diff --git a/contrib/labstack/echo.v4/echotrace.go b/contrib/labstack/echo.v4/echotrace.go index 6afa2562bc..1628711d9b 100644 --- a/contrib/labstack/echo.v4/echotrace.go +++ b/contrib/labstack/echo.v4/echotrace.go @@ -7,7 +7,11 @@ package echo import ( + "errors" + "fmt" "math" + "net/http" + "strconv" "gopkg.in/DataDog/dd-trace-go.v1/contrib/internal/httptrace" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace" @@ -60,23 +64,48 @@ func Middleware(opts ...Option) echo.MiddlewareFunc { span, ctx := httptrace.StartRequestSpan(request, opts...) defer func() { - httptrace.FinishRequestSpan(span, c.Response().Status, finishOpts...) + span.Finish(finishOpts...) }() // pass the span through the request context c.SetRequest(request.WithContext(ctx)) - // serve the request to the next middleware + if appsecEnabled { afterMiddleware := useAppSec(c, span) defer afterMiddleware() } + // serve the request to the next middleware err := next(c) if err != nil { - finishOpts = append(finishOpts, tracer.WithError(err)) // invokes the registered HTTP error handler c.Error(err) - } + // It is impossible to determine what the final status code of a request is in echo. + // This is the best we can do. + var echoErr *echo.HTTPError + if errors.As(err, &echoErr) { + if cfg.isStatusError(echoErr.Code) { + finishOpts = append(finishOpts, tracer.WithError(err)) + } + span.SetTag(ext.HTTPCode, strconv.Itoa(echoErr.Code)) + } else { + // Any error that is not an *echo.HTTPError will be treated as an error with 500 status code. + if cfg.isStatusError(500) { + finishOpts = append(finishOpts, tracer.WithError(err)) + } + span.SetTag(ext.HTTPCode, "500") + } + } else if status := c.Response().Status; status > 0 { + if cfg.isStatusError(status) { + finishOpts = append(finishOpts, tracer.WithError(fmt.Errorf("%d: %s", status, http.StatusText(status)))) + } + span.SetTag(ext.HTTPCode, strconv.Itoa(status)) + } else { + if cfg.isStatusError(200) { + finishOpts = append(finishOpts, tracer.WithError(fmt.Errorf("%d: %s", 200, http.StatusText(200)))) + } + span.SetTag(ext.HTTPCode, "200") + } return err } } diff --git a/contrib/labstack/echo.v4/echotrace_test.go b/contrib/labstack/echo.v4/echotrace_test.go index 52b2345d4f..c3894e1fbf 100644 --- a/contrib/labstack/echo.v4/echotrace_test.go +++ b/contrib/labstack/echo.v4/echotrace_test.go @@ -7,6 +7,7 @@ package echo import ( "errors" + "fmt" "io" "net/http" "net/http/httptest" @@ -78,7 +79,7 @@ func TestTrace200(t *testing.T) { assert.True(traced) spans := mt.FinishedSpans() - assert.Len(spans, 1) + require.Len(t, spans, 1) span := spans[0] assert.Equal("http.request", span.OperationName()) @@ -126,7 +127,7 @@ func TestTraceAnalytics(t *testing.T) { assert.True(traced) spans := mt.FinishedSpans() - assert.Len(spans, 1) + require.Len(t, spans, 1) span := spans[0] assert.Equal("http.request", span.OperationName()) @@ -173,7 +174,7 @@ func TestError(t *testing.T) { assert.True(traced) spans := mt.FinishedSpans() - assert.Len(spans, 1) + require.Len(t, spans, 1) span := spans[0] assert.Equal("http.request", span.OperationName()) @@ -213,7 +214,7 @@ func TestErrorHandling(t *testing.T) { assert.True(traced) spans := mt.FinishedSpans() - assert.Len(spans, 1) + require.Len(t, spans, 1) span := spans[0] assert.Equal("http.request", span.OperationName()) @@ -224,6 +225,116 @@ func TestErrorHandling(t *testing.T) { assert.Equal(ext.SpanKindServer, span.Tag(ext.SpanKind)) } +func TestStatusError(t *testing.T) { + for _, tt := range []struct { + isStatusError func(statusCode int) bool + err error + code string + handler func(c echo.Context) error + }{ + { + err: errors.New("oh no"), + code: "500", + handler: func(c echo.Context) error { + return errors.New("oh no") + }, + }, + { + err: echo.NewHTTPError(http.StatusInternalServerError, "my error message"), + code: "500", + handler: func(c echo.Context) error { + return echo.NewHTTPError(http.StatusInternalServerError, "my error message") + }, + }, + { + err: nil, + code: "400", + handler: func(c echo.Context) error { + return echo.NewHTTPError(http.StatusBadRequest, "my error message") + }, + }, + { + isStatusError: func(statusCode int) bool { return statusCode >= 400 && statusCode < 500 }, + err: nil, + code: "500", + handler: func(c echo.Context) error { + return errors.New("oh no") + }, + }, + { + isStatusError: func(statusCode int) bool { return statusCode >= 400 && statusCode < 500 }, + err: nil, + code: "500", + handler: func(c echo.Context) error { + return echo.NewHTTPError(http.StatusInternalServerError, "my error message") + }, + }, + { + isStatusError: func(statusCode int) bool { return statusCode >= 400 }, + err: echo.NewHTTPError(http.StatusBadRequest, "my error message"), + code: "400", + handler: func(c echo.Context) error { + return echo.NewHTTPError(http.StatusBadRequest, "my error message") + }, + }, + { + isStatusError: func(statusCode int) bool { return statusCode >= 200 }, + err: fmt.Errorf("201: Created"), + code: "201", + handler: func(c echo.Context) error { + c.JSON(201, map[string]string{"status": "ok", "type": "test"}) + return nil + }, + }, + { + isStatusError: func(statusCode int) bool { return statusCode >= 200 }, + err: fmt.Errorf("200: OK"), + code: "200", + handler: func(c echo.Context) error { + // It's not clear if unset (0) status is possible naturally, but we can simulate that situation. + c.Response().Status = 0 + return nil + }, + }, + } { + t.Run("", func(t *testing.T) { + assert := assert.New(t) + mt := mocktracer.Start() + defer mt.Stop() + + router := echo.New() + opts := []Option{WithServiceName("foobar")} + if tt.isStatusError != nil { + opts = append(opts, WithStatusCheck(tt.isStatusError)) + } + router.Use(Middleware(opts...)) + router.GET("/err", tt.handler) + r := httptest.NewRequest("GET", "/err", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, r) + + spans := mt.FinishedSpans() + require.Len(t, spans, 1) + span := spans[0] + assert.Equal("http.request", span.OperationName()) + assert.Equal(ext.SpanTypeWeb, span.Tag(ext.SpanType)) + assert.Equal("foobar", span.Tag(ext.ServiceName)) + assert.Contains(span.Tag(ext.ResourceName), "/err") + assert.Equal(tt.code, span.Tag(ext.HTTPCode)) + assert.Equal("GET", span.Tag(ext.HTTPMethod)) + err := span.Tag(ext.Error) + if tt.err != nil { + if !assert.NotNil(err) { + return + } + assert.Equal(tt.err.Error(), err.(error).Error()) + } else { + assert.Nil(err) + } + }) + } +} + func TestGetSpanNotInstrumented(t *testing.T) { assert := assert.New(t) router := echo.New() @@ -273,7 +384,7 @@ func TestNoDebugStack(t *testing.T) { assert.True(traced) spans := mt.FinishedSpans() - assert.Len(spans, 1) + require.Len(t, spans, 1) span := spans[0] assert.Equal(wantErr.Error(), span.Tag(ext.Error).(error).Error()) diff --git a/contrib/labstack/echo.v4/option.go b/contrib/labstack/echo.v4/option.go index ea6bd4f97e..d634b3136b 100644 --- a/contrib/labstack/echo.v4/option.go +++ b/contrib/labstack/echo.v4/option.go @@ -18,6 +18,7 @@ type config struct { analyticsRate float64 noDebugStack bool ignoreRequestFunc IgnoreRequestFunc + isStatusError func(statusCode int) bool } // Option represents an option that can be passed to Middleware. @@ -32,6 +33,7 @@ func defaults(cfg *config) { cfg.serviceName = svc } cfg.analyticsRate = math.NaN() + cfg.isStatusError = isServerError } // WithServiceName sets the given service name for the system. @@ -80,3 +82,15 @@ func WithIgnoreRequest(ignoreRequestFunc IgnoreRequestFunc) Option { cfg.ignoreRequestFunc = ignoreRequestFunc } } + +// WithStatusCheck specifies a function fn which reports whether the passed +// statusCode should be considered an error. +func WithStatusCheck(fn func(statusCode int) bool) Option { + return func(cfg *config) { + cfg.isStatusError = fn + } +} + +func isServerError(statusCode int) bool { + return statusCode >= 500 && statusCode < 600 +} diff --git a/contrib/labstack/echo/echotrace.go b/contrib/labstack/echo/echotrace.go index 5a827a437a..d1614b6766 100644 --- a/contrib/labstack/echo/echotrace.go +++ b/contrib/labstack/echo/echotrace.go @@ -7,7 +7,11 @@ package echo import ( + "errors" + "fmt" "math" + "net/http" + "strconv" "gopkg.in/DataDog/dd-trace-go.v1/contrib/internal/httptrace" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace" @@ -48,7 +52,8 @@ func Middleware(opts ...Option) echo.MiddlewareFunc { span, ctx := httptrace.StartRequestSpan(request, opts...) defer func() { - httptrace.FinishRequestSpan(span, c.Response().Status, finishOpts...) + //httptrace.FinishRequestSpan(span, c.Response().Status, finishOpts...) + span.Finish(finishOpts...) }() // pass the span through the request context @@ -57,11 +62,35 @@ func Middleware(opts ...Option) echo.MiddlewareFunc { // serve the request to the next middleware err := next(c) if err != nil { - finishOpts = append(finishOpts, tracer.WithError(err)) // invokes the registered HTTP error handler c.Error(err) - } + // It is impossible to determine what the final status code of a request is in echo. + // This is the best we can do. + var echoErr *echo.HTTPError + if errors.As(err, &echoErr) { + if cfg.isStatusError(echoErr.Code) { + finishOpts = append(finishOpts, tracer.WithError(err)) + } + span.SetTag(ext.HTTPCode, strconv.Itoa(echoErr.Code)) + } else { + // Any error that is not an *echo.HTTPError will be treated as an error with 500 status code. + if cfg.isStatusError(500) { + finishOpts = append(finishOpts, tracer.WithError(err)) + } + span.SetTag(ext.HTTPCode, "500") + } + } else if status := c.Response().Status; status > 0 { + if cfg.isStatusError(status) { + finishOpts = append(finishOpts, tracer.WithError(fmt.Errorf("%d: %s", status, http.StatusText(status)))) + } + span.SetTag(ext.HTTPCode, strconv.Itoa(status)) + } else { + if cfg.isStatusError(200) { + finishOpts = append(finishOpts, tracer.WithError(fmt.Errorf("%d: %s", 200, http.StatusText(200)))) + } + span.SetTag(ext.HTTPCode, "200") + } return err } } diff --git a/contrib/labstack/echo/echotrace_test.go b/contrib/labstack/echo/echotrace_test.go index 6f9b098798..625811a81d 100644 --- a/contrib/labstack/echo/echotrace_test.go +++ b/contrib/labstack/echo/echotrace_test.go @@ -7,6 +7,7 @@ package echo import ( "errors" + "fmt" "net/http" "net/http/httptest" "testing" @@ -17,6 +18,7 @@ import ( "github.com/labstack/echo" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestChildSpan(t *testing.T) { @@ -174,6 +176,7 @@ func TestError(t *testing.T) { assert.Equal("http.request", span.OperationName()) assert.Equal("foobar", span.Tag(ext.ServiceName)) assert.Equal("500", span.Tag(ext.HTTPCode)) + require.NotNil(t, span.Tag(ext.Error)) assert.Equal(wantErr.Error(), span.Tag(ext.Error).(error).Error()) assert.Equal("labstack/echo", span.Tag(ext.Component)) assert.Equal(ext.SpanKindServer, span.Tag(ext.SpanKind)) @@ -214,11 +217,121 @@ func TestErrorHandling(t *testing.T) { assert.Equal("http.request", span.OperationName()) assert.Equal("foobar", span.Tag(ext.ServiceName)) assert.Equal("500", span.Tag(ext.HTTPCode)) + require.NotNil(t, span.Tag(ext.Error)) assert.Equal(wantErr.Error(), span.Tag(ext.Error).(error).Error()) assert.Equal("labstack/echo", span.Tag(ext.Component)) assert.Equal(ext.SpanKindServer, span.Tag(ext.SpanKind)) } +func TestStatusError(t *testing.T) { + for _, tt := range []struct { + isStatusError func(statusCode int) bool + err error + code string + handler func(c echo.Context) error + }{ + { + err: errors.New("oh no"), + code: "500", + handler: func(c echo.Context) error { + return errors.New("oh no") + }, + }, + { + err: echo.NewHTTPError(http.StatusInternalServerError, "my error message"), + code: "500", + handler: func(c echo.Context) error { + return echo.NewHTTPError(http.StatusInternalServerError, "my error message") + }, + }, + { + err: nil, + code: "400", + handler: func(c echo.Context) error { + return echo.NewHTTPError(http.StatusBadRequest, "my error message") + }, + }, + { + isStatusError: func(statusCode int) bool { return statusCode >= 400 && statusCode < 500 }, + err: nil, + code: "500", + handler: func(c echo.Context) error { + return errors.New("oh no") + }, + }, + { + isStatusError: func(statusCode int) bool { return statusCode >= 400 && statusCode < 500 }, + err: nil, + code: "500", + handler: func(c echo.Context) error { + return echo.NewHTTPError(http.StatusInternalServerError, "my error message") + }, + }, + { + isStatusError: func(statusCode int) bool { return statusCode >= 400 }, + err: echo.NewHTTPError(http.StatusBadRequest, "my error message"), + code: "400", + handler: func(c echo.Context) error { + return echo.NewHTTPError(http.StatusBadRequest, "my error message") + }, + }, + { + isStatusError: func(statusCode int) bool { return statusCode >= 200 }, + err: fmt.Errorf("201: Created"), + code: "201", + handler: func(c echo.Context) error { + c.JSON(201, map[string]string{"status": "ok", "type": "test"}) + return nil + }, + }, + { + isStatusError: func(statusCode int) bool { return statusCode >= 200 }, + err: fmt.Errorf("200: OK"), + code: "200", + handler: func(c echo.Context) error { + // It's not clear if unset (0) status is possible naturally, but we can simulate that situation. + c.Response().Status = 0 + return nil + }, + }, + } { + t.Run("", func(t *testing.T) { + assert := assert.New(t) + mt := mocktracer.Start() + defer mt.Stop() + + router := echo.New() + opts := []Option{WithServiceName("foobar")} + if tt.isStatusError != nil { + opts = append(opts, WithStatusCheck(tt.isStatusError)) + } + router.Use(Middleware(opts...)) + router.GET("/err", tt.handler) + r := httptest.NewRequest("GET", "/err", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, r) + + spans := mt.FinishedSpans() + assert.Len(spans, 1) + span := spans[0] + assert.Equal("http.request", span.OperationName()) + assert.Equal(ext.SpanTypeWeb, span.Tag(ext.SpanType)) + assert.Equal("foobar", span.Tag(ext.ServiceName)) + assert.Contains(span.Tag(ext.ResourceName), "/err") + assert.Equal(tt.code, span.Tag(ext.HTTPCode)) + assert.Equal("GET", span.Tag(ext.HTTPMethod)) + err := span.Tag(ext.Error) + if tt.err != nil { + assert.NotNil(err) + require.NotNil(t, span.Tag(ext.Error)) + assert.Equal(tt.err.Error(), err.(error).Error()) + } else { + assert.Nil(err) + } + }) + } +} + func TestGetSpanNotInstrumented(t *testing.T) { assert := assert.New(t) router := echo.New() @@ -271,6 +384,7 @@ func TestNoDebugStack(t *testing.T) { assert.Len(spans, 1) span := spans[0] + require.NotNil(t, span.Tag(ext.Error)) assert.Equal(wantErr.Error(), span.Tag(ext.Error).(error).Error()) assert.Equal("", span.Tag(ext.ErrorStack)) assert.Equal("labstack/echo", span.Tag(ext.Component)) diff --git a/contrib/labstack/echo/option.go b/contrib/labstack/echo/option.go index 4af218c5bb..959a0627d6 100644 --- a/contrib/labstack/echo/option.go +++ b/contrib/labstack/echo/option.go @@ -16,6 +16,7 @@ type config struct { serviceName string analyticsRate float64 noDebugStack bool + isStatusError func(statusCode int) bool } // Option represents an option that can be passed to Middleware. @@ -31,6 +32,7 @@ func defaults(cfg *config) { } else { cfg.analyticsRate = math.NaN() } + cfg.isStatusError = isServerError } // WithServiceName sets the given service name for the system. @@ -71,3 +73,15 @@ func NoDebugStack() Option { cfg.noDebugStack = true } } + +// WithStatusCheck specifies a function fn which reports whether the passed +// statusCode should be considered an error. +func WithStatusCheck(fn func(statusCode int) bool) Option { + return func(cfg *config) { + cfg.isStatusError = fn + } +} + +func isServerError(statusCode int) bool { + return statusCode >= 500 && statusCode < 600 +}