diff --git a/contrib/gorilla/mux/mux_test.go b/contrib/gorilla/mux/mux_test.go index 08844fedff..2c8774ab59 100644 --- a/contrib/gorilla/mux/mux_test.go +++ b/contrib/gorilla/mux/mux_test.go @@ -384,4 +384,24 @@ func TestAppSec(t *testing.T) { require.True(t, strings.Contains(event, "myPathParam2")) require.True(t, strings.Contains(event, "server.request.path_params")) }) + + t.Run("response-status", func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + req, err := http.NewRequest("POST", srv.URL+"/etc/", nil) + if err != nil { + panic(err) + } + res, err := srv.Client().Do(req) + require.NoError(t, err) + require.Equal(t, 404, res.StatusCode) + + finished := mt.FinishedSpans() + require.Len(t, finished, 1) + event := finished[0].Tag("_dd.appsec.json").(string) + require.NotNil(t, event) + require.True(t, strings.Contains(event, "server.response.status")) + require.True(t, strings.Contains(event, "nfd-000-001")) + }) } diff --git a/contrib/labstack/echo.v4/appsec.go b/contrib/labstack/echo.v4/appsec.go index e67698c332..1b12f4b8d7 100644 --- a/contrib/labstack/echo.v4/appsec.go +++ b/contrib/labstack/echo.v4/appsec.go @@ -14,30 +14,23 @@ import ( "github.com/labstack/echo/v4" ) -func withAppSec(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - req := c.Request() - span, ok := tracer.SpanFromContext(req.Context()) - if !ok { - return next(c) - } - httpsec.SetAppSecTags(span) - params := make(map[string]string) - for _, n := range c.ParamNames() { - params[n] = c.Param(n) - } - args := httpsec.MakeHandlerOperationArgs(req, params) - op := httpsec.StartOperation(args, nil) - defer func() { - events := op.Finish(httpsec.HandlerOperationRes{Status: c.Response().Status}) - if len(events) > 0 { - remoteIP, _, err := net.SplitHostPort(req.RemoteAddr) - if err != nil { - remoteIP = req.RemoteAddr - } - httpsec.SetSecurityEventTags(span, events, remoteIP, args.Headers, c.Response().Writer.Header()) +func useAppSec(c echo.Context, span tracer.Span) func() { + req := c.Request() + httpsec.SetAppSecTags(span) + params := make(map[string]string) + for _, n := range c.ParamNames() { + params[n] = c.Param(n) + } + args := httpsec.MakeHandlerOperationArgs(req, params) + op := httpsec.StartOperation(args, nil) + return func() { + events := op.Finish(httpsec.HandlerOperationRes{Status: c.Response().Status}) + if len(events) > 0 { + remoteIP, _, err := net.SplitHostPort(req.RemoteAddr) + if err != nil { + remoteIP = req.RemoteAddr } - }() - return next(c) + httpsec.SetSecurityEventTags(span, events, remoteIP, args.Headers, c.Response().Writer.Header()) + } } } diff --git a/contrib/labstack/echo.v4/echotrace.go b/contrib/labstack/echo.v4/echotrace.go index 0b291de7f7..fc16489626 100644 --- a/contrib/labstack/echo.v4/echotrace.go +++ b/contrib/labstack/echo.v4/echotrace.go @@ -20,15 +20,13 @@ import ( // Middleware returns echo middleware which will trace incoming requests. func Middleware(opts ...Option) echo.MiddlewareFunc { + appsecEnabled := appsec.Enabled() cfg := new(config) defaults(cfg) for _, fn := range opts { fn(cfg) } return func(next echo.HandlerFunc) echo.HandlerFunc { - if appsec.Enabled() { - next = withAppSec(next) - } return func(c echo.Context) error { request := c.Request() resource := request.Method + " " + c.Path() @@ -57,6 +55,10 @@ func Middleware(opts ...Option) echo.MiddlewareFunc { // pass the span through the request context c.SetRequest(request.WithContext(ctx)) // serve the request to the next middleware + if appsecEnabled { + afterMiddleware := useAppSec(c, span) + defer afterMiddleware() + } err := next(c) if err != nil { finishOpts = append(finishOpts, tracer.WithError(err)) diff --git a/contrib/labstack/echo.v4/echotrace_test.go b/contrib/labstack/echo.v4/echotrace_test.go index 973b976f85..334735170b 100644 --- a/contrib/labstack/echo.v4/echotrace_test.go +++ b/contrib/labstack/echo.v4/echotrace_test.go @@ -367,5 +367,25 @@ func TestAppSec(t *testing.T) { require.False(t, strings.Contains(event, "myPathParam3")) require.True(t, strings.Contains(event, "server.request.path_params")) }) + + t.Run("response-status", func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + req, err := http.NewRequest("POST", srv.URL+"/etc/", nil) + if err != nil { + panic(err) + } + res, err := srv.Client().Do(req) + require.NoError(t, err) + require.Equal(t, 404, res.StatusCode) + + finished := mt.FinishedSpans() + require.Len(t, finished, 1) + event := finished[0].Tag("_dd.appsec.json").(string) + require.NotNil(t, event) + require.True(t, strings.Contains(event, "server.response.status")) + require.True(t, strings.Contains(event, "nfd-000-001")) + }) }) }