From 92b5d1471931f6a34af3cd361c5e867e0d3e745b Mon Sep 17 00:00:00 2001 From: Joe Chen Date: Sat, 11 Jun 2022 17:58:12 +0800 Subject: [PATCH] route: add `Headers` method for matching request headers (#133) Co-authored-by: E99p1ant --- codecov.yml | 4 + internal/route/header_matcher.go | 37 ++++ internal/route/header_matcher_test.go | 81 ++++++++ internal/route/leaf.go | 48 ++++- internal/route/leaf_test.go | 19 +- internal/route/tree.go | 36 ++-- internal/route/tree_test.go | 278 ++++++++++++++++++++++---- router.go | 53 ++++- router_test.go | 45 ++++- 9 files changed, 510 insertions(+), 91 deletions(-) create mode 100644 internal/route/header_matcher.go create mode 100644 internal/route/header_matcher_test.go diff --git a/codecov.yml b/codecov.yml index 198b73a..4b99d11 100644 --- a/codecov.yml +++ b/codecov.yml @@ -5,6 +5,10 @@ coverage: default: threshold: 1% informational: true + patch: + default: + only_pulls: true + informational: true comment: layout: 'diff' diff --git a/internal/route/header_matcher.go b/internal/route/header_matcher.go new file mode 100644 index 0000000..e61af50 --- /dev/null +++ b/internal/route/header_matcher.go @@ -0,0 +1,37 @@ +// Copyright 2022 Flamego. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package route + +import ( + "net/http" + "regexp" +) + +// HeaderMatcher stores matchers for request headers. +type HeaderMatcher struct { + matches map[string]*regexp.Regexp // Key is the header name +} + +// NewHeaderMatcher creates a new HeaderMatcher using given matches, where keys +// are header names. +func NewHeaderMatcher(matches map[string]*regexp.Regexp) *HeaderMatcher { + return &HeaderMatcher{ + matches: matches, + } +} + +// Match returns true if all matches are successfully in the given header. +func (m *HeaderMatcher) Match(header http.Header) bool { + for name, re := range m.matches { + v := header.Get(name) + if v == "" { + return false + } + if !re.MatchString(v) { + return false + } + } + return true +} diff --git a/internal/route/header_matcher_test.go b/internal/route/header_matcher_test.go new file mode 100644 index 0000000..c1e5cab --- /dev/null +++ b/internal/route/header_matcher_test.go @@ -0,0 +1,81 @@ +// Copyright 2022 Flamego. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package route + +import ( + "net/http" + "regexp" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestHeaderMatcher(t *testing.T) { + header := make(http.Header) + header.Set("Server", "Caddy") + header.Set("Status", "200 OK") + + tests := []struct { + name string + matches map[string]*regexp.Regexp + want bool + }{ + { + name: "loose matches", + matches: map[string]*regexp.Regexp{ + "Server": regexp.MustCompile("Caddy"), + "Status": regexp.MustCompile("200"), + }, + want: true, + }, + { + name: "loose matches", + matches: map[string]*regexp.Regexp{ + "Server": regexp.MustCompile("Caddy"), + "Status": regexp.MustCompile("404"), + }, + want: false, + }, + + { + name: "exact matches", + matches: map[string]*regexp.Regexp{ + "Server": regexp.MustCompile("^Caddy$"), + "Status": regexp.MustCompile("^200 OK$"), + }, + want: true, + }, + { + name: "exact matches", + matches: map[string]*regexp.Regexp{ + "Server": regexp.MustCompile("^Caddy$"), + "Status": regexp.MustCompile("^200$"), + }, + want: false, + }, + + { + name: "presence match", + matches: map[string]*regexp.Regexp{ + "Server": regexp.MustCompile(""), + }, + want: true, + }, + { + name: "presence match", + matches: map[string]*regexp.Regexp{ + "Server": regexp.MustCompile(""), + "Cache-Control": regexp.MustCompile(""), + }, + want: false, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got := NewHeaderMatcher(test.matches).Match(header) + assert.Equal(t, test.want, got) + }) + } +} diff --git a/internal/route/leaf.go b/internal/route/leaf.go index 62cef3a..501ab4d 100644 --- a/internal/route/leaf.go +++ b/internal/route/leaf.go @@ -6,6 +6,7 @@ package route import ( "bytes" + "net/http" "regexp" "strconv" "strings" @@ -27,6 +28,9 @@ const ( // Leaf is a leaf derived from a segment. type Leaf interface { + // SetHeaderMatcher sets the HeaderMatcher for the leaf. + SetHeaderMatcher(m *HeaderMatcher) + // URLPath fills in bind parameters with given values to build the "path" // portion of the URL. If `withOptional` is true, the path will include the // current leaf when it is optional; otherwise, the current leaf is excluded. @@ -46,15 +50,16 @@ type Leaf interface { getMatchStyle() MatchStyle // match returns true if the leaf matches the segment, values of bind parameters // are stored in the `Params`. - match(segment string, params Params) bool + match(segment string, params Params, header http.Header) bool } // baseLeaf contains common fields for any leaf. type baseLeaf struct { - parent Tree // The parent tree this leaf belongs to. - route *Route // The route that the segment belongs to. - segment *Segment // The segment that the leaf is derived from. - handler Handler // The handler bound to the leaf. + parent Tree // The parent tree this leaf belongs to. + route *Route // The route that the segment belongs to. + segment *Segment // The segment that the leaf is derived from. + handler Handler // The handler bound to the leaf. + headerMatcher *HeaderMatcher // The matcher for header values. } func (l *baseLeaf) getParent() Tree { @@ -65,6 +70,14 @@ func (l *baseLeaf) getSegment() *Segment { return l.segment } +func (l *baseLeaf) SetHeaderMatcher(m *HeaderMatcher) { + l.headerMatcher = m +} + +func (l *baseLeaf) matchHeader(header http.Header) bool { + return l.headerMatcher == nil || l.headerMatcher.Match(header) +} + func (l *baseLeaf) URLPath(vals map[string]string, withOptional bool) string { var buf bytes.Buffer for _, s := range l.route.Segments { @@ -123,8 +136,8 @@ func (*staticLeaf) getMatchStyle() MatchStyle { return matchStyleStatic } -func (l *staticLeaf) match(segment string, _ Params) bool { - return l.literals == segment +func (l *staticLeaf) match(segment string, _ Params, header http.Header) bool { + return l.literals == segment && l.matchHeader(header) } func (l *staticLeaf) Static() bool { @@ -149,12 +162,16 @@ func (*regexLeaf) getMatchStyle() MatchStyle { return matchStyleRegex } -func (l *regexLeaf) match(segment string, params Params) bool { +func (l *regexLeaf) match(segment string, params Params, header http.Header) bool { submatches := l.regexp.FindStringSubmatch(segment) if len(submatches) < len(l.binds)+1 { return false } + if !l.matchHeader(header) { + return false + } + for i, bind := range l.binds { params[bind] = submatches[i+1] } @@ -171,7 +188,10 @@ func (*placeholderLeaf) getMatchStyle() MatchStyle { return matchStylePlaceholder } -func (l *placeholderLeaf) match(segment string, params Params) bool { +func (l *placeholderLeaf) match(segment string, params Params, header http.Header) bool { + if !l.matchHeader(header) { + return false + } params[l.bind] = segment return true } @@ -187,7 +207,10 @@ func (*matchAllLeaf) getMatchStyle() MatchStyle { return matchStyleAll } -func (l *matchAllLeaf) match(segment string, params Params) bool { +func (l *matchAllLeaf) match(segment string, params Params, header http.Header) bool { + if !l.matchHeader(header) { + return false + } params[l.bind] = segment return true } @@ -196,13 +219,16 @@ func (l *matchAllLeaf) match(segment string, params Params) bool { // defined). The `path` should be original request path, `segment` should NOT be // unescaped by the caller. It returns true if segments are captured within the // limit, and the capture result is stored in `params`. -func (l *matchAllLeaf) matchAll(path, segment string, next int, params Params) bool { +func (l *matchAllLeaf) matchAll(path, segment string, next int, params Params, header http.Header) bool { // Do `next-1` because "next" starts at the next character of preceding "/"; do // `strings.Count()+1` because the segment itself also counts. E.g. "webapi" + // "users/events" => 3 if l.capture > 0 && l.capture < strings.Count(path[next-1:], "/")+1 { return false } + if !l.matchHeader(header) { + return false + } params[l.bind] = segment + "/" + path[next:] return true diff --git a/internal/route/leaf_test.go b/internal/route/leaf_test.go index 0852032..285807e 100644 --- a/internal/route/leaf_test.go +++ b/internal/route/leaf_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewLeaf(t *testing.T) { @@ -21,7 +22,7 @@ func TestNewLeaf(t *testing.T) { }) parser, err := NewParser() - assert.Nil(t, err) + require.NoError(t, err) tests := []struct { route string @@ -60,12 +61,12 @@ func TestNewLeaf(t *testing.T) { for _, test := range tests { t.Run(test.route, func(t *testing.T) { route, err := parser.Parse(test.route) - assert.Nil(t, err) + require.NoError(t, err) assert.Len(t, route.Segments, 1) segment := route.Segments[0] got, err := newLeaf(nil, route, segment, nil) - assert.Nil(t, err) + require.NoError(t, err) switch test.style { case matchStyleStatic: @@ -86,7 +87,7 @@ func TestNewLeaf(t *testing.T) { func TestNewLeaf_Regex(t *testing.T) { parser, err := NewParser() - assert.Nil(t, err) + require.NoError(t, err) tests := []struct { route string @@ -122,12 +123,12 @@ func TestNewLeaf_Regex(t *testing.T) { for _, test := range tests { t.Run(test.route, func(t *testing.T) { route, err := parser.Parse(test.route) - assert.Nil(t, err) + require.NoError(t, err) assert.Len(t, route.Segments, 1) segment := route.Segments[0] got, err := newLeaf(nil, route, segment, nil) - assert.Nil(t, err) + require.NoError(t, err) leaf := got.(*regexLeaf) assert.Equal(t, test.wantRegexp, leaf.regexp.String()) @@ -138,7 +139,7 @@ func TestNewLeaf_Regex(t *testing.T) { func TestLeaf_URLPath(t *testing.T) { parser, err := NewParser() - assert.Nil(t, err) + require.NoError(t, err) tests := []struct { route string @@ -245,11 +246,11 @@ func TestLeaf_URLPath(t *testing.T) { for _, test := range tests { t.Run(test.route, func(t *testing.T) { route, err := parser.Parse(test.route) - assert.Nil(t, err) + require.NoError(t, err) segment := route.Segments[len(route.Segments)-1] leaf, err := newLeaf(nil, route, segment, nil) - assert.Nil(t, err) + require.NoError(t, err) got := leaf.URLPath(test.vals, test.withOptional) assert.Equal(t, test.want, got) diff --git a/internal/route/tree.go b/internal/route/tree.go index 585132b..b784361 100644 --- a/internal/route/tree.go +++ b/internal/route/tree.go @@ -15,10 +15,10 @@ import ( // Tree is a tree derived from a segment. type Tree interface { - // Match matches a leaf for the given request path, values of bind parameters - // are stored in the `Params`. The `Params` may contain extra values that do not - // belong to the final leaf due to backtrace. - Match(path string) (Leaf, Params, bool) + // Match matches a leaf for the given request path and provided headers, values + // of bind parameters are stored in the `Params`. The `Params` may contain extra + // values that do not belong to the final leaf due to backtrace. + Match(path string, header http.Header) (Leaf, Params, bool) // getParent returns the parent tree. The root tree does not have parent. getParent() Tree @@ -46,7 +46,7 @@ type Tree interface { match(segment string, params Params) bool // matchNextSegment advances the `next` cursor for matching next segment in the // request path. - matchNextSegment(path string, next int, params Params) (Leaf, bool) + matchNextSegment(path string, next int, params Params, header http.Header) (Leaf, bool) } // baseTree contains common fields and methods for any tree. @@ -299,10 +299,10 @@ func (t *matchAllTree) getBinds() []string { // defined). The `path` should be original request path, `segment` should NOT be // unescaped by the caller. It returns the matched leaf and true if segments are // captured within the limit, and the capture result is stored in `params`. -func (t *matchAllTree) matchAll(path, segment string, next int, params Params) (Leaf, bool) { +func (t *matchAllTree) matchAll(path, segment string, next int, params Params, header http.Header) (Leaf, bool) { captured := 1 // Starts with 1 because the segment itself also count. for t.capture <= 0 || t.capture >= captured { - leaf, ok := t.matchNextSegment(path, next, params) + leaf, ok := t.matchNextSegment(path, next, params, header) if ok { params[t.bind] = segment return leaf, true @@ -413,9 +413,9 @@ func AddRoute(t Tree, r *Route, h Handler) (Leaf, error) { // matchLeaf returns the matched leaf and true if any leaf of the tree matches // the given segment. -func (t *baseTree) matchLeaf(segment string, params Params) (Leaf, bool) { +func (t *baseTree) matchLeaf(segment string, params Params, header http.Header) (Leaf, bool) { for _, l := range t.leaves { - ok := l.match(segment, params) + ok := l.match(segment, params, header) if ok { return l, true } @@ -425,10 +425,10 @@ func (t *baseTree) matchLeaf(segment string, params Params) (Leaf, bool) { // matchSubtree returns the matched leaf and true if any subtree or leaf of the // tree matches the given segment. -func (t *baseTree) matchSubtree(path, segment string, next int, params Params) (Leaf, bool) { +func (t *baseTree) matchSubtree(path, segment string, next int, params Params, header http.Header) (Leaf, bool) { for _, st := range t.subtrees { if st.getMatchStyle() == matchStyleAll { - leaf, ok := st.(*matchAllTree).matchAll(path, segment, next, params) + leaf, ok := st.(*matchAllTree).matchAll(path, segment, next, params, header) if ok { return leaf, true } @@ -444,7 +444,7 @@ func (t *baseTree) matchSubtree(path, segment string, next int, params Params) ( continue } - leaf, ok := st.matchNextSegment(path, next, params) + leaf, ok := st.matchNextSegment(path, next, params, header) if !ok { continue } @@ -458,7 +458,7 @@ func (t *baseTree) matchSubtree(path, segment string, next int, params Params) ( return nil, false } - ok := leaf.(*matchAllLeaf).matchAll(path, segment, next, params) + ok := leaf.(*matchAllLeaf).matchAll(path, segment, next, params, header) if ok { return leaf, ok } @@ -467,18 +467,18 @@ func (t *baseTree) matchSubtree(path, segment string, next int, params Params) ( return nil, false } -func (t *baseTree) matchNextSegment(path string, next int, params Params) (Leaf, bool) { +func (t *baseTree) matchNextSegment(path string, next int, params Params, header http.Header) (Leaf, bool) { i := strings.Index(path[next:], "/") if i == -1 { - return t.matchLeaf(path[next:], params) + return t.matchLeaf(path[next:], params, header) } - return t.matchSubtree(path, path[next:next+i], next+i+1, params) + return t.matchSubtree(path, path[next:next+i], next+i+1, params, header) } -func (t *baseTree) Match(path string) (Leaf, Params, bool) { +func (t *baseTree) Match(path string, header http.Header) (Leaf, Params, bool) { path = strings.TrimLeft(path, "/") params := make(Params) - leaf, ok := t.matchNextSegment(path, 0, params) + leaf, ok := t.matchNextSegment(path, 0, params, header) if !ok { return nil, nil, false } diff --git a/internal/route/tree_test.go b/internal/route/tree_test.go index 0c806eb..566fdeb 100644 --- a/internal/route/tree_test.go +++ b/internal/route/tree_test.go @@ -6,11 +6,13 @@ package route import ( "fmt" + "net/http" "regexp" "strings" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewTree(t *testing.T) { @@ -29,7 +31,7 @@ func TestNewTree(t *testing.T) { }) parser, err := NewParser() - assert.Nil(t, err) + require.NoError(t, err) tests := []struct { route string @@ -59,12 +61,12 @@ func TestNewTree(t *testing.T) { for _, test := range tests { t.Run(test.route, func(t *testing.T) { route, err := parser.Parse(test.route) - assert.Nil(t, err) - assert.Len(t, route.Segments, 2) + require.NoError(t, err) + require.Len(t, route.Segments, 2) segment := route.Segments[0] got, err := newTree(nil, segment) - assert.Nil(t, err) + require.NoError(t, err) switch test.style { case matchStyleStatic: @@ -82,7 +84,7 @@ func TestNewTree(t *testing.T) { func TestNewTree_Regex(t *testing.T) { parser, err := NewParser() - assert.Nil(t, err) + require.NoError(t, err) tests := []struct { route string @@ -118,12 +120,12 @@ func TestNewTree_Regex(t *testing.T) { for _, test := range tests { t.Run(test.route, func(t *testing.T) { route, err := parser.Parse(test.route) - assert.Nil(t, err) - assert.Len(t, route.Segments, 2) + require.NoError(t, err) + require.Len(t, route.Segments, 2) segment := route.Segments[0] got, err := newTree(nil, segment) - assert.Nil(t, err) + require.NoError(t, err) tree := got.(*regexTree) assert.Equal(t, test.wantRegexp, tree.regexp.String()) @@ -134,18 +136,18 @@ func TestNewTree_Regex(t *testing.T) { func TestAddRoute(t *testing.T) { parser, err := NewParser() - assert.Nil(t, err) + require.NoError(t, err) t.Run("duplicated routes", func(t *testing.T) { tree := NewTree() r1, err := parser.Parse(`/webapi/users`) - assert.Nil(t, err) + require.NoError(t, err) _, err = AddRoute(tree, r1, nil) - assert.Nil(t, err) + require.NoError(t, err) r2, err := parser.Parse(`/webapi/users/?events`) - assert.Nil(t, err) + require.NoError(t, err) _, err = AddRoute(tree, r2, nil) got := fmt.Sprintf("%v", err) want := `add optional leaf to grandparent: duplicated route "/webapi/users/?events"` @@ -154,7 +156,7 @@ func TestAddRoute(t *testing.T) { t.Run("duplicated match all styles", func(t *testing.T) { route, err := parser.Parse(`/webapi/tree/{paths: **}/{names: **}/upload`) - assert.Nil(t, err) + require.NoError(t, err) _, err = AddRoute(NewTree(), route, nil) got := fmt.Sprintf("%v", err) @@ -223,12 +225,12 @@ func TestAddRoute(t *testing.T) { }, }, { - route: `/webapi/article_{id: /[0-9]+/}_{page: /[\\w]+/}.{ext: /diff|patch/}`, + route: `/webapi/article_{id: /\d+/}_{page: /[\\w]+/}.{ext: /diff|patch/}`, style: matchStyleRegex, wantDepth: 3, wantLeaf: ®exLeaf{ baseLeaf: baseLeaf{}, - regexp: regexp.MustCompile(`^article_([0-9]+)_([\\w]+)\.(diff|patch)$`), + regexp: regexp.MustCompile(`^article_(\d+)_([\\w]+)\.(diff|patch)$`), binds: []string{"id", "page", "ext"}, }, }, @@ -236,10 +238,10 @@ func TestAddRoute(t *testing.T) { for _, test := range tests { t.Run(test.route, func(t *testing.T) { route, err := parser.Parse(test.route) - assert.Nil(t, err) + require.NoError(t, err) got, err := AddRoute(NewTree(), route, nil) - assert.Nil(t, err) + require.NoError(t, err) segment := route.Segments[len(route.Segments)-1] switch test.style { @@ -276,7 +278,7 @@ func TestAddRoute(t *testing.T) { func TestAddRoute_DuplicatedBinds(t *testing.T) { parser, err := NewParser() - assert.Nil(t, err) + require.NoError(t, err) tree := NewTree() @@ -294,7 +296,7 @@ func TestAddRoute_DuplicatedBinds(t *testing.T) { for _, route := range routes { t.Run(route, func(t *testing.T) { r, err := parser.Parse(route) - assert.Nil(t, err) + require.NoError(t, err) _, err = AddRoute(tree, r, nil) got := fmt.Sprintf("%v", err) @@ -305,7 +307,7 @@ func TestAddRoute_DuplicatedBinds(t *testing.T) { func TestAddRoute_DuplicatedMatchAll(t *testing.T) { parser, err := NewParser() - assert.Nil(t, err) + require.NoError(t, err) tree := NewTree() @@ -321,12 +323,11 @@ func TestAddRoute_DuplicatedMatchAll(t *testing.T) { for i, route := range routes { t.Run(route, func(t *testing.T) { r, err := parser.Parse(route) - assert.Nil(t, err) + require.NoError(t, err) _, err = AddRoute(tree, r, nil) - if i%2 == 0 { - assert.Nil(t, err) + assert.NoError(t, err) } else { got := fmt.Sprintf("%v", err) assert.Contains(t, got, "duplicated match all bind parameter") @@ -337,7 +338,7 @@ func TestAddRoute_DuplicatedMatchAll(t *testing.T) { func TestTree_Match(t *testing.T) { parser, err := NewParser() - assert.Nil(t, err) + require.NoError(t, err) tree := NewTree() @@ -361,10 +362,10 @@ func TestTree_Match(t *testing.T) { } for _, route := range routes { r, err := parser.Parse(route) - assert.Nil(t, err) + require.NoError(t, err) _, err = AddRoute(tree, r, nil) - assert.Nil(t, err) + require.NoError(t, err) } tests := []struct { @@ -534,8 +535,8 @@ func TestTree_Match(t *testing.T) { } for _, test := range tests { t.Run(test.path, func(t *testing.T) { - leaf, params, ok := tree.Match(test.path) - assert.Equal(t, test.wantOK, ok) + leaf, params, ok := tree.Match(test.path, nil) + require.Equal(t, test.wantOK, ok) if !ok { return @@ -548,26 +549,22 @@ func TestTree_Match(t *testing.T) { func TestTree_MatchEscape(t *testing.T) { parser, err := NewParser() - assert.Nil(t, err) + require.NoError(t, err) tree := NewTree() - - // NOTE: The order of routes and tests matters, matching for the same priority - // is first in first match. routes := []string{ "/webapi/special/vars/{var}", } for _, route := range routes { r, err := parser.Parse(route) - assert.Nil(t, err) + require.NoError(t, err) _, err = AddRoute(tree, r, nil) - assert.Nil(t, err) + assert.NoError(t, err) } tests := []struct { path string - withOptional bool wantOK bool wantParams Params wantUnescapedURL string @@ -591,14 +588,219 @@ func TestTree_MatchEscape(t *testing.T) { } for _, test := range tests { t.Run(test.path, func(t *testing.T) { - leaf, params, ok := tree.Match(test.path) - assert.Equal(t, test.wantOK, ok) + leaf, params, ok := tree.Match(test.path, nil) + require.Equal(t, test.wantOK, ok) + + if !ok { + return + } + assert.Equal(t, test.wantParams, params) + assert.Equal(t, strings.TrimRight(test.wantUnescapedURL, "/"), leaf.URLPath(params, false)) + }) + } +} + +func TestTree_MatchHeader(t *testing.T) { + parser, err := NewParser() + require.NoError(t, err) + + tree := NewTree() + + addRoute := func(path string, header map[string]string) { + t.Helper() + + r, err := parser.Parse(path) + require.NoError(t, err) + + l, err := AddRoute(tree, r, nil) + assert.NoError(t, err) + + matches := make(map[string]*regexp.Regexp, len(header)) + for k, v := range header { + matches[k] = regexp.MustCompile(v) + } + l.SetHeaderMatcher(NewHeaderMatcher(matches)) + } + // Note: The order of routes and tests matters, matching for the same priority + // is first in first match. + addRoute("/webapi/static", + map[string]string{ + "Server": "Caddy", + "Status": "", + }, + ) + + addRoute("/webapi/vars/{var}", + map[string]string{ + "Server": "Caddy", + }, + ) + addRoute("/webapi/vars/{var}.html", + map[string]string{ + "Server": "Caddy", + "Status": "", + }, + ) + + addRoute(`/webapi/users/ids/{id: /[0-9]+/}_html`, + map[string]string{ + "Server": "Caddy", + "Status": "", + }, + ) + addRoute(`/webapi/users/ids/{id: /\w+/}`, + map[string]string{ + "Server": "Caddy", + }, + ) + + addRoute("/webapi/users/sessions/123", + map[string]string{ + "Server": "Caddy", + "Status": "", + }, + ) + addRoute("/webapi/users/sessions/{paths: **}", + map[string]string{ + "Server": "Caddy", + }, + ) + + addRoute("/webapi/users/events/{names: **}/feed", + map[string]string{ + "Server": "Caddy", + "Status": "", + }, + ) + addRoute("/webapi/users/events/{names: **}", + map[string]string{ + "Server": "Caddy", + }, + ) + + tests := []struct { + path string + header map[string]string + wantOK bool + wantParams Params + }{ + { + path: "/webapi/static", + header: map[string]string{ + "Server": "Caddy", + "Status": "200 OK", + }, + wantOK: true, + wantParams: Params{}, + }, + { + path: "/webapi/static", + header: map[string]string{ + "Server": "Caddy", + }, + wantOK: false, // Missing "Status" header + }, + + { + path: "/webapi/vars/abc.html", + header: map[string]string{ + "Server": "Caddy", + }, + wantOK: true, + wantParams: Params{ + "var": "abc.html", // Not matching "/webapi/vars/{var}.html" because missing "Status" header + }, + }, + { + path: "/webapi/vars/abc.html", + header: map[string]string{ + "Server": "Caddy", + "Status": "200 OK", + }, + wantOK: true, + wantParams: Params{ + "var": "abc", + }, + }, + + { + path: "/webapi/users/ids/abc_html", + header: map[string]string{ + "Server": "Caddy", + }, + wantOK: true, + wantParams: Params{ + "id": "abc_html", // Not matching "/webapi/users/ids/{id: /[0-9]+/}_html" because missing "Status" header + }, + }, + { + path: "/webapi/users/ids/2830_html", + header: map[string]string{ + "Server": "Caddy", + "Status": "200 OK", + }, + wantOK: true, + wantParams: Params{ + "id": "2830", + }, + }, + + { + path: "/webapi/users/sessions/123", + header: map[string]string{ + "Server": "Caddy", + }, + wantOK: true, + wantParams: Params{ + "paths": "123", // Not matching "/webapi/users/sessions/123" because missing "Status" header + }, + }, + { + path: "/webapi/users/sessions/123", + header: map[string]string{ + "Server": "Caddy", + "Status": "200 OK", + }, + wantOK: true, + wantParams: Params{}, + }, + + { + path: "/webapi/users/events/push/feed", + header: map[string]string{ + "Server": "Caddy", + }, + wantOK: true, + wantParams: Params{ + "names": "push/feed", // Not matching "/webapi/users/events/{names: **}/feed" because missing "Status" header + }, + }, + { + path: "/webapi/users/events/push/feed", + header: map[string]string{ + "Server": "Caddy", + "Status": "200 OK", + }, + wantOK: true, + wantParams: Params{ + "names": "push", + }, + }, + } + for _, test := range tests { + t.Run(test.path, func(t *testing.T) { + header := make(http.Header, len(test.header)) + for k, v := range test.header { + header.Set(k, v) + } + leaf, params, ok := tree.Match(test.path, header) + require.Equal(t, test.wantOK, ok) if !ok { return } assert.Equal(t, test.wantParams, params) - assert.Equal(t, strings.TrimRight(test.wantUnescapedURL, "/"), leaf.URLPath(params, test.withOptional)) + assert.Equal(t, strings.TrimRight(test.path, "/"), leaf.URLPath(params, false)) }) } } diff --git a/router.go b/router.go index 2c257d3..fe5e14a 100644 --- a/router.go +++ b/router.go @@ -7,6 +7,7 @@ package flamego import ( "fmt" "net/http" + "regexp" "strings" "github.com/flamego/flamego/internal/route" @@ -132,10 +133,43 @@ func (r *router) HandlerWrapper(f func(Handler) Handler) { r.handlerWrapper = f } -// Route is a wrapper of the route leaf and its router. +// Route is a wrapper of the route leaves and its router. type Route struct { router *router - leaf route.Leaf + leaves map[string]route.Leaf +} + +// Headers uses given key-value pairs as the list of matching criteria for +// request headers, where key is the header name and value is a regex. Once set, +// the route will only be matched if all header matches are successful in +// addition to the request path. +// +// For example: +// f.Get("/", ...).Headers( +// "User-Agent", "Chrome", // Loose match +// "Host", "^flamego\.dev$", // Exact match +// "Cache-Control", "", // As long as "Cache-Control" is not empty +// ) +// +// Subsequent calls to Headers() replace previously set matches. +func (r *Route) Headers(pairs ...string) *Route { + if len(pairs)%2 != 0 { + panic(fmt.Sprintf("imbalanced pairs with %d", len(pairs))) + } + + matches := make(map[string]*regexp.Regexp, len(pairs)/2) + for i := 1; i < len(pairs); i += 2 { + matches[pairs[i-1]] = regexp.MustCompile(pairs[i]) + } + for m, leaf := range r.leaves { + leaf.SetHeaderMatcher(route.NewHeaderMatcher(matches)) + + // Delete static route from fast paths since header matches are dynamic. + if leaf.Static() { + delete(r.router.staticRoutes[m], leaf.Route()) + } + } + return r } // Name sets the name for the route. @@ -145,7 +179,11 @@ func (r *Route) Name(name string) { } else if _, ok := r.router.namedRoutes[name]; ok { panic("duplicated route name: " + name) } - r.router.namedRoutes[name] = r.leaf + + for _, leaf := range r.leaves { + r.router.namedRoutes[name] = leaf + break + } } func (r *router) addRoute(method, routePath string, handler route.Handler) *Route { @@ -171,9 +209,9 @@ func (r *router) addRoute(method, routePath string, handler route.Handler) *Rout panic(fmt.Sprintf("unable to parse route %q: %v", routePath, err)) } - var leaf route.Leaf + leaves := make(map[string]route.Leaf, len(methods)) for _, m := range methods { - leaf, err = route.AddRoute(r.routeTrees[m], ast, handler) + leaf, err := route.AddRoute(r.routeTrees[m], ast, handler) if err != nil { panic(fmt.Sprintf("unable to add route %q with method %s: %v", routePath, m, err)) } @@ -181,11 +219,12 @@ func (r *router) addRoute(method, routePath string, handler route.Handler) *Rout if leaf.Static() { r.staticRoutes[m][leaf.Route()] = leaf } + leaves[m] = leaf } return &Route{ router: r, - leaf: leaf, + leaves: leaves, } } @@ -319,7 +358,7 @@ func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } - leaf, params, ok := routeTree.Match(req.URL.Path) + leaf, params, ok := routeTree.Match(req.URL.Path, req.Header) if !ok { r.notFound(w, req) return diff --git a/router_test.go b/router_test.go index dbf95ea..66eda6d 100644 --- a/router_test.go +++ b/router_test.go @@ -10,6 +10,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/flamego/flamego/internal/route" ) @@ -35,7 +36,7 @@ func TestRouter_Route(t *testing.T) { t.Run("request with invalid HTTP method", func(t *testing.T) { resp := httptest.NewRecorder() req, err := http.NewRequest("UNEXPECTED", "/", nil) - assert.Nil(t, err) + require.NoError(t, err) ctx.run_ = func() {} r.ServeHTTP(resp, req) @@ -107,7 +108,7 @@ func TestRouter_Route(t *testing.T) { resp := httptest.NewRecorder() req, err := http.NewRequest(test.method, test.routePath, nil) - assert.Nil(t, err) + require.NoError(t, err) r.ServeHTTP(resp, req) @@ -137,7 +138,7 @@ func TestRouter_Routes(t *testing.T) { resp := httptest.NewRecorder() req, err := http.NewRequest(m, "/routes", nil) - assert.Nil(t, err) + require.NoError(t, err) r.ServeHTTP(resp, req) @@ -157,7 +158,7 @@ func TestRouter_Routes(t *testing.T) { resp := httptest.NewRecorder() req, err := http.NewRequest(m, "/routes", nil) - assert.Nil(t, err) + require.NoError(t, err) r.ServeHTTP(resp, req) @@ -185,7 +186,7 @@ func TestRouter_AutoHead(t *testing.T) { resp := httptest.NewRecorder() req, err := http.NewRequest(http.MethodHead, "/", nil) - assert.Nil(t, err) + require.NoError(t, err) r.ServeHTTP(resp, req) @@ -203,7 +204,7 @@ func TestRouter_AutoHead(t *testing.T) { resp := httptest.NewRecorder() req, err := http.NewRequest(http.MethodHead, "/", nil) - assert.Nil(t, err) + require.NoError(t, err) r.ServeHTTP(resp, req) @@ -226,6 +227,34 @@ func TestRouter_DuplicatedRoutes(t *testing.T) { r.Get("/", func() {}) } +func TestRoute_Headers(t *testing.T) { + f := New() + f.Get("/", func() {}).Headers("Server", "Caddy", "Cache-Control", "") + + t.Run("ok", func(t *testing.T) { + resp := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/", nil) + require.NoError(t, err) + + req.Header.Set("Server", "Caddy") + req.Header.Set("Cache-Control", "No-Cache") + f.ServeHTTP(resp, req) + + assert.Equal(t, http.StatusOK, resp.Code) + }) + + t.Run("not found", func(t *testing.T) { + resp := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/", nil) + require.NoError(t, err) + + req.Header.Set("Server", "Caddy") + f.ServeHTTP(resp, req) + + assert.Equal(t, http.StatusNotFound, resp.Code) + }) +} + func TestRoute_Name(t *testing.T) { contextCreator := func(w http.ResponseWriter, r *http.Request, params route.Params, handlers []Handler, urlPath urlPather) internalContext { return newMockContext() @@ -324,7 +353,7 @@ func TestRouter_Group(t *testing.T) { resp := httptest.NewRecorder() req, err := http.NewRequest(http.MethodGet, route, nil) - assert.Nil(t, err) + require.NoError(t, err) r.ServeHTTP(resp, req) @@ -362,7 +391,7 @@ func TestComboRoute(t *testing.T) { resp := httptest.NewRecorder() req, err := http.NewRequest(m, "/", nil) - assert.Nil(t, err) + require.NoError(t, err) r.ServeHTTP(resp, req)