Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Respect the http.RoundTripper contract #61

Merged
merged 1 commit into from Mar 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
31 changes: 28 additions & 3 deletions transport.go
Expand Up @@ -111,14 +111,25 @@ func NewFromAppsTransport(atr *AppsTransport, installationID int64) *Transport {

// RoundTrip implements http.RoundTripper interface.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
reqBodyClosed := false
if req.Body != nil {
defer func() {
if !reqBodyClosed {
req.Body.Close()
}
}()
}

token, err := t.Token(req.Context())
if err != nil {
return nil, err
}

req.Header.Set("Authorization", "token "+token)
req.Header.Add("Accept", acceptHeader) // We add to "Accept" header to avoid overwriting existing req headers.
resp, err := t.tr.RoundTrip(req)
creq := cloneRequest(req) // per RoundTripper contract
creq.Header.Set("Authorization", "token "+token)
creq.Header.Add("Accept", acceptHeader) // We add to "Accept" header to avoid overwriting existing req headers.
reqBodyClosed = true // req.Body is assumed to be closed by the tr RoundTripper.
resp, err := t.tr.RoundTrip(creq)
return resp, err
}

Expand Down Expand Up @@ -212,3 +223,17 @@ func GetReadWriter(i interface{}) (io.ReadWriter, error) {
}
return buf, nil
}

// cloneRequest returns a clone of the provided *http.Request.
// The clone is a shallow copy of the struct and its Header map.
func cloneRequest(r *http.Request) *http.Request {
// shallow copy of the struct
r2 := new(http.Request)
*r2 = *r
// deep copy of the Header
r2.Header = make(http.Header, len(r.Header))
for k, s := range r.Header {
r2.Header[k] = append([]string(nil), s...)
}
return r2
}
35 changes: 35 additions & 0 deletions transport_test.go
Expand Up @@ -9,6 +9,7 @@ import (
"net/http/httptest"
"os"
"strings"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -344,3 +345,37 @@ func TestRefreshTokenWithTrailingSlashBaseURL(t *testing.T) {
t.Fatalf("Unexpected RoundTrip response code: %d", res.StatusCode)
}
}

func TestRoundTripperContract(t *testing.T) {
tr := &Transport{
token: &accessToken{
ExpiresAt: time.Now().Add(1 * time.Hour),
Token: "42",
},
mu: &sync.Mutex{},
tr: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
if auth := req.Header.Get("Authorization"); auth != "token 42" {
t.Errorf("got unexpected Authorization request header in parent RoundTripper: %q", auth)
}
return nil, nil
}),
}
req, err := http.NewRequest("GET", "http://localhost", nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Authorization", "xxx")
_, err = tr.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
if accept := req.Header.Get("Authorization"); accept != "xxx" {
t.Errorf("got unexpected Authorization request header in caller: %q", accept)
}
}

type roundTripperFunc func(*http.Request) (*http.Response, error)

func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return fn(req)
}