diff --git a/device/device_flow.go b/device/device_flow.go index c665a5a..f14eb43 100644 --- a/device/device_flow.go +++ b/device/device_flow.go @@ -51,9 +51,6 @@ type CodeResponse struct { // The minimum number of seconds that must pass before you can make a new access token request to // complete the device authorization. Interval int - - timeNow func() time.Time - timeSleep func(time.Duration) } // RequestCode initiates the authorization flow by requesting a code from uri. @@ -67,6 +64,10 @@ func RequestCode(c httpClient, uri string, clientID string, scopes []string) (*C } verificationURI := resp.Get("verification_uri") + if verificationURI == "" { + // Google's "OAuth 2.0 for TV and Limited-Input Device Applications" uses `verification_url`. + verificationURI = resp.Get("verification_url") + } if resp.StatusCode == 401 || resp.StatusCode == 403 || resp.StatusCode == 404 || resp.StatusCode == 422 || (resp.StatusCode == 200 && verificationURI == "") || @@ -98,30 +99,66 @@ func RequestCode(c httpClient, uri string, clientID string, scopes []string) (*C }, nil } -const grantType = "urn:ietf:params:oauth:grant-type:device_code" +const defaultGrantType = "urn:ietf:params:oauth:grant-type:device_code" // PollToken polls the server at pollURL until an access token is granted or denied. +// +// Deprecated: use PollTokenWithOptions. func PollToken(c httpClient, pollURL string, clientID string, code *CodeResponse) (*api.AccessToken, error) { - timeNow := code.timeNow + return PollTokenWithOptions(c, pollURL, PollOptions{ + ClientID: clientID, + DeviceCode: code, + }) +} + +// PollOptions specifies parameters to poll the server with until authentication completes. +type PollOptions struct { + // ClientID is the app client ID value. + ClientID string + // ClientSecret is the app client secret value. Optional: only pass if the server requires it. + ClientSecret string + // DeviceCode is the value obtained from RequestCode. + DeviceCode *CodeResponse + // GrantType overrides the default value specified by OAuth 2.0 Device Code. Optional. + GrantType string + + timeNow func() time.Time + timeSleep func(time.Duration) +} + +// PollTokenWithOptions polls the server at uri until authorization completes. +func PollTokenWithOptions(c httpClient, uri string, opts PollOptions) (*api.AccessToken, error) { + timeNow := opts.timeNow if timeNow == nil { timeNow = time.Now } - timeSleep := code.timeSleep + timeSleep := opts.timeSleep if timeSleep == nil { timeSleep = time.Sleep } - checkInterval := time.Duration(code.Interval) * time.Second - expiresAt := timeNow().Add(time.Duration(code.ExpiresIn) * time.Second) + checkInterval := time.Duration(opts.DeviceCode.Interval) * time.Second + expiresAt := timeNow().Add(time.Duration(opts.DeviceCode.ExpiresIn) * time.Second) + grantType := opts.GrantType + if opts.GrantType == "" { + grantType = defaultGrantType + } for { timeSleep(checkInterval) - resp, err := api.PostForm(c, pollURL, url.Values{ - "client_id": {clientID}, - "device_code": {code.DeviceCode}, + values := url.Values{ + "client_id": {opts.ClientID}, + "device_code": {opts.DeviceCode.DeviceCode}, "grant_type": {grantType}, - }) + } + + // Google's "OAuth 2.0 for TV and Limited-Input Device Applications" requires `client_secret`. + if opts.ClientSecret != "" { + values.Add("client_secret", opts.ClientSecret) + } + + resp, err := api.PostForm(c, uri, values) if err != nil { return nil, err } diff --git a/device/device_flow_test.go b/device/device_flow_test.go index ba93139..d8fb4f3 100644 --- a/device/device_flow_test.go +++ b/device/device_flow_test.go @@ -249,10 +249,9 @@ func TestPollToken(t *testing.T) { } type args struct { - http apiClient - url string - clientID string - code *CodeResponse + http apiClient + url string + opts PollOptions } tests := []struct { name string @@ -279,16 +278,18 @@ func TestPollToken(t *testing.T) { }, }, }, - url: "https://github.com/oauth", - clientID: "CLIENT-ID", - code: &CodeResponse{ - DeviceCode: "DEVIC", - UserCode: "123-abc", - VerificationURI: "http://verify.me", - ExpiresIn: 99, - Interval: 5, - timeSleep: mockSleep, - timeNow: clock("0", "5s", "10s"), + url: "https://github.com/oauth", + opts: PollOptions{ + ClientID: "CLIENT-ID", + DeviceCode: &CodeResponse{ + DeviceCode: "DEVIC", + UserCode: "123-abc", + VerificationURI: "http://verify.me", + ExpiresIn: 99, + Interval: 5, + }, + timeSleep: mockSleep, + timeNow: clock("0", "5s", "10s"), }, }, want: &api.AccessToken{ @@ -314,6 +315,50 @@ func TestPollToken(t *testing.T) { }, }, }, + { + name: "with client secret and grant type", + args: args{ + http: apiClient{ + stubs: []apiStub{ + { + body: "access_token=123abc", + status: 200, + contentType: "application/x-www-form-urlencoded; charset=utf-8", + }, + }, + }, + url: "https://github.com/oauth", + opts: PollOptions{ + ClientID: "CLIENT-ID", + ClientSecret: "SEKRIT", + GrantType: "device_code", + DeviceCode: &CodeResponse{ + DeviceCode: "DEVIC", + UserCode: "123-abc", + VerificationURI: "http://verify.me", + ExpiresIn: 99, + Interval: 5, + }, + timeSleep: mockSleep, + timeNow: clock("0", "5s", "10s"), + }, + }, + want: &api.AccessToken{ + Token: "123abc", + }, + slept: duration("5s"), + posts: []postArgs{ + { + url: "https://github.com/oauth", + params: url.Values{ + "client_id": {"CLIENT-ID"}, + "client_secret": {"SEKRIT"}, + "device_code": {"DEVIC"}, + "grant_type": {"device_code"}, + }, + }, + }, + }, { name: "timed out", args: args{ @@ -331,16 +376,18 @@ func TestPollToken(t *testing.T) { }, }, }, - url: "https://github.com/oauth", - clientID: "CLIENT-ID", - code: &CodeResponse{ - DeviceCode: "DEVIC", - UserCode: "123-abc", - VerificationURI: "http://verify.me", - ExpiresIn: 99, - Interval: 5, - timeSleep: mockSleep, - timeNow: clock("0", "5s", "15m"), + url: "https://github.com/oauth", + opts: PollOptions{ + ClientID: "CLIENT-ID", + DeviceCode: &CodeResponse{ + DeviceCode: "DEVIC", + UserCode: "123-abc", + VerificationURI: "http://verify.me", + ExpiresIn: 99, + Interval: 5, + }, + timeSleep: mockSleep, + timeNow: clock("0", "5s", "15m"), }, }, wantErr: "authentication timed out", @@ -376,16 +423,18 @@ func TestPollToken(t *testing.T) { }, }, }, - url: "https://github.com/oauth", - clientID: "CLIENT-ID", - code: &CodeResponse{ - DeviceCode: "DEVIC", - UserCode: "123-abc", - VerificationURI: "http://verify.me", - ExpiresIn: 99, - Interval: 5, - timeSleep: mockSleep, - timeNow: clock("0", "5s"), + url: "https://github.com/oauth", + opts: PollOptions{ + ClientID: "CLIENT-ID", + DeviceCode: &CodeResponse{ + DeviceCode: "DEVIC", + UserCode: "123-abc", + VerificationURI: "http://verify.me", + ExpiresIn: 99, + Interval: 5, + }, + timeSleep: mockSleep, + timeNow: clock("0", "5s"), }, }, wantErr: "access_denied", @@ -405,7 +454,7 @@ func TestPollToken(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { totalSlept = 0 - got, err := PollToken(&tt.args.http, tt.args.url, tt.args.clientID, tt.args.code) + got, err := PollTokenWithOptions(&tt.args.http, tt.args.url, tt.args.opts) if (err != nil) != (tt.wantErr != "") { t.Errorf("PollToken() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/device/examples_test.go b/device/examples_test.go index c03fa8e..778a5ac 100644 --- a/device/examples_test.go +++ b/device/examples_test.go @@ -22,7 +22,10 @@ func Example() { fmt.Printf("Copy code: %s\n", code.UserCode) fmt.Printf("then open: %s\n", code.VerificationURI) - accessToken, err := PollToken(httpClient, "https://github.com/login/oauth/access_token", clientID, code) + accessToken, err := PollTokenWithOptions(httpClient, "https://github.com/login/oauth/access_token", PollOptions{ + ClientID: clientID, + DeviceCode: code, + }) if err != nil { panic(err) } diff --git a/oauth_device.go b/oauth_device.go index 4faa4d6..139216a 100644 --- a/oauth_device.go +++ b/oauth_device.go @@ -58,7 +58,10 @@ func (oa *Flow) DeviceFlow() (*api.AccessToken, error) { return nil, fmt.Errorf("error opening the web browser: %w", err) } - return device.PollToken(httpClient, host.TokenURL, oa.ClientID, code) + return device.PollTokenWithOptions(httpClient, host.TokenURL, device.PollOptions{ + ClientID: oa.ClientID, + DeviceCode: code, + }) } func waitForEnter(r io.Reader) error {