Skip to content

Commit

Permalink
Allow to override the http.Client for RequestToken and AccessToken
Browse files Browse the repository at this point in the history
Before, the `http.DefaultClient` was always used.  This is e.g. not
sufficient, if the server being authenticated against uses TLS
certificates signed by a custom root CA.  In this case, a custom
`http.Client` needs to be passed.

This change follows the design of https://github.com/golang/oauth2
and allows passing the HTTP client via `context.Context`.

Since the user of a function that accepts a `context.Context` would
likely expect it also to be used for HTTP calls, we also pass it to
`http.NewRequest`.
  • Loading branch information
fdcds committed May 17, 2021
1 parent d117b84 commit 75b456f
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 15 deletions.
12 changes: 6 additions & 6 deletions config.go
Expand Up @@ -62,16 +62,16 @@ func NewClient(ctx context.Context, config *Config, token *Token) *http.Client {
// oauth_callback_confirmed is true. Returns the request token and secret
// (temporary credentials).
// See RFC 5849 2.1 Temporary Credentials.
func (c *Config) RequestToken() (requestToken, requestSecret string, err error) {
req, err := http.NewRequest("POST", c.Endpoint.RequestTokenURL, nil)
func (c *Config) RequestToken(ctx context.Context) (requestToken, requestSecret string, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.Endpoint.RequestTokenURL, nil)
if err != nil {
return "", "", err
}
err = newAuther(c).setRequestTokenAuthHeader(req)
if err != nil {
return "", "", err
}
resp, err := http.DefaultClient.Do(req)
resp, err := contextClient(ctx).Do(req)
if err != nil {
return "", "", err
}
Expand Down Expand Up @@ -139,16 +139,16 @@ func ParseAuthorizationCallback(req *http.Request) (requestToken, verifier strin
// Endpoint AccessTokenURL. Returns the access token and secret (token
// credentials).
// See RFC 5849 2.3 Token Credentials.
func (c *Config) AccessToken(requestToken, requestSecret, verifier string) (accessToken, accessSecret string, err error) {
req, err := http.NewRequest("POST", c.Endpoint.AccessTokenURL, nil)
func (c *Config) AccessToken(ctx context.Context, requestToken, requestSecret, verifier string) (accessToken, accessSecret string, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.Endpoint.AccessTokenURL, nil)
if err != nil {
return "", "", err
}
err = newAuther(c).setAccessTokenAuthHeader(req, requestToken, requestSecret, verifier)
if err != nil {
return "", "", err
}
resp, err := http.DefaultClient.Do(req)
resp, err := contextClient(ctx).Do(req)
if err != nil {
return "", "", err
}
Expand Down
18 changes: 9 additions & 9 deletions config_test.go
Expand Up @@ -105,7 +105,7 @@ func TestConfigRequestToken(t *testing.T) {
RequestTokenURL: server.URL,
},
}
requestToken, requestSecret, err := config.RequestToken()
requestToken, requestSecret, err := config.RequestToken(context.Background())
assert.Nil(t, err)
assert.Equal(t, expectedToken, requestToken)
assert.Equal(t, expectedSecret, requestSecret)
Expand All @@ -117,7 +117,7 @@ func TestConfigRequestToken_InvalidRequestTokenURL(t *testing.T) {
RequestTokenURL: "http://wrong.com/oauth/request_token",
},
}
requestToken, requestSecret, err := config.RequestToken()
requestToken, requestSecret, err := config.RequestToken(context.Background())
assert.NotNil(t, err)
assert.Equal(t, "", requestToken)
assert.Equal(t, "", requestSecret)
Expand All @@ -138,7 +138,7 @@ func TestConfigRequestToken_CallbackNotConfirmed(t *testing.T) {
RequestTokenURL: server.URL,
},
}
requestToken, requestSecret, err := config.RequestToken()
requestToken, requestSecret, err := config.RequestToken(context.Background())
if assert.Error(t, err) {
assert.Equal(t, "oauth1: oauth_callback_confirmed was not true", err.Error())
}
Expand All @@ -155,7 +155,7 @@ func TestConfigRequestToken_CannotParseBody(t *testing.T) {
RequestTokenURL: server.URL,
},
}
requestToken, requestSecret, err := config.RequestToken()
requestToken, requestSecret, err := config.RequestToken(context.Background())
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "invalid URL escape")
}
Expand All @@ -175,7 +175,7 @@ func TestConfigRequestToken_MissingTokenOrSecret(t *testing.T) {
RequestTokenURL: server.URL,
},
}
requestToken, requestSecret, err := config.RequestToken()
requestToken, requestSecret, err := config.RequestToken(context.Background())
if assert.Error(t, err) {
assert.Equal(t, "oauth1: Response missing oauth_token or oauth_token_secret", err.Error())
}
Expand Down Expand Up @@ -225,7 +225,7 @@ func TestConfigAccessToken(t *testing.T) {
AccessTokenURL: server.URL,
},
}
accessToken, accessSecret, err := config.AccessToken("request_token", "request_secret", expectedVerifier)
accessToken, accessSecret, err := config.AccessToken(context.Background(), "request_token", "request_secret", expectedVerifier)
assert.Nil(t, err)
assert.Equal(t, expectedToken, accessToken)
assert.Equal(t, expectedSecret, accessSecret)
Expand All @@ -237,7 +237,7 @@ func TestConfigAccessToken_InvalidAccessTokenURL(t *testing.T) {
AccessTokenURL: "http://wrong.com/oauth/access_token",
},
}
accessToken, accessSecret, err := config.AccessToken("any_token", "any_secret", "any_verifier")
accessToken, accessSecret, err := config.AccessToken(context.Background(), "any_token", "any_secret", "any_verifier")
assert.NotNil(t, err)
assert.Equal(t, "", accessToken)
assert.Equal(t, "", accessSecret)
Expand All @@ -252,7 +252,7 @@ func TestConfigAccessToken_CannotParseBody(t *testing.T) {
AccessTokenURL: server.URL,
},
}
accessToken, accessSecret, err := config.AccessToken("any_token", "any_secret", "any_verifier")
accessToken, accessSecret, err := config.AccessToken(context.Background(), "any_token", "any_secret", "any_verifier")
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "invalid URL escape")
}
Expand All @@ -271,7 +271,7 @@ func TestConfigAccessToken_MissingTokenOrSecret(t *testing.T) {
AccessTokenURL: server.URL,
},
}
accessToken, accessSecret, err := config.AccessToken("request_token", "request_secret", expectedVerifier)
accessToken, accessSecret, err := config.AccessToken(context.Background(), "request_token", "request_secret", expectedVerifier)
if assert.Error(t, err) {
assert.Equal(t, "oauth1: Response missing oauth_token or oauth_token_secret", err.Error())
}
Expand Down
9 changes: 9 additions & 0 deletions context.go
Expand Up @@ -21,3 +21,12 @@ func contextTransport(ctx context.Context) http.RoundTripper {
}
return nil
}

func contextClient(ctx context.Context) *http.Client {
if ctx != nil {
if client, ok := ctx.Value(HTTPClient).(*http.Client); ok {
return client
}
}
return http.DefaultClient
}

0 comments on commit 75b456f

Please sign in to comment.