From c32fafad68daa7214f0ca005b4614ca38e90b2b8 Mon Sep 17 00:00:00 2001 From: Guilherme Cardoso Date: Tue, 7 Dec 2021 10:56:32 +0000 Subject: [PATCH] Add support for configurable target header for the request_id middleware --- echo.go | 1 + middleware/request_id.go | 15 +++++++++++---- middleware/request_id_test.go | 31 +++++++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/echo.go b/echo.go index df5d35843..ad03dd519 100644 --- a/echo.go +++ b/echo.go @@ -214,6 +214,7 @@ const ( HeaderXHTTPMethodOverride = "X-HTTP-Method-Override" HeaderXRealIP = "X-Real-IP" HeaderXRequestID = "X-Request-ID" + HeaderXCorrelationID = "X-Correlation-ID" HeaderXRequestedWith = "X-Requested-With" HeaderServer = "Server" HeaderOrigin = "Origin" diff --git a/middleware/request_id.go b/middleware/request_id.go index b0baeeb2d..8c5ff6605 100644 --- a/middleware/request_id.go +++ b/middleware/request_id.go @@ -17,14 +17,18 @@ type ( // RequestIDHandler defines a function which is executed for a request id. RequestIDHandler func(echo.Context, string) + + // TargetHeader defines what header to look for to populate the id + TargetHeader string } ) var ( // DefaultRequestIDConfig is the default RequestID middleware config. DefaultRequestIDConfig = RequestIDConfig{ - Skipper: DefaultSkipper, - Generator: generator, + Skipper: DefaultSkipper, + Generator: generator, + TargetHeader: echo.HeaderXRequestID, } ) @@ -42,6 +46,9 @@ func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc { if config.Generator == nil { config.Generator = generator } + if config.TargetHeader == "" { + config.TargetHeader = echo.HeaderXRequestID + } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { @@ -51,11 +58,11 @@ func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc { req := c.Request() res := c.Response() - rid := req.Header.Get(echo.HeaderXRequestID) + rid := req.Header.Get(config.TargetHeader) if rid == "" { rid = config.Generator() } - res.Header().Set(echo.HeaderXRequestID, rid) + res.Header().Set(config.TargetHeader, rid) if config.RequestIDHandler != nil { config.RequestIDHandler(c, rid) } diff --git a/middleware/request_id_test.go b/middleware/request_id_test.go index 944b3b49e..21b777826 100644 --- a/middleware/request_id_test.go +++ b/middleware/request_id_test.go @@ -55,3 +55,34 @@ func TestRequestID_IDNotAltered(t *testing.T) { _ = h(c) assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "") } + +func TestRequestIDConfigDifferentHeader(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + + rid := RequestIDWithConfig(RequestIDConfig{TargetHeader: echo.HeaderXCorrelationID}) + h := rid(handler) + h(c) + assert.Len(t, rec.Header().Get(echo.HeaderXCorrelationID), 32) + + // Custom generator and handler + customID := "customGenerator" + calledHandler := false + rid = RequestIDWithConfig(RequestIDConfig{ + Generator: func() string { return customID }, + TargetHeader: echo.HeaderXCorrelationID, + RequestIDHandler: func(_ echo.Context, id string) { + calledHandler = true + assert.Equal(t, customID, id) + }, + }) + h = rid(handler) + h(c) + assert.Equal(t, rec.Header().Get(echo.HeaderXCorrelationID), "customGenerator") + assert.True(t, calledHandler) +}