Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce Go context-aware Wait functions for blocking operation #39

Merged
merged 3 commits into from Dec 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
42 changes: 19 additions & 23 deletions device/device_flow.go
Expand Up @@ -13,6 +13,7 @@
package device

import (
"context"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -103,16 +104,16 @@ const defaultGrantType = "urn:ietf:params:oauth:grant-type:device_code"

// PollToken polls the server at pollURL until an access token is granted or denied.
//
// Deprecated: use PollTokenWithOptions.
// Deprecated: use Wait.
func PollToken(c httpClient, pollURL string, clientID string, code *CodeResponse) (*api.AccessToken, error) {
return PollTokenWithOptions(c, pollURL, PollOptions{
return Wait(context.Background(), c, pollURL, WaitOptions{
ClientID: clientID,
DeviceCode: code,
})
}

// PollOptions specifies parameters to poll the server with until authentication completes.
type PollOptions struct {
// WaitOptions specifies parameters to poll the server with until authentication completes.
type WaitOptions struct {
// ClientID is the app client ID value.
ClientID string
// ClientSecret is the app client secret value. Optional: only pass if the server requires it.
Expand All @@ -122,30 +123,28 @@ type PollOptions struct {
// GrantType overrides the default value specified by OAuth 2.0 Device Code. Optional.
GrantType string

timeNow func() time.Time
timeSleep func(time.Duration)
newPoller pollerFactory
}

// PollTokenWithOptions polls the server at uri until authorization completes.
func PollTokenWithOptions(c httpClient, uri string, opts PollOptions) (*api.AccessToken, error) {
timeNow := opts.timeNow
if timeNow == nil {
timeNow = time.Now
}
timeSleep := opts.timeSleep
if timeSleep == nil {
timeSleep = time.Sleep
}

// Wait polls the server at uri until authorization completes.
func Wait(ctx context.Context, c httpClient, uri string, opts WaitOptions) (*api.AccessToken, error) {
checkInterval := time.Duration(opts.DeviceCode.Interval) * time.Second
expiresAt := timeNow().Add(time.Duration(opts.DeviceCode.ExpiresIn) * time.Second)
expiresIn := time.Duration(opts.DeviceCode.ExpiresIn) * time.Second
grantType := opts.GrantType
if opts.GrantType == "" {
grantType = defaultGrantType
}

makePoller := opts.newPoller
if makePoller == nil {
makePoller = newPoller
}
_, poll := makePoller(ctx, checkInterval, expiresIn)

for {
timeSleep(checkInterval)
if err := poll.Wait(); err != nil {
return nil, err
}

values := url.Values{
"client_id": {opts.ClientID},
Expand All @@ -158,6 +157,7 @@ func PollTokenWithOptions(c httpClient, uri string, opts PollOptions) (*api.Acce
values.Add("client_secret", opts.ClientSecret)
}

// TODO: pass tctx down to the HTTP layer
resp, err := api.PostForm(c, uri, values)
if err != nil {
return nil, err
Expand All @@ -170,9 +170,5 @@ func PollTokenWithOptions(c httpClient, uri string, opts PollOptions) (*api.Acce
} else if !(errors.As(err, &apiError) && apiError.Code == "authorization_pending") {
return nil, err
}

if timeNow().After(expiresAt) {
return nil, ErrTimeout
}
}
}
72 changes: 33 additions & 39 deletions device/device_flow_test.go
Expand Up @@ -2,6 +2,8 @@ package device

import (
"bytes"
"context"
"errors"
"io/ioutil"
"net/http"
"net/url"
Expand Down Expand Up @@ -230,28 +232,16 @@ func TestRequestCode(t *testing.T) {
}

func TestPollToken(t *testing.T) {
var totalSlept time.Duration
mockSleep := func(d time.Duration) {
totalSlept += d
}
duration := func(d string) time.Duration {
res, _ := time.ParseDuration(d)
return res
}
clock := func(durations ...string) func() time.Time {
count := 0
now := time.Now()
return func() time.Time {
t := now.Add(duration(durations[count]))
count++
return t
makeFakePoller := func(maxWaits int) pollerFactory {
return func(ctx context.Context, interval, expiresIn time.Duration) (context.Context, poller) {
return ctx, &fakePoller{maxWaits: maxWaits}
}
}

type args struct {
http apiClient
url string
opts PollOptions
opts WaitOptions
}
tests := []struct {
name string
Expand Down Expand Up @@ -279,7 +269,7 @@ func TestPollToken(t *testing.T) {
},
},
url: "https://github.com/oauth",
opts: PollOptions{
opts: WaitOptions{
ClientID: "CLIENT-ID",
DeviceCode: &CodeResponse{
DeviceCode: "DEVIC",
Expand All @@ -288,14 +278,12 @@ func TestPollToken(t *testing.T) {
ExpiresIn: 99,
Interval: 5,
},
timeSleep: mockSleep,
timeNow: clock("0", "5s", "10s"),
newPoller: makeFakePoller(2),
},
},
want: &api.AccessToken{
Token: "123abc",
},
slept: duration("10s"),
posts: []postArgs{
{
url: "https://github.com/oauth",
Expand Down Expand Up @@ -328,7 +316,7 @@ func TestPollToken(t *testing.T) {
},
},
url: "https://github.com/oauth",
opts: PollOptions{
opts: WaitOptions{
ClientID: "CLIENT-ID",
ClientSecret: "SEKRIT",
GrantType: "device_code",
Expand All @@ -339,14 +327,12 @@ func TestPollToken(t *testing.T) {
ExpiresIn: 99,
Interval: 5,
},
timeSleep: mockSleep,
timeNow: clock("0", "5s", "10s"),
newPoller: makeFakePoller(1),
},
},
want: &api.AccessToken{
Token: "123abc",
},
slept: duration("5s"),
posts: []postArgs{
{
url: "https://github.com/oauth",
Expand Down Expand Up @@ -377,21 +363,19 @@ func TestPollToken(t *testing.T) {
},
},
url: "https://github.com/oauth",
opts: PollOptions{
opts: WaitOptions{
ClientID: "CLIENT-ID",
DeviceCode: &CodeResponse{
DeviceCode: "DEVIC",
UserCode: "123-abc",
VerificationURI: "http://verify.me",
ExpiresIn: 99,
ExpiresIn: 14,
Interval: 5,
},
timeSleep: mockSleep,
timeNow: clock("0", "5s", "15m"),
newPoller: makeFakePoller(2),
},
},
wantErr: "authentication timed out",
slept: duration("10s"),
wantErr: "context deadline exceeded",
posts: []postArgs{
{
url: "https://github.com/oauth",
Expand Down Expand Up @@ -424,7 +408,7 @@ func TestPollToken(t *testing.T) {
},
},
url: "https://github.com/oauth",
opts: PollOptions{
opts: WaitOptions{
ClientID: "CLIENT-ID",
DeviceCode: &CodeResponse{
DeviceCode: "DEVIC",
Expand All @@ -433,12 +417,10 @@ func TestPollToken(t *testing.T) {
ExpiresIn: 99,
Interval: 5,
},
timeSleep: mockSleep,
timeNow: clock("0", "5s"),
newPoller: makeFakePoller(1),
},
},
wantErr: "access_denied",
slept: duration("5s"),
posts: []postArgs{
{
url: "https://github.com/oauth",
Expand All @@ -453,8 +435,7 @@ func TestPollToken(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
totalSlept = 0
got, err := PollTokenWithOptions(&tt.args.http, tt.args.url, tt.args.opts)
got, err := Wait(context.Background(), &tt.args.http, tt.args.url, tt.args.opts)
if (err != nil) != (tt.wantErr != "") {
t.Errorf("PollToken() error = %v, wantErr %v", err, tt.wantErr)
return
Expand All @@ -468,9 +449,22 @@ func TestPollToken(t *testing.T) {
if !reflect.DeepEqual(tt.args.http.calls, tt.posts) {
t.Errorf("PostForm() = %v, want %v", tt.args.http.calls, tt.posts)
}
if totalSlept != tt.slept {
t.Errorf("slept %v, wanted %v", totalSlept, tt.slept)
}
})
}
}

type fakePoller struct {
maxWaits int
count int
}

func (p *fakePoller) Wait() error {
if p.count == p.maxWaits {
return errors.New("context deadline exceeded")
}
p.count++
return nil
}

func (p *fakePoller) Cancel() {
}
3 changes: 2 additions & 1 deletion device/examples_test.go
@@ -1,6 +1,7 @@
package device

import (
"context"
"fmt"
"net/http"
"os"
Expand All @@ -22,7 +23,7 @@ func Example() {
fmt.Printf("Copy code: %s\n", code.UserCode)
fmt.Printf("then open: %s\n", code.VerificationURI)

accessToken, err := PollTokenWithOptions(httpClient, "https://github.com/login/oauth/access_token", PollOptions{
accessToken, err := Wait(context.TODO(), httpClient, "https://github.com/login/oauth/access_token", WaitOptions{
ClientID: clientID,
DeviceCode: code,
})
Expand Down
43 changes: 43 additions & 0 deletions device/poller.go
@@ -0,0 +1,43 @@
package device

import (
"context"
"time"
)

type poller interface {
Wait() error
Cancel()
}

type pollerFactory func(context.Context, time.Duration, time.Duration) (context.Context, poller)

func newPoller(ctx context.Context, checkInteval, expiresIn time.Duration) (context.Context, poller) {
c, cancel := context.WithTimeout(ctx, expiresIn)
return c, &intervalPoller{
ctx: c,
interval: checkInteval,
cancelFunc: cancel,
}
}

type intervalPoller struct {
ctx context.Context
interval time.Duration
cancelFunc func()
}

func (p intervalPoller) Wait() error {
t := time.NewTimer(p.interval)
select {
case <-p.ctx.Done():
t.Stop()
return p.ctx.Err()
case <-t.C:
return nil
}
}

func (p intervalPoller) Cancel() {
p.cancelFunc()
}
3 changes: 2 additions & 1 deletion oauth_device.go
Expand Up @@ -2,6 +2,7 @@ package oauth

import (
"bufio"
"context"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -58,7 +59,7 @@ func (oa *Flow) DeviceFlow() (*api.AccessToken, error) {
return nil, fmt.Errorf("error opening the web browser: %w", err)
}

return device.PollTokenWithOptions(httpClient, host.TokenURL, device.PollOptions{
return device.Wait(context.TODO(), httpClient, host.TokenURL, device.WaitOptions{
ClientID: oa.ClientID,
DeviceCode: code,
})
Expand Down
5 changes: 4 additions & 1 deletion oauth_webapp.go
@@ -1,6 +1,7 @@
package oauth

import (
"context"
"fmt"
"net/http"

Expand Down Expand Up @@ -52,5 +53,7 @@ func (oa *Flow) WebAppFlow() (*api.AccessToken, error) {
httpClient = http.DefaultClient
}

return flow.AccessToken(httpClient, host.TokenURL, oa.ClientSecret)
return flow.Wait(context.TODO(), httpClient, host.TokenURL, webapp.WaitOptions{
ClientSecret: oa.ClientSecret,
})
}
5 changes: 4 additions & 1 deletion webapp/examples_test.go
@@ -1,6 +1,7 @@
package webapp

import (
"context"
"fmt"
"net/http"
"os"
Expand Down Expand Up @@ -42,7 +43,9 @@ func Example() {
}

httpClient := http.DefaultClient
accessToken, err := flow.AccessToken(httpClient, "https://github.com/login/oauth/access_token", clientSecret)
accessToken, err := flow.Wait(context.TODO(), httpClient, "https://github.com/login/oauth/access_token", WaitOptions{
ClientSecret: clientSecret,
})
if err != nil {
panic(err)
}
Expand Down