Skip to content

Commit

Permalink
contrib/internal/httptrace: simplify getClientIP prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
Hellzy committed Jun 14, 2022
1 parent 277cb8f commit 9208ede
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 22 deletions.
16 changes: 8 additions & 8 deletions contrib/internal/httptrace/httptrace.go
Expand Up @@ -57,7 +57,7 @@ func StartRequestSpan(r *http.Request, opts ...ddtrace.StartSpanOption) (tracer.
tracer.Tag("http.host", r.Host),
}, opts...)
}
if ip := getClientIP(r.RemoteAddr, r.Header, clientIPHeader); ip.IsValid() {
if ip := getClientIP(r); ip.IsValid() {
opts = append(opts, tracer.Tag(ext.HTTPClientIP, ip.String()))
}
if spanctx, err := tracer.Extract(tracer.HTTPHeadersCarrier(r.Header)); err == nil {
Expand Down Expand Up @@ -92,7 +92,7 @@ func ippref(s string) *netaddr.IPPrefix {

// getClientIP uses the request headers to resolve the client IP. If a specific header to check is provided through
// DD_CLIENT_IP_HEADER, then only this header is checked.
func getClientIP(remoteAddr string, headers http.Header, clientIPHeader string) netaddr.IP {
func getClientIP(r *http.Request) netaddr.IP {
ipHeaders := defaultIPHeaders
if len(clientIPHeader) > 0 {
ipHeaders = []string{clientIPHeader}
Expand All @@ -110,13 +110,13 @@ func getClientIP(remoteAddr string, headers http.Header, clientIPHeader string)
return netaddr.IP{}
}
for _, hdr := range ipHeaders {
if v := headers.Get(hdr); v != "" {
if v := r.Header.Get(hdr); v != "" {
if ip := check(v); ip.IsValid() {
return ip
}
}
}
if remoteIP := parseIP(remoteAddr); remoteIP.IsValid() && isGlobal(remoteIP) {
if remoteIP := parseIP(r.RemoteAddr); remoteIP.IsValid() && isGlobal(remoteIP) {
return remoteIP
}
return netaddr.IP{}
Expand All @@ -136,14 +136,14 @@ func parseIP(s string) netaddr.IP {

func isGlobal(ip netaddr.IP) bool {
//IsPrivate also checks for ipv6 ULA
globalCheck := !ip.IsPrivate() && !ip.IsLoopback() && !ip.IsLinkLocalUnicast()
if !globalCheck || !ip.Is6() {
return globalCheck
isGlobal := !ip.IsPrivate() && !ip.IsLoopback() && !ip.IsLinkLocalUnicast()
if !isGlobal || !ip.Is6() {
return isGlobal
}
for _, n := range ipv6SpecialNetworks {
if n.Contains(ip) {
return false
}
}
return globalCheck
return isGlobal
}
32 changes: 18 additions & 14 deletions contrib/internal/httptrace/httptrace_test.go
Expand Up @@ -32,11 +32,11 @@ func TestStartRequestSpan(t *testing.T) {
}

type IPTestCase struct {
name string
remoteAddr string
headers map[string]string
expectedIP netaddr.IP
userIPHeader string
name string
remoteAddr string
headers map[string]string
expectedIP netaddr.IP
clientIPHeader string
}

func genIPTestCases() []IPTestCase {
Expand Down Expand Up @@ -148,30 +148,34 @@ func genIPTestCases() []IPTestCase {
headers: map[string]string{"X-fOrWaRdEd-FoR": ipv4Global},
},
{
name: "user-header",
expectedIP: netaddr.MustParseIP(ipv4Global),
headers: map[string]string{"x-forwarded-for": ipv6Global, "custom-header": ipv4Global},
userIPHeader: "custom-header",
name: "user-header",
expectedIP: netaddr.MustParseIP(ipv4Global),
headers: map[string]string{"x-forwarded-for": ipv6Global, "custom-header": ipv4Global},
clientIPHeader: "custom-header",
},
{
name: "user-header-not-found",
expectedIP: netaddr.IP{},
headers: map[string]string{"x-forwarded-for": ipv4Global},
userIPHeader: "custom-header",
name: "user-header-not-found",
expectedIP: netaddr.IP{},
headers: map[string]string{"x-forwarded-for": ipv4Global},
clientIPHeader: "custom-header",
},
}, tcs...)

return tcs
}

func TestIPHeaders(t *testing.T) {
// Make sure to restore the real value of clientIPHeader at the end of the test
defer func(s string) { clientIPHeader = s }(clientIPHeader)
for _, tc := range genIPTestCases() {
t.Run(tc.name, func(t *testing.T) {
header := http.Header{}
for k, v := range tc.headers {
header.Add(k, v)
}
require.Equal(t, tc.expectedIP.String(), getClientIP(tc.remoteAddr, header, tc.userIPHeader).String())
r := http.Request{Header: header, RemoteAddr: tc.remoteAddr}
clientIPHeader = tc.clientIPHeader
require.Equal(t, tc.expectedIP.String(), getClientIP(&r).String())
})
}
}
Expand Down

0 comments on commit 9208ede

Please sign in to comment.