From 4bfa954c24f5ae9811682883069d53cbf21053c8 Mon Sep 17 00:00:00 2001 From: Levi Harrison Date: Wed, 13 Oct 2021 07:55:52 -0400 Subject: [PATCH] Add `tls_config` field to OAuth 2.0 Config (#331) * Add `TLSConfig` field Signed-off-by: Levi Harrison * Add tests Signed-off-by: Levi Harrison --- config/http_config.go | 24 ++++++++++++++-- config/http_config_test.go | 57 ++++++++++++++++++++++++++++++++++---- 2 files changed, 74 insertions(+), 7 deletions(-) diff --git a/config/http_config.go b/config/http_config.go index 8c07c41f..5f520edc 100644 --- a/config/http_config.go +++ b/config/http_config.go @@ -159,6 +159,9 @@ type OAuth2 struct { Scopes []string `yaml:"scopes,omitempty" json:"scopes,omitempty"` TokenURL string `yaml:"token_url" json:"token_url"` EndpointParams map[string]string `yaml:"endpoint_params,omitempty" json:"endpoint_params,omitempty"` + + // TLSConfig is used to connect to the token URL. + TLSConfig TLSConfig `yaml:"tls_config,omitempty"` } // SetDirectory joins any relative file paths with dir. @@ -594,7 +597,25 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro EndpointParams: mapToValues(rt.config.EndpointParams), } - tokenSource := config.TokenSource(context.Background()) + tlsConfig, err := NewTLSConfig(&rt.config.TLSConfig) + if err != nil { + return nil, err + } + + var t http.RoundTripper + if len(rt.config.TLSConfig.CAFile) == 0 { + t = &http.Transport{TLSClientConfig: tlsConfig} + } else { + t, err = NewTLSRoundTripper(tlsConfig, rt.config.TLSConfig.CAFile, func(tls *tls.Config) (http.RoundTripper, error) { + return &http.Transport{TLSClientConfig: tls}, nil + }) + if err != nil { + return nil, err + } + } + + ctx := context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{Transport: t}) + tokenSource := config.TokenSource(ctx) rt.mtx.Lock() rt.secret = secret @@ -763,7 +784,6 @@ func NewTLSRoundTripper( return nil, err } t.rt = rt - _, t.hashCAFile, err = t.getCAWithHash() if err != nil { return nil, err diff --git a/config/http_config_test.go b/config/http_config_test.go index 689bbde2..cbaaba0a 100644 --- a/config/http_config_test.go +++ b/config/http_config_test.go @@ -66,6 +66,7 @@ const ( ExpectedAuthenticationCredentials = AuthorizationType + " " + BearerToken ExpectedUsername = "arthurdent" ExpectedPassword = "42" + ExpectedAccessToken = "12345" ) var invalidHTTPClientConfigs = []struct { @@ -363,6 +364,45 @@ func TestNewClientFromConfig(t *testing.T) { } }, }, + { + clientConfig: HTTPClientConfig{ + OAuth2: &OAuth2{ + ClientID: "ExpectedUsername", + ClientSecret: "ExpectedPassword", + TLSConfig: TLSConfig{ + CAFile: TLSCAChainPath, + CertFile: ClientCertificatePath, + KeyFile: ClientKeyNoPassPath, + ServerName: "", + InsecureSkipVerify: false}, + }, + TLSConfig: TLSConfig{ + CAFile: TLSCAChainPath, + CertFile: ClientCertificatePath, + KeyFile: ClientKeyNoPassPath, + ServerName: "", + InsecureSkipVerify: false}, + }, + handler: func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + res, _ := json.Marshal(oauth2TestServerResponse{ + AccessToken: ExpectedAccessToken, + TokenType: "Bearer", + }) + w.Header().Add("Content-Type", "application/json") + _, _ = w.Write(res) + + default: + authorization := r.Header.Get("Authorization") + if authorization != "Bearer "+ExpectedAccessToken { + fmt.Fprintf(w, "Expected Authorization header %q, got %q", "Bearer "+ExpectedAccessToken, authorization) + } else { + fmt.Fprint(w, ExpectedMessage) + } + } + }, + }, } for _, validConfig := range newClientValidConfig { @@ -372,6 +412,12 @@ func TestNewClientFromConfig(t *testing.T) { } defer testServer.Close() + if validConfig.clientConfig.OAuth2 != nil { + // We don't have access to the test server's URL when configuring the test cases, + // so it has to be specified here. + validConfig.clientConfig.OAuth2.TokenURL = testServer.URL + "/token" + } + err = validConfig.clientConfig.Validate() if err != nil { t.Fatal(err.Error()) @@ -381,6 +427,7 @@ func TestNewClientFromConfig(t *testing.T) { t.Errorf("Can't create a client from this config: %+v", validConfig.clientConfig) continue } + response, err := client.Get(testServer.URL) if err != nil { t.Errorf("Can't connect to the test server using this config: %+v: %v", validConfig.clientConfig, err) @@ -1129,14 +1176,14 @@ func NewRoundTripCheckRequest(checkRequest func(*http.Request), theResponse *htt theError: theError}} } -type testServerResponse struct { +type oauth2TestServerResponse struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` } func TestOAuth2(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - res, _ := json.Marshal(testServerResponse{ + res, _ := json.Marshal(oauth2TestServerResponse{ AccessToken: "12345", TokenType: "Bearer", }) @@ -1169,7 +1216,7 @@ endpoint_params: t.Fatalf("Expected no error unmarshalling yaml, got %v", err) } if !reflect.DeepEqual(unmarshalledConfig, expectedConfig) { - t.Fatalf("Got unmarshalled config %q, expected %q", unmarshalledConfig, expectedConfig) + t.Fatalf("Got unmarshalled config %v, expected %v", unmarshalledConfig, expectedConfig) } rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport) @@ -1197,7 +1244,7 @@ func TestOAuth2WithFile(t *testing.T) { t.Fatal("token endpoint called twice") } previousAuth = auth - res, _ := json.Marshal(testServerResponse{ + res, _ := json.Marshal(oauth2TestServerResponse{ AccessToken: "12345", TokenType: "Bearer", }) @@ -1244,7 +1291,7 @@ endpoint_params: t.Fatalf("Expected no error unmarshalling yaml, got %v", err) } if !reflect.DeepEqual(unmarshalledConfig, expectedConfig) { - t.Fatalf("Got unmarshalled config %q, expected %q", unmarshalledConfig, expectedConfig) + t.Fatalf("Got unmarshalled config %v, expected %v", unmarshalledConfig, expectedConfig) } rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport)