Skip to content

Commit

Permalink
Feature issue gorilla#479
Browse files Browse the repository at this point in the history
  • Loading branch information
WilliamHeaven committed Jun 15, 2020
1 parent b65e629 commit c4646c6
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 5 deletions.
5 changes: 4 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ type Dialer struct {
// If Jar is nil, cookies are not sent in requests and ignored
// in responses.
Jar http.CookieJar

// custom proxy connect header
ProxyConnectHeader http.Header
}

// Dial creates a new client connection by calling DialContext with a background context.
Expand Down Expand Up @@ -274,7 +277,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
return nil, nil, err
}
if proxyURL != nil {
dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial))
dialer, err := proxy_FromURL(proxyURL, &netDialer{d.ProxyConnectHeader,netDial})
if err != nil {
return nil, nil, err
}
Expand Down
7 changes: 7 additions & 0 deletions client_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ func TestProxyDial(t *testing.T) {

cstDialer := cstDialer // make local copy for modification on next line.
cstDialer.Proxy = http.ProxyURL(surl)
cstDialer.ProxyConnectHeader = map[string][]string{
"User-Agents": {"xxx"},
}

connect := false
origHandler := s.Server.Config.Handler
Expand All @@ -166,6 +169,10 @@ func TestProxyDial(t *testing.T) {
if r.Method == "CONNECT" {
connect = true
w.WriteHeader(http.StatusOK)
if r.Header.Get("User-Agents") != "xxx" {
t.Log("xxx not found in the request header")
http.Error(w, "header xxx not found", http.StatusMethodNotAllowed)
}
return
}

Expand Down
23 changes: 19 additions & 4 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,32 @@ import (
"strings"
)

type netDialerFunc func(network, addr string) (net.Conn, error)
// type netDialerFunc func(network, addr string) (net.Conn, error)
//
// func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
// return fn(network, addr)
// }
type netDialer struct {
proxyHeader http.Header
f func(network, addr string) (net.Conn, error)
}

func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
return fn(network, addr)
func (n netDialer) Dial(network, addr string) (net.Conn, error) {
return n.f(network, addr)
}

func init() {
proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil
p, _ := forwardDialer.(*netDialer)
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial, proxyHeader: p.proxyHeader}, nil
// return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil
})
}

type httpProxyDialer struct {
proxyURL *url.URL
forwardDial func(network, addr string) (net.Conn, error)
proxyHeader http.Header
}

func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) {
Expand All @@ -47,6 +58,10 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
}
}

for k, v := range hpd.proxyHeader {
connectHeader[k] = v
}

connectReq := &http.Request{
Method: "CONNECT",
URL: &url.URL{Opaque: addr},
Expand Down

0 comments on commit c4646c6

Please sign in to comment.