Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add User-Agent header to oauth2 requests #386

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
28 changes: 28 additions & 0 deletions config/http_config.go
Expand Up @@ -224,6 +224,8 @@ type OAuth2 struct {
ProxyURL URL `yaml:"proxy_url,omitempty" json:"proxy_url,omitempty"`
// TLSConfig is used to connect to the token URL.
TLSConfig TLSConfig `yaml:"tls_config,omitempty"`
// UserAgent is used to set a custom User-Agent http header while making the oauth request.
UserAgent string `yaml:"user_agent,omitempty" json:"user_agent,omitempty"`
}

// SetDirectory joins any relative file paths with dir.
Expand Down Expand Up @@ -681,6 +683,10 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
}
}

if rt.config.UserAgent != "" {
t = NewUserAgentRoundTripper(rt.config.UserAgent, t)
}

ctx := context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{Transport: t})
tokenSource := config.TokenSource(ctx)

Expand Down Expand Up @@ -911,6 +917,28 @@ func (t *tlsRoundTripper) CloseIdleConnections() {
}
}

type userAgentRoundTripper struct {
userAgent string
rt http.RoundTripper
}

// NewUserAgentRoundTripper adds the user agent every request header.
func NewUserAgentRoundTripper(userAgent string, rt http.RoundTripper) http.RoundTripper {
return &userAgentRoundTripper{userAgent, rt}
}

func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
req = cloneRequest(req)
req.Header.Set("User-Agent", rt.userAgent)
return rt.rt.RoundTrip(req)
}

func (rt *userAgentRoundTripper) CloseIdleConnections() {
if ci, ok := rt.rt.(closeIdler); ok {
ci.CloseIdleConnections()
}
}

func (c HTTPClientConfig) String() string {
b, err := yaml.Marshal(c)
if err != nil {
Expand Down
19 changes: 17 additions & 2 deletions config/http_config_test.go
Expand Up @@ -1183,6 +1183,12 @@ type oauth2TestServerResponse struct {

func TestOAuth2(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/token" {
if r.Header.Get("User-Agent") != "myuseragent" {
t.Fatalf("Expected User-Agent header in oauth request to be 'myuseragent', got '%s'", r.Header.Get("User-Agent"))
}
}

res, _ := json.Marshal(oauth2TestServerResponse{
AccessToken: "12345",
TokenType: "Bearer",
Expand All @@ -1198,7 +1204,8 @@ client_secret: 2
scopes:
- A
- B
token_url: %s
token_url: %s/token
user_agent: myuseragent
endpoint_params:
hi: hello
`, ts.URL)
Expand All @@ -1207,7 +1214,8 @@ endpoint_params:
ClientSecret: "2",
Scopes: []string{"A", "B"},
EndpointParams: map[string]string{"hi": "hello"},
TokenURL: ts.URL,
TokenURL: fmt.Sprintf("%s/token", ts.URL),
UserAgent: "myuseragent",
}

var unmarshalledConfig OAuth2
Expand Down Expand Up @@ -1488,3 +1496,10 @@ func TestOAuth2Proxy(t *testing.T) {
t.Errorf("Error loading OAuth2 client config: %v", err)
}
}

func TestOAuth2UserAgent(t *testing.T) {
_, _, err := LoadHTTPConfigFile("testdata/http.conf.oauth2-user-agent.good.yml")
if err != nil {
t.Errorf("Error loading OAuth2 client config: %v", err)
}
}
5 changes: 5 additions & 0 deletions config/testdata/http.conf.oauth2-user-agent.good.yml
@@ -0,0 +1,5 @@
oauth2:
client_id: "myclient"
client_secret: "mysecret"
token_url: "http://auth"
user_agent: "myuseragent"