diff --git a/bench_test.go b/bench_test.go index 6859382..688d2b2 100644 --- a/bench_test.go +++ b/bench_test.go @@ -2,6 +2,7 @@ package cors import ( "net/http" + "strings" "testing" ) @@ -87,7 +88,22 @@ func BenchmarkPreflightHeader(b *testing.B) { req, _ := http.NewRequest(http.MethodOptions, dummyEndpoint, nil) req.Header.Add(headerOrigin, dummyOrigin) req.Header.Add(headerACRM, http.MethodGet) - req.Header.Add(headerACRH, "Accept") + req.Header.Add(headerACRH, "accept") + handler := Default().Handler(testHandler) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + handler.ServeHTTP(resps[i], req) + } +} + +func BenchmarkPreflightAdversarialACRH(b *testing.B) { + resps := makeFakeResponses(b.N) + req, _ := http.NewRequest(http.MethodOptions, dummyEndpoint, nil) + req.Header.Add(headerOrigin, dummyOrigin) + req.Header.Add(headerACRM, http.MethodGet) + req.Header.Add(headerACRH, strings.Repeat(",", 1024)) handler := Default().Handler(testHandler) b.ReportAllocs() diff --git a/cors.go b/cors.go index 08aff84..da80d34 100644 --- a/cors.go +++ b/cors.go @@ -26,6 +26,8 @@ import ( "os" "strconv" "strings" + + "github.com/rs/cors/internal" ) var headerVaryOrigin = []string{"Origin"} @@ -111,7 +113,11 @@ type Cors struct { // Optional origin validator function allowOriginFunc func(r *http.Request, origin string) (bool, []string) // Normalized list of allowed headers - allowedHeaders []string + // Note: the Fetch standard guarantees that CORS-unsafe request-header names + // (i.e. the values listed in the Access-Control-Request-Headers header) + // are unique and sorted; + // see https://fetch.spec.whatwg.org/#cors-unsafe-request-header-names. + allowedHeaders internal.SortedSet // Normalized list of allowed methods allowedMethods []string // Pre-computed normalized list of exposed headers @@ -183,15 +189,19 @@ func New(options Options) *Cors { } // Allowed Headers + // Note: the Fetch standard guarantees that CORS-unsafe request-header names + // (i.e. the values listed in the Access-Control-Request-Headers header) + // are lowercase; see https://fetch.spec.whatwg.org/#cors-unsafe-request-header-names. if len(options.AllowedHeaders) == 0 { // Use sensible defaults - c.allowedHeaders = []string{"Accept", "Content-Type", "X-Requested-With"} + c.allowedHeaders = internal.NewSortedSet("accept", "content-type", "x-requested-with") } else { - c.allowedHeaders = convert(options.AllowedHeaders, http.CanonicalHeaderKey) + normalized := convert(options.AllowedHeaders, strings.ToLower) + c.allowedHeaders = internal.NewSortedSet(normalized...) for _, h := range options.AllowedHeaders { if h == "*" { c.allowedHeadersAll = true - c.allowedHeaders = nil + c.allowedHeaders = internal.SortedSet{} break } } @@ -351,10 +361,12 @@ func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) { c.logf(" Preflight aborted: method '%s' not allowed", reqMethod) return } - reqHeadersRaw := r.Header["Access-Control-Request-Headers"] - reqHeaders, reqHeadersEdited := convertDidCopy(splitHeaderValues(reqHeadersRaw), http.CanonicalHeaderKey) - if !c.areHeadersAllowed(reqHeaders) { - c.logf(" Preflight aborted: headers '%v' not allowed", reqHeaders) + // Note: the Fetch standard guarantees that at most one + // Access-Control-Request-Headers header is present in the preflight request; + // see step 5.2 in https://fetch.spec.whatwg.org/#cors-preflight-fetch-0. + reqHeaders, found := first(r.Header, "Access-Control-Request-Headers") + if found && !c.allowedHeadersAll && !c.allowedHeaders.Subsumes(reqHeaders[0]) { + c.logf(" Preflight aborted: headers '%v' not allowed", reqHeaders[0]) return } if c.allowedOriginsAll { @@ -365,14 +377,10 @@ func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) { // Spec says: Since the list of methods can be unbounded, simply returning the method indicated // by Access-Control-Request-Method (if supported) can be enough headers["Access-Control-Allow-Methods"] = r.Header["Access-Control-Request-Method"] - if len(reqHeaders) > 0 { + if found && len(reqHeaders[0]) > 0 { // Spec says: Since the list of headers can be unbounded, simply returning supported headers // from Access-Control-Request-Headers can be enough - if reqHeadersEdited || len(reqHeaders) != len(reqHeadersRaw) { - headers.Set("Access-Control-Allow-Headers", strings.Join(reqHeaders, ", ")) - } else { - headers["Access-Control-Allow-Headers"] = reqHeadersRaw - } + headers["Access-Control-Allow-Headers"] = reqHeaders } if c.allowCredentials { headers["Access-Control-Allow-Credentials"] = headerTrue @@ -492,24 +500,3 @@ func (c *Cors) isMethodAllowed(method string) bool { } return false } - -// areHeadersAllowed checks if a given list of headers are allowed to used within -// a cross-domain request. -func (c *Cors) areHeadersAllowed(requestedHeaders []string) bool { - if c.allowedHeadersAll || len(requestedHeaders) == 0 { - return true - } - for _, header := range requestedHeaders { - found := false - for _, h := range c.allowedHeaders { - if h == header { - found = true - break - } - } - if !found { - return false - } - } - return true -} diff --git a/cors_test.go b/cors_test.go index c17dee2..a3c0aab 100644 --- a/cors_test.go +++ b/cors_test.go @@ -303,19 +303,19 @@ func TestSpec(t *testing.T) { "AllowedHeaders", Options{ AllowedOrigins: []string{"http://foobar.com"}, - AllowedHeaders: []string{"X-Header-1", "x-header-2"}, + AllowedHeaders: []string{"X-Header-1", "x-header-2", "X-HEADER-3"}, }, "OPTIONS", map[string]string{ "Origin": "http://foobar.com", "Access-Control-Request-Method": "GET", - "Access-Control-Request-Headers": "X-Header-2, X-HEADER-1", + "Access-Control-Request-Headers": "x-header-1,x-header-2", }, map[string]string{ "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", "Access-Control-Allow-Origin": "http://foobar.com", "Access-Control-Allow-Methods": "GET", - "Access-Control-Allow-Headers": "X-Header-2, X-Header-1", + "Access-Control-Allow-Headers": "x-header-1,x-header-2", }, true, }, @@ -329,13 +329,13 @@ func TestSpec(t *testing.T) { map[string]string{ "Origin": "http://foobar.com", "Access-Control-Request-Method": "GET", - "Access-Control-Request-Headers": "X-Requested-With", + "Access-Control-Request-Headers": "x-requested-with", }, map[string]string{ "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", "Access-Control-Allow-Origin": "http://foobar.com", "Access-Control-Allow-Methods": "GET", - "Access-Control-Allow-Headers": "X-Requested-With", + "Access-Control-Allow-Headers": "x-requested-with", }, true, }, @@ -349,13 +349,13 @@ func TestSpec(t *testing.T) { map[string]string{ "Origin": "http://foobar.com", "Access-Control-Request-Method": "GET", - "Access-Control-Request-Headers": "X-Header-2, X-HEADER-1", + "Access-Control-Request-Headers": "x-header-1,x-header-2", }, map[string]string{ "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", "Access-Control-Allow-Origin": "http://foobar.com", "Access-Control-Allow-Methods": "GET", - "Access-Control-Allow-Headers": "X-Header-2, X-Header-1", + "Access-Control-Allow-Headers": "x-header-1,x-header-2", }, true, }, @@ -369,7 +369,7 @@ func TestSpec(t *testing.T) { map[string]string{ "Origin": "http://foobar.com", "Access-Control-Request-Method": "GET", - "Access-Control-Request-Headers": "X-Header-3, X-Header-1", + "Access-Control-Request-Headers": "x-header-1,x-header-3", }, map[string]string{ "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", @@ -577,8 +577,8 @@ func TestDefault(t *testing.T) { if !s.allowedOriginsAll { t.Error("c.allowedOriginsAll should be true when Default") } - if s.allowedHeaders == nil { - t.Error("c.allowedHeaders should be nil when Default") + if s.allowedHeaders.Size() == 0 { + t.Error("c.allowedHeaders should be empty when Default") } if s.allowedMethods == nil { t.Error("c.allowedMethods should be nil when Default") @@ -712,64 +712,6 @@ func TestOptionsSuccessStatusCodeOverride(t *testing.T) { }) } -func TestCorsAreHeadersAllowed(t *testing.T) { - cases := []struct { - name string - allowedHeaders []string - requestedHeaders []string - want bool - }{ - { - name: "nil allowedHeaders", - allowedHeaders: nil, - requestedHeaders: []string{"X-PINGOTHER, Content-Type"}, - want: false, - }, - { - name: "star allowedHeaders", - allowedHeaders: []string{"*"}, - requestedHeaders: []string{"X-PINGOTHER, Content-Type"}, - want: true, - }, - { - name: "empty reqHeader", - allowedHeaders: nil, - requestedHeaders: []string{}, - want: true, - }, - { - name: "match allowedHeaders", - allowedHeaders: []string{"Content-Type", "X-PINGOTHER", "X-APP-KEY"}, - requestedHeaders: []string{"X-PINGOTHER, Content-Type"}, - want: true, - }, - { - name: "not matched allowedHeaders", - allowedHeaders: []string{"X-PINGOTHER"}, - requestedHeaders: []string{"X-API-KEY, Content-Type"}, - want: false, - }, - { - name: "allowedHeaders should be a superset of requestedHeaders", - allowedHeaders: []string{"X-PINGOTHER"}, - requestedHeaders: []string{"X-PINGOTHER, Content-Type"}, - want: false, - }, - } - - for _, tt := range cases { - tt := tt - - t.Run(tt.name, func(t *testing.T) { - c := New(Options{AllowedHeaders: tt.allowedHeaders}) - have := c.areHeadersAllowed(convert(splitHeaderValues(tt.requestedHeaders), http.CanonicalHeaderKey)) - if have != tt.want { - t.Errorf("Cors.areHeadersAllowed() have: %t want: %t", have, tt.want) - } - }) - } -} - func TestAccessControlExposeHeadersPresence(t *testing.T) { cases := []struct { name string diff --git a/internal/sortedset.go b/internal/sortedset.go new file mode 100644 index 0000000..513da20 --- /dev/null +++ b/internal/sortedset.go @@ -0,0 +1,113 @@ +// adapted from github.com/jub0bs/cors +package internal + +import ( + "sort" + "strings" +) + +// A SortedSet represents a mathematical set of strings sorted in +// lexicographical order. +// Each element has a unique position ranging from 0 (inclusive) +// to the set's cardinality (exclusive). +// The zero value represents an empty set. +type SortedSet struct { + m map[string]int + maxLen int +} + +// NewSortedSet returns a SortedSet that contains all of elems, +// but no other elements. +func NewSortedSet(elems ...string) SortedSet { + sort.Strings(elems) + m := make(map[string]int) + var maxLen int + i := 0 + for _, s := range elems { + if _, exists := m[s]; exists { + continue + } + m[s] = i + i++ + maxLen = max(maxLen, len(s)) + } + return SortedSet{ + m: m, + maxLen: maxLen, + } +} + +// Size returns the cardinality of set. +func (set SortedSet) Size() int { + return len(set.m) +} + +// String sorts joins the elements of set (in lexicographical order) +// with a comma and returns the resulting string. +func (set SortedSet) String() string { + elems := make([]string, len(set.m)) + for elem, i := range set.m { + elems[i] = elem // safe indexing, by construction of SortedSet + } + return strings.Join(elems, ",") +} + +// Subsumes reports whether csv is a sequence of comma-separated names that are +// - all elements of set, +// - sorted in lexicographically order, +// - unique. +func (set SortedSet) Subsumes(csv string) bool { + if csv == "" { + return true + } + posOfLastNameSeen := -1 + chunkSize := set.maxLen + 1 // (to accommodate for at least one comma) + for { + // As a defense against maliciously long names in csv, + // we only process at most chunkSize bytes per iteration. + end := min(len(csv), chunkSize) + comma := strings.IndexByte(csv[:end], ',') + var name string + if comma == -1 { + name = csv + } else { + name = csv[:comma] + } + pos, found := set.m[name] + if !found { + return false + } + // The names in csv are expected to be sorted in lexicographical order + // and appear at most once in csv. + // Therefore, the positions (in set) of the names that + // appear in csv should form a strictly increasing sequence. + // If that's not actually the case, bail out. + if pos <= posOfLastNameSeen { + return false + } + posOfLastNameSeen = pos + if comma < 0 { // We've now processed all the names in csv. + break + } + csv = csv[comma+1:] + } + return true +} + +// TODO: when updating go directive to 1.21 or later, +// use min builtin instead. +func min(a, b int) int { + if a < b { + return a + } + return b +} + +// TODO: when updating go directive to 1.21 or later, +// use max builtin instead. +func max(a, b int) int { + if a > b { + return a + } + return b +} diff --git a/internal/sortedset_test.go b/internal/sortedset_test.go new file mode 100644 index 0000000..9727686 --- /dev/null +++ b/internal/sortedset_test.go @@ -0,0 +1,119 @@ +package internal + +import ( + "testing" +) + +func TestSortedSet(t *testing.T) { + cases := []struct { + desc string + elems []string + combined string + subsets []string + notSubsets []string + wantSize int + }{ + { + desc: "empty set", + combined: "", + notSubsets: []string{ + "bar", + "bar,foo", + }, + wantSize: 0, + }, { + desc: "singleton set", + elems: []string{"foo"}, + combined: "foo", + subsets: []string{ + "", + "foo", + }, + notSubsets: []string{ + "bar", + "bar,foo", + }, + wantSize: 1, + }, { + desc: "no dupes", + elems: []string{"foo", "bar", "baz"}, + combined: "bar,baz,foo", + subsets: []string{ + "", + "bar", + "baz", + "foo", + "bar,baz", + "bar,foo", + "baz,foo", + "bar,baz,foo", + }, + notSubsets: []string{ + "qux", + "bar,baz,baz", + "qux,baz", + "qux,foo", + "quxbaz,foo", + }, + wantSize: 3, + }, { + desc: "some dupes", + elems: []string{"foo", "bar", "bar", "foo", "e"}, + combined: "bar,e,foo", + subsets: []string{ + "", + "bar", + "e", + "foo", + "bar,foo", + "bar,e", + "e,foo", + "bar,e,foo", + }, + notSubsets: []string{ + "qux", + "qux,bar", + "qux,foo", + "qux,baz,foo", + }, + wantSize: 3, + }, + } + for _, tc := range cases { + f := func(t *testing.T) { + elems := clone(tc.elems) + s := NewSortedSet(tc.elems...) + size := s.Size() + if s.Size() != tc.wantSize { + const tmpl = "NewSortedSet(%#v...).Size(): got %d; want %d" + t.Errorf(tmpl, elems, size, tc.wantSize) + } + combined := s.String() + if combined != tc.combined { + const tmpl = "NewSortedSet(%#v...).String(): got %q; want %q" + t.Errorf(tmpl, elems, combined, tc.combined) + } + for _, sub := range tc.subsets { + if !s.Subsumes(sub) { + const tmpl = "%q is not a subset of %q, but should be" + t.Errorf(tmpl, sub, s) + } + } + for _, notSub := range tc.notSubsets { + if s.Subsumes(notSub) { + const tmpl = "%q is a subset of %q, but should not be" + t.Errorf(tmpl, notSub, s) + } + } + } + t.Run(tc.desc, f) + } +} + +// adapted from https://pkg.go.dev/slices#Clone +// TODO: when updating go directive to 1.21 or later, +// use slices.Clone instead. +func clone(s []string) []string { + // The s[:0:0] preserves nil in case it matters. + return append(s[:0:0], s...) +} diff --git a/utils.go b/utils.go index ca9983d..7019f45 100644 --- a/utils.go +++ b/utils.go @@ -1,72 +1,34 @@ package cors import ( + "net/http" "strings" ) -type converter func(string) string - type wildcard struct { prefix string suffix string } func (w wildcard) match(s string) bool { - return len(s) >= len(w.prefix)+len(w.suffix) && strings.HasPrefix(s, w.prefix) && strings.HasSuffix(s, w.suffix) -} - -// split compounded header values ["foo, bar", "baz"] -> ["foo", "bar", "baz"] -func splitHeaderValues(values []string) []string { - out := values - copied := false - for i, v := range values { - needsSplit := strings.IndexByte(v, ',') != -1 - if !copied { - if needsSplit { - split := strings.Split(v, ",") - out = make([]string, i, len(values)+len(split)-1) - copy(out, values[:i]) - for _, s := range split { - out = append(out, strings.TrimSpace(s)) - } - copied = true - } - } else { - if needsSplit { - split := strings.Split(v, ",") - for _, s := range split { - out = append(out, strings.TrimSpace(s)) - } - } else { - out = append(out, v) - } - } - } - return out + return len(s) >= len(w.prefix)+len(w.suffix) && + strings.HasPrefix(s, w.prefix) && + strings.HasSuffix(s, w.suffix) } // convert converts a list of string using the passed converter function -func convert(s []string, c converter) []string { - out, _ := convertDidCopy(s, c) +func convert(s []string, f func(string) string) []string { + out := make([]string, len(s)) + for i := range s { + out[i] = f(s[i]) + } return out } -// convertDidCopy is same as convert but returns true if it copied the slice -func convertDidCopy(s []string, c converter) ([]string, bool) { - out := s - copied := false - for i, v := range s { - if !copied { - v2 := c(v) - if v2 != v { - out = make([]string, len(s)) - copy(out, s[:i]) - out[i] = v2 - copied = true - } - } else { - out[i] = c(v) - } +func first(hdrs http.Header, k string) ([]string, bool) { + v, found := hdrs[k] + if !found || len(v) == 0 { + return nil, false } - return out, copied + return v[:1], true } diff --git a/utils_test.go b/utils_test.go index f6a83ec..409f24e 100644 --- a/utils_test.go +++ b/utils_test.go @@ -1,7 +1,6 @@ package cors import ( - "reflect" "strings" "testing" ) @@ -24,41 +23,6 @@ func TestWildcard(t *testing.T) { } } -func TestSplitHeaderValues(t *testing.T) { - testCases := []struct { - input []string - expected []string - }{ - { - input: []string{}, - expected: []string{}, - }, - { - input: []string{"foo"}, - expected: []string{"foo"}, - }, - { - input: []string{"foo, bar, baz"}, - expected: []string{"foo", "bar", "baz"}, - }, - { - input: []string{"abc", "def, ghi", "jkl"}, - expected: []string{"abc", "def", "ghi", "jkl"}, - }, - { - input: []string{"foo, bar", "baz, qux", "quux, corge"}, - expected: []string{"foo", "bar", "baz", "qux", "quux", "corge"}, - }, - } - - for _, testCase := range testCases { - output := splitHeaderValues(testCase.input) - if !reflect.DeepEqual(output, testCase.expected) { - t.Errorf("Input: %v, Expected: %v, Got: %v", testCase.input, testCase.expected, output) - } - } -} - func TestConvert(t *testing.T) { s := convert([]string{"A", "b", "C"}, strings.ToLower) e := []string{"a", "b", "c"}