From 27aac5d2f921e4049f7f1bd51bf43ee16627c049 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Mazeau?= Date: Fri, 10 Jun 2022 14:58:41 +0200 Subject: [PATCH] contrib/internal/httptrace: code review fix - Improve code wrt style guidelines - Rework error checking in parseIP - Use net.SplitHostPort instead of local function - Minor nits --- contrib/internal/httptrace/httptrace.go | 54 ++++++++++--------------- 1 file changed, 21 insertions(+), 33 deletions(-) diff --git a/contrib/internal/httptrace/httptrace.go b/contrib/internal/httptrace/httptrace.go index a5558ee63c..07ee8e2ed2 100644 --- a/contrib/internal/httptrace/httptrace.go +++ b/contrib/internal/httptrace/httptrace.go @@ -10,6 +10,7 @@ package httptrace import ( "context" "fmt" + "net" "net/http" "os" "strconv" @@ -96,8 +97,8 @@ func getClientIP(remoteAddr string, headers http.Header, clientIPHeader string) if len(clientIPHeader) > 0 { ipHeaders = []string{clientIPHeader} } - check := func(value string) netaddr.IP { - for _, ipstr := range strings.Split(value, ",") { + check := func(s string) netaddr.IP { + for _, ipstr := range strings.Split(s, ",") { ip := parseIP(strings.TrimSpace(ipstr)) if !ip.IsValid() { continue @@ -109,13 +110,12 @@ func getClientIP(remoteAddr string, headers http.Header, clientIPHeader string) return netaddr.IP{} } for _, hdr := range ipHeaders { - if value := headers.Get(hdr); value != "" { - if ip := check(value); ip.IsValid() { + if v := headers.Get(hdr); v != "" { + if ip := check(v); ip.IsValid() { return ip } } } - if remoteIP := parseIP(remoteAddr); remoteIP.IsValid() && isGlobal(remoteIP) { return remoteIP } @@ -123,39 +123,27 @@ func getClientIP(remoteAddr string, headers http.Header, clientIPHeader string) } func parseIP(s string) netaddr.IP { - ip, err := netaddr.ParseIP(s) - if err != nil { - h, _ := splitHostPort(s) - ip, err = netaddr.ParseIP(h) + if ip, err := netaddr.ParseIP(s); err == nil { + return ip } - return ip -} - -func isGlobal(ip netaddr.IP) bool { - if ip.Is6() { - for _, network := range ipv6SpecialNetworks { - if network.Contains(ip) { - return false - } + if h, _, err := net.SplitHostPort(s); err == nil { + if ip, err := netaddr.ParseIP(h); err == nil { + return ip } } - //IsPrivate also checks for ipv6 ULA - return !ip.IsPrivate() && !ip.IsLoopback() && !ip.IsLinkLocalUnicast() + return netaddr.IP{} } -// SplitHostPort splits a network address of the form `host:port` or -// `[host]:port` into `host` and `port`. -func splitHostPort(addr string) (host string, port string) { - i := strings.LastIndex(addr, "]:") - if i != -1 { - // ipv6 - return strings.Trim(addr[:i+1], "[]"), addr[i+2:] +func isGlobal(ip netaddr.IP) bool { + //IsPrivate also checks for ipv6 ULA + globalCheck := !ip.IsPrivate() && !ip.IsLoopback() && !ip.IsLinkLocalUnicast() + if !globalCheck || !ip.Is6() { + return globalCheck } - - i = strings.LastIndex(addr, ":") - if i == -1 { - // not an address with a port number - return addr, "" + for _, n := range ipv6SpecialNetworks { + if n.Contains(ip) { + return false + } } - return addr[:i], addr[i+1:] + return globalCheck }