Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

trim trailing slashes from base url #57

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion transport.go
Expand Up @@ -8,6 +8,7 @@ import (
"io"
"io/ioutil"
"net/http"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -159,7 +160,8 @@ func (t *Transport) refreshToken(ctx context.Context) error {
return fmt.Errorf("could not convert installation token parameters into json: %s", err)
}

req, err := http.NewRequest("POST", fmt.Sprintf("%s/app/installations/%v/access_tokens", t.BaseURL, t.installationID), body)
requestURL := fmt.Sprintf("%s/app/installations/%v/access_tokens", strings.TrimRight(t.BaseURL, "/"), t.installationID)
req, err := http.NewRequest("POST", requestURL, body)
if err != nil {
return fmt.Errorf("could not create request: %s", err)
}
Expand Down
95 changes: 95 additions & 0 deletions transport_test.go
Expand Up @@ -8,6 +8,7 @@ import (
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -249,3 +250,97 @@ func TestRefreshTokenWithParameters(t *testing.T) {
t.Fatalf("error calling RoundTrip: %v", err)
}
}

func TestRefreshTokenWithTrailingSlashBaseURL(t *testing.T) {
installationTokenOptions := &github.InstallationTokenOptions{
RepositoryIDs: []int64{1234},
Permissions: &github.InstallationPermissions{
Contents: github.String("write"),
Issues: github.String("read"),
},
}

tokenToBe := "token_string"

// Convert io.ReadWriter to String without deleting body data.
wantBody, _ := GetReadWriter(installationTokenOptions)
wantBodyBytes := new(bytes.Buffer)
wantBodyBytes.ReadFrom(wantBody)
wantBodyString := wantBodyBytes.String()

roundTripper := RoundTrip{
rt: func(req *http.Request) (*http.Response, error) {
if strings.Contains(req.URL.Path, "//") {
return &http.Response{
Body: ioutil.NopCloser(strings.NewReader("Forbidden\n")),
StatusCode: 401,
}, fmt.Errorf("Got simulated 401 Github Forbidden response")
}

if req.URL.Path == "test_endpoint/" && req.Header.Get("Authorization") == fmt.Sprintf("token %s", tokenToBe) {
return &http.Response{
Body: ioutil.NopCloser(strings.NewReader("Beautiful\n")),
StatusCode: 200,
}, nil
}

// Convert io.ReadCloser to String without deleting body data.
var gotBodyBytes []byte
gotBodyBytes, _ = ioutil.ReadAll(req.Body)
req.Body = ioutil.NopCloser(bytes.NewBuffer(gotBodyBytes))
gotBodyString := string(gotBodyBytes)

// Compare request sent with request received.
if diff := cmp.Diff(wantBodyString, gotBodyString); diff != "" {
t.Errorf("HTTP body want->got: %s", diff)
}

// Return acceptable access token.
accessToken := accessToken{
Token: tokenToBe,
ExpiresAt: time.Now(),
Repositories: []github.Repository{{
ID: github.Int64(1234),
}},
Permissions: github.InstallationPermissions{
Contents: github.String("write"),
Issues: github.String("read"),
},
}
tokenReadWriter, err := GetReadWriter(accessToken)
if err != nil {
return nil, fmt.Errorf("error converting token into io.ReadWriter: %+v", err)
}
tokenBody := ioutil.NopCloser(tokenReadWriter)
return &http.Response{
Body: tokenBody,
StatusCode: 200,
}, nil
},
}

tr, err := New(roundTripper, appID, installationID, key)
if err != nil {
t.Fatal("unexpected error:", err)
}
tr.InstallationTokenOptions = installationTokenOptions
tr.BaseURL = "http://localhost/github/api/v3/"

// Convert InstallationTokenOptions into a ReadWriter to pass as an argument to http.NewRequest.
body, err := GetReadWriter(installationTokenOptions)
if err != nil {
t.Fatalf("error calling GetReadWriter: %v", err)
}

req, err := http.NewRequest("POST", "http://localhost/test_endpoint/", body)
if err != nil {
t.Fatal("unexpected error:", err)
}
res, err := tr.RoundTrip(req)
if err != nil {
t.Fatalf("error calling RoundTrip: %v", err)
}
if res.StatusCode != 200 {
t.Fatalf("Unexpected RoundTrip response code: %d", res.StatusCode)
}
}