diff --git a/middleware/csrf.go b/middleware/csrf.go index 61299f5ca..ea90fdba7 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -61,7 +61,13 @@ type ( // Indicates SameSite mode of the CSRF cookie. // Optional. Default value SameSiteDefaultMode. CookieSameSite http.SameSite `yaml:"cookie_same_site"` + + // ErrorHandler defines a function which is executed for returning custom errors. + ErrorHandler CSRFErrorHandler } + + // CSRFErrorHandler is a function which is executed for creating custom errors. + CSRFErrorHandler func(err error, c echo.Context) error ) // ErrCSRFInvalid is returned when CSRF check fails @@ -154,8 +160,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { lastTokenErr = ErrCSRFInvalid } } + var finalErr error if lastTokenErr != nil { - return lastTokenErr + finalErr = lastTokenErr } else if lastExtractorErr != nil { // ugly part to preserve backwards compatible errors. someone could rely on them if lastExtractorErr == errQueryExtractorValueMissing { @@ -167,7 +174,14 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { } else { lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, lastExtractorErr.Error()) } - return lastExtractorErr + finalErr = lastExtractorErr + } + + if finalErr != nil { + if config.ErrorHandler != nil { + return config.ErrorHandler(finalErr, c) + } + return finalErr } } diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index 9aff82a98..6bccdbe4d 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -358,3 +358,25 @@ func TestCSRFConfig_skipper(t *testing.T) { }) } } + +func TestCSRFErrorHandling(t *testing.T) { + cfg := CSRFConfig{ + ErrorHandler: func(err error, c echo.Context) error { + return echo.NewHTTPError(http.StatusTeapot, "error_handler_executed") + }, + } + + e := echo.New() + e.POST("/", func(c echo.Context) error { + return c.String(http.StatusNotImplemented, "should not end up here") + }) + + e.Use(CSRFWithConfig(cfg)) + + req := httptest.NewRequest(http.MethodPost, "/", nil) + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusTeapot, res.Code) + assert.Equal(t, "{\"message\":\"error_handler_executed\"}\n", res.Body.String()) +}