Skip to content

Commit

Permalink
Add tls_config field to OAuth 2.0 Config (#331)
Browse files Browse the repository at this point in the history
* Add `TLSConfig` field

Signed-off-by: Levi Harrison <git@leviharrison.dev>

* Add tests

Signed-off-by: Levi Harrison <git@leviharrison.dev>
  • Loading branch information
LeviHarrison committed Oct 13, 2021
1 parent 5a26535 commit 4bfa954
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 7 deletions.
24 changes: 22 additions & 2 deletions config/http_config.go
Expand Up @@ -159,6 +159,9 @@ type OAuth2 struct {
Scopes []string `yaml:"scopes,omitempty" json:"scopes,omitempty"`
TokenURL string `yaml:"token_url" json:"token_url"`
EndpointParams map[string]string `yaml:"endpoint_params,omitempty" json:"endpoint_params,omitempty"`

// TLSConfig is used to connect to the token URL.
TLSConfig TLSConfig `yaml:"tls_config,omitempty"`
}

// SetDirectory joins any relative file paths with dir.
Expand Down Expand Up @@ -594,7 +597,25 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
EndpointParams: mapToValues(rt.config.EndpointParams),
}

tokenSource := config.TokenSource(context.Background())
tlsConfig, err := NewTLSConfig(&rt.config.TLSConfig)
if err != nil {
return nil, err
}

var t http.RoundTripper
if len(rt.config.TLSConfig.CAFile) == 0 {
t = &http.Transport{TLSClientConfig: tlsConfig}
} else {
t, err = NewTLSRoundTripper(tlsConfig, rt.config.TLSConfig.CAFile, func(tls *tls.Config) (http.RoundTripper, error) {
return &http.Transport{TLSClientConfig: tls}, nil
})
if err != nil {
return nil, err
}
}

ctx := context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{Transport: t})
tokenSource := config.TokenSource(ctx)

rt.mtx.Lock()
rt.secret = secret
Expand Down Expand Up @@ -763,7 +784,6 @@ func NewTLSRoundTripper(
return nil, err
}
t.rt = rt

_, t.hashCAFile, err = t.getCAWithHash()
if err != nil {
return nil, err
Expand Down
57 changes: 52 additions & 5 deletions config/http_config_test.go
Expand Up @@ -66,6 +66,7 @@ const (
ExpectedAuthenticationCredentials = AuthorizationType + " " + BearerToken
ExpectedUsername = "arthurdent"
ExpectedPassword = "42"
ExpectedAccessToken = "12345"
)

var invalidHTTPClientConfigs = []struct {
Expand Down Expand Up @@ -363,6 +364,45 @@ func TestNewClientFromConfig(t *testing.T) {
}
},
},
{
clientConfig: HTTPClientConfig{
OAuth2: &OAuth2{
ClientID: "ExpectedUsername",
ClientSecret: "ExpectedPassword",
TLSConfig: TLSConfig{
CAFile: TLSCAChainPath,
CertFile: ClientCertificatePath,
KeyFile: ClientKeyNoPassPath,
ServerName: "",
InsecureSkipVerify: false},
},
TLSConfig: TLSConfig{
CAFile: TLSCAChainPath,
CertFile: ClientCertificatePath,
KeyFile: ClientKeyNoPassPath,
ServerName: "",
InsecureSkipVerify: false},
},
handler: func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/token":
res, _ := json.Marshal(oauth2TestServerResponse{
AccessToken: ExpectedAccessToken,
TokenType: "Bearer",
})
w.Header().Add("Content-Type", "application/json")
_, _ = w.Write(res)

default:
authorization := r.Header.Get("Authorization")
if authorization != "Bearer "+ExpectedAccessToken {
fmt.Fprintf(w, "Expected Authorization header %q, got %q", "Bearer "+ExpectedAccessToken, authorization)
} else {
fmt.Fprint(w, ExpectedMessage)
}
}
},
},
}

for _, validConfig := range newClientValidConfig {
Expand All @@ -372,6 +412,12 @@ func TestNewClientFromConfig(t *testing.T) {
}
defer testServer.Close()

if validConfig.clientConfig.OAuth2 != nil {
// We don't have access to the test server's URL when configuring the test cases,
// so it has to be specified here.
validConfig.clientConfig.OAuth2.TokenURL = testServer.URL + "/token"
}

err = validConfig.clientConfig.Validate()
if err != nil {
t.Fatal(err.Error())
Expand All @@ -381,6 +427,7 @@ func TestNewClientFromConfig(t *testing.T) {
t.Errorf("Can't create a client from this config: %+v", validConfig.clientConfig)
continue
}

response, err := client.Get(testServer.URL)
if err != nil {
t.Errorf("Can't connect to the test server using this config: %+v: %v", validConfig.clientConfig, err)
Expand Down Expand Up @@ -1129,14 +1176,14 @@ func NewRoundTripCheckRequest(checkRequest func(*http.Request), theResponse *htt
theError: theError}}
}

type testServerResponse struct {
type oauth2TestServerResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
}

func TestOAuth2(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
res, _ := json.Marshal(testServerResponse{
res, _ := json.Marshal(oauth2TestServerResponse{
AccessToken: "12345",
TokenType: "Bearer",
})
Expand Down Expand Up @@ -1169,7 +1216,7 @@ endpoint_params:
t.Fatalf("Expected no error unmarshalling yaml, got %v", err)
}
if !reflect.DeepEqual(unmarshalledConfig, expectedConfig) {
t.Fatalf("Got unmarshalled config %q, expected %q", unmarshalledConfig, expectedConfig)
t.Fatalf("Got unmarshalled config %v, expected %v", unmarshalledConfig, expectedConfig)
}

rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport)
Expand Down Expand Up @@ -1197,7 +1244,7 @@ func TestOAuth2WithFile(t *testing.T) {
t.Fatal("token endpoint called twice")
}
previousAuth = auth
res, _ := json.Marshal(testServerResponse{
res, _ := json.Marshal(oauth2TestServerResponse{
AccessToken: "12345",
TokenType: "Bearer",
})
Expand Down Expand Up @@ -1244,7 +1291,7 @@ endpoint_params:
t.Fatalf("Expected no error unmarshalling yaml, got %v", err)
}
if !reflect.DeepEqual(unmarshalledConfig, expectedConfig) {
t.Fatalf("Got unmarshalled config %q, expected %q", unmarshalledConfig, expectedConfig)
t.Fatalf("Got unmarshalled config %v, expected %v", unmarshalledConfig, expectedConfig)
}

rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport)
Expand Down

0 comments on commit 4bfa954

Please sign in to comment.