diff --git a/transport.go b/transport.go index da04a4c..9d27fd6 100644 --- a/transport.go +++ b/transport.go @@ -111,14 +111,25 @@ func NewFromAppsTransport(atr *AppsTransport, installationID int64) *Transport { // RoundTrip implements http.RoundTripper interface. func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { + reqBodyClosed := false + if req.Body != nil { + defer func() { + if !reqBodyClosed { + req.Body.Close() + } + }() + } + token, err := t.Token(req.Context()) if err != nil { return nil, err } - req.Header.Set("Authorization", "token "+token) - req.Header.Add("Accept", acceptHeader) // We add to "Accept" header to avoid overwriting existing req headers. - resp, err := t.tr.RoundTrip(req) + creq := cloneRequest(req) // per RoundTripper contract + creq.Header.Set("Authorization", "token "+token) + creq.Header.Add("Accept", acceptHeader) // We add to "Accept" header to avoid overwriting existing req headers. + reqBodyClosed = true // req.Body is assumed to be closed by the tr RoundTripper. + resp, err := t.tr.RoundTrip(creq) return resp, err } @@ -212,3 +223,17 @@ func GetReadWriter(i interface{}) (io.ReadWriter, error) { } return buf, nil } + +// cloneRequest returns a clone of the provided *http.Request. +// The clone is a shallow copy of the struct and its Header map. +func cloneRequest(r *http.Request) *http.Request { + // shallow copy of the struct + r2 := new(http.Request) + *r2 = *r + // deep copy of the Header + r2.Header = make(http.Header, len(r.Header)) + for k, s := range r.Header { + r2.Header[k] = append([]string(nil), s...) + } + return r2 +} diff --git a/transport_test.go b/transport_test.go index 005b28a..e84716e 100644 --- a/transport_test.go +++ b/transport_test.go @@ -9,6 +9,7 @@ import ( "net/http/httptest" "os" "strings" + "sync" "testing" "time" @@ -344,3 +345,37 @@ func TestRefreshTokenWithTrailingSlashBaseURL(t *testing.T) { t.Fatalf("Unexpected RoundTrip response code: %d", res.StatusCode) } } + +func TestRoundTripperContract(t *testing.T) { + tr := &Transport{ + token: &accessToken{ + ExpiresAt: time.Now().Add(1 * time.Hour), + Token: "42", + }, + mu: &sync.Mutex{}, + tr: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if auth := req.Header.Get("Authorization"); auth != "token 42" { + t.Errorf("got unexpected Authorization request header in parent RoundTripper: %q", auth) + } + return nil, nil + }), + } + req, err := http.NewRequest("GET", "http://localhost", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Authorization", "xxx") + _, err = tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + if accept := req.Header.Get("Authorization"); accept != "xxx" { + t.Errorf("got unexpected Authorization request header in caller: %q", accept) + } +} + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +}