Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support third-party OAuth hosts #7

Merged
merged 2 commits into from Oct 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
36 changes: 24 additions & 12 deletions api/form.go
@@ -1,12 +1,14 @@
package api

import (
"encoding/json"
"fmt"
"io"
"io/ioutil"
"mime"
"net/http"
"net/url"
"strings"
"strconv"
)

type httpClient interface {
Expand Down Expand Up @@ -71,7 +73,9 @@ func PostForm(c httpClient, u string, params url.Values) (*FormResponse, error)
requestURI: u,
}

if contentType(resp.Header.Get("Content-Type")) == formType {
mediaType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
switch mediaType {
case "application/x-www-form-urlencoded":
var bb []byte
bb, err = ioutil.ReadAll(resp.Body)
if err != nil {
Expand All @@ -82,7 +86,24 @@ func PostForm(c httpClient, u string, params url.Values) (*FormResponse, error)
if err != nil {
return r, err
}
} else {
case "application/json":
var values map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&values); err != nil {
return r, err
}

r.values = make(url.Values)
for key, value := range values {
switch v := value.(type) {
case string:
r.values.Set(key, v)
case int64:
r.values.Set(key, strconv.FormatInt(v, 10))
case float64:
r.values.Set(key, strconv.FormatFloat(v, 'f', -1, 64))
}
}
default:
_, err = io.Copy(ioutil.Discard, resp.Body)
if err != nil {
return r, err
Expand All @@ -91,12 +112,3 @@ func PostForm(c httpClient, u string, params url.Values) (*FormResponse, error)

return r, nil
}

const formType = "application/x-www-form-urlencoded"

func contentType(t string) string {
if i := strings.IndexRune(t, ';'); i >= 0 {
return t[0:i]
}
return t
}
22 changes: 21 additions & 1 deletion api/form_test.go
Expand Up @@ -141,7 +141,7 @@ func TestPostForm(t *testing.T) {
wantErr bool
}{
{
name: "success",
name: "success urlencoded",
args: args{
url: "https://github.com/oauth",
},
Expand All @@ -160,6 +160,26 @@ func TestPostForm(t *testing.T) {
},
wantErr: false,
},
{
name: "success JSON",
args: args{
url: "https://github.com/oauth",
},
http: apiClient{
body: `{"access_token":"123abc", "scopes":"repo gist"}`,
status: 200,
contentType: "application/json; charset=utf-8",
},
want: &FormResponse{
StatusCode: 200,
requestURI: "https://github.com/oauth",
values: url.Values{
"access_token": {"123abc"},
"scopes": {"repo gist"},
},
},
wantErr: false,
},
{
name: "HTML response",
args: args{
Expand Down
2 changes: 1 addition & 1 deletion examples_test.go
Expand Up @@ -10,7 +10,7 @@ import (
// flow support is globally available, but enables logging in to hosted GitHub instances as well.
func Example() {
flow := &Flow{
Hostname: "github.com",
Host: GitHubHost("https://github.com"),
ClientID: os.Getenv("OAUTH_CLIENT_ID"),
ClientSecret: os.Getenv("OAUTH_CLIENT_SECRET"), // only applicable to web app flow
CallbackURI: "http://127.0.0.1/callback", // only applicable to web app flow
Expand Down
1 change: 1 addition & 0 deletions go.sum
@@ -1,3 +1,4 @@
github.com/cli/browser v1.0.0 h1:RIleZgXrhdiCVgFBSjtWwkLPUCWyhhhN5k5HGSBt1js=
github.com/cli/browser v1.0.0/go.mod h1:IEWkHYbLjkhtjwwWlwTHW2lGxeS5gezEQBMLTwDHf5Q=
github.com/cli/safeexec v1.0.0 h1:0VngyaIyqACHdcMNWfo6+KdUYnqEr2Sg+bSP1pdF+dI=
github.com/cli/safeexec v1.0.0/go.mod h1:Z/D4tTN8Vs5gXYHDCbaM1S/anmEDnJb1iW0+EJ5zx3Q=
36 changes: 23 additions & 13 deletions oauth.go
Expand Up @@ -17,10 +17,32 @@ type httpClient interface {
PostForm(string, url.Values) (*http.Response, error)
}

// Host defines the endpoints used to authorize against an OAuth server.
type Host struct {
DeviceCodeURL string
AuthorizeURL string
TokenURL string
}

// GitHubHost constructs a Host from the given URL to a GitHub instance.
func GitHubHost(hostURL string) *Host {
u, _ := url.Parse(hostURL)

return &Host{
DeviceCodeURL: fmt.Sprintf("%s://%s/login/device/code", u.Scheme, u.Host),
AuthorizeURL: fmt.Sprintf("%s://%s/login/oauth/authorize", u.Scheme, u.Host),
TokenURL: fmt.Sprintf("%s://%s/login/oauth/access_token", u.Scheme, u.Host),
}
}

// Flow facilitates a single OAuth authorization flow.
type Flow struct {
// The host to authorize the app with.
// The hostname to authorize the app with.
//
// Deprecated: Use Host instead.
Hostname string
// Host configuration to authorize the app with.
Host *Host
// OAuth scopes to request from the user.
Scopes []string
// OAuth application ID.
Expand All @@ -47,18 +69,6 @@ type Flow struct {
Stdout io.Writer
}

func deviceInitURL(host string) string {
return fmt.Sprintf("https://%s/login/device/code", host)
}

func webappInitURL(host string) string {
return fmt.Sprintf("https://%s/login/oauth/authorize", host)
}

func tokenURL(host string) string {
return fmt.Sprintf("https://%s/login/oauth/access_token", host)
}

// DetectFlow tries to perform Device flow first and falls back to Web application flow.
func (oa *Flow) DetectFlow() (*api.AccessToken, error) {
accessToken, err := oa.DeviceFlow()
Expand Down
8 changes: 6 additions & 2 deletions oauth_device.go
Expand Up @@ -28,8 +28,12 @@ func (oa *Flow) DeviceFlow() (*api.AccessToken, error) {
if stdout == nil {
stdout = os.Stdout
}
host := oa.Host
if host == nil {
host = GitHubHost("https://" + oa.Hostname)
}

code, err := device.RequestCode(httpClient, deviceInitURL(oa.Hostname), oa.ClientID, oa.Scopes)
code, err := device.RequestCode(httpClient, host.DeviceCodeURL, oa.ClientID, oa.Scopes)
if err != nil {
return nil, err
}
Expand All @@ -54,7 +58,7 @@ func (oa *Flow) DeviceFlow() (*api.AccessToken, error) {
return nil, fmt.Errorf("error opening the web browser: %w", err)
}

return device.PollToken(httpClient, tokenURL(oa.Hostname), oa.ClientID, code)
return device.PollToken(httpClient, host.TokenURL, oa.ClientID, code)
}

func waitForEnter(r io.Reader) error {
Expand Down
9 changes: 7 additions & 2 deletions oauth_webapp.go
Expand Up @@ -12,6 +12,11 @@ import (
// WebAppFlow starts a local HTTP server, opens the web browser to initiate the OAuth Web application
// flow, blocks until the user completes authorization and is redirected back, and returns the access token.
func (oa *Flow) WebAppFlow() (*api.AccessToken, error) {
host := oa.Host
if host == nil {
host = GitHubHost("https://" + oa.Hostname)
}

flow, err := webapp.InitFlow()
if err != nil {
return nil, err
Expand All @@ -23,7 +28,7 @@ func (oa *Flow) WebAppFlow() (*api.AccessToken, error) {
Scopes: oa.Scopes,
AllowSignup: true,
}
browserURL, err := flow.BrowserURL(webappInitURL(oa.Hostname), params)
browserURL, err := flow.BrowserURL(host.AuthorizeURL, params)
if err != nil {
return nil, err
}
Expand All @@ -47,5 +52,5 @@ func (oa *Flow) WebAppFlow() (*api.AccessToken, error) {
httpClient = http.DefaultClient
}

return flow.AccessToken(httpClient, tokenURL(oa.Hostname), oa.ClientSecret)
return flow.AccessToken(httpClient, host.TokenURL, oa.ClientSecret)
}