diff --git a/contrib/internal/httptrace/httptrace.go b/contrib/internal/httptrace/httptrace.go index 07ee8e2ed2..60a3c4d424 100644 --- a/contrib/internal/httptrace/httptrace.go +++ b/contrib/internal/httptrace/httptrace.go @@ -57,7 +57,7 @@ func StartRequestSpan(r *http.Request, opts ...ddtrace.StartSpanOption) (tracer. tracer.Tag("http.host", r.Host), }, opts...) } - if ip := getClientIP(r.RemoteAddr, r.Header, clientIPHeader); ip.IsValid() { + if ip := getClientIP(r); ip.IsValid() { opts = append(opts, tracer.Tag(ext.HTTPClientIP, ip.String())) } if spanctx, err := tracer.Extract(tracer.HTTPHeadersCarrier(r.Header)); err == nil { @@ -92,7 +92,7 @@ func ippref(s string) *netaddr.IPPrefix { // getClientIP uses the request headers to resolve the client IP. If a specific header to check is provided through // DD_CLIENT_IP_HEADER, then only this header is checked. -func getClientIP(remoteAddr string, headers http.Header, clientIPHeader string) netaddr.IP { +func getClientIP(r *http.Request) netaddr.IP { ipHeaders := defaultIPHeaders if len(clientIPHeader) > 0 { ipHeaders = []string{clientIPHeader} @@ -110,13 +110,13 @@ func getClientIP(remoteAddr string, headers http.Header, clientIPHeader string) return netaddr.IP{} } for _, hdr := range ipHeaders { - if v := headers.Get(hdr); v != "" { + if v := r.Header.Get(hdr); v != "" { if ip := check(v); ip.IsValid() { return ip } } } - if remoteIP := parseIP(remoteAddr); remoteIP.IsValid() && isGlobal(remoteIP) { + if remoteIP := parseIP(r.RemoteAddr); remoteIP.IsValid() && isGlobal(remoteIP) { return remoteIP } return netaddr.IP{} @@ -136,14 +136,14 @@ func parseIP(s string) netaddr.IP { func isGlobal(ip netaddr.IP) bool { //IsPrivate also checks for ipv6 ULA - globalCheck := !ip.IsPrivate() && !ip.IsLoopback() && !ip.IsLinkLocalUnicast() - if !globalCheck || !ip.Is6() { - return globalCheck + isGlobal := !ip.IsPrivate() && !ip.IsLoopback() && !ip.IsLinkLocalUnicast() + if !isGlobal || !ip.Is6() { + return isGlobal } for _, n := range ipv6SpecialNetworks { if n.Contains(ip) { return false } } - return globalCheck + return isGlobal } diff --git a/contrib/internal/httptrace/httptrace_test.go b/contrib/internal/httptrace/httptrace_test.go index e5d7e94dd8..391d5f8dfa 100644 --- a/contrib/internal/httptrace/httptrace_test.go +++ b/contrib/internal/httptrace/httptrace_test.go @@ -32,11 +32,11 @@ func TestStartRequestSpan(t *testing.T) { } type IPTestCase struct { - name string - remoteAddr string - headers map[string]string - expectedIP netaddr.IP - userIPHeader string + name string + remoteAddr string + headers map[string]string + expectedIP netaddr.IP + clientIPHeader string } func genIPTestCases() []IPTestCase { @@ -148,16 +148,16 @@ func genIPTestCases() []IPTestCase { headers: map[string]string{"X-fOrWaRdEd-FoR": ipv4Global}, }, { - name: "user-header", - expectedIP: netaddr.MustParseIP(ipv4Global), - headers: map[string]string{"x-forwarded-for": ipv6Global, "custom-header": ipv4Global}, - userIPHeader: "custom-header", + name: "user-header", + expectedIP: netaddr.MustParseIP(ipv4Global), + headers: map[string]string{"x-forwarded-for": ipv6Global, "custom-header": ipv4Global}, + clientIPHeader: "custom-header", }, { - name: "user-header-not-found", - expectedIP: netaddr.IP{}, - headers: map[string]string{"x-forwarded-for": ipv4Global}, - userIPHeader: "custom-header", + name: "user-header-not-found", + expectedIP: netaddr.IP{}, + headers: map[string]string{"x-forwarded-for": ipv4Global}, + clientIPHeader: "custom-header", }, }, tcs...) @@ -165,13 +165,17 @@ func genIPTestCases() []IPTestCase { } func TestIPHeaders(t *testing.T) { + // Make sure to restore the real value of clientIPHeader at the end of the test + defer func(s string) { clientIPHeader = s }(clientIPHeader) for _, tc := range genIPTestCases() { t.Run(tc.name, func(t *testing.T) { header := http.Header{} for k, v := range tc.headers { header.Add(k, v) } - require.Equal(t, tc.expectedIP.String(), getClientIP(tc.remoteAddr, header, tc.userIPHeader).String()) + r := http.Request{Header: header, RemoteAddr: tc.remoteAddr} + clientIPHeader = tc.clientIPHeader + require.Equal(t, tc.expectedIP.String(), getClientIP(&r).String()) }) } }