diff --git a/contrib/internal/httptrace/httptrace.go b/contrib/internal/httptrace/httptrace.go index 18bcd7c84a..1e3ae32d3e 100644 --- a/contrib/internal/httptrace/httptrace.go +++ b/contrib/internal/httptrace/httptrace.go @@ -92,64 +92,41 @@ func ippref(s string) *netaddr.IPPrefix { return nil } -// genClientIPSpanTags generates the client IP related span tags. -func genClientIPSpanTags(r *http.Request) []ddtrace.StartSpanOption { - tags := []ddtrace.StartSpanOption{} - ip, matches := getClientIP(r) - if matches == nil { - if ip.IsValid() { - tags = append(tags, tracer.Tag(ext.HTTPClientIP, ip.String())) - } - return tags - } - for _, hdr := range matches { - tags = append(tags, tracer.Tag(ext.HTTPRequestHeaders+"."+hdr, ip)) - } - tags = append(tags, tracer.Tag(ext.MultipleIPHeaders, strings.Join(matches, ","))) - return tags -} - -// getClientIP attempts to find the client IP address in the given request r. -// If several IP headers are present in the request, the returned IP is invalid and the list of all IP headers is -// returned. Otherwise, the returned list is nil. +// genClientIPSpanTags generates the client IP related tags that need to be added to the span. // See https://datadoghq.atlassian.net/wiki/spaces/APS/pages/2118779066/Client+IP+addresses+resolution -func getClientIP(r *http.Request) (netaddr.IP, []string) { +func genClientIPSpanTags(r *http.Request) []ddtrace.StartSpanOption { ipHeaders := defaultIPHeaders if len(clientIPHeader) > 0 { ipHeaders = []string{clientIPHeader} } - check := func(s string) netaddr.IP { - for _, ipstr := range strings.Split(s, ",") { - ip := parseIP(strings.TrimSpace(ipstr)) - if !ip.IsValid() { - continue - } - if isGlobal(ip) { - return ip - } - } - return netaddr.IP{} - } - matches := []string{} - var matchedIP netaddr.IP + var headers []string + var ips []string + var opts []ddtrace.StartSpanOption for _, hdr := range ipHeaders { if v := r.Header.Get(hdr); v != "" { - matches = append(matches, hdr) - if ip := check(v); ip.IsValid() { - matchedIP = ip - } + headers = append(headers, hdr) + ips = append(ips, v) } } - if len(matches) == 1 { - return matchedIP, nil - } - if len(matches) > 1 { - return netaddr.IP{}, matches - } - if remoteIP := parseIP(r.RemoteAddr); remoteIP.IsValid() && isGlobal(remoteIP) { - return remoteIP, nil + if len(ips) == 0 { + if remoteIP := parseIP(r.RemoteAddr); remoteIP.IsValid() && isGlobal(remoteIP) { + opts = append(opts, tracer.Tag(ext.HTTPClientIP, remoteIP.String())) + } + } else if len(ips) == 1 { + for _, ipstr := range strings.Split(ips[0], ",") { + ip := parseIP(strings.TrimSpace(ipstr)) + if ip.IsValid() && isGlobal(ip) { + opts = append(opts, tracer.Tag(ext.HTTPClientIP, ip.String())) + break + } + } + } else { + for i := range ips { + opts = append(opts, tracer.Tag(ext.HTTPRequestHeaders+"."+headers[i], ips[i])) + } + opts = append(opts, tracer.Tag(ext.MultipleIPHeaders, strings.Join(headers, ","))) } - return netaddr.IP{}, nil + return opts } func parseIP(s string) netaddr.IP { diff --git a/contrib/internal/httptrace/httptrace_test.go b/contrib/internal/httptrace/httptrace_test.go index e6df6282c7..a16d482553 100644 --- a/contrib/internal/httptrace/httptrace_test.go +++ b/contrib/internal/httptrace/httptrace_test.go @@ -16,6 +16,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer" ) @@ -32,12 +34,12 @@ func TestStartRequestSpan(t *testing.T) { } type IPTestCase struct { - name string - remoteAddr string - headers map[string]string - expectedIP netaddr.IP - expectedMatches []string - clientIPHeader string + name string + remoteAddr string + headers map[string]string + expectedIP netaddr.IP + multiHeaders string + clientIPHeader string } func genIPTestCases() []IPTestCase { @@ -108,16 +110,16 @@ func genIPTestCases() []IPTestCase { expectedIP: netaddr.MustParseIP(ipv4Global), }, { - name: "ipv4-multi-header-1", - headers: map[string]string{"x-forwarded-for": "127.0.0.1", "forwarded-for": ipv4Global}, - expectedIP: netaddr.IP{}, - expectedMatches: []string{"x-forwarded-for", "forwarded-for"}, + name: "ipv4-multi-header-1", + headers: map[string]string{"x-forwarded-for": "127.0.0.1", "forwarded-for": ipv4Global}, + expectedIP: netaddr.IP{}, + multiHeaders: "x-forwarded-for,forwarded-for", }, { - name: "ipv4-multi-header-2", - headers: map[string]string{"forwarded-for": ipv4Global, "x-forwarded-for": "127.0.0.1"}, - expectedIP: netaddr.IP{}, - expectedMatches: []string{"x-forwarded-for", "forwarded-for"}, + name: "ipv4-multi-header-2", + headers: map[string]string{"forwarded-for": ipv4Global, "x-forwarded-for": "127.0.0.1"}, + expectedIP: netaddr.IP{}, + multiHeaders: "x-forwarded-for,forwarded-for", }, { name: "invalid-ipv6", @@ -130,16 +132,16 @@ func genIPTestCases() []IPTestCase { expectedIP: netaddr.MustParseIP(ipv6Global), }, { - name: "ipv6-multi-header-1", - headers: map[string]string{"x-forwarded-for": "2001:0db8:2001:zzzz::", "forwarded-for": ipv6Global}, - expectedIP: netaddr.IP{}, - expectedMatches: []string{"x-forwarded-for", "forwarded-for"}, + name: "ipv6-multi-header-1", + headers: map[string]string{"x-forwarded-for": "2001:0db8:2001:zzzz::", "forwarded-for": ipv6Global}, + expectedIP: netaddr.IP{}, + multiHeaders: "x-forwarded-for,forwarded-for", }, { - name: "ipv6-multi-header-2", - headers: map[string]string{"forwarded-for": ipv6Global, "x-forwarded-for": "2001:0db8:2001:zzzz::"}, - expectedIP: netaddr.IP{}, - expectedMatches: []string{"x-forwarded-for", "forwarded-for"}, + name: "ipv6-multi-header-2", + headers: map[string]string{"forwarded-for": ipv6Global, "x-forwarded-for": "2001:0db8:2001:zzzz::"}, + expectedIP: netaddr.IP{}, + multiHeaders: "x-forwarded-for,forwarded-for", }, }, tcs...) tcs = append([]IPTestCase{ @@ -180,9 +182,19 @@ func TestIPHeaders(t *testing.T) { } r := http.Request{Header: header, RemoteAddr: tc.remoteAddr} clientIPHeader = tc.clientIPHeader - ip, matches := getClientIP(&r) - require.Equal(t, tc.expectedIP, ip) - require.Equal(t, tc.expectedMatches, matches) + cfg := ddtrace.StartSpanConfig{} + for _, opt := range genClientIPSpanTags(&r) { + opt(&cfg) + } + if tc.expectedIP.IsValid() { + require.Equal(t, tc.expectedIP.String(), cfg.Tags[ext.HTTPClientIP]) + require.Nil(t, cfg.Tags[ext.MultipleIPHeaders]) + } else { + require.Nil(t, cfg.Tags[ext.HTTPClientIP]) + if tc.multiHeaders != "" { + require.Equal(t, tc.multiHeaders, cfg.Tags[ext.MultipleIPHeaders]) + } + } }) } }