Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose *gin.Context to AllowOriginFunc #67

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 4 additions & 4 deletions config.go
Expand Up @@ -10,7 +10,7 @@ import (
type cors struct {
allowAllOrigins bool
allowCredentials bool
allowOriginFunc func(string) bool
allowOriginFunc func(string, *gin.Context) bool
allowOrigins []string
exposeHeaders []string
normalHeaders http.Header
Expand Down Expand Up @@ -68,7 +68,7 @@ func (cors *cors) applyCors(c *gin.Context) {
return
}

if !cors.validateOrigin(origin) {
if !cors.validateOrigin(origin, c) {
c.AbortWithStatus(http.StatusForbidden)
return
}
Expand Down Expand Up @@ -101,7 +101,7 @@ func (cors *cors) validateWildcardOrigin(origin string) bool {
return false
}

func (cors *cors) validateOrigin(origin string) bool {
func (cors *cors) validateOrigin(origin string, c *gin.Context) bool {
if cors.allowAllOrigins {
return true
}
Expand All @@ -114,7 +114,7 @@ func (cors *cors) validateOrigin(origin string) bool {
return true
}
if cors.allowOriginFunc != nil {
return cors.allowOriginFunc(origin)
return cors.allowOriginFunc(origin, c)
}
return false
}
Expand Down
2 changes: 1 addition & 1 deletion cors.go
Expand Up @@ -20,7 +20,7 @@ type Config struct {
// AllowOriginFunc is a custom function to validate the origin. It take the origin
// as argument and returns true if allowed or false otherwise. If this option is
// set, the content of AllowOrigins is ignored.
AllowOriginFunc func(origin string) bool
AllowOriginFunc func(origin string, c *gin.Context) bool

// AllowMethods is a list of methods the client is allowed to use with
// cross-domain requests. Default value is simple methods (GET and POST)
Expand Down
65 changes: 33 additions & 32 deletions cors_test.go
Expand Up @@ -81,7 +81,7 @@ func TestBadConfig(t *testing.T) {
assert.Panics(t, func() {
New(Config{
AllowAllOrigins: true,
AllowOriginFunc: func(origin string) bool { return false },
AllowOriginFunc: func(origin string, c *gin.Context) bool { return false },
})
})
assert.Panics(t, func() {
Expand Down Expand Up @@ -200,66 +200,67 @@ func TestGeneratePreflightHeaders_MaxAge(t *testing.T) {
}

func TestValidateOrigin(t *testing.T) {
emptyContext := &gin.Context{}
cors := newCors(Config{
AllowAllOrigins: true,
})
assert.True(t, cors.validateOrigin("http://google.com"))
assert.True(t, cors.validateOrigin("https://google.com"))
assert.True(t, cors.validateOrigin("example.com"))
assert.True(t, cors.validateOrigin("chrome-extension://random-extension-id"))
assert.True(t, cors.validateOrigin("http://google.com", emptyContext))
assert.True(t, cors.validateOrigin("https://google.com", emptyContext))
assert.True(t, cors.validateOrigin("example.com", emptyContext))
assert.True(t, cors.validateOrigin("chrome-extension://random-extension-id", emptyContext))

cors = newCors(Config{
AllowOrigins: []string{"https://google.com", "https://github.com"},
AllowOriginFunc: func(origin string) bool {
AllowOriginFunc: func(origin string, c *gin.Context) bool {
return (origin == "http://news.ycombinator.com")
},
AllowBrowserExtensions: true,
})
assert.False(t, cors.validateOrigin("http://google.com"))
assert.True(t, cors.validateOrigin("https://google.com"))
assert.True(t, cors.validateOrigin("https://github.com"))
assert.True(t, cors.validateOrigin("http://news.ycombinator.com"))
assert.False(t, cors.validateOrigin("http://example.com"))
assert.False(t, cors.validateOrigin("google.com"))
assert.False(t, cors.validateOrigin("chrome-extension://random-extension-id"))
assert.False(t, cors.validateOrigin("http://google.com", emptyContext))
assert.True(t, cors.validateOrigin("https://google.com", emptyContext))
assert.True(t, cors.validateOrigin("https://github.com", emptyContext))
assert.True(t, cors.validateOrigin("http://news.ycombinator.com", emptyContext))
assert.False(t, cors.validateOrigin("http://example.com", emptyContext))
assert.False(t, cors.validateOrigin("google.com", emptyContext))
assert.False(t, cors.validateOrigin("chrome-extension://random-extension-id", emptyContext))

cors = newCors(Config{
AllowOrigins: []string{"https://google.com", "https://github.com"},
})
assert.False(t, cors.validateOrigin("chrome-extension://random-extension-id"))
assert.False(t, cors.validateOrigin("file://some-dangerous-file.js"))
assert.False(t, cors.validateOrigin("wss://socket-connection"))
assert.False(t, cors.validateOrigin("chrome-extension://random-extension-id", emptyContext))
assert.False(t, cors.validateOrigin("file://some-dangerous-file.js", emptyContext))
assert.False(t, cors.validateOrigin("wss://socket-connection", emptyContext))

cors = newCors(Config{
AllowOrigins: []string{"chrome-extension://*", "safari-extension://my-extension-*-app", "*.some-domain.com"},
AllowBrowserExtensions: true,
AllowWildcard: true,
})
assert.True(t, cors.validateOrigin("chrome-extension://random-extension-id"))
assert.True(t, cors.validateOrigin("chrome-extension://another-one"))
assert.True(t, cors.validateOrigin("safari-extension://my-extension-one-app"))
assert.True(t, cors.validateOrigin("safari-extension://my-extension-two-app"))
assert.False(t, cors.validateOrigin("moz-extension://ext-id-we-not-allow"))
assert.True(t, cors.validateOrigin("http://api.some-domain.com"))
assert.False(t, cors.validateOrigin("http://api.another-domain.com"))
assert.True(t, cors.validateOrigin("chrome-extension://random-extension-id", emptyContext))
assert.True(t, cors.validateOrigin("chrome-extension://another-one", emptyContext))
assert.True(t, cors.validateOrigin("safari-extension://my-extension-one-app", emptyContext))
assert.True(t, cors.validateOrigin("safari-extension://my-extension-two-app", emptyContext))
assert.False(t, cors.validateOrigin("moz-extension://ext-id-we-not-allow", emptyContext))
assert.True(t, cors.validateOrigin("http://api.some-domain.com", emptyContext))
assert.False(t, cors.validateOrigin("http://api.another-domain.com", emptyContext))

cors = newCors(Config{
AllowOrigins: []string{"file://safe-file.js", "wss://some-session-layer-connection"},
AllowFiles: true,
AllowWebSockets: true,
})
assert.True(t, cors.validateOrigin("file://safe-file.js"))
assert.False(t, cors.validateOrigin("file://some-dangerous-file.js"))
assert.True(t, cors.validateOrigin("wss://some-session-layer-connection"))
assert.False(t, cors.validateOrigin("ws://not-what-we-expected"))
assert.True(t, cors.validateOrigin("file://safe-file.js", emptyContext))
assert.False(t, cors.validateOrigin("file://some-dangerous-file.js", emptyContext))
assert.True(t, cors.validateOrigin("wss://some-session-layer-connection", emptyContext))
assert.False(t, cors.validateOrigin("ws://not-what-we-expected", emptyContext))

cors = newCors(Config{
AllowOrigins: []string{"*"},
})
assert.True(t, cors.validateOrigin("http://google.com"))
assert.True(t, cors.validateOrigin("https://google.com"))
assert.True(t, cors.validateOrigin("example.com"))
assert.True(t, cors.validateOrigin("chrome-extension://random-extension-id"))
assert.True(t, cors.validateOrigin("http://google.com", emptyContext))
assert.True(t, cors.validateOrigin("https://google.com", emptyContext))
assert.True(t, cors.validateOrigin("example.com", emptyContext))
assert.True(t, cors.validateOrigin("chrome-extension://random-extension-id", emptyContext))
}

func TestPassesAllowOrigins(t *testing.T) {
Expand All @@ -270,7 +271,7 @@ func TestPassesAllowOrigins(t *testing.T) {
ExposeHeaders: []string{"Data", "x-User"},
AllowCredentials: false,
MaxAge: 12 * time.Hour,
AllowOriginFunc: func(origin string) bool {
AllowOriginFunc: func(origin string, c *gin.Context) bool {
return origin == "http://github.com"
},
})
Expand Down
2 changes: 1 addition & 1 deletion examples/example.go
Expand Up @@ -20,7 +20,7 @@ func main() {
AllowHeaders: []string{"Origin"},
ExposeHeaders: []string{"Content-Length"},
AllowCredentials: true,
AllowOriginFunc: func(origin string) bool {
AllowOriginFunc: func(origin string, c *gin.Context) bool {
return origin == "https://github.com"
},
MaxAge: 12 * time.Hour,
Expand Down