diff --git a/module/apmhttp/client.go b/module/apmhttp/client.go index 2d0df4034..322f19232 100644 --- a/module/apmhttp/client.go +++ b/module/apmhttp/client.go @@ -70,6 +70,7 @@ type roundTripper struct { r http.RoundTripper requestName RequestNameFunc requestIgnorer RequestIgnorerFunc + traceRequests bool } // RoundTrip delegates to r.r, emitting a span if req's context @@ -102,9 +103,13 @@ func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { name := r.requestName(req) span := tx.StartSpan(name, "external.http", apm.SpanFromContext(ctx)) + var rt *requestTracer if !span.Dropped() { traceContext = span.TraceContext() ctx = apm.ContextWithSpan(ctx, span) + if r.traceRequests { + ctx, rt = withClientTrace(ctx, tx, span) + } req = RequestWithContext(ctx, req) span.Context.SetHTTPRequest(req) } else { @@ -116,10 +121,13 @@ func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { resp, err := r.r.RoundTrip(req) if span != nil { if err != nil { + if rt != nil { + rt.end() + } span.End() } else { span.Context.SetHTTPStatusCode(resp.StatusCode) - resp.Body = &responseBody{span: span, body: resp.Body} + resp.Body = &responseBody{span: span, body: resp.Body, requestTracer: rt} } } return resp, err @@ -157,8 +165,9 @@ func (r *roundTripper) CancelRequest(req *http.Request) { } type responseBody struct { - span *apm.Span - body io.ReadCloser + span *apm.Span + requestTracer *requestTracer + body io.ReadCloser } // Close closes the response body, and ends the span if it hasn't already been ended. @@ -180,6 +189,9 @@ func (b *responseBody) Read(p []byte) (n int, err error) { func (b *responseBody) endSpan() { addr := (*unsafe.Pointer)(unsafe.Pointer(&b.span)) if old := atomic.SwapPointer(addr, nil); old != nil { + if b.requestTracer != nil { + b.requestTracer.end() + } (*apm.Span)(old).End() } } diff --git a/module/apmhttp/client_test.go b/module/apmhttp/client_test.go index 2bd2980c5..4a28ffe44 100644 --- a/module/apmhttp/client_test.go +++ b/module/apmhttp/client_test.go @@ -21,6 +21,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "io/ioutil" "net" "net/http" @@ -311,6 +312,20 @@ func TestWithClientRequestName(t *testing.T) { assert.Equal(t, "http://test", span.Name) } +func TestWithClientTrace(t *testing.T) { + server := httptest.NewServer(http.NotFoundHandler()) + defer server.Close() + + _, spans, _ := apmtest.WithTransaction(func(ctx context.Context) { + mustGET(ctx, server.URL, apmhttp.WithClientTrace()) + }) + + require.Len(t, spans, 4) + assert.Equal(t, fmt.Sprintf(fmt.Sprintf("Connect %s", server.Listener.Addr())), spans[0].Name) + assert.Equal(t, "Request", spans[1].Name) + assert.Equal(t, "Response", spans[2].Name) +} + func mustGET(ctx context.Context, url string, o ...apmhttp.ClientOption) (statusCode int, responseBody string) { client := apmhttp.WrapClient(http.DefaultClient, o...) resp, err := ctxhttp.Get(ctx, client, url) diff --git a/module/apmhttp/clienttrace.go b/module/apmhttp/clienttrace.go new file mode 100644 index 000000000..8aba6c646 --- /dev/null +++ b/module/apmhttp/clienttrace.go @@ -0,0 +1,103 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package apmhttp + +import ( + "context" + "crypto/tls" + "fmt" + "net/http/httptrace" + "sync" + + "go.elastic.co/apm" +) + +// WithClientTrace returns a ClientOption for +// tracing events within HTTP client requests. +func WithClientTrace() ClientOption { + return func(rt *roundTripper) { + rt.traceRequests = true + } +} + +type connectKey struct { + network, addr string +} + +type requestTracer struct { + DNS, + TLS, + Request, + Response *apm.Span + + mu sync.RWMutex + Connects map[connectKey]*apm.Span +} + +func withClientTrace(ctx context.Context, tx *apm.Transaction, parent *apm.Span) (context.Context, *requestTracer) { + r := requestTracer{ + Connects: make(map[connectKey]*apm.Span), + } + + return httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ + DNSStart: func(i httptrace.DNSStartInfo) { + r.DNS = tx.StartSpan(fmt.Sprintf("DNS %s", i.Host), "http.dns", parent) + }, + + DNSDone: func(i httptrace.DNSDoneInfo) { + r.DNS.End() + }, + + ConnectStart: func(network, addr string) { + span := tx.StartSpan(fmt.Sprintf("Connect %s", addr), "http.connect", parent) + r.mu.Lock() + r.Connects[connectKey{network: network, addr: addr}] = span + r.mu.Unlock() + }, + + ConnectDone: func(network, addr string, err error) { + r.mu.RLock() + span := r.Connects[connectKey{network: network, addr: addr}] + r.mu.RUnlock() + span.End() + }, + + GotConn: func(info httptrace.GotConnInfo) { + r.Request = tx.StartSpan("Request", "http.request", parent) + }, + + TLSHandshakeStart: func() { + r.TLS = tx.StartSpan("TLS", "http.tls", parent) + }, + + TLSHandshakeDone: func(_ tls.ConnectionState, _ error) { + r.TLS.End() + }, + + GotFirstResponseByte: func() { + r.Request.End() + r.Response = tx.StartSpan("Response", "http.response", parent) + }, + }), &r +} + +func (r *requestTracer) end() { + if r.Response != nil { + r.Response.End() + } +}