Skip to content

Commit

Permalink
contrib/internal/httptrace: simplify client ip tag generation down to…
Browse files Browse the repository at this point in the history
… one function
  • Loading branch information
Hellzy committed Jun 22, 2022
1 parent 3eded75 commit 5f70211
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 62 deletions.
52 changes: 15 additions & 37 deletions contrib/internal/httptrace/httptrace.go
Expand Up @@ -92,33 +92,14 @@ func ippref(s string) *netaddr.IPPrefix {
return nil
}

// genClientIPSpanTags generates the client IP related span tags.
func genClientIPSpanTags(r *http.Request) []ddtrace.StartSpanOption {
tags := []ddtrace.StartSpanOption{}
ip, matches := getClientIP(r)
if matches == nil {
if ip.IsValid() {
tags = append(tags, tracer.Tag(ext.HTTPClientIP, ip.String()))
}
return tags
}
for _, hdr := range matches {
tags = append(tags, tracer.Tag(ext.HTTPRequestHeaders+"."+hdr, ip))
}
tags = append(tags, tracer.Tag(ext.MultipleIPHeaders, strings.Join(matches, ",")))
return tags
}

// getClientIP attempts to find the client IP address in the given request r.
// If several IP headers are present in the request, the returned IP is invalid and the list of all IP headers is
// returned. Otherwise, the returned list is nil.
// genClientIPSpanTags generates the client IP related tags that need to be added to the span.
// See https://datadoghq.atlassian.net/wiki/spaces/APS/pages/2118779066/Client+IP+addresses+resolution
func getClientIP(r *http.Request) (netaddr.IP, []string) {
func genClientIPSpanTags(r *http.Request) []ddtrace.StartSpanOption {
ipHeaders := defaultIPHeaders
if len(clientIPHeader) > 0 {
ipHeaders = []string{clientIPHeader}
}
check := func(s string) netaddr.IP {
checkIP := func(s string) netaddr.IP {
for _, ipstr := range strings.Split(s, ",") {
ip := parseIP(strings.TrimSpace(ipstr))
if !ip.IsValid() {
Expand All @@ -130,26 +111,23 @@ func getClientIP(r *http.Request) (netaddr.IP, []string) {
}
return netaddr.IP{}
}
matches := []string{}
var matchedIP netaddr.IP
headers := []string{}
opts := []ddtrace.StartSpanOption{}
var ip netaddr.IP
for _, hdr := range ipHeaders {
if v := r.Header.Get(hdr); v != "" {
matches = append(matches, hdr)
if ip := check(v); ip.IsValid() {
matchedIP = ip
}
headers = append(headers, hdr)
ip = checkIP(v)
}
}
if len(matches) == 1 {
return matchedIP, nil
}
if len(matches) > 1 {
return netaddr.IP{}, matches
}
if remoteIP := parseIP(r.RemoteAddr); remoteIP.IsValid() && isGlobal(remoteIP) {
return remoteIP, nil
if len(headers) == 1 && ip.IsValid() {
opts = append(opts, tracer.Tag(ext.HTTPClientIP, ip.String()))
} else if len(headers) > 1 {
opts = append(opts, tracer.Tag(ext.MultipleIPHeaders, strings.Join(headers, ",")))
} else if remoteIP := parseIP(r.RemoteAddr); remoteIP.IsValid() && isGlobal(remoteIP) {
opts = append(opts, tracer.Tag(ext.HTTPClientIP, remoteIP.String()))
}
return netaddr.IP{}, nil
return opts
}

func parseIP(s string) netaddr.IP {
Expand Down
62 changes: 37 additions & 25 deletions contrib/internal/httptrace/httptrace_test.go
Expand Up @@ -16,6 +16,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"gopkg.in/DataDog/dd-trace-go.v1/ddtrace"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer"
)

Expand All @@ -32,12 +34,12 @@ func TestStartRequestSpan(t *testing.T) {
}

type IPTestCase struct {
name string
remoteAddr string
headers map[string]string
expectedIP netaddr.IP
expectedMatches []string
clientIPHeader string
name string
remoteAddr string
headers map[string]string
expectedIP netaddr.IP
multiHeaders string
clientIPHeader string
}

func genIPTestCases() []IPTestCase {
Expand Down Expand Up @@ -108,16 +110,16 @@ func genIPTestCases() []IPTestCase {
expectedIP: netaddr.MustParseIP(ipv4Global),
},
{
name: "ipv4-multi-header-1",
headers: map[string]string{"x-forwarded-for": "127.0.0.1", "forwarded-for": ipv4Global},
expectedIP: netaddr.IP{},
expectedMatches: []string{"x-forwarded-for", "forwarded-for"},
name: "ipv4-multi-header-1",
headers: map[string]string{"x-forwarded-for": "127.0.0.1", "forwarded-for": ipv4Global},
expectedIP: netaddr.IP{},
multiHeaders: "x-forwarded-for,forwarded-for",
},
{
name: "ipv4-multi-header-2",
headers: map[string]string{"forwarded-for": ipv4Global, "x-forwarded-for": "127.0.0.1"},
expectedIP: netaddr.IP{},
expectedMatches: []string{"x-forwarded-for", "forwarded-for"},
name: "ipv4-multi-header-2",
headers: map[string]string{"forwarded-for": ipv4Global, "x-forwarded-for": "127.0.0.1"},
expectedIP: netaddr.IP{},
multiHeaders: "x-forwarded-for,forwarded-for",
},
{
name: "invalid-ipv6",
Expand All @@ -130,16 +132,16 @@ func genIPTestCases() []IPTestCase {
expectedIP: netaddr.MustParseIP(ipv6Global),
},
{
name: "ipv6-multi-header-1",
headers: map[string]string{"x-forwarded-for": "2001:0db8:2001:zzzz::", "forwarded-for": ipv6Global},
expectedIP: netaddr.IP{},
expectedMatches: []string{"x-forwarded-for", "forwarded-for"},
name: "ipv6-multi-header-1",
headers: map[string]string{"x-forwarded-for": "2001:0db8:2001:zzzz::", "forwarded-for": ipv6Global},
expectedIP: netaddr.IP{},
multiHeaders: "x-forwarded-for,forwarded-for",
},
{
name: "ipv6-multi-header-2",
headers: map[string]string{"forwarded-for": ipv6Global, "x-forwarded-for": "2001:0db8:2001:zzzz::"},
expectedIP: netaddr.IP{},
expectedMatches: []string{"x-forwarded-for", "forwarded-for"},
name: "ipv6-multi-header-2",
headers: map[string]string{"forwarded-for": ipv6Global, "x-forwarded-for": "2001:0db8:2001:zzzz::"},
expectedIP: netaddr.IP{},
multiHeaders: "x-forwarded-for,forwarded-for",
},
}, tcs...)
tcs = append([]IPTestCase{
Expand Down Expand Up @@ -180,9 +182,19 @@ func TestIPHeaders(t *testing.T) {
}
r := http.Request{Header: header, RemoteAddr: tc.remoteAddr}
clientIPHeader = tc.clientIPHeader
ip, matches := getClientIP(&r)
require.Equal(t, tc.expectedIP, ip)
require.Equal(t, tc.expectedMatches, matches)
cfg := ddtrace.StartSpanConfig{}
for _, opt := range genClientIPSpanTags(&r) {
opt(&cfg)
}
if tc.expectedIP.IsValid() {
require.Equal(t, tc.expectedIP.String(), cfg.Tags[ext.HTTPClientIP])
require.Nil(t, cfg.Tags[ext.MultipleIPHeaders])
} else {
require.Nil(t, cfg.Tags[ext.HTTPClientIP])
if tc.multiHeaders != "" {
require.Equal(t, tc.multiHeaders, cfg.Tags[ext.MultipleIPHeaders])
}
}
})
}
}
Expand Down

0 comments on commit 5f70211

Please sign in to comment.