diff --git a/config.go b/config.go index a0124ff..0ae6712 100644 --- a/config.go +++ b/config.go @@ -62,8 +62,8 @@ func NewClient(ctx context.Context, config *Config, token *Token) *http.Client { // oauth_callback_confirmed is true. Returns the request token and secret // (temporary credentials). // See RFC 5849 2.1 Temporary Credentials. -func (c *Config) RequestToken() (requestToken, requestSecret string, err error) { - req, err := http.NewRequest("POST", c.Endpoint.RequestTokenURL, nil) +func (c *Config) RequestToken(ctx context.Context) (requestToken, requestSecret string, err error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.Endpoint.RequestTokenURL, nil) if err != nil { return "", "", err } @@ -71,7 +71,7 @@ func (c *Config) RequestToken() (requestToken, requestSecret string, err error) if err != nil { return "", "", err } - resp, err := http.DefaultClient.Do(req) + resp, err := contextClient(ctx).Do(req) if err != nil { return "", "", err } @@ -139,8 +139,8 @@ func ParseAuthorizationCallback(req *http.Request) (requestToken, verifier strin // Endpoint AccessTokenURL. Returns the access token and secret (token // credentials). // See RFC 5849 2.3 Token Credentials. -func (c *Config) AccessToken(requestToken, requestSecret, verifier string) (accessToken, accessSecret string, err error) { - req, err := http.NewRequest("POST", c.Endpoint.AccessTokenURL, nil) +func (c *Config) AccessToken(ctx context.Context, requestToken, requestSecret, verifier string) (accessToken, accessSecret string, err error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.Endpoint.AccessTokenURL, nil) if err != nil { return "", "", err } @@ -148,7 +148,7 @@ func (c *Config) AccessToken(requestToken, requestSecret, verifier string) (acce if err != nil { return "", "", err } - resp, err := http.DefaultClient.Do(req) + resp, err := contextClient(ctx).Do(req) if err != nil { return "", "", err } diff --git a/config_test.go b/config_test.go index 914f9f9..c5cc33d 100644 --- a/config_test.go +++ b/config_test.go @@ -105,7 +105,7 @@ func TestConfigRequestToken(t *testing.T) { RequestTokenURL: server.URL, }, } - requestToken, requestSecret, err := config.RequestToken() + requestToken, requestSecret, err := config.RequestToken(context.Background()) assert.Nil(t, err) assert.Equal(t, expectedToken, requestToken) assert.Equal(t, expectedSecret, requestSecret) @@ -117,7 +117,7 @@ func TestConfigRequestToken_InvalidRequestTokenURL(t *testing.T) { RequestTokenURL: "http://wrong.com/oauth/request_token", }, } - requestToken, requestSecret, err := config.RequestToken() + requestToken, requestSecret, err := config.RequestToken(context.Background()) assert.NotNil(t, err) assert.Equal(t, "", requestToken) assert.Equal(t, "", requestSecret) @@ -138,7 +138,7 @@ func TestConfigRequestToken_CallbackNotConfirmed(t *testing.T) { RequestTokenURL: server.URL, }, } - requestToken, requestSecret, err := config.RequestToken() + requestToken, requestSecret, err := config.RequestToken(context.Background()) if assert.Error(t, err) { assert.Equal(t, "oauth1: oauth_callback_confirmed was not true", err.Error()) } @@ -155,7 +155,7 @@ func TestConfigRequestToken_CannotParseBody(t *testing.T) { RequestTokenURL: server.URL, }, } - requestToken, requestSecret, err := config.RequestToken() + requestToken, requestSecret, err := config.RequestToken(context.Background()) if assert.Error(t, err) { assert.Contains(t, err.Error(), "invalid URL escape") } @@ -175,7 +175,7 @@ func TestConfigRequestToken_MissingTokenOrSecret(t *testing.T) { RequestTokenURL: server.URL, }, } - requestToken, requestSecret, err := config.RequestToken() + requestToken, requestSecret, err := config.RequestToken(context.Background()) if assert.Error(t, err) { assert.Equal(t, "oauth1: Response missing oauth_token or oauth_token_secret", err.Error()) } @@ -225,7 +225,7 @@ func TestConfigAccessToken(t *testing.T) { AccessTokenURL: server.URL, }, } - accessToken, accessSecret, err := config.AccessToken("request_token", "request_secret", expectedVerifier) + accessToken, accessSecret, err := config.AccessToken(context.Background(), "request_token", "request_secret", expectedVerifier) assert.Nil(t, err) assert.Equal(t, expectedToken, accessToken) assert.Equal(t, expectedSecret, accessSecret) @@ -237,7 +237,7 @@ func TestConfigAccessToken_InvalidAccessTokenURL(t *testing.T) { AccessTokenURL: "http://wrong.com/oauth/access_token", }, } - accessToken, accessSecret, err := config.AccessToken("any_token", "any_secret", "any_verifier") + accessToken, accessSecret, err := config.AccessToken(context.Background(), "any_token", "any_secret", "any_verifier") assert.NotNil(t, err) assert.Equal(t, "", accessToken) assert.Equal(t, "", accessSecret) @@ -252,7 +252,7 @@ func TestConfigAccessToken_CannotParseBody(t *testing.T) { AccessTokenURL: server.URL, }, } - accessToken, accessSecret, err := config.AccessToken("any_token", "any_secret", "any_verifier") + accessToken, accessSecret, err := config.AccessToken(context.Background(), "any_token", "any_secret", "any_verifier") if assert.Error(t, err) { assert.Contains(t, err.Error(), "invalid URL escape") } @@ -271,7 +271,7 @@ func TestConfigAccessToken_MissingTokenOrSecret(t *testing.T) { AccessTokenURL: server.URL, }, } - accessToken, accessSecret, err := config.AccessToken("request_token", "request_secret", expectedVerifier) + accessToken, accessSecret, err := config.AccessToken(context.Background(), "request_token", "request_secret", expectedVerifier) if assert.Error(t, err) { assert.Equal(t, "oauth1: Response missing oauth_token or oauth_token_secret", err.Error()) } diff --git a/context.go b/context.go index 25e9394..251b2d3 100644 --- a/context.go +++ b/context.go @@ -21,3 +21,12 @@ func contextTransport(ctx context.Context) http.RoundTripper { } return nil } + +func contextClient(ctx context.Context) *http.Client { + if ctx != nil { + if client, ok := ctx.Value(HTTPClient).(*http.Client); ok { + return client + } + } + return http.DefaultClient +}