Skip to content

Commit

Permalink
TrustedProxies: Add default IPv6 support and refactor (gin-gonic#2967)
Browse files Browse the repository at this point in the history
  • Loading branch information
Bisstocuz authored and daheige committed Apr 18, 2022
1 parent 226769e commit a442221
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 49 deletions.
52 changes: 8 additions & 44 deletions context.go
Expand Up @@ -757,10 +757,14 @@ func (c *Context) ClientIP() string {
}
}

remoteIP, trusted := c.RemoteIP()
// It also checks if the remoteIP is a trusted proxy or not.
// In order to perform this validation, it will see if the IP is contained within at least one of the CIDR blocks
// defined by Engine.SetTrustedProxies()
remoteIP := net.ParseIP(c.RemoteIP())
if remoteIP == nil {
return ""
}
trusted := c.engine.isTrustedProxy(remoteIP)

if trusted && c.engine.ForwardedByClientIP && c.engine.RemoteIPHeaders != nil {
for _, headerName := range c.engine.RemoteIPHeaders {
Expand All @@ -773,53 +777,13 @@ func (c *Context) ClientIP() string {
return remoteIP.String()
}

func (e *Engine) isTrustedProxy(ip net.IP) bool {
if e.trustedCIDRs != nil {
for _, cidr := range e.trustedCIDRs {
if cidr.Contains(ip) {
return true
}
}
}
return false
}

// RemoteIP parses the IP from Request.RemoteAddr, normalizes and returns the IP (without the port).
// It also checks if the remoteIP is a trusted proxy or not.
// In order to perform this validation, it will see if the IP is contained within at least one of the CIDR blocks
// defined by Engine.SetTrustedProxies()
func (c *Context) RemoteIP() (net.IP, bool) {
func (c *Context) RemoteIP() string {
ip, _, err := net.SplitHostPort(strings.TrimSpace(c.Request.RemoteAddr))
if err != nil {
return nil, false
}
remoteIP := net.ParseIP(ip)
if remoteIP == nil {
return nil, false
}

return remoteIP, c.engine.isTrustedProxy(remoteIP)
}

func (e *Engine) validateHeader(header string) (clientIP string, valid bool) {
if header == "" {
return "", false
}
items := strings.Split(header, ",")
for i := len(items) - 1; i >= 0; i-- {
ipStr := strings.TrimSpace(items[i])
ip := net.ParseIP(ipStr)
if ip == nil {
return "", false
}

// X-Forwarded-For is appended by proxy
// Check IPs in reverse order and stop when find untrusted proxy
if (i == 0) || (!e.isTrustedProxy(ip)) {
return ipStr, true
}
return ""
}
return
return ip
}

// ContentType returns the Content-Type header of the request.
Expand Down
10 changes: 9 additions & 1 deletion context_test.go
Expand Up @@ -12,6 +12,7 @@ import (
"html/template"
"io"
"mime/multipart"
"net"
"net/http"
"net/http/httptest"
"os"
Expand Down Expand Up @@ -1404,6 +1405,11 @@ func TestContextClientIP(t *testing.T) {
// Tests exercising the TrustedProxies functionality
resetContextForClientIPTests(c)

// IPv6 support
c.Request.RemoteAddr = "[::1]:12345"
assert.Equal(t, "20.20.20.20", c.ClientIP())

resetContextForClientIPTests(c)
// No trusted proxies
_ = c.engine.SetTrustedProxies([]string{})
c.engine.RemoteIPHeaders = []string{"X-Forwarded-For"}
Expand Down Expand Up @@ -1500,6 +1506,7 @@ func resetContextForClientIPTests(c *Context) {
c.Request.Header.Set("CF-Connecting-IP", "60.60.60.60")
c.Request.RemoteAddr = " 40.40.40.40:42123 "
c.engine.TrustedPlatform = ""
c.engine.trustedCIDRs = defaultTrustedCIDRs
c.engine.AppEngine = false
}

Expand Down Expand Up @@ -2051,7 +2058,8 @@ func TestRemoteIPFail(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("POST", "/", nil)
c.Request.RemoteAddr = "[:::]:80"
ip, trust := c.RemoteIP()
ip := net.ParseIP(c.RemoteIP())
trust := c.engine.isTrustedProxy(ip)
assert.Nil(t, ip)
assert.False(t, trust)
}
Expand Down
51 changes: 47 additions & 4 deletions gin.go
Expand Up @@ -11,7 +11,6 @@ import (
"net/http"
"os"
"path"
"reflect"
"strings"
"sync"

Expand All @@ -28,7 +27,16 @@ var (

var defaultPlatform string

var defaultTrustedCIDRs = []*net.IPNet{{IP: net.IP{0x0, 0x0, 0x0, 0x0}, Mask: net.IPMask{0x0, 0x0, 0x0, 0x0}}} // 0.0.0.0/0
var defaultTrustedCIDRs = []*net.IPNet{
{ // 0.0.0.0/0 (IPv4)
IP: net.IP{0x0, 0x0, 0x0, 0x0},
Mask: net.IPMask{0x0, 0x0, 0x0, 0x0},
},
{ // ::/0 (IPv6)
IP: net.IP{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
Mask: net.IPMask{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
},
}

// HandlerFunc defines the handler used by gin middleware as return value.
type HandlerFunc func(*Context)
Expand Down Expand Up @@ -399,9 +407,9 @@ func (engine *Engine) SetTrustedProxies(trustedProxies []string) error {
return engine.parseTrustedProxies()
}

// isUnsafeTrustedProxies compares Engine.trustedCIDRs and defaultTrustedCIDRs, it's not safe if equal (returns true)
// isUnsafeTrustedProxies checks if Engine.trustedCIDRs contains all IPs, it's not safe if it has (returns true)
func (engine *Engine) isUnsafeTrustedProxies() bool {
return reflect.DeepEqual(engine.trustedCIDRs, defaultTrustedCIDRs)
return engine.isTrustedProxy(net.ParseIP("0.0.0.0")) || engine.isTrustedProxy(net.ParseIP("::"))
}

// parseTrustedProxies parse Engine.trustedProxies to Engine.trustedCIDRs
Expand All @@ -411,6 +419,41 @@ func (engine *Engine) parseTrustedProxies() error {
return err
}

// isTrustedProxy will check whether the IP address is included in the trusted list according to Engine.trustedCIDRs
func (engine *Engine) isTrustedProxy(ip net.IP) bool {
if engine.trustedCIDRs == nil {
return false
}
for _, cidr := range engine.trustedCIDRs {
if cidr.Contains(ip) {
return true
}
}
return false
}

// validateHeader will parse X-Forwarded-For header and return the trusted client IP address
func (engine *Engine) validateHeader(header string) (clientIP string, valid bool) {
if header == "" {
return "", false
}
items := strings.Split(header, ",")
for i := len(items) - 1; i >= 0; i-- {
ipStr := strings.TrimSpace(items[i])
ip := net.ParseIP(ipStr)
if ip == nil {
break
}

// X-Forwarded-For is appended by proxy
// Check IPs in reverse order and stop when find untrusted proxy
if (i == 0) || (!engine.isTrustedProxy(ip)) {
return ipStr, true
}
}
return "", false
}

// parseIP parse a string representation of an IP and returns a net.IP with the
// minimum byte representation or nil if input is invalid.
func parseIP(ip string) net.IP {
Expand Down

0 comments on commit a442221

Please sign in to comment.