From 9d49f1698d35c415b5205e9192e75cbf42bc0f1d Mon Sep 17 00:00:00 2001 From: Dhruv Date: Sat, 9 Mar 2024 22:47:28 -0800 Subject: [PATCH] chore(cors): Allow a custom validation function which receives the full gin context (#140) * Allow a origin validation function with context * Revert "Allow a origin validation function with context" This reverts commit 82827c2bf1d62ccce85980c21253843245f61218. * Allow origin validation function which receives the full request context * fix logic in conditional * add test, fix logic * slightly re-work to pass linter * update comments * restructure to shorten line lengths to pass linter * remove punctuation at the end of error string * Add multi-group preflight test * remove comment --- config.go | 44 ++++++++++++++++----------- cors.go | 24 +++++++++++++-- cors_test.go | 84 +++++++++++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 128 insertions(+), 24 deletions(-) diff --git a/config.go b/config.go index 427cfc0..8a295e3 100644 --- a/config.go +++ b/config.go @@ -8,14 +8,15 @@ import ( ) type cors struct { - allowAllOrigins bool - allowCredentials bool - allowOriginFunc func(string) bool - allowOrigins []string - normalHeaders http.Header - preflightHeaders http.Header - wildcardOrigins [][]string - optionsResponseStatusCode int + allowAllOrigins bool + allowCredentials bool + allowOriginFunc func(string) bool + allowOriginWithContextFunc func(*gin.Context, string) bool + allowOrigins []string + normalHeaders http.Header + preflightHeaders http.Header + wildcardOrigins [][]string + optionsResponseStatusCode int } var ( @@ -54,14 +55,15 @@ func newCors(config Config) *cors { } return &cors{ - allowOriginFunc: config.AllowOriginFunc, - allowAllOrigins: config.AllowAllOrigins, - allowCredentials: config.AllowCredentials, - allowOrigins: normalize(config.AllowOrigins), - normalHeaders: generateNormalHeaders(config), - preflightHeaders: generatePreflightHeaders(config), - wildcardOrigins: config.parseWildcardRules(), - optionsResponseStatusCode: config.OptionsResponseStatusCode, + allowOriginFunc: config.AllowOriginFunc, + allowOriginWithContextFunc: config.AllowOriginWithContextFunc, + allowAllOrigins: config.AllowAllOrigins, + allowCredentials: config.AllowCredentials, + allowOrigins: normalize(config.AllowOrigins), + normalHeaders: generateNormalHeaders(config), + preflightHeaders: generatePreflightHeaders(config), + wildcardOrigins: config.parseWildcardRules(), + optionsResponseStatusCode: config.OptionsResponseStatusCode, } } @@ -79,7 +81,7 @@ func (cors *cors) applyCors(c *gin.Context) { return } - if !cors.validateOrigin(origin) { + if !cors.isOriginValid(c, origin) { c.AbortWithStatus(http.StatusForbidden) return } @@ -112,6 +114,14 @@ func (cors *cors) validateWildcardOrigin(origin string) bool { return false } +func (cors *cors) isOriginValid(c *gin.Context, origin string) bool { + valid := cors.validateOrigin(origin) + if !valid && cors.allowOriginWithContextFunc != nil { + valid = cors.allowOriginWithContextFunc(c, origin) + } + return valid +} + func (cors *cors) validateOrigin(origin string) bool { if cors.allowAllOrigins { return true diff --git a/cors.go b/cors.go index 844bc68..2261df7 100644 --- a/cors.go +++ b/cors.go @@ -2,6 +2,7 @@ package cors import ( "errors" + "fmt" "strings" "time" @@ -22,6 +23,12 @@ type Config struct { // set, the content of AllowOrigins is ignored. AllowOriginFunc func(origin string) bool + // Same as AllowOriginFunc except also receives the full request context. + // This function should use the context as a read only source and not + // have any side effects on the request, such as aborting or injecting + // values on the request. + AllowOriginWithContextFunc func(c *gin.Context, origin string) bool + // AllowMethods is a list of methods the client is allowed to use with // cross-domain requests. Default value is simple methods (GET, POST, PUT, PATCH, DELETE, HEAD, and OPTIONS) AllowMethods []string @@ -108,10 +115,21 @@ func (c Config) validateAllowedSchemas(origin string) bool { // Validate is check configuration of user defined. func (c Config) Validate() error { - if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) { - return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowOrigins is not needed") + hasOriginFn := c.AllowOriginFunc != nil + hasOriginFn = hasOriginFn || c.AllowOriginWithContextFunc != nil + + if c.AllowAllOrigins && (hasOriginFn || len(c.AllowOrigins) > 0) { + originFields := strings.Join([]string{ + "AllowOriginFunc", + "AllowOriginFuncWithContext", + "AllowOrigins", + }, " or ") + return fmt.Errorf( + "conflict settings: all origins enabled. %s is not needed", + originFields, + ) } - if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowOrigins) == 0 { + if !c.AllowAllOrigins && !hasOriginFn && len(c.AllowOrigins) == 0 { return errors.New("conflict settings: all origins disabled") } for _, origin := range c.AllowOrigins { diff --git a/cors_test.go b/cors_test.go index 687ac0d..dedd3cc 100644 --- a/cors_test.go +++ b/cors_test.go @@ -28,12 +28,34 @@ func newTestRouter(config Config) *gin.Engine { return router } +func multiGroupRouter(config Config) *gin.Engine { + router := gin.New() + router.Use(New(config)) + + app1 := router.Group("/app1") + app1.GET("", func(c *gin.Context) { + c.String(http.StatusOK, "app1") + }) + + app2 := router.Group("/app2") + app2.GET("", func(c *gin.Context) { + c.String(http.StatusOK, "app2") + }) + + app3 := router.Group("/app3") + app3.GET("", func(c *gin.Context) { + c.String(http.StatusOK, "app3") + }) + + return router +} + func performRequest(r http.Handler, method, origin string) *httptest.ResponseRecorder { - return performRequestWithHeaders(r, method, origin, http.Header{}) + return performRequestWithHeaders(r, method, "/", origin, http.Header{}) } -func performRequestWithHeaders(r http.Handler, method, origin string, header http.Header) *httptest.ResponseRecorder { - req, _ := http.NewRequestWithContext(context.Background(), method, "/", nil) +func performRequestWithHeaders(r http.Handler, method, path, origin string, header http.Header) *httptest.ResponseRecorder { + req, _ := http.NewRequestWithContext(context.Background(), method, path, nil) // From go/net/http/request.go: // For incoming requests, the Host header is promoted to the // Request.Host field and removed from the Header map. @@ -299,6 +321,9 @@ func TestPassesAllowOrigins(t *testing.T) { AllowOriginFunc: func(origin string) bool { return origin == "http://github.com" }, + AllowOriginWithContextFunc: func(c *gin.Context, origin string) bool { + return origin == "http://sample.com" + }, }) // no CORS request, origin == "" @@ -311,7 +336,7 @@ func TestPassesAllowOrigins(t *testing.T) { // no CORS request, origin == host h := http.Header{} h.Set("Host", "facebook.com") - w = performRequestWithHeaders(router, "GET", "http://facebook.com", h) + w = performRequestWithHeaders(router, "GET", "/", "http://facebook.com", h) assert.Equal(t, "get", w.Body.String()) assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials")) @@ -346,6 +371,15 @@ func TestPassesAllowOrigins(t *testing.T) { assert.Equal(t, "Content-Type,Timestamp", w.Header().Get("Access-Control-Allow-Headers")) assert.Equal(t, "43200", w.Header().Get("Access-Control-Max-Age")) + // allowed CORS prefligh request: allowed via AllowOriginWithContextFunc + w = performRequest(router, "OPTIONS", "http://sample.com") + assert.Equal(t, http.StatusNoContent, w.Code) + assert.Equal(t, "http://sample.com", w.Header().Get("Access-Control-Allow-Origin")) + assert.Equal(t, "", w.Header().Get("Access-Control-Allow-Credentials")) + assert.Equal(t, "GET,POST,PUT,HEAD", w.Header().Get("Access-Control-Allow-Methods")) + assert.Equal(t, "Content-Type,Timestamp", w.Header().Get("Access-Control-Allow-Headers")) + assert.Equal(t, "43200", w.Header().Get("Access-Control-Max-Age")) + // deny CORS prefligh request w = performRequest(router, "OPTIONS", "http://example.com") assert.Equal(t, http.StatusForbidden, w.Code) @@ -432,6 +466,48 @@ func TestWildcard(t *testing.T) { assert.Equal(t, 200, w.Code) } +func TestMultiGroupRouter(t *testing.T) { + router := multiGroupRouter(Config{ + AllowMethods: []string{"GET"}, + AllowOriginWithContextFunc: func(c *gin.Context, origin string) bool { + path := c.Request.URL.Path + if strings.HasPrefix(path, "/app1") { + return "http://app1.example.com" == origin + } + + if strings.HasPrefix(path, "/app2") { + return "http://app2.example.com" == origin + } + + // app 3 allows all origins + return true + }, + }) + + // allowed CORS prefligh request + emptyHeaders := http.Header{} + app1Origin := "http://app1.example.com" + app2Origin := "http://app2.example.com" + randomOrgin := "http://random.com" + + // allowed CORS preflight + w := performRequestWithHeaders(router, "OPTIONS", "/app1", app1Origin, emptyHeaders) + assert.Equal(t, http.StatusNoContent, w.Code) + + w = performRequestWithHeaders(router, "OPTIONS", "/app2", app2Origin, emptyHeaders) + assert.Equal(t, http.StatusNoContent, w.Code) + + w = performRequestWithHeaders(router, "OPTIONS", "/app3", randomOrgin, emptyHeaders) + assert.Equal(t, http.StatusNoContent, w.Code) + + // disallowed CORS preflight + w = performRequestWithHeaders(router, "OPTIONS", "/app1", randomOrgin, emptyHeaders) + assert.Equal(t, http.StatusForbidden, w.Code) + + w = performRequestWithHeaders(router, "OPTIONS", "/app2", randomOrgin, emptyHeaders) + assert.Equal(t, http.StatusForbidden, w.Code) +} + func TestParseWildcardRules_NoWildcard(t *testing.T) { config := Config{ AllowOrigins: []string{