diff --git a/config.go b/config.go index d4fc118..adb7ef8 100644 --- a/config.go +++ b/config.go @@ -43,6 +43,12 @@ func newCors(config Config) *cors { panic(err.Error()) } + for _, origin := range config.AllowOrigins { + if origin == "*" { + config.AllowAllOrigins = true + } + } + return &cors{ allowOriginFunc: config.AllowOriginFunc, allowAllOrigins: config.AllowAllOrigins, diff --git a/cors.go b/cors.go index d6d06de..48576a8 100644 --- a/cors.go +++ b/cors.go @@ -95,7 +95,7 @@ func (c Config) validateAllowedSchemas(origin string) bool { } // Validate is check configuration of user defined. -func (c *Config) Validate() error { +func (c Config) Validate() error { if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) { return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowOrigins is not needed") } @@ -103,10 +103,7 @@ func (c *Config) Validate() error { return errors.New("conflict settings: all origins disabled") } for _, origin := range c.AllowOrigins { - if origin == "*" { - c.AllowAllOrigins = true - return nil - } else if !strings.Contains(origin, "*") && !c.validateAllowedSchemas(origin) { + if !strings.Contains(origin, "*") && !c.validateAllowedSchemas(origin) { return errors.New("bad origin: origins must contain '*' or include " + strings.Join(c.getAllowedSchemas(), ",")) } }