Skip to content

Commit

Permalink
JWT, KeyAuth, CSRF multivalue extractors (#2060)
Browse files Browse the repository at this point in the history
* CSRF, JWT, KeyAuth middleware support for multivalue value extractors
* Add flag to JWT and KeyAuth middleware to allow continuing execution `next(c)` when error handler decides to swallow the error (returns nil).
  • Loading branch information
aldas committed Jan 24, 2022
1 parent 9e9924d commit 4a1ccdf
Show file tree
Hide file tree
Showing 10 changed files with 1,562 additions and 395 deletions.
4 changes: 2 additions & 2 deletions echo.go
Expand Up @@ -111,10 +111,10 @@ type (
}

// MiddlewareFunc defines a function to process middleware.
MiddlewareFunc func(HandlerFunc) HandlerFunc
MiddlewareFunc func(next HandlerFunc) HandlerFunc

// HandlerFunc defines a function to serve HTTP requests.
HandlerFunc func(Context) error
HandlerFunc func(c Context) error

// HTTPErrorHandler is a centralized HTTP error handler.
HTTPErrorHandler func(error, Context)
Expand Down
110 changes: 47 additions & 63 deletions middleware/csrf.go
Expand Up @@ -2,9 +2,7 @@ package middleware

import (
"crypto/subtle"
"errors"
"net/http"
"strings"
"time"

"github.com/labstack/echo/v4"
Expand All @@ -21,13 +19,15 @@ type (
TokenLength uint8 `yaml:"token_length"`
// Optional. Default value 32.

// TokenLookup is a string in the form of "<source>:<key>" that is used
// TokenLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
// to extract token from the request.
// Optional. Default value "header:X-CSRF-Token".
// Possible values:
// - "header:<name>"
// - "form:<name>"
// - "header:<name>" or "header:<name>:<cut-prefix>"
// - "query:<name>"
// - "form:<name>"
// Multiple sources example:
// - "header:X-CSRF-Token,query:csrf"
TokenLookup string `yaml:"token_lookup"`

// Context key to store generated CSRF token into context.
Expand Down Expand Up @@ -62,12 +62,11 @@ type (
// Optional. Default value SameSiteDefaultMode.
CookieSameSite http.SameSite `yaml:"cookie_same_site"`
}

// csrfTokenExtractor defines a function that takes `echo.Context` and returns
// either a token or an error.
csrfTokenExtractor func(echo.Context) (string, error)
)

// ErrCSRFInvalid is returned when CSRF check fails
var ErrCSRFInvalid = echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")

var (
// DefaultCSRFConfig is the default CSRF middleware config.
DefaultCSRFConfig = CSRFConfig{
Expand Down Expand Up @@ -114,14 +113,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
config.CookieSecure = true
}

// Initialize
parts := strings.Split(config.TokenLookup, ":")
extractor := csrfTokenFromHeader(parts[1])
switch parts[0] {
case "form":
extractor = csrfTokenFromForm(parts[1])
case "query":
extractor = csrfTokenFromQuery(parts[1])
extractors, err := createExtractors(config.TokenLookup, "")
if err != nil {
panic(err)
}

return func(next echo.HandlerFunc) echo.HandlerFunc {
Expand All @@ -130,28 +124,50 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
return next(c)
}

req := c.Request()
k, err := c.Cookie(config.CookieName)
token := ""

// Generate token
if err != nil {
token = random.String(config.TokenLength)
if k, err := c.Cookie(config.CookieName); err != nil {
token = random.String(config.TokenLength) // Generate token
} else {
// Reuse token
token = k.Value
token = k.Value // Reuse token
}

switch req.Method {
switch c.Request().Method {
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
default:
// Validate token only for requests which are not defined as 'safe' by RFC7231
clientToken, err := extractor(c)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
var lastExtractorErr error
var lastTokenErr error
outer:
for _, extractor := range extractors {
clientTokens, err := extractor(c)
if err != nil {
lastExtractorErr = err
continue
}

for _, clientToken := range clientTokens {
if validateCSRFToken(token, clientToken) {
lastTokenErr = nil
lastExtractorErr = nil
break outer
}
lastTokenErr = ErrCSRFInvalid
}
}
if !validateCSRFToken(token, clientToken) {
return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
if lastTokenErr != nil {
return lastTokenErr
} else if lastExtractorErr != nil {
// ugly part to preserve backwards compatible errors. someone could rely on them
if lastExtractorErr == errQueryExtractorValueMissing {
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the query string")
} else if lastExtractorErr == errFormExtractorValueMissing {
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the form parameter")
} else if lastExtractorErr == errHeaderExtractorValueMissing {
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in request header")
} else {
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, lastExtractorErr.Error())
}
return lastExtractorErr
}
}

Expand Down Expand Up @@ -184,38 +200,6 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
}
}

// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
// provided request header.
func csrfTokenFromHeader(header string) csrfTokenExtractor {
return func(c echo.Context) (string, error) {
return c.Request().Header.Get(header), nil
}
}

// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
// provided form parameter.
func csrfTokenFromForm(param string) csrfTokenExtractor {
return func(c echo.Context) (string, error) {
token := c.FormValue(param)
if token == "" {
return "", errors.New("missing csrf token in the form parameter")
}
return token, nil
}
}

// csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the
// provided query parameter.
func csrfTokenFromQuery(param string) csrfTokenExtractor {
return func(c echo.Context) (string, error) {
token := c.QueryParam(param)
if token == "" {
return "", errors.New("missing csrf token in the query string")
}
return token, nil
}
}

func validateCSRFToken(token, clientToken string) bool {
return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1
}

0 comments on commit 4a1ccdf

Please sign in to comment.