diff --git a/cors.go b/cors.go index 69f89d8..9a8b419 100644 --- a/cors.go +++ b/cors.go @@ -140,33 +140,29 @@ func New(options Options) *Cors { c.Log = log.New(os.Stdout, "[cors] ", log.LstdFlags) } - if options.AllowOriginVaryRequestFunc != nil { + // Allowed origins + switch { + case options.AllowOriginVaryRequestFunc != nil: c.allowOriginFunc = options.AllowOriginVaryRequestFunc - } else if options.AllowOriginRequestFunc != nil { + case options.AllowOriginRequestFunc != nil: c.allowOriginFunc = func(r *http.Request, origin string) (bool, []string) { return options.AllowOriginRequestFunc(r, origin), nil } - } else if options.AllowOriginFunc != nil { + case options.AllowOriginFunc != nil: c.allowOriginFunc = func(r *http.Request, origin string) (bool, []string) { return options.AllowOriginFunc(origin), nil } - } - - // Normalize options - // Note: for origins matching, the spec requires a case-sensitive matching. - // As it may error prone, we chose to ignore the spec here. - - // Allowed Origins - if len(options.AllowedOrigins) == 0 { + case len(options.AllowedOrigins) == 0: if c.allowOriginFunc == nil { // Default is all origins c.allowedOriginsAll = true } - } else { + default: c.allowedOrigins = []string{} c.allowedWOrigins = []wildcard{} for _, origin := range options.AllowedOrigins { - // Normalize + // Note: for origins matching, the spec requires a case-sensitive matching. + // As it may error prone, we chose to ignore the spec here. origin = strings.ToLower(origin) if origin == "*" { // If "*" is present in the list, turn the whole list into a match all diff --git a/cors_test.go b/cors_test.go index 05ccf3b..c17dee2 100644 --- a/cors_test.go +++ b/cors_test.go @@ -481,6 +481,23 @@ func TestSpec(t *testing.T) { "Access-Control-Allow-Origin": "http://foobar.com", }, true, + }, { + "AllowedOriginsPlusAllowOriginFunc", + Options{ + AllowedOrigins: []string{"*"}, + AllowOriginFunc: func(origin string) bool { + return true + }, + }, + "GET", + map[string]string{ + "Origin": "http://foobar.com", + }, + map[string]string{ + "Vary": "Origin", + "Access-Control-Allow-Origin": "http://foobar.com", + }, + true, }, } for i := range cases {