Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use exact matching of allowed domain entries, issue #489 #493

Merged
merged 5 commits into from
Jun 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
64 changes: 26 additions & 38 deletions cors_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,22 @@ import (
// http://enable-cors.org/server.html
// http://www.html5rocks.com/en/tutorials/cors/#toc-handling-a-not-so-simple-request
type CrossOriginResourceSharing struct {
ExposeHeaders []string // list of Header names
AllowedHeaders []string // list of Header names
AllowedDomains []string // list of allowed values for Http Origin. An allowed value can be a regular expression to support subdomain matching. If empty all are allowed.
ExposeHeaders []string // list of Header names

// AllowedHeaders is alist of Header names. Checking is case-insensitive.
// The list may contain the special wildcard string ".*" ; all is allowed
AllowedHeaders []string

// AllowedDomains is a list of allowed values for Http Origin.
// The list may contain the special wildcard string ".*" ; all is allowed
// If empty all are allowed.
AllowedDomains []string

// AllowedDomainFunc is optional and is a function that will do the check
// when the origin is not part of the AllowedDomains and it does not contain the wildcard ".*".
AllowedDomainFunc func(origin string) bool

// AllowedMethods is either empty or has a list of http methods names. Checking is case-insensitive.
AllowedMethods []string
MaxAge int // number of seconds before requiring new Options request
CookiesAllowed bool
Expand Down Expand Up @@ -119,36 +132,24 @@ func (c CrossOriginResourceSharing) isOriginAllowed(origin string) bool {
if len(origin) == 0 {
return false
}
lowerOrigin := strings.ToLower(origin)
if len(c.AllowedDomains) == 0 {
if c.AllowedDomainFunc != nil {
return c.AllowedDomainFunc(lowerOrigin)
}
return true
}

allowed := false
// exact match on each allowed domain
for _, domain := range c.AllowedDomains {
if domain == origin {
allowed = true
break
if domain == ".*" || strings.ToLower(domain) == lowerOrigin {
return true
}
}

if !allowed {
if len(c.allowedOriginPatterns) == 0 {
// compile allowed domains to allowed origin patterns
allowedOriginRegexps, err := compileRegexps(c.AllowedDomains)
if err != nil {
return false
}
c.allowedOriginPatterns = allowedOriginRegexps
}

for _, pattern := range c.allowedOriginPatterns {
if allowed = pattern.MatchString(origin); allowed {
break
}
}
if c.AllowedDomainFunc != nil {
return c.AllowedDomainFunc(origin)
}

return allowed
return false
}

func (c CrossOriginResourceSharing) setAllowOriginHeader(req *Request, resp *Response) {
Expand Down Expand Up @@ -190,16 +191,3 @@ func (c CrossOriginResourceSharing) isValidAccessControlRequestHeader(header str
}
return false
}

// Take a list of strings and compile them into a list of regular expressions.
func compileRegexps(regexpStrings []string) ([]*regexp.Regexp, error) {
regexps := []*regexp.Regexp{}
for _, regexpStr := range regexpStrings {
r, err := regexp.Compile(regexpStr)
if err != nil {
return regexps, err
}
regexps = append(regexps, r)
}
return regexps, nil
}
40 changes: 38 additions & 2 deletions cors_filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,46 @@ func TestCORSFilter_AllowedDomains(t *testing.T) {
DefaultContainer.Dispatch(httpWriter, httpRequest)
actual := httpWriter.Header().Get(HEADER_AccessControlAllowOrigin)
if actual != each.origin && each.allowed {
t.Fatal("expected to be accepted")
t.Error("expected to be accepted", each)
}
if actual == each.origin && !each.allowed {
t.Fatal("did not expect to be accepted")
t.Error("did not expect to be accepted")
}
}
}

func TestCORSFilter_AllowedDomainFunc(t *testing.T) {
cors := CrossOriginResourceSharing{
AllowedDomains: []string{"here", "there"},
AllowedDomainFunc: func(origin string) bool {
return "where" == origin
},
}
if got, want := cors.isOriginAllowed("here"), true; got != want {
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
}
if got, want := cors.isOriginAllowed("HERE"), true; got != want {
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
}
if got, want := cors.isOriginAllowed("there"), true; got != want {
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
}
if got, want := cors.isOriginAllowed("where"), true; got != want {
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
}
if got, want := cors.isOriginAllowed("nowhere"), false; got != want {
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
}
// just func
cors.AllowedDomains = []string{}
if got, want := cors.isOriginAllowed("here"), false; got != want {
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
}
if got, want := cors.isOriginAllowed("where"), true; got != want {
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
}
// empty domain
if got, want := cors.isOriginAllowed(""), false; got != want {
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
}
}