Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow a cors custom validation function which receives the full gin context #140

Merged
merged 12 commits into from Mar 10, 2024
44 changes: 27 additions & 17 deletions config.go
Expand Up @@ -8,14 +8,15 @@ import (
)

type cors struct {
allowAllOrigins bool
allowCredentials bool
allowOriginFunc func(string) bool
allowOrigins []string
normalHeaders http.Header
preflightHeaders http.Header
wildcardOrigins [][]string
optionsResponseStatusCode int
allowAllOrigins bool
allowCredentials bool
allowOriginFunc func(string) bool
allowOriginWithContextFunc func(*gin.Context, string) bool
allowOrigins []string
normalHeaders http.Header
preflightHeaders http.Header
wildcardOrigins [][]string
optionsResponseStatusCode int
}

var (
Expand Down Expand Up @@ -54,14 +55,15 @@ func newCors(config Config) *cors {
}

return &cors{
allowOriginFunc: config.AllowOriginFunc,
allowAllOrigins: config.AllowAllOrigins,
allowCredentials: config.AllowCredentials,
allowOrigins: normalize(config.AllowOrigins),
normalHeaders: generateNormalHeaders(config),
preflightHeaders: generatePreflightHeaders(config),
wildcardOrigins: config.parseWildcardRules(),
optionsResponseStatusCode: config.OptionsResponseStatusCode,
allowOriginFunc: config.AllowOriginFunc,
allowOriginWithContextFunc: config.AllowOriginWithContextFunc,
allowAllOrigins: config.AllowAllOrigins,
allowCredentials: config.AllowCredentials,
allowOrigins: normalize(config.AllowOrigins),
normalHeaders: generateNormalHeaders(config),
preflightHeaders: generatePreflightHeaders(config),
wildcardOrigins: config.parseWildcardRules(),
optionsResponseStatusCode: config.OptionsResponseStatusCode,
}
}

Expand All @@ -79,7 +81,7 @@ func (cors *cors) applyCors(c *gin.Context) {
return
}

if !cors.validateOrigin(origin) {
if !cors.isOriginValid(c, origin) {
c.AbortWithStatus(http.StatusForbidden)
return
}
Expand Down Expand Up @@ -112,6 +114,14 @@ func (cors *cors) validateWildcardOrigin(origin string) bool {
return false
}

func (cors *cors) isOriginValid(c *gin.Context, origin string) bool {
valid := cors.validateOrigin(origin)
if !valid && cors.allowOriginWithContextFunc != nil {
valid = cors.allowOriginWithContextFunc(c, origin)
}
return valid
}

func (cors *cors) validateOrigin(origin string) bool {
if cors.allowAllOrigins {
return true
Expand Down
24 changes: 21 additions & 3 deletions cors.go
Expand Up @@ -2,6 +2,7 @@ package cors

import (
"errors"
"fmt"
"strings"
"time"

Expand All @@ -22,6 +23,12 @@ type Config struct {
// set, the content of AllowOrigins is ignored.
AllowOriginFunc func(origin string) bool

// Same as AllowOriginFunc except also receives the full request context.
// This function should use the context as a read only source and not
// have any side effects on the request, such as aborting or injecting
// values on the request.
AllowOriginWithContextFunc func(c *gin.Context, origin string) bool

// AllowMethods is a list of methods the client is allowed to use with
// cross-domain requests. Default value is simple methods (GET, POST, PUT, PATCH, DELETE, HEAD, and OPTIONS)
AllowMethods []string
Expand Down Expand Up @@ -108,10 +115,21 @@ func (c Config) validateAllowedSchemas(origin string) bool {

// Validate is check configuration of user defined.
func (c Config) Validate() error {
if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) {
return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowOrigins is not needed")
hasOriginFn := c.AllowOriginFunc != nil
hasOriginFn = hasOriginFn || c.AllowOriginWithContextFunc != nil

if c.AllowAllOrigins && (hasOriginFn || len(c.AllowOrigins) > 0) {
originFields := strings.Join([]string{
"AllowOriginFunc",
"AllowOriginFuncWithContext",
"AllowOrigins",
}, " or ")
return fmt.Errorf(
"conflict settings: all origins enabled. %s is not needed",
originFields,
)
}
if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowOrigins) == 0 {
if !c.AllowAllOrigins && !hasOriginFn && len(c.AllowOrigins) == 0 {
return errors.New("conflict settings: all origins disabled")
}
for _, origin := range c.AllowOrigins {
Expand Down
84 changes: 80 additions & 4 deletions cors_test.go
Expand Up @@ -28,12 +28,34 @@ func newTestRouter(config Config) *gin.Engine {
return router
}

func multiGroupRouter(config Config) *gin.Engine {
router := gin.New()
router.Use(New(config))

app1 := router.Group("/app1")
app1.GET("", func(c *gin.Context) {
c.String(http.StatusOK, "app1")
})

app2 := router.Group("/app2")
app2.GET("", func(c *gin.Context) {
c.String(http.StatusOK, "app2")
})

app3 := router.Group("/app3")
app3.GET("", func(c *gin.Context) {
c.String(http.StatusOK, "app3")
})

return router
}

func performRequest(r http.Handler, method, origin string) *httptest.ResponseRecorder {
return performRequestWithHeaders(r, method, origin, http.Header{})
return performRequestWithHeaders(r, method, "/", origin, http.Header{})
}

func performRequestWithHeaders(r http.Handler, method, origin string, header http.Header) *httptest.ResponseRecorder {
req, _ := http.NewRequestWithContext(context.Background(), method, "/", nil)
func performRequestWithHeaders(r http.Handler, method, path, origin string, header http.Header) *httptest.ResponseRecorder {
req, _ := http.NewRequestWithContext(context.Background(), method, path, nil)
// From go/net/http/request.go:
// For incoming requests, the Host header is promoted to the
// Request.Host field and removed from the Header map.
Expand Down Expand Up @@ -299,6 +321,9 @@ func TestPassesAllowOrigins(t *testing.T) {
AllowOriginFunc: func(origin string) bool {
return origin == "http://github.com"
},
AllowOriginWithContextFunc: func(c *gin.Context, origin string) bool {
return origin == "http://sample.com"
},
dbhoot marked this conversation as resolved.
Show resolved Hide resolved
})

// no CORS request, origin == ""
Expand All @@ -311,7 +336,7 @@ func TestPassesAllowOrigins(t *testing.T) {
// no CORS request, origin == host
h := http.Header{}
h.Set("Host", "facebook.com")
w = performRequestWithHeaders(router, "GET", "http://facebook.com", h)
w = performRequestWithHeaders(router, "GET", "/", "http://facebook.com", h)
assert.Equal(t, "get", w.Body.String())
assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
Expand Down Expand Up @@ -346,6 +371,15 @@ func TestPassesAllowOrigins(t *testing.T) {
assert.Equal(t, "Content-Type,Timestamp", w.Header().Get("Access-Control-Allow-Headers"))
assert.Equal(t, "43200", w.Header().Get("Access-Control-Max-Age"))

// allowed CORS prefligh request: allowed via AllowOriginWithContextFunc
w = performRequest(router, "OPTIONS", "http://sample.com")
assert.Equal(t, http.StatusNoContent, w.Code)
assert.Equal(t, "http://sample.com", w.Header().Get("Access-Control-Allow-Origin"))
assert.Equal(t, "", w.Header().Get("Access-Control-Allow-Credentials"))
assert.Equal(t, "GET,POST,PUT,HEAD", w.Header().Get("Access-Control-Allow-Methods"))
assert.Equal(t, "Content-Type,Timestamp", w.Header().Get("Access-Control-Allow-Headers"))
assert.Equal(t, "43200", w.Header().Get("Access-Control-Max-Age"))

// deny CORS prefligh request
w = performRequest(router, "OPTIONS", "http://example.com")
assert.Equal(t, http.StatusForbidden, w.Code)
Expand Down Expand Up @@ -432,6 +466,48 @@ func TestWildcard(t *testing.T) {
assert.Equal(t, 200, w.Code)
}

func TestMultiGroupRouter(t *testing.T) {
router := multiGroupRouter(Config{
AllowMethods: []string{"GET"},
AllowOriginWithContextFunc: func(c *gin.Context, origin string) bool {
path := c.Request.URL.Path
if strings.HasPrefix(path, "/app1") {
return "http://app1.example.com" == origin
}

if strings.HasPrefix(path, "/app2") {
return "http://app2.example.com" == origin
}

// app 3 allows all origins
return true
},
})

// allowed CORS prefligh request
emptyHeaders := http.Header{}
app1Origin := "http://app1.example.com"
app2Origin := "http://app2.example.com"
randomOrgin := "http://random.com"

// allowed CORS preflight
w := performRequestWithHeaders(router, "OPTIONS", "/app1", app1Origin, emptyHeaders)
assert.Equal(t, http.StatusNoContent, w.Code)

w = performRequestWithHeaders(router, "OPTIONS", "/app2", app2Origin, emptyHeaders)
assert.Equal(t, http.StatusNoContent, w.Code)

w = performRequestWithHeaders(router, "OPTIONS", "/app3", randomOrgin, emptyHeaders)
assert.Equal(t, http.StatusNoContent, w.Code)

// disallowed CORS preflight
w = performRequestWithHeaders(router, "OPTIONS", "/app1", randomOrgin, emptyHeaders)
assert.Equal(t, http.StatusForbidden, w.Code)

w = performRequestWithHeaders(router, "OPTIONS", "/app2", randomOrgin, emptyHeaders)
assert.Equal(t, http.StatusForbidden, w.Code)
}

func TestParseWildcardRules_NoWildcard(t *testing.T) {
config := Config{
AllowOrigins: []string{
Expand Down