Skip to content

Commit

Permalink
Merge pull request #61 from CAFxX/patch-1
Browse files Browse the repository at this point in the history
Respect the http.RoundTripper contract
  • Loading branch information
bradleyfalzon committed Mar 28, 2022
2 parents f6d76ac + c28e0c0 commit 6f00d4f
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 3 deletions.
31 changes: 28 additions & 3 deletions transport.go
Expand Up @@ -111,14 +111,25 @@ func NewFromAppsTransport(atr *AppsTransport, installationID int64) *Transport {

// RoundTrip implements http.RoundTripper interface.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
reqBodyClosed := false
if req.Body != nil {
defer func() {
if !reqBodyClosed {
req.Body.Close()
}
}()
}

token, err := t.Token(req.Context())
if err != nil {
return nil, err
}

req.Header.Set("Authorization", "token "+token)
req.Header.Add("Accept", acceptHeader) // We add to "Accept" header to avoid overwriting existing req headers.
resp, err := t.tr.RoundTrip(req)
creq := cloneRequest(req) // per RoundTripper contract
creq.Header.Set("Authorization", "token "+token)
creq.Header.Add("Accept", acceptHeader) // We add to "Accept" header to avoid overwriting existing req headers.
reqBodyClosed = true // req.Body is assumed to be closed by the tr RoundTripper.
resp, err := t.tr.RoundTrip(creq)
return resp, err
}

Expand Down Expand Up @@ -212,3 +223,17 @@ func GetReadWriter(i interface{}) (io.ReadWriter, error) {
}
return buf, nil
}

// cloneRequest returns a clone of the provided *http.Request.
// The clone is a shallow copy of the struct and its Header map.
func cloneRequest(r *http.Request) *http.Request {
// shallow copy of the struct
r2 := new(http.Request)
*r2 = *r
// deep copy of the Header
r2.Header = make(http.Header, len(r.Header))
for k, s := range r.Header {
r2.Header[k] = append([]string(nil), s...)
}
return r2
}
35 changes: 35 additions & 0 deletions transport_test.go
Expand Up @@ -9,6 +9,7 @@ import (
"net/http/httptest"
"os"
"strings"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -344,3 +345,37 @@ func TestRefreshTokenWithTrailingSlashBaseURL(t *testing.T) {
t.Fatalf("Unexpected RoundTrip response code: %d", res.StatusCode)
}
}

func TestRoundTripperContract(t *testing.T) {
tr := &Transport{
token: &accessToken{
ExpiresAt: time.Now().Add(1 * time.Hour),
Token: "42",
},
mu: &sync.Mutex{},
tr: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
if auth := req.Header.Get("Authorization"); auth != "token 42" {
t.Errorf("got unexpected Authorization request header in parent RoundTripper: %q", auth)
}
return nil, nil
}),
}
req, err := http.NewRequest("GET", "http://localhost", nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Authorization", "xxx")
_, err = tr.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
if accept := req.Header.Get("Authorization"); accept != "xxx" {
t.Errorf("got unexpected Authorization request header in caller: %q", accept)
}
}

type roundTripperFunc func(*http.Request) (*http.Response, error)

func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return fn(req)
}

0 comments on commit 6f00d4f

Please sign in to comment.