From c7079da44a75fe3fcaf929ad5590a4bdfc4dbb5e Mon Sep 17 00:00:00 2001 From: Mojtaba Arezoomand Date: Thu, 1 Sep 2022 00:35:22 +0430 Subject: [PATCH 1/5] feat: add error handlers to csrf middleware --- middleware/csrf.go | 21 +++++++++++++++-- middleware/csrf_test.go | 51 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 2 deletions(-) diff --git a/middleware/csrf.go b/middleware/csrf.go index 61299f5ca..280c6d9c8 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -61,7 +61,15 @@ type ( // Indicates SameSite mode of the CSRF cookie. // Optional. Default value SameSiteDefaultMode. CookieSameSite http.SameSite `yaml:"cookie_same_site"` + + ErrorHandler CSRFErrorHandler + + ErrorHandlerWithContext CSRFErrorHandlerWithContext } + + CSRFErrorHandler func(err error) error + + CSRFErrorHandlerWithContext func(err error, c echo.Context) error ) // ErrCSRFInvalid is returned when CSRF check fails @@ -154,8 +162,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { lastTokenErr = ErrCSRFInvalid } } + var finalErr error if lastTokenErr != nil { - return lastTokenErr + finalErr = lastTokenErr } else if lastExtractorErr != nil { // ugly part to preserve backwards compatible errors. someone could rely on them if lastExtractorErr == errQueryExtractorValueMissing { @@ -167,8 +176,16 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { } else { lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, lastExtractorErr.Error()) } - return lastExtractorErr + finalErr = lastExtractorErr + } + + if config.ErrorHandler != nil { + return config.ErrorHandler(finalErr) + } + if config.ErrorHandlerWithContext != nil { + return config.ErrorHandlerWithContext(finalErr, c) } + return finalErr } // Set CSRF cookie diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index 9aff82a98..d721acb1e 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -1,6 +1,7 @@ package middleware import ( + "encoding/json" "net/http" "net/http/httptest" "net/url" @@ -358,3 +359,53 @@ func TestCSRFConfig_skipper(t *testing.T) { }) } } + +func TestCSRFErrorHandling(t *testing.T) { + testCases := []struct { + name string + cfg CSRFConfig + expectedErr *echo.HTTPError + }{ + { + name: "ok, ErrorHandler is executed", + cfg: CSRFConfig{ + ErrorHandler: func(err error) error { + return echo.NewHTTPError(http.StatusTeapot, "error_handler_executed") + }, + }, + expectedErr: echo.NewHTTPError(http.StatusTeapot, "error_handler_executed"), + }, + { + name: "ok, ErrorHandlerWithContext is executed", + cfg: CSRFConfig{ + ErrorHandlerWithContext: func(err error, c echo.Context) error { + return echo.NewHTTPError(http.StatusTeapot, "error_handler_with_context_executed") + }, + }, + expectedErr: echo.NewHTTPError(http.StatusTeapot, "error_handler_with_context_executed"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + e.POST("/", func(c echo.Context) error { + return c.String(http.StatusNotImplemented, "should not end up here") + }) + + e.Use(CSRFWithConfig(tc.cfg)) + + req := httptest.NewRequest(http.MethodPost, "/", nil) + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + var response struct { + Message string `json:"message"` + } + + assert.NoError(t, json.Unmarshal(res.Body.Bytes(), &response)) + assert.Equal(t, tc.expectedErr.Code, res.Code) + assert.Equal(t, tc.expectedErr.Message, response.Message) + }) + } +} From ea3161f66a4abdf723c9715a6ac32bb79b62b96a Mon Sep 17 00:00:00 2001 From: Mojtaba Arezoomand Date: Thu, 1 Sep 2022 00:49:29 +0430 Subject: [PATCH 2/5] fix: return errors if there is errors in csrf middleware --- middleware/csrf.go | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/middleware/csrf.go b/middleware/csrf.go index 280c6d9c8..a8d629405 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -62,8 +62,10 @@ type ( // Optional. Default value SameSiteDefaultMode. CookieSameSite http.SameSite `yaml:"cookie_same_site"` + // ErrorHandler defines a function which is executed for returning custom errors. ErrorHandler CSRFErrorHandler + // ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context. ErrorHandlerWithContext CSRFErrorHandlerWithContext } @@ -179,13 +181,15 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { finalErr = lastExtractorErr } - if config.ErrorHandler != nil { - return config.ErrorHandler(finalErr) - } - if config.ErrorHandlerWithContext != nil { - return config.ErrorHandlerWithContext(finalErr, c) + if finalErr != nil { + if config.ErrorHandler != nil { + return config.ErrorHandler(finalErr) + } + if config.ErrorHandlerWithContext != nil { + return config.ErrorHandlerWithContext(finalErr, c) + } + return finalErr } - return finalErr } // Set CSRF cookie From ecc6c1ead415ccf3816609decea23f255df62271 Mon Sep 17 00:00:00 2001 From: Mojtaba Arezoomand Date: Thu, 1 Sep 2022 11:28:57 +0430 Subject: [PATCH 3/5] fix: remove ErrorHanlderWithContext --- middleware/csrf.go | 12 ++------- middleware/csrf_test.go | 58 +++++++++++++---------------------------- 2 files changed, 20 insertions(+), 50 deletions(-) diff --git a/middleware/csrf.go b/middleware/csrf.go index a8d629405..695a745e4 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -64,14 +64,9 @@ type ( // ErrorHandler defines a function which is executed for returning custom errors. ErrorHandler CSRFErrorHandler - - // ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context. - ErrorHandlerWithContext CSRFErrorHandlerWithContext } - CSRFErrorHandler func(err error) error - - CSRFErrorHandlerWithContext func(err error, c echo.Context) error + CSRFErrorHandler func(err error, c echo.Context) error ) // ErrCSRFInvalid is returned when CSRF check fails @@ -183,10 +178,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { if finalErr != nil { if config.ErrorHandler != nil { - return config.ErrorHandler(finalErr) - } - if config.ErrorHandlerWithContext != nil { - return config.ErrorHandlerWithContext(finalErr, c) + return config.ErrorHandler(finalErr, c) } return finalErr } diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index d721acb1e..23cc1a767 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -361,51 +361,29 @@ func TestCSRFConfig_skipper(t *testing.T) { } func TestCSRFErrorHandling(t *testing.T) { - testCases := []struct { - name string - cfg CSRFConfig - expectedErr *echo.HTTPError - }{ - { - name: "ok, ErrorHandler is executed", - cfg: CSRFConfig{ - ErrorHandler: func(err error) error { - return echo.NewHTTPError(http.StatusTeapot, "error_handler_executed") - }, - }, - expectedErr: echo.NewHTTPError(http.StatusTeapot, "error_handler_executed"), - }, - { - name: "ok, ErrorHandlerWithContext is executed", - cfg: CSRFConfig{ - ErrorHandlerWithContext: func(err error, c echo.Context) error { - return echo.NewHTTPError(http.StatusTeapot, "error_handler_with_context_executed") - }, - }, - expectedErr: echo.NewHTTPError(http.StatusTeapot, "error_handler_with_context_executed"), + cfg := CSRFConfig{ + ErrorHandler: func(err error, c echo.Context) error { + return echo.NewHTTPError(http.StatusTeapot, "error_handler_executed") }, } + expectedErr := echo.NewHTTPError(http.StatusTeapot, "error_handler_executed") - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := echo.New() - e.POST("/", func(c echo.Context) error { - return c.String(http.StatusNotImplemented, "should not end up here") - }) - - e.Use(CSRFWithConfig(tc.cfg)) + e := echo.New() + e.POST("/", func(c echo.Context) error { + return c.String(http.StatusNotImplemented, "should not end up here") + }) - req := httptest.NewRequest(http.MethodPost, "/", nil) - res := httptest.NewRecorder() - e.ServeHTTP(res, req) + e.Use(CSRFWithConfig(cfg)) - var response struct { - Message string `json:"message"` - } + req := httptest.NewRequest(http.MethodPost, "/", nil) + res := httptest.NewRecorder() + e.ServeHTTP(res, req) - assert.NoError(t, json.Unmarshal(res.Body.Bytes(), &response)) - assert.Equal(t, tc.expectedErr.Code, res.Code) - assert.Equal(t, tc.expectedErr.Message, response.Message) - }) + var response struct { + Message string `json:"message"` } + + assert.NoError(t, json.Unmarshal(res.Body.Bytes(), &response)) + assert.Equal(t, expectedErr.Code, res.Code) + assert.Equal(t, expectedErr.Message, response.Message) } From bf9249e3faaf689749686b53ec7a74c57b19229e Mon Sep 17 00:00:00 2001 From: Mojtaba Arezoomand Date: Thu, 1 Sep 2022 11:42:58 +0430 Subject: [PATCH 4/5] fix: remove unmarshaling in csrf tests --- middleware/csrf_test.go | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index 23cc1a767..6bccdbe4d 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -1,7 +1,6 @@ package middleware import ( - "encoding/json" "net/http" "net/http/httptest" "net/url" @@ -366,7 +365,6 @@ func TestCSRFErrorHandling(t *testing.T) { return echo.NewHTTPError(http.StatusTeapot, "error_handler_executed") }, } - expectedErr := echo.NewHTTPError(http.StatusTeapot, "error_handler_executed") e := echo.New() e.POST("/", func(c echo.Context) error { @@ -379,11 +377,6 @@ func TestCSRFErrorHandling(t *testing.T) { res := httptest.NewRecorder() e.ServeHTTP(res, req) - var response struct { - Message string `json:"message"` - } - - assert.NoError(t, json.Unmarshal(res.Body.Bytes(), &response)) - assert.Equal(t, expectedErr.Code, res.Code) - assert.Equal(t, expectedErr.Message, response.Message) + assert.Equal(t, http.StatusTeapot, res.Code) + assert.Equal(t, "{\"message\":\"error_handler_executed\"}\n", res.Body.String()) } From fc8a542efac71beb93edb0b145fd6ab25eb018b6 Mon Sep 17 00:00:00 2001 From: Mojtaba Arezoomand Date: Thu, 1 Sep 2022 11:47:10 +0430 Subject: [PATCH 5/5] style: add comments for CSRF error handler type --- middleware/csrf.go | 1 + 1 file changed, 1 insertion(+) diff --git a/middleware/csrf.go b/middleware/csrf.go index 695a745e4..ea90fdba7 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -66,6 +66,7 @@ type ( ErrorHandler CSRFErrorHandler } + // CSRFErrorHandler is a function which is executed for creating custom errors. CSRFErrorHandler func(err error, c echo.Context) error )