diff --git a/echo.go b/echo.go index 246a62256..a28fa0c1a 100644 --- a/echo.go +++ b/echo.go @@ -358,6 +358,11 @@ func (e *Echo) Routers() map[string]*Router { // DefaultHTTPErrorHandler is the default HTTP error handler. It sends a JSON response // with status code. func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) { + + if c.Response().Committed { + return + } + he, ok := err.(*HTTPError) if ok { if he.Internal != nil { @@ -384,15 +389,13 @@ func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) { } // Send response - if !c.Response().Committed { - if c.Request().Method == http.MethodHead { // Issue #608 - err = c.NoContent(he.Code) - } else { - err = c.JSON(code, message) - } - if err != nil { - e.Logger.Error(err) - } + if c.Request().Method == http.MethodHead { // Issue #608 + err = c.NoContent(he.Code) + } else { + err = c.JSON(code, message) + } + if err != nil { + e.Logger.Error(err) } } diff --git a/echo_test.go b/echo_test.go index dc553490b..f28915864 100644 --- a/echo_test.go +++ b/echo_test.go @@ -1124,6 +1124,15 @@ func TestDefaultHTTPErrorHandler(t *testing.T) { "error": "stackinfo", }) }) + e.Any("/early-return", func(c Context) error { + c.String(http.StatusOK, "OK") + return errors.New("ERROR") + }) + e.GET("/internal-error", func(c Context) error { + err := errors.New("internal error message body") + return NewHTTPError(http.StatusBadRequest).SetInternal(err) + }) + // With Debug=true plain response contains error message c, b := request(http.MethodGet, "/plain", e) assert.Equal(t, http.StatusInternalServerError, c) @@ -1136,6 +1145,14 @@ func TestDefaultHTTPErrorHandler(t *testing.T) { c, b = request(http.MethodGet, "/servererror", e) assert.Equal(t, http.StatusInternalServerError, c) assert.Equal(t, "{\n \"code\": 33,\n \"error\": \"stackinfo\",\n \"message\": \"Something bad happened\"\n}\n", b) + // if the body is already set HTTPErrorHandler should not add anything to response body + c, b = request(http.MethodGet, "/early-return", e) + assert.Equal(t, http.StatusOK, c) + assert.Equal(t, "OK", b) + // internal error should be reflected in the message + c, b = request(http.MethodGet, "/internal-error", e) + assert.Equal(t, http.StatusBadRequest, c) + assert.Equal(t, "{\n \"error\": \"code=400, message=Bad Request, internal=internal error message body\",\n \"message\": \"Bad Request\"\n}\n", b) e.Debug = false // With Debug=false the error response is shortened diff --git a/middleware/jwt.go b/middleware/jwt.go index c2e7c06d4..21e33ab82 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -295,7 +295,7 @@ func jwtFromHeader(header string, authScheme string) jwtExtractor { return func(c echo.Context) (string, error) { auth := c.Request().Header.Get(header) l := len(authScheme) - if len(auth) > l+1 && auth[:l] == authScheme { + if len(auth) > l+1 && strings.EqualFold(auth[:l], authScheme) { return auth[l+1:], nil } return "", ErrJWTMissing diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index 393fd93d3..5f36ce0a5 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -261,6 +261,11 @@ func TestJWT(t *testing.T) { expErrCode: http.StatusUnauthorized, info: "Token verification does not pass using a user-defined KeyFunc", }, + { + hdrAuth: strings.ToLower(DefaultJWTConfig.AuthScheme) + " " + token, + config: JWTConfig{SigningKey: validKey}, + info: "Valid JWT with lower case AuthScheme", + }, } { if tc.reqURL == "" { tc.reqURL = "/" diff --git a/middleware/key_auth.go b/middleware/key_auth.go index fd169aa2c..54f3b47f3 100644 --- a/middleware/key_auth.go +++ b/middleware/key_auth.go @@ -2,6 +2,7 @@ package middleware import ( "errors" + "fmt" "net/http" "strings" @@ -21,6 +22,7 @@ type ( // - "header:" // - "query:" // - "form:" + // - "cookie:" KeyLookup string `yaml:"key_lookup"` // AuthScheme to be used in the Authorization header. @@ -91,6 +93,8 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc { extractor = keyFromQuery(parts[1]) case "form": extractor = keyFromForm(parts[1]) + case "cookie": + extractor = keyFromCookie(parts[1]) } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -164,3 +168,14 @@ func keyFromForm(param string) keyExtractor { return key, nil } } + +// keyFromCookie returns a `keyExtractor` that extracts key from the form. +func keyFromCookie(cookieName string) keyExtractor { + return func(c echo.Context) (string, error) { + key, err := c.Cookie(cookieName) + if err != nil { + return "", fmt.Errorf("missing key in cookies: %w", err) + } + return key.Value, nil + } +} diff --git a/middleware/key_auth_test.go b/middleware/key_auth_test.go index 476b402d9..0cc513ab0 100644 --- a/middleware/key_auth_test.go +++ b/middleware/key_auth_test.go @@ -157,6 +157,30 @@ func TestKeyAuthWithConfig(t *testing.T) { expectHandlerCalled: false, expectError: "code=400, message=missing key in the form", }, + { + name: "ok, custom key lookup, cookie", + givenRequest: func(req *http.Request) { + req.AddCookie(&http.Cookie{ + Name: "key", + Value: "valid-key", + }) + q := req.URL.Query() + q.Add("key", "valid-key") + req.URL.RawQuery = q.Encode() + }, + whenConfig: func(conf *KeyAuthConfig) { + conf.KeyLookup = "cookie:key" + }, + expectHandlerCalled: true, + }, + { + name: "nok, custom key lookup, missing cookie param", + whenConfig: func(conf *KeyAuthConfig) { + conf.KeyLookup = "cookie:key" + }, + expectHandlerCalled: false, + expectError: "code=400, message=missing key in cookies: http: named cookie not present", + }, { name: "nok, custom errorHandler, error from extractor", whenConfig: func(conf *KeyAuthConfig) {