From 94b0f6eaae1478cd1b8ef842ddef5b6edd0f550e Mon Sep 17 00:00:00 2001 From: Sam Coe Date: Wed, 22 Jun 2022 11:28:59 +0100 Subject: [PATCH] Add support for localhost (#51) --- internal/api/gql_client.go | 20 ++++++---- internal/api/gql_client_test.go | 31 +++++++++++++++ internal/api/http.go | 16 +++++++- internal/api/http_test.go | 67 ++++++++++++++++++++++++++++++++ internal/api/rest_client.go | 6 ++- internal/api/rest_client_test.go | 31 +++++++++++++++ pkg/auth/auth.go | 6 ++- pkg/auth/auth_test.go | 67 ++++++++++++++++++++++++++++++++ 8 files changed, 233 insertions(+), 11 deletions(-) diff --git a/internal/api/gql_client.go b/internal/api/gql_client.go index 2c381ab..73aeb8f 100644 --- a/internal/api/gql_client.go +++ b/internal/api/gql_client.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "strings" "github.com/cli/go-gh/pkg/api" @@ -23,15 +24,9 @@ type gqlClient struct { func NewGQLClient(host string, opts *api.ClientOptions) api.GQLClient { httpClient := NewHTTPClient(opts) - if isEnterprise(host) { - host = fmt.Sprintf("https://%s/api/graphql", host) - } else { - host = "https://api.github.com/graphql" - } - return gqlClient{ client: graphql.NewClient(host, &httpClient), - host: host, + host: gqlEndpoint(host), httpClient: &httpClient, } } @@ -114,3 +109,14 @@ type gqlResponse struct { Data interface{} Errors []api.GQLErrorItem } + +func gqlEndpoint(host string) string { + host = normalizeHostname(host) + if isEnterprise(host) { + return fmt.Sprintf("https://%s/api/graphql", host) + } + if strings.EqualFold(host, localhost) { + return fmt.Sprintf("http://api.%s/graphql", host) + } + return fmt.Sprintf("https://api.%s/graphql", host) +} diff --git a/internal/api/gql_client_test.go b/internal/api/gql_client_test.go index c5c2eb0..34dab98 100644 --- a/internal/api/gql_client_test.go +++ b/internal/api/gql_client_test.go @@ -167,3 +167,34 @@ func TestGQLClientDoWithContext(t *testing.T) { }) } } + +func TestGQLEndpoint(t *testing.T) { + tests := []struct { + name string + host string + wantEndpoint string + }{ + { + name: "github", + host: "github.com", + wantEndpoint: "https://api.github.com/graphql", + }, + { + name: "localhost", + host: "github.localhost", + wantEndpoint: "http://api.github.localhost/graphql", + }, + { + name: "enterprise", + host: "enterprise.com", + wantEndpoint: "https://enterprise.com/api/graphql", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + endpoint := gqlEndpoint(tt.host) + assert.Equal(t, tt.wantEndpoint, endpoint) + }) + } +} diff --git a/internal/api/http.go b/internal/api/http.go index ef27f83..4104057 100644 --- a/internal/api/http.go +++ b/internal/api/http.go @@ -19,8 +19,9 @@ const ( accept = "Accept" authorization = "Authorization" contentType = "Content-Type" - defaultHostname = "github.com" + github = "github.com" jsonContentType = "application/json; charset=utf-8" + localhost = "github.localhost" modulePath = "github.com/cli/go-gh" timeZone = "Time-Zone" userAgent = "User-Agent" @@ -127,7 +128,18 @@ func isSameDomain(requestHost, domain string) bool { } func isEnterprise(host string) bool { - return host != defaultHostname + return host != github && host != localhost +} + +func normalizeHostname(hostname string) string { + hostname = strings.ToLower(hostname) + if strings.HasSuffix(hostname, "."+github) { + return github + } + if strings.HasSuffix(hostname, "."+localhost) { + return localhost + } + return hostname } type headerRoundTripper struct { diff --git a/internal/api/http_test.go b/internal/api/http_test.go index efeb65f..4384094 100644 --- a/internal/api/http_test.go +++ b/internal/api/http_test.go @@ -121,6 +121,73 @@ func TestNewHTTPClient(t *testing.T) { } } +func TestIsEnterprise(t *testing.T) { + tests := []struct { + name string + host string + wantOut bool + }{ + { + name: "github", + host: "github.com", + wantOut: false, + }, + { + name: "localhost", + host: "github.localhost", + wantOut: false, + }, + { + name: "enterprise", + host: "mygithub.com", + wantOut: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out := isEnterprise(tt.host) + assert.Equal(t, tt.wantOut, out) + }) + } +} + +func TestNormalizeHostname(t *testing.T) { + tests := []struct { + name string + host string + wantHost string + }{ + { + name: "github domain", + host: "test.github.com", + wantHost: "github.com", + }, + { + name: "capitalized", + host: "GitHub.com", + wantHost: "github.com", + }, + { + name: "localhost domain", + host: "test.github.localhost", + wantHost: "github.localhost", + }, + { + name: "enterprise domain", + host: "mygithub.com", + wantHost: "mygithub.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + normalized := normalizeHostname(tt.host) + assert.Equal(t, tt.wantHost, normalized) + }) + } +} + type tripper struct { roundTrip func(*http.Request) (*http.Response, error) } diff --git a/internal/api/rest_client.go b/internal/api/rest_client.go index 8545577..c95770b 100644 --- a/internal/api/rest_client.go +++ b/internal/api/rest_client.go @@ -119,8 +119,12 @@ func restURL(hostname string, pathOrURL string) string { } func restPrefix(hostname string) string { + hostname = normalizeHostname(hostname) if isEnterprise(hostname) { return fmt.Sprintf("https://%s/api/v3/", hostname) } - return "https://api.github.com/" + if strings.EqualFold(hostname, localhost) { + return fmt.Sprintf("http://api.%s/", hostname) + } + return fmt.Sprintf("https://api.%s/", hostname) } diff --git a/internal/api/rest_client_test.go b/internal/api/rest_client_test.go index fcf309e..3479087 100644 --- a/internal/api/rest_client_test.go +++ b/internal/api/rest_client_test.go @@ -390,6 +390,37 @@ func TestRESTClientRequestWithContext(t *testing.T) { } } +func TestRestPrefix(t *testing.T) { + tests := []struct { + name string + host string + wantEndpoint string + }{ + { + name: "github", + host: "github.com", + wantEndpoint: "https://api.github.com/", + }, + { + name: "localhost", + host: "github.localhost", + wantEndpoint: "http://api.github.localhost/", + }, + { + name: "enterprise", + host: "enterprise.com", + wantEndpoint: "https://enterprise.com/api/v3/", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + endpoint := restPrefix(tt.host) + assert.Equal(t, tt.wantEndpoint, endpoint) + }) + } +} + func printPendingMocks(mocks []gock.Mock) string { paths := []string{} for _, mock := range mocks { diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 1bcd5eb..4edc52f 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -21,6 +21,7 @@ const ( githubEnterpriseToken = "GITHUB_ENTERPRISE_TOKEN" githubToken = "GITHUB_TOKEN" hostsKey = "hosts" + localhost = "github.localhost" oauthToken = "oauth_token" ) @@ -114,7 +115,7 @@ func defaultHost(cfg *config.Config) (string, string) { } func isEnterprise(host string) bool { - return host != github + return host != github && host != localhost } func normalizeHostname(host string) string { @@ -122,5 +123,8 @@ func normalizeHostname(host string) string { if strings.HasSuffix(hostname, "."+github) { return github } + if strings.HasSuffix(hostname, "."+localhost) { + return localhost + } return hostname } diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index a07cdbf..9ab66cc 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -223,6 +223,73 @@ func TestKnownHosts(t *testing.T) { } } +func TestIsEnterprise(t *testing.T) { + tests := []struct { + name string + host string + wantOut bool + }{ + { + name: "github", + host: "github.com", + wantOut: false, + }, + { + name: "localhost", + host: "github.localhost", + wantOut: false, + }, + { + name: "enterprise", + host: "mygithub.com", + wantOut: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out := isEnterprise(tt.host) + assert.Equal(t, tt.wantOut, out) + }) + } +} + +func TestNormalizeHostname(t *testing.T) { + tests := []struct { + name string + host string + wantHost string + }{ + { + name: "github domain", + host: "test.github.com", + wantHost: "github.com", + }, + { + name: "capitalized", + host: "GitHub.com", + wantHost: "github.com", + }, + { + name: "localhost domain", + host: "test.github.localhost", + wantHost: "github.localhost", + }, + { + name: "enterprise domain", + host: "mygithub.com", + wantHost: "mygithub.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + normalized := normalizeHostname(tt.host) + assert.Equal(t, tt.wantHost, normalized) + }) + } +} + func testNoHostsConfig() *config.Config { var data = `` return config.ReadFromString(data)