Skip to content

Commit

Permalink
introduce AllowedDomainFunc #489
Browse files Browse the repository at this point in the history
  • Loading branch information
emicklei committed Jun 2, 2022
1 parent 3e9df1c commit 5728f44
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 29 deletions.
40 changes: 15 additions & 25 deletions cors_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@ import (
// http://www.html5rocks.com/en/tutorials/cors/#toc-handling-a-not-so-simple-request
type CrossOriginResourceSharing struct {
ExposeHeaders []string // list of Header names
AllowedHeaders []string // list of Header names
AllowedHeaders []string // list of Header names. Checking is case-insensitive.
// AllowedDomains list of allowed values for Http Origin.
// An allowed value can be a regular expression to support subdomain matching.
// Non-regular expression values will be changed into an exact match: ^yourdomain.com$
// The list may contain the special wildcard string ".*" ; all is allowed
// If empty all are allowed.
AllowedDomains []string
// AllowedDomainFunc is optional and is a function that will do the check
// whether the origin is not part of the AllowedDomains.
AllowedDomainFunc func(origin string) bool
// AllowedMethods is either empty or has a list of http methods names. Checking is case-insensitive.
AllowedMethods []string
MaxAge int // number of seconds before requiring new Options request
CookiesAllowed bool
Expand Down Expand Up @@ -125,35 +128,22 @@ func (c CrossOriginResourceSharing) isOriginAllowed(origin string) bool {
return false
}
if len(c.AllowedDomains) == 0 {
if c.AllowedDomainFunc != nil {
return c.AllowedDomainFunc(origin)
}
return true
}

allowed := false
// exact match on each allowed domain
for _, domain := range c.AllowedDomains {
if domain == origin {
allowed = true
break
if domain == ".*" || domain == origin {
return true
}
}

if !allowed {
if len(c.allowedOriginPatterns) == 0 {
// compile allowed domains to allowed origin patterns
allowedOriginRegexps, err := compileRegexps(c.AllowedDomains)
if err != nil {
return false
}
c.allowedOriginPatterns = allowedOriginRegexps
}

for _, pattern := range c.allowedOriginPatterns {
if allowed = pattern.MatchString(origin); allowed {
break
}
}
if c.AllowedDomainFunc != nil {
return c.AllowedDomainFunc(origin)
}

return allowed
return false
}

func (c CrossOriginResourceSharing) setAllowOriginHeader(req *Request, resp *Response) {
Expand Down
27 changes: 23 additions & 4 deletions cors_filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,6 @@ var allowedDomainInput = []struct {
{[]string{"example.com"}, "not-allowed", false},
{[]string{"not-matching.com", "example.com"}, "example.com", true},
{[]string{".*"}, "example.com", true},
{[]string{"^some.example.com$"}, "some.example.com", true},
{[]string{"^some\\.example\\.com"}, "some.example.com.org", false},
}

// go test -v -test.run TestCORSFilter_AllowedDomains ...restful
Expand All @@ -122,10 +120,31 @@ func TestCORSFilter_AllowedDomains(t *testing.T) {
DefaultContainer.Dispatch(httpWriter, httpRequest)
actual := httpWriter.Header().Get(HEADER_AccessControlAllowOrigin)
if actual != each.origin && each.allowed {
t.Fatal("expected to be accepted")
t.Error("expected to be accepted", each)
}
if actual == each.origin && !each.allowed {
t.Fatal("did not expect to be accepted")
t.Error("did not expect to be accepted")
}
}
}

func TestCORSFilter_AllowedDomainFunc(t *testing.T) {
cors := CrossOriginResourceSharing{
AllowedDomains: []string{"here", "there"},
AllowedDomainFunc: func(origin string) bool {
return "where" == origin
},
}
if got, want := cors.isOriginAllowed("here"), true; got != want {
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
}
if got, want := cors.isOriginAllowed("there"), true; got != want {
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
}
if got, want := cors.isOriginAllowed("where"), true; got != want {
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
}
if got, want := cors.isOriginAllowed("nowhere"), false; got != want {
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
}
}

0 comments on commit 5728f44

Please sign in to comment.