diff --git a/config/http_config.go b/config/http_config.go index 063edde0..2ce312f6 100644 --- a/config/http_config.go +++ b/config/http_config.go @@ -372,6 +372,7 @@ type httpClientOptions struct { keepAlivesEnabled bool http2Enabled bool idleConnTimeout time.Duration + userAgent string } // HTTPClientOption defines an option that can be applied to the HTTP client. @@ -405,6 +406,13 @@ func WithIdleConnTimeout(timeout time.Duration) HTTPClientOption { } } +// WithUserAgent allows setting the user agent. +func WithUserAgent(ua string) HTTPClientOption { + return func(opts *httpClientOptions) { + opts.userAgent = ua + } +} + // NewClient returns a http.Client using the specified http.RoundTripper. func newClient(rt http.RoundTripper) *http.Client { return &http.Client{Transport: rt} @@ -497,8 +505,12 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT rt = NewBasicAuthRoundTripper(cfg.BasicAuth.Username, cfg.BasicAuth.Password, cfg.BasicAuth.PasswordFile, rt) } + if opts.userAgent != "" { + rt = NewUserAgentRoundTripper(opts.userAgent, rt) + } + if cfg.OAuth2 != nil { - rt = NewOAuth2RoundTripper(cfg.OAuth2, rt) + rt = NewOAuth2RoundTripper(cfg.OAuth2, rt, &opts) } // Return a new configured RoundTripper. return rt, nil @@ -619,12 +631,14 @@ type oauth2RoundTripper struct { next http.RoundTripper secret string mtx sync.RWMutex + opts *httpClientOptions } -func NewOAuth2RoundTripper(config *OAuth2, next http.RoundTripper) http.RoundTripper { +func NewOAuth2RoundTripper(config *OAuth2, next http.RoundTripper, opts *httpClientOptions) http.RoundTripper { return &oauth2RoundTripper{ config: config, next: next, + opts: opts, } } @@ -681,6 +695,10 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro } } + if rt.opts.userAgent != "" { + t = NewUserAgentRoundTripper(rt.opts.userAgent, t) + } + ctx := context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{Transport: t}) tokenSource := config.TokenSource(ctx) @@ -911,6 +929,28 @@ func (t *tlsRoundTripper) CloseIdleConnections() { } } +type userAgentRoundTripper struct { + userAgent string + rt http.RoundTripper +} + +// NewUserAgentRoundTripper adds the user agent every request header. +func NewUserAgentRoundTripper(userAgent string, rt http.RoundTripper) http.RoundTripper { + return &userAgentRoundTripper{userAgent, rt} +} + +func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + req = cloneRequest(req) + req.Header.Set("User-Agent", rt.userAgent) + return rt.rt.RoundTrip(req) +} + +func (rt *userAgentRoundTripper) CloseIdleConnections() { + if ci, ok := rt.rt.(closeIdler); ok { + ci.CloseIdleConnections() + } +} + func (c HTTPClientConfig) String() string { b, err := yaml.Marshal(c) if err != nil { diff --git a/config/http_config_test.go b/config/http_config_test.go index 42cd851b..722c7fea 100644 --- a/config/http_config_test.go +++ b/config/http_config_test.go @@ -1198,7 +1198,7 @@ client_secret: 2 scopes: - A - B -token_url: %s +token_url: %s/token endpoint_params: hi: hello `, ts.URL) @@ -1207,7 +1207,7 @@ endpoint_params: ClientSecret: "2", Scopes: []string{"A", "B"}, EndpointParams: map[string]string{"hi": "hello"}, - TokenURL: ts.URL, + TokenURL: fmt.Sprintf("%s/token", ts.URL), } var unmarshalledConfig OAuth2 @@ -1219,7 +1219,7 @@ endpoint_params: t.Fatalf("Got unmarshalled config %v, expected %v", unmarshalledConfig, expectedConfig) } - rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport) + rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport, &defaultHTTPClientOptions) client := http.Client{ Transport: rt, @@ -1232,6 +1232,49 @@ endpoint_params: } } +func TestOAuth2UserAgent(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("User-Agent") != "myuseragent" { + t.Fatalf("Expected User-Agent header in oauth request to be 'myuseragent', got '%s'", r.Header.Get("User-Agent")) + } + + res, _ := json.Marshal(oauth2TestServerResponse{ + AccessToken: "12345", + TokenType: "Bearer", + }) + w.Header().Add("Content-Type", "application/json") + _, _ = w.Write(res) + })) + defer ts.Close() + + config := DefaultHTTPClientConfig + config.OAuth2 = &OAuth2{ + ClientID: "1", + ClientSecret: "2", + Scopes: []string{"A", "B"}, + EndpointParams: map[string]string{"hi": "hello"}, + TokenURL: fmt.Sprintf("%s/token", ts.URL), + } + + rt, err := NewRoundTripperFromConfig(config, "test_oauth2", WithUserAgent("myuseragent")) + if err != nil { + t.Fatal(err) + } + + client := http.Client{ + Transport: rt, + } + resp, err := client.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + + authorization := resp.Request.Header.Get("Authorization") + if authorization != "Bearer 12345" { + t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization) + } +} + func TestOAuth2WithFile(t *testing.T) { var expectedAuth *string var previousAuth string @@ -1294,7 +1337,7 @@ endpoint_params: t.Fatalf("Got unmarshalled config %v, expected %v", unmarshalledConfig, expectedConfig) } - rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport) + rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport, &defaultHTTPClientOptions) client := http.Client{ Transport: rt,