Skip to content

Commit

Permalink
add User-Agent header to oauth2 requests
Browse files Browse the repository at this point in the history
  • Loading branch information
clayton-gonsalves committed Jun 28, 2022
1 parent 26d4974 commit ef6984c
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 2 deletions.
28 changes: 28 additions & 0 deletions config/http_config.go
Expand Up @@ -224,6 +224,8 @@ type OAuth2 struct {
ProxyURL URL `yaml:"proxy_url,omitempty" json:"proxy_url,omitempty"`
// TLSConfig is used to connect to the token URL.
TLSConfig TLSConfig `yaml:"tls_config,omitempty"`
// UserAgent is used to set a custom User-Agent http header while making the oauth request.
UserAgent string `yaml:"user_agent,omitempty" json:"user_agent,omitempty"`
}

// SetDirectory joins any relative file paths with dir.
Expand Down Expand Up @@ -681,6 +683,10 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
}
}

if rt.config.UserAgent != "" {
t = NewUserAgentRoundTripper(rt.config.UserAgent, t)
}

ctx := context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{Transport: t})
tokenSource := config.TokenSource(ctx)

Expand Down Expand Up @@ -911,6 +917,28 @@ func (t *tlsRoundTripper) CloseIdleConnections() {
}
}

type userAgentRoundTripper struct {
userAgent string
rt http.RoundTripper
}

// NewUserAgentRoundTripper adds the user agent every request header.
func NewUserAgentRoundTripper(userAgent string, rt http.RoundTripper) http.RoundTripper {
return &userAgentRoundTripper{userAgent, rt}
}

func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
req = cloneRequest(req)
req.Header.Set("User-Agent", rt.userAgent)
return rt.rt.RoundTrip(req)
}

func (rt *userAgentRoundTripper) CloseIdleConnections() {
if ci, ok := rt.rt.(closeIdler); ok {
ci.CloseIdleConnections()
}
}

func (c HTTPClientConfig) String() string {
b, err := yaml.Marshal(c)
if err != nil {
Expand Down
19 changes: 17 additions & 2 deletions config/http_config_test.go
Expand Up @@ -1183,6 +1183,12 @@ type oauth2TestServerResponse struct {

func TestOAuth2(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/token" {
if r.Header.Get("User-Agent") != "myuseragent" {
t.Fatalf("Expected User-Agent header in oauth request to be 'myuseragent', got '%s'", r.Header.Get("User-Agent"))
}
}

res, _ := json.Marshal(oauth2TestServerResponse{
AccessToken: "12345",
TokenType: "Bearer",
Expand All @@ -1198,7 +1204,8 @@ client_secret: 2
scopes:
- A
- B
token_url: %s
token_url: %s/token
user_agent: myuseragent
endpoint_params:
hi: hello
`, ts.URL)
Expand All @@ -1207,7 +1214,8 @@ endpoint_params:
ClientSecret: "2",
Scopes: []string{"A", "B"},
EndpointParams: map[string]string{"hi": "hello"},
TokenURL: ts.URL,
TokenURL: fmt.Sprintf("%s/token", ts.URL),
UserAgent: "myuseragent",
}

var unmarshalledConfig OAuth2
Expand Down Expand Up @@ -1488,3 +1496,10 @@ func TestOAuth2Proxy(t *testing.T) {
t.Errorf("Error loading OAuth2 client config: %v", err)
}
}

func TestOAuth2UserAgent(t *testing.T) {
_, _, err := LoadHTTPConfigFile("testdata/http.conf.oauth2-user-agent.good.yml")
if err != nil {
t.Errorf("Error loading OAuth2 client config: %v", err)
}
}
5 changes: 5 additions & 0 deletions config/testdata/http.conf.oauth2-user-agent.good.yml
@@ -0,0 +1,5 @@
oauth2:
client_id: "myclient"
client_secret: "mysecret"
token_url: "http://auth"
user_agent: "myuseragent"

0 comments on commit ef6984c

Please sign in to comment.