From ef6984c1fa55f7bc3bd14acd8a0678794ae1a3d8 Mon Sep 17 00:00:00 2001 From: clayton-gonsalves Date: Tue, 28 Jun 2022 16:09:51 +0530 Subject: [PATCH] add User-Agent header to oauth2 requests --- config/http_config.go | 28 +++++++++++++++++++ config/http_config_test.go | 19 +++++++++++-- .../http.conf.oauth2-user-agent.good.yml | 5 ++++ 3 files changed, 50 insertions(+), 2 deletions(-) create mode 100644 config/testdata/http.conf.oauth2-user-agent.good.yml diff --git a/config/http_config.go b/config/http_config.go index 063edde0..803f6a1c 100644 --- a/config/http_config.go +++ b/config/http_config.go @@ -224,6 +224,8 @@ type OAuth2 struct { ProxyURL URL `yaml:"proxy_url,omitempty" json:"proxy_url,omitempty"` // TLSConfig is used to connect to the token URL. TLSConfig TLSConfig `yaml:"tls_config,omitempty"` + // UserAgent is used to set a custom User-Agent http header while making the oauth request. + UserAgent string `yaml:"user_agent,omitempty" json:"user_agent,omitempty"` } // SetDirectory joins any relative file paths with dir. @@ -681,6 +683,10 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro } } + if rt.config.UserAgent != "" { + t = NewUserAgentRoundTripper(rt.config.UserAgent, t) + } + ctx := context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{Transport: t}) tokenSource := config.TokenSource(ctx) @@ -911,6 +917,28 @@ func (t *tlsRoundTripper) CloseIdleConnections() { } } +type userAgentRoundTripper struct { + userAgent string + rt http.RoundTripper +} + +// NewUserAgentRoundTripper adds the user agent every request header. +func NewUserAgentRoundTripper(userAgent string, rt http.RoundTripper) http.RoundTripper { + return &userAgentRoundTripper{userAgent, rt} +} + +func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + req = cloneRequest(req) + req.Header.Set("User-Agent", rt.userAgent) + return rt.rt.RoundTrip(req) +} + +func (rt *userAgentRoundTripper) CloseIdleConnections() { + if ci, ok := rt.rt.(closeIdler); ok { + ci.CloseIdleConnections() + } +} + func (c HTTPClientConfig) String() string { b, err := yaml.Marshal(c) if err != nil { diff --git a/config/http_config_test.go b/config/http_config_test.go index 06eb6d04..9c9c6237 100644 --- a/config/http_config_test.go +++ b/config/http_config_test.go @@ -1183,6 +1183,12 @@ type oauth2TestServerResponse struct { func TestOAuth2(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/token" { + if r.Header.Get("User-Agent") != "myuseragent" { + t.Fatalf("Expected User-Agent header in oauth request to be 'myuseragent', got '%s'", r.Header.Get("User-Agent")) + } + } + res, _ := json.Marshal(oauth2TestServerResponse{ AccessToken: "12345", TokenType: "Bearer", @@ -1198,7 +1204,8 @@ client_secret: 2 scopes: - A - B -token_url: %s +token_url: %s/token +user_agent: myuseragent endpoint_params: hi: hello `, ts.URL) @@ -1207,7 +1214,8 @@ endpoint_params: ClientSecret: "2", Scopes: []string{"A", "B"}, EndpointParams: map[string]string{"hi": "hello"}, - TokenURL: ts.URL, + TokenURL: fmt.Sprintf("%s/token", ts.URL), + UserAgent: "myuseragent", } var unmarshalledConfig OAuth2 @@ -1488,3 +1496,10 @@ func TestOAuth2Proxy(t *testing.T) { t.Errorf("Error loading OAuth2 client config: %v", err) } } + +func TestOAuth2UserAgent(t *testing.T) { + _, _, err := LoadHTTPConfigFile("testdata/http.conf.oauth2-user-agent.good.yml") + if err != nil { + t.Errorf("Error loading OAuth2 client config: %v", err) + } +} diff --git a/config/testdata/http.conf.oauth2-user-agent.good.yml b/config/testdata/http.conf.oauth2-user-agent.good.yml new file mode 100644 index 00000000..a0a407f2 --- /dev/null +++ b/config/testdata/http.conf.oauth2-user-agent.good.yml @@ -0,0 +1,5 @@ +oauth2: + client_id: "myclient" + client_secret: "mysecret" + token_url: "http://auth" + user_agent: "myuseragent"