Skip to content

Commit

Permalink
Merge pull request #387 from roidelapluie/useragent
Browse files Browse the repository at this point in the history
Useragent for OAuth2
  • Loading branch information
roidelapluie committed Jul 8, 2022
2 parents d75e027 + db0284d commit cdc09f0
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 6 deletions.
44 changes: 42 additions & 2 deletions config/http_config.go
Expand Up @@ -372,6 +372,7 @@ type httpClientOptions struct {
keepAlivesEnabled bool
http2Enabled bool
idleConnTimeout time.Duration
userAgent string
}

// HTTPClientOption defines an option that can be applied to the HTTP client.
Expand Down Expand Up @@ -405,6 +406,13 @@ func WithIdleConnTimeout(timeout time.Duration) HTTPClientOption {
}
}

// WithUserAgent allows setting the user agent.
func WithUserAgent(ua string) HTTPClientOption {
return func(opts *httpClientOptions) {
opts.userAgent = ua
}
}

// NewClient returns a http.Client using the specified http.RoundTripper.
func newClient(rt http.RoundTripper) *http.Client {
return &http.Client{Transport: rt}
Expand Down Expand Up @@ -497,8 +505,12 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT
rt = NewBasicAuthRoundTripper(cfg.BasicAuth.Username, cfg.BasicAuth.Password, cfg.BasicAuth.PasswordFile, rt)
}

if opts.userAgent != "" {
rt = NewUserAgentRoundTripper(opts.userAgent, rt)
}

if cfg.OAuth2 != nil {
rt = NewOAuth2RoundTripper(cfg.OAuth2, rt)
rt = NewOAuth2RoundTripper(cfg.OAuth2, rt, &opts)
}
// Return a new configured RoundTripper.
return rt, nil
Expand Down Expand Up @@ -619,12 +631,14 @@ type oauth2RoundTripper struct {
next http.RoundTripper
secret string
mtx sync.RWMutex
opts *httpClientOptions
}

func NewOAuth2RoundTripper(config *OAuth2, next http.RoundTripper) http.RoundTripper {
func NewOAuth2RoundTripper(config *OAuth2, next http.RoundTripper, opts *httpClientOptions) http.RoundTripper {
return &oauth2RoundTripper{
config: config,
next: next,
opts: opts,
}
}

Expand Down Expand Up @@ -681,6 +695,10 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
}
}

if rt.opts.userAgent != "" {
t = NewUserAgentRoundTripper(rt.opts.userAgent, t)
}

ctx := context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{Transport: t})
tokenSource := config.TokenSource(ctx)

Expand Down Expand Up @@ -911,6 +929,28 @@ func (t *tlsRoundTripper) CloseIdleConnections() {
}
}

type userAgentRoundTripper struct {
userAgent string
rt http.RoundTripper
}

// NewUserAgentRoundTripper adds the user agent every request header.
func NewUserAgentRoundTripper(userAgent string, rt http.RoundTripper) http.RoundTripper {
return &userAgentRoundTripper{userAgent, rt}
}

func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
req = cloneRequest(req)
req.Header.Set("User-Agent", rt.userAgent)
return rt.rt.RoundTrip(req)
}

func (rt *userAgentRoundTripper) CloseIdleConnections() {
if ci, ok := rt.rt.(closeIdler); ok {
ci.CloseIdleConnections()
}
}

func (c HTTPClientConfig) String() string {
b, err := yaml.Marshal(c)
if err != nil {
Expand Down
51 changes: 47 additions & 4 deletions config/http_config_test.go
Expand Up @@ -1198,7 +1198,7 @@ client_secret: 2
scopes:
- A
- B
token_url: %s
token_url: %s/token
endpoint_params:
hi: hello
`, ts.URL)
Expand All @@ -1207,7 +1207,7 @@ endpoint_params:
ClientSecret: "2",
Scopes: []string{"A", "B"},
EndpointParams: map[string]string{"hi": "hello"},
TokenURL: ts.URL,
TokenURL: fmt.Sprintf("%s/token", ts.URL),
}

var unmarshalledConfig OAuth2
Expand All @@ -1219,7 +1219,7 @@ endpoint_params:
t.Fatalf("Got unmarshalled config %v, expected %v", unmarshalledConfig, expectedConfig)
}

rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport)
rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport, &defaultHTTPClientOptions)

client := http.Client{
Transport: rt,
Expand All @@ -1232,6 +1232,49 @@ endpoint_params:
}
}

func TestOAuth2UserAgent(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("User-Agent") != "myuseragent" {
t.Fatalf("Expected User-Agent header in oauth request to be 'myuseragent', got '%s'", r.Header.Get("User-Agent"))
}

res, _ := json.Marshal(oauth2TestServerResponse{
AccessToken: "12345",
TokenType: "Bearer",
})
w.Header().Add("Content-Type", "application/json")
_, _ = w.Write(res)
}))
defer ts.Close()

config := DefaultHTTPClientConfig
config.OAuth2 = &OAuth2{
ClientID: "1",
ClientSecret: "2",
Scopes: []string{"A", "B"},
EndpointParams: map[string]string{"hi": "hello"},
TokenURL: fmt.Sprintf("%s/token", ts.URL),
}

rt, err := NewRoundTripperFromConfig(config, "test_oauth2", WithUserAgent("myuseragent"))
if err != nil {
t.Fatal(err)
}

client := http.Client{
Transport: rt,
}
resp, err := client.Get(ts.URL)
if err != nil {
t.Fatal(err)
}

authorization := resp.Request.Header.Get("Authorization")
if authorization != "Bearer 12345" {
t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization)
}
}

func TestOAuth2WithFile(t *testing.T) {
var expectedAuth *string
var previousAuth string
Expand Down Expand Up @@ -1294,7 +1337,7 @@ endpoint_params:
t.Fatalf("Got unmarshalled config %v, expected %v", unmarshalledConfig, expectedConfig)
}

rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport)
rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport, &defaultHTTPClientOptions)

client := http.Client{
Transport: rt,
Expand Down

0 comments on commit cdc09f0

Please sign in to comment.