diff --git a/httpx/private_ip_validator.go b/httpx/private_ip_validator.go index d54d03fc..7beceb18 100644 --- a/httpx/private_ip_validator.go +++ b/httpx/private_ip_validator.go @@ -8,8 +8,8 @@ import ( "net" "net/http" "net/url" - - "github.com/ory/x/stringsx" + "syscall" + "time" "github.com/pkg/errors" ) @@ -80,29 +80,54 @@ var _ http.RoundTripper = (*NoInternalIPRoundTripper)(nil) // NoInternalIPRoundTripper is a RoundTripper that disallows internal IP addresses. type NoInternalIPRoundTripper struct { - http.RoundTripper internalIPExceptions []string } func (n NoInternalIPRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) { - rt := http.DefaultTransport - if n.RoundTripper != nil { - rt = n.RoundTripper - } - incoming := IncomingRequestURL(request) incoming.RawQuery = "" incoming.RawFragment = "" for _, exception := range n.internalIPExceptions { if incoming.String() == exception { - return rt.RoundTrip(request) + return http.DefaultTransport.RoundTrip(request) } } - host, _, _ := net.SplitHostPort(request.Host) - if err := DisallowIPPrivateAddresses(stringsx.Coalesce(host, request.Host)); err != nil { - return nil, err - } + return NoInternalTransport.RoundTrip(request) +} + +var NoInternalDialer = &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + Control: func(network, address string, _ syscall.RawConn) error { + if !(network == "tcp4" || network == "tcp6") { + return ErrPrivateIPAddressDisallowed(fmt.Errorf("%s is not a safe network type", network)) + } + + host, _, err := net.SplitHostPort(address) + if err != nil { + return ErrPrivateIPAddressDisallowed(fmt.Errorf("%s is not a valid host/port pair: %s", address, err)) + } + + ip := net.ParseIP(host) + if ip == nil { + return ErrPrivateIPAddressDisallowed(fmt.Errorf("%s is not a valid IP address", host)) + } + + if ip.IsPrivate() || ip.IsLoopback() || ip.IsUnspecified() { + return ErrPrivateIPAddressDisallowed(fmt.Errorf("%s is not a public IP address", ip)) + } + + return nil + }, +} - return rt.RoundTrip(request) +var NoInternalTransport http.RoundTripper = &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: NoInternalDialer.DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, } diff --git a/httpx/private_ip_validator_test.go b/httpx/private_ip_validator_test.go index a50c98a2..c65a6cf2 100644 --- a/httpx/private_ip_validator_test.go +++ b/httpx/private_ip_validator_test.go @@ -4,10 +4,12 @@ package httpx import ( + "net" "net/http" "net/url" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -47,7 +49,7 @@ func (n noOpRoundTripper) RoundTrip(request *http.Request) (*http.Response, erro var _ http.RoundTripper = new(noOpRoundTripper) func TestAllowExceptions(t *testing.T) { - rt := &NoInternalIPRoundTripper{RoundTripper: new(noOpRoundTripper), internalIPExceptions: []string{"http://localhost/asdf"}} + rt := &NoInternalIPRoundTripper{internalIPExceptions: []string{"http://localhost/asdf"}} _, err := rt.RoundTrip(&http.Request{ Host: "localhost", @@ -56,7 +58,12 @@ func TestAllowExceptions(t *testing.T) { "Host": []string{"localhost"}, }, }) - require.NoError(t, err) + // assert that the error is eiher nil or a dial error. + if err != nil { + opErr := new(net.OpError) + require.ErrorAs(t, err, &opErr) + require.Equal(t, "dial", opErr.Op) + } _, err = rt.RoundTrip(&http.Request{ Host: "localhost", @@ -67,3 +74,52 @@ func TestAllowExceptions(t *testing.T) { }) require.Error(t, err) } + +func assertErrorContains(msg string) assert.ErrorAssertionFunc { + return func(t assert.TestingT, err error, i ...interface{}) bool { + if !assert.Error(t, err, i...) { + return false + } + return assert.Contains(t, err.Error(), msg) + } +} + +func TestNoInternalDialer(t *testing.T) { + for _, tt := range []struct { + name string + network string + address string + assertErr assert.ErrorAssertionFunc + }{{ + name: "TCP public is allowed", + network: "tcp", + address: "www.google.de:443", + assertErr: assert.NoError, + }, { + name: "TCP private is denied", + network: "tcp", + address: "localhost:443", + assertErr: assertErrorContains("is not a public IP address"), + }, { + name: "UDP public is denied", + network: "udp", + address: "www.google.de:443", + assertErr: assertErrorContains("not a safe network type"), + }, { + name: "UDP public is denied", + network: "udp", + address: "www.google.de:443", + assertErr: assertErrorContains("not a safe network type"), + }, { + name: "UNIX sockets are denied", + network: "unix", + address: "/etc/passwd", + assertErr: assertErrorContains("not a safe network type"), + }} { + + t.Run("case="+tt.name, func(t *testing.T) { + _, err := NoInternalDialer.Dial(tt.network, tt.address) + tt.assertErr(t, err) + }) + } +} diff --git a/httpx/resilient_client.go b/httpx/resilient_client.go index 81cbb467..04f1ef0f 100644 --- a/httpx/resilient_client.go +++ b/httpx/resilient_client.go @@ -117,7 +117,6 @@ func NewResilientClient(opts ...ResilientOptions) *retryablehttp.Client { if o.noInternalIPs == true { o.c.Transport = &NoInternalIPRoundTripper{ - RoundTripper: o.c.Transport, internalIPExceptions: o.internalIPExceptions, } } diff --git a/httpx/resilient_client_test.go b/httpx/resilient_client_test.go index bb9c8db7..f08364e4 100644 --- a/httpx/resilient_client_test.go +++ b/httpx/resilient_client_test.go @@ -17,7 +17,7 @@ import ( ) func TestNoPrivateIPs(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("Hello, world!")) })) t.Cleanup(ts.Close) @@ -46,7 +46,7 @@ func TestNoPrivateIPs(t *testing.T) { _, err := c.Get(destination) if !passes { require.Error(t, err) - assert.Contains(t, err.Error(), "is in the") + assert.Contains(t, err.Error(), "is not a public IP address") } else { require.NoError(t, err) } @@ -54,7 +54,7 @@ func TestNoPrivateIPs(t *testing.T) { } func TestClientWithTracer(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("Hello, world!")) })) t.Cleanup(ts.Close)