Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

route: add Headers method for matching request headers #133

Merged
merged 4 commits into from Jun 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions codecov.yml
Expand Up @@ -5,6 +5,10 @@ coverage:
default:
threshold: 1%
informational: true
patch:
default:
only_pulls: true
informational: true

comment:
layout: 'diff'
Expand Down
37 changes: 37 additions & 0 deletions internal/route/header_matcher.go
@@ -0,0 +1,37 @@
// Copyright 2022 Flamego. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.

package route

import (
"net/http"
"regexp"
)

// HeaderMatcher stores matchers for request headers.
type HeaderMatcher struct {
matches map[string]*regexp.Regexp // Key is the header name
}

// NewHeaderMatcher creates a new HeaderMatcher using given matches, where keys
// are header names.
func NewHeaderMatcher(matches map[string]*regexp.Regexp) *HeaderMatcher {
return &HeaderMatcher{
matches: matches,
}
}

// Match returns true if all matches are successfully in the given header.
func (m *HeaderMatcher) Match(header http.Header) bool {
for name, re := range m.matches {
v := header.Get(name)
if v == "" {
return false
}
if !re.MatchString(v) {
return false
}
}
return true
}
81 changes: 81 additions & 0 deletions internal/route/header_matcher_test.go
@@ -0,0 +1,81 @@
// Copyright 2022 Flamego. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.

package route

import (
"net/http"
"regexp"
"testing"

"github.com/stretchr/testify/assert"
)

func TestHeaderMatcher(t *testing.T) {
header := make(http.Header)
header.Set("Server", "Caddy")
header.Set("Status", "200 OK")

tests := []struct {
name string
matches map[string]*regexp.Regexp
want bool
}{
{
name: "loose matches",
matches: map[string]*regexp.Regexp{
"Server": regexp.MustCompile("Caddy"),
"Status": regexp.MustCompile("200"),
},
want: true,
},
{
name: "loose matches",
matches: map[string]*regexp.Regexp{
"Server": regexp.MustCompile("Caddy"),
"Status": regexp.MustCompile("404"),
},
want: false,
},

{
name: "exact matches",
matches: map[string]*regexp.Regexp{
"Server": regexp.MustCompile("^Caddy$"),
"Status": regexp.MustCompile("^200 OK$"),
},
want: true,
},
{
name: "exact matches",
matches: map[string]*regexp.Regexp{
"Server": regexp.MustCompile("^Caddy$"),
"Status": regexp.MustCompile("^200$"),
},
want: false,
},

{
name: "presence match",
matches: map[string]*regexp.Regexp{
"Server": regexp.MustCompile(""),
},
want: true,
},
{
name: "presence match",
matches: map[string]*regexp.Regexp{
"Server": regexp.MustCompile(""),
"Cache-Control": regexp.MustCompile(""),
},
want: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
got := NewHeaderMatcher(test.matches).Match(header)
assert.Equal(t, test.want, got)
})
}
}
48 changes: 37 additions & 11 deletions internal/route/leaf.go
Expand Up @@ -6,6 +6,7 @@ package route

import (
"bytes"
"net/http"
"regexp"
"strconv"
"strings"
Expand All @@ -27,6 +28,9 @@ const (

// Leaf is a leaf derived from a segment.
type Leaf interface {
// SetHeaderMatcher sets the HeaderMatcher for the leaf.
SetHeaderMatcher(m *HeaderMatcher)

// URLPath fills in bind parameters with given values to build the "path"
// portion of the URL. If `withOptional` is true, the path will include the
// current leaf when it is optional; otherwise, the current leaf is excluded.
Expand All @@ -46,15 +50,16 @@ type Leaf interface {
getMatchStyle() MatchStyle
// match returns true if the leaf matches the segment, values of bind parameters
// are stored in the `Params`.
match(segment string, params Params) bool
match(segment string, params Params, header http.Header) bool
}

// baseLeaf contains common fields for any leaf.
type baseLeaf struct {
parent Tree // The parent tree this leaf belongs to.
route *Route // The route that the segment belongs to.
segment *Segment // The segment that the leaf is derived from.
handler Handler // The handler bound to the leaf.
parent Tree // The parent tree this leaf belongs to.
route *Route // The route that the segment belongs to.
segment *Segment // The segment that the leaf is derived from.
handler Handler // The handler bound to the leaf.
headerMatcher *HeaderMatcher // The matcher for header values.
}

func (l *baseLeaf) getParent() Tree {
Expand All @@ -65,6 +70,14 @@ func (l *baseLeaf) getSegment() *Segment {
return l.segment
}

func (l *baseLeaf) SetHeaderMatcher(m *HeaderMatcher) {
l.headerMatcher = m
}

func (l *baseLeaf) matchHeader(header http.Header) bool {
return l.headerMatcher == nil || l.headerMatcher.Match(header)
}

func (l *baseLeaf) URLPath(vals map[string]string, withOptional bool) string {
var buf bytes.Buffer
for _, s := range l.route.Segments {
Expand Down Expand Up @@ -123,8 +136,8 @@ func (*staticLeaf) getMatchStyle() MatchStyle {
return matchStyleStatic
}

func (l *staticLeaf) match(segment string, _ Params) bool {
return l.literals == segment
func (l *staticLeaf) match(segment string, _ Params, header http.Header) bool {
return l.literals == segment && l.matchHeader(header)
}

func (l *staticLeaf) Static() bool {
Expand All @@ -149,12 +162,16 @@ func (*regexLeaf) getMatchStyle() MatchStyle {
return matchStyleRegex
}

func (l *regexLeaf) match(segment string, params Params) bool {
func (l *regexLeaf) match(segment string, params Params, header http.Header) bool {
submatches := l.regexp.FindStringSubmatch(segment)
if len(submatches) < len(l.binds)+1 {
return false
}

if !l.matchHeader(header) {
return false
}

for i, bind := range l.binds {
params[bind] = submatches[i+1]
}
Expand All @@ -171,7 +188,10 @@ func (*placeholderLeaf) getMatchStyle() MatchStyle {
return matchStylePlaceholder
}

func (l *placeholderLeaf) match(segment string, params Params) bool {
func (l *placeholderLeaf) match(segment string, params Params, header http.Header) bool {
if !l.matchHeader(header) {
return false
}
params[l.bind] = segment
return true
}
Expand All @@ -187,7 +207,10 @@ func (*matchAllLeaf) getMatchStyle() MatchStyle {
return matchStyleAll
}

func (l *matchAllLeaf) match(segment string, params Params) bool {
func (l *matchAllLeaf) match(segment string, params Params, header http.Header) bool {
if !l.matchHeader(header) {
return false
}
params[l.bind] = segment
return true
}
Expand All @@ -196,13 +219,16 @@ func (l *matchAllLeaf) match(segment string, params Params) bool {
// defined). The `path` should be original request path, `segment` should NOT be
// unescaped by the caller. It returns true if segments are captured within the
// limit, and the capture result is stored in `params`.
func (l *matchAllLeaf) matchAll(path, segment string, next int, params Params) bool {
func (l *matchAllLeaf) matchAll(path, segment string, next int, params Params, header http.Header) bool {
// Do `next-1` because "next" starts at the next character of preceding "/"; do
// `strings.Count()+1` because the segment itself also counts. E.g. "webapi" +
// "users/events" => 3
if l.capture > 0 && l.capture < strings.Count(path[next-1:], "/")+1 {
return false
}
if !l.matchHeader(header) {
return false
}

params[l.bind] = segment + "/" + path[next:]
return true
Expand Down
19 changes: 10 additions & 9 deletions internal/route/leaf_test.go
Expand Up @@ -9,6 +9,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestNewLeaf(t *testing.T) {
Expand All @@ -21,7 +22,7 @@ func TestNewLeaf(t *testing.T) {
})

parser, err := NewParser()
assert.Nil(t, err)
require.NoError(t, err)

tests := []struct {
route string
Expand Down Expand Up @@ -60,12 +61,12 @@ func TestNewLeaf(t *testing.T) {
for _, test := range tests {
t.Run(test.route, func(t *testing.T) {
route, err := parser.Parse(test.route)
assert.Nil(t, err)
require.NoError(t, err)
assert.Len(t, route.Segments, 1)

segment := route.Segments[0]
got, err := newLeaf(nil, route, segment, nil)
assert.Nil(t, err)
require.NoError(t, err)

switch test.style {
case matchStyleStatic:
Expand All @@ -86,7 +87,7 @@ func TestNewLeaf(t *testing.T) {

func TestNewLeaf_Regex(t *testing.T) {
parser, err := NewParser()
assert.Nil(t, err)
require.NoError(t, err)

tests := []struct {
route string
Expand Down Expand Up @@ -122,12 +123,12 @@ func TestNewLeaf_Regex(t *testing.T) {
for _, test := range tests {
t.Run(test.route, func(t *testing.T) {
route, err := parser.Parse(test.route)
assert.Nil(t, err)
require.NoError(t, err)
assert.Len(t, route.Segments, 1)

segment := route.Segments[0]
got, err := newLeaf(nil, route, segment, nil)
assert.Nil(t, err)
require.NoError(t, err)

leaf := got.(*regexLeaf)
assert.Equal(t, test.wantRegexp, leaf.regexp.String())
Expand All @@ -138,7 +139,7 @@ func TestNewLeaf_Regex(t *testing.T) {

func TestLeaf_URLPath(t *testing.T) {
parser, err := NewParser()
assert.Nil(t, err)
require.NoError(t, err)

tests := []struct {
route string
Expand Down Expand Up @@ -245,11 +246,11 @@ func TestLeaf_URLPath(t *testing.T) {
for _, test := range tests {
t.Run(test.route, func(t *testing.T) {
route, err := parser.Parse(test.route)
assert.Nil(t, err)
require.NoError(t, err)

segment := route.Segments[len(route.Segments)-1]
leaf, err := newLeaf(nil, route, segment, nil)
assert.Nil(t, err)
require.NoError(t, err)

got := leaf.URLPath(test.vals, test.withOptional)
assert.Equal(t, test.want, got)
Expand Down