From f3006cd8ac6bbc7aba4fa0606334f7a73b8cbd25 Mon Sep 17 00:00:00 2001 From: Daniel Hrabovcak Date: Thu, 30 Nov 2023 10:54:08 -0500 Subject: [PATCH] Add support for secret refs Signed-off-by: Daniel Hrabovcak --- config/http_config.go | 286 ++++++++++++++---- config/http_config_test.go | 92 +++++- config/testdata/http.conf.basic-auth.ref.yaml | 3 + 3 files changed, 303 insertions(+), 78 deletions(-) create mode 100644 config/testdata/http.conf.basic-auth.ref.yaml diff --git a/config/http_config.go b/config/http_config.go index b52786f2..26db43f0 100644 --- a/config/http_config.go +++ b/config/http_config.go @@ -20,6 +20,7 @@ import ( "crypto/tls" "crypto/x509" "encoding/json" + "errors" "fmt" "net" "net/http" @@ -131,8 +132,10 @@ func (tv *TLSVersion) String() string { type BasicAuth struct { Username string `yaml:"username" json:"username"` UsernameFile string `yaml:"username_file,omitempty" json:"username_file,omitempty"` + UsernameRef string `yaml:"username_ref,omitempty" json:"username_ref,omitempty"` Password Secret `yaml:"password,omitempty" json:"password,omitempty"` PasswordFile string `yaml:"password_file,omitempty" json:"password_file,omitempty"` + PasswordRef string `yaml:"password_ref,omitempty" json:"password_ref,omitempty"` } // SetDirectory joins any relative file paths with dir. @@ -149,6 +152,7 @@ type Authorization struct { Type string `yaml:"type,omitempty" json:"type,omitempty"` Credentials Secret `yaml:"credentials,omitempty" json:"credentials,omitempty"` CredentialsFile string `yaml:"credentials_file,omitempty" json:"credentials_file,omitempty"` + CredentialsRef string `yaml:"credentials_ref,omitempty" json:"credentials_ref,omitempty"` } // SetDirectory joins any relative file paths with dir. @@ -228,6 +232,7 @@ type OAuth2 struct { ClientID string `yaml:"client_id" json:"client_id"` ClientSecret Secret `yaml:"client_secret" json:"client_secret"` ClientSecretFile string `yaml:"client_secret_file" json:"client_secret_file"` + ClientSecretRef string `yaml:"client_secret_ref" json:"client_secret_ref"` Scopes []string `yaml:"scopes,omitempty" json:"scopes,omitempty"` TokenURL string `yaml:"token_url" json:"token_url"` EndpointParams map[string]string `yaml:"endpoint_params,omitempty" json:"endpoint_params,omitempty"` @@ -326,6 +331,18 @@ func (c *HTTPClientConfig) SetDirectory(dir string) { c.BearerTokenFile = JoinDir(dir, c.BearerTokenFile) } +// nonZeroCount returns the amount of values that are non-zero. +func nonZeroCount[T comparable](values ...T) int { + count := 0 + var zero T + for _, value := range values { + if value != zero { + count += 1 + } + } + return count +} + // Validate validates the HTTPClientConfig to check only one of BearerToken, // BasicAuth and BearerTokenFile is configured. It also validates that ProxyURL // is set if ProxyConnectHeader is set. @@ -337,17 +354,17 @@ func (c *HTTPClientConfig) Validate() error { if (c.BasicAuth != nil || c.OAuth2 != nil) && (len(c.BearerToken) > 0 || len(c.BearerTokenFile) > 0) { return fmt.Errorf("at most one of basic_auth, oauth2, bearer_token & bearer_token_file must be configured") } - if c.BasicAuth != nil && (string(c.BasicAuth.Username) != "" && c.BasicAuth.UsernameFile != "") { - return fmt.Errorf("at most one of basic_auth username & username_file must be configured") + if c.BasicAuth != nil && nonZeroCount(string(c.BasicAuth.Username) != "", c.BasicAuth.UsernameFile != "", c.BasicAuth.UsernameRef != "") > 1 { + return fmt.Errorf("at most one of basic_auth username, username_file & username_ref must be configured") } - if c.BasicAuth != nil && (string(c.BasicAuth.Password) != "" && c.BasicAuth.PasswordFile != "") { - return fmt.Errorf("at most one of basic_auth password & password_file must be configured") + if c.BasicAuth != nil && nonZeroCount(string(c.BasicAuth.Password) != "", c.BasicAuth.PasswordFile != "", c.BasicAuth.PasswordRef != "") > 1 { + return fmt.Errorf("at most one of basic_auth password, password_file & password_ref must be configured") } if c.Authorization != nil { if len(c.BearerToken) > 0 || len(c.BearerTokenFile) > 0 { return fmt.Errorf("authorization is not compatible with bearer_token & bearer_token_file") } - if string(c.Authorization.Credentials) != "" && c.Authorization.CredentialsFile != "" { + if nonZeroCount(string(c.Authorization.Credentials) != "", c.Authorization.CredentialsFile != "", c.Authorization.CredentialsRef != "") > 1 { return fmt.Errorf("at most one of authorization credentials & credentials_file must be configured") } c.Authorization.Type = strings.TrimSpace(c.Authorization.Type) @@ -382,8 +399,8 @@ func (c *HTTPClientConfig) Validate() error { if len(c.OAuth2.TokenURL) == 0 { return fmt.Errorf("oauth2 token_url must be configured") } - if len(c.OAuth2.ClientSecret) > 0 && len(c.OAuth2.ClientSecretFile) > 0 { - return fmt.Errorf("at most one of oauth2 client_secret & client_secret_file must be configured") + if nonZeroCount(len(c.OAuth2.ClientSecret) > 0, len(c.OAuth2.ClientSecretFile) > 0, len(c.OAuth2.ClientSecretRef) > 0) > 1 { + return fmt.Errorf("at most one of oauth2 client_secret, client_secret_file & client_secret_ref must be configured") } } if err := c.ProxyConfig.Validate(); err != nil { @@ -429,6 +446,7 @@ type httpClientOptions struct { idleConnTimeout time.Duration userAgent string host string + secretManager SecretManager } // HTTPClientOption defines an option that can be applied to the HTTP client. @@ -476,6 +494,13 @@ func WithHost(host string) HTTPClientOption { } } +// WithSecretManager allows setting the secret manager. +func WithSecretManager(manager SecretManager) HTTPClientOption { + return func(opts *httpClientOptions) { + opts.secretManager = manager + } +} + // NewClient returns a http.Client using the specified http.RoundTripper. func newClient(rt http.RoundTripper) *http.Client { return &http.Client{Transport: rt} @@ -502,6 +527,13 @@ func NewClientFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HTTPClie // given config.HTTPClientConfig and config.HTTPClientOption. // The name is used as go-conntrack metric label. func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HTTPClientOption) (http.RoundTripper, error) { + return NewRoundTripperFromConfigWithContext(context.Background(), cfg, name, optFuncs...) +} + +// NewRoundTripperFromConfig returns a new HTTP RoundTripper configured for the +// given config.HTTPClientConfig and config.HTTPClientOption. +// The name is used as go-conntrack metric label. +func NewRoundTripperFromConfigWithContext(ctx context.Context, cfg HTTPClientConfig, name string, optFuncs ...HTTPClientOption) (http.RoundTripper, error) { opts := defaultHTTPClientOptions for _, f := range optFuncs { f(&opts) @@ -553,16 +585,32 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT // If a authorization_credentials is provided, create a round tripper that will set the // Authorization header correctly on each request. if cfg.Authorization != nil { - rt = NewAuthorizationCredentialsRoundTripper(cfg.Authorization.Type, secretFrom(cfg.Authorization.Credentials, cfg.Authorization.CredentialsFile), rt) + credentialsSecret, err := secretFrom(opts.secretManager, cfg.Authorization.Credentials, cfg.Authorization.CredentialsFile, cfg.Authorization.CredentialsRef) + if err != nil { + return nil, fmt.Errorf("unable to use credentials: %w", err) + } + rt = NewAuthorizationCredentialsRoundTripper(cfg.Authorization.Type, credentialsSecret, rt) } // Backwards compatibility, be nice with importers who would not have // called Validate(). if len(cfg.BearerToken) > 0 || len(cfg.BearerTokenFile) > 0 { - rt = NewAuthorizationCredentialsRoundTripper("Bearer", secretFrom(cfg.BearerToken, cfg.BearerTokenFile), rt) + bearerSecret, err := secretFrom(opts.secretManager, cfg.BearerToken, cfg.BearerTokenFile, "") + if err != nil { + return nil, fmt.Errorf("unable to use bearer token: %w", err) + } + rt = NewAuthorizationCredentialsRoundTripper("Bearer", bearerSecret, rt) } if cfg.BasicAuth != nil { - rt = NewBasicAuthRoundTripper(secretFrom(Secret(cfg.BasicAuth.Username), cfg.BasicAuth.UsernameFile), secretFrom(cfg.BasicAuth.Password, cfg.BasicAuth.PasswordFile), rt) + usernameSecret, err := secretFrom(opts.secretManager, Secret(cfg.BasicAuth.Username), cfg.BasicAuth.UsernameFile, cfg.BasicAuth.UsernameRef) + if err != nil { + return nil, fmt.Errorf("unable to use username: %w", err) + } + passwordSecret, err := secretFrom(opts.secretManager, cfg.BasicAuth.Password, cfg.BasicAuth.PasswordFile, cfg.BasicAuth.PasswordRef) + if err != nil { + return nil, fmt.Errorf("unable to use password: %w", err) + } + rt = NewBasicAuthRoundTripper(usernameSecret, passwordSecret, rt) } if cfg.OAuth2 != nil { @@ -581,21 +629,28 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT return rt, nil } - tlsConfig, err := NewTLSConfig(&cfg.TLSConfig) + tlsConfig, err := NewTLSConfig(&cfg.TLSConfig, withSecretManager(opts.secretManager)) if err != nil { return nil, err } - tlsSettings := cfg.TLSConfig.roundTripperSettings() + tlsSettings, err := cfg.TLSConfig.roundTripperSettings(opts.secretManager) + if err != nil { + return nil, err + } if tlsSettings.CA == nil || tlsSettings.CA.immutable() { // No need for a RoundTripper that reloads the CA file automatically. return newRT(tlsConfig) } - return NewTLSRoundTripper(tlsConfig, tlsSettings, newRT) + return NewTLSRoundTripperWithContext(ctx, tlsConfig, tlsSettings, newRT) +} + +type SecretManager interface { + Fetch(ctx context.Context, secretRef string) (string, error) } type secret interface { - fetch() (string, error) + fetch(ctx context.Context) (string, error) description() string immutable() bool } @@ -604,7 +659,7 @@ type inlineSecret struct { text string } -func (s *inlineSecret) fetch() (string, error) { +func (s *inlineSecret) fetch(ctx context.Context) (string, error) { return s.text, nil } @@ -620,7 +675,7 @@ type fileSecret struct { file string } -func (s *fileSecret) fetch() (string, error) { +func (s *fileSecret) fetch(ctx context.Context) (string, error) { fileBytes, err := os.ReadFile(s.file) if err != nil { return "", fmt.Errorf("unable to read file %s: %w", s.file, err) @@ -636,18 +691,44 @@ func (s *fileSecret) immutable() bool { return false } -func secretFrom(text Secret, file string) secret { +type refSecret struct { + ref string + manager SecretManager +} + +func (s *refSecret) fetch(ctx context.Context) (string, error) { + return s.manager.Fetch(ctx, s.ref) +} + +func (s *refSecret) description() string { + return fmt.Sprintf("ref %s", s.ref) +} + +func (s *refSecret) immutable() bool { + return false +} + +func secretFrom(secretManager SecretManager, text Secret, file, ref string) (secret, error) { if text != "" { return &inlineSecret{ text: string(text), - } + }, nil } if file != "" { return &fileSecret{ file: file, + }, nil + } + if ref != "" { + if secretManager == nil { + return nil, errors.New("cannot use secret ref without manager") } + return &refSecret{ + ref: ref, + manager: secretManager, + }, nil } - return nil + return nil, nil } type authorizationCredentialsRoundTripper struct { @@ -668,7 +749,7 @@ func (rt *authorizationCredentialsRoundTripper) RoundTrip(req *http.Request) (*h var authCredentials string if rt.authCredentials != nil { var err error - authCredentials, err = rt.authCredentials.fetch() + authCredentials, err = rt.authCredentials.fetch(req.Context()) if err != nil { return nil, fmt.Errorf("unable to read authorization credentials: %w", err) } @@ -707,14 +788,14 @@ func (rt *basicAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, e var password string if rt.username != nil { var err error - username, err = rt.username.fetch() + username, err = rt.username.fetch(req.Context()) if err != nil { return nil, fmt.Errorf("unable to read basic auth username: %w", err) } } if rt.password != nil { var err error - password, err = rt.password.fetch() + password, err = rt.password.fetch(req.Context()) if err != nil { return nil, fmt.Errorf("unable to read basic auth password: %w", err) } @@ -731,22 +812,20 @@ func (rt *basicAuthRoundTripper) CloseIdleConnections() { } type oauth2RoundTripper struct { - config *OAuth2 - clientSecret secret - rt http.RoundTripper - next http.RoundTripper - secret string - mtx sync.RWMutex - opts *httpClientOptions - client *http.Client + config *OAuth2 + rt http.RoundTripper + next http.RoundTripper + secret string + mtx sync.RWMutex + opts *httpClientOptions + client *http.Client } func NewOAuth2RoundTripper(config *OAuth2, next http.RoundTripper, opts *httpClientOptions) http.RoundTripper { return &oauth2RoundTripper{ - config: config, - clientSecret: secretFrom(config.ClientSecret, config.ClientSecretFile), - next: next, - opts: opts, + config: config, + next: next, + opts: opts, } } @@ -756,9 +835,13 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro changed bool ) - if rt.clientSecret != nil { + clientSecret, err := secretFrom(rt.opts.secretManager, rt.config.ClientSecret, rt.config.ClientSecretFile, rt.config.ClientSecretRef) + if err != nil { + return nil, fmt.Errorf("unable to use client secret: %w", err) + } + if clientSecret != nil { var err error - secret, err = rt.clientSecret.fetch() + secret, err = clientSecret.fetch(req.Context()) if err != nil { return nil, fmt.Errorf("unable to read oauth2 client secret: %w", err) } @@ -776,7 +859,7 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro EndpointParams: mapToValues(rt.config.EndpointParams), } - tlsConfig, err := NewTLSConfig(&rt.config.TLSConfig) + tlsConfig, err := NewTLSConfig(&rt.config.TLSConfig, withSecretManager(rt.opts.secretManager)) if err != nil { return nil, err } @@ -796,11 +879,14 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro } var t http.RoundTripper - tlsSettings := rt.config.TLSConfig.roundTripperSettings() + tlsSettings, err := rt.config.TLSConfig.roundTripperSettings(rt.opts.secretManager) + if err != nil { + return nil, err + } if tlsSettings.CA == nil || tlsSettings.CA.immutable() { t, _ = tlsTransport(tlsConfig) } else { - t, err = NewTLSRoundTripper(tlsConfig, tlsSettings, tlsTransport) + t, err = NewTLSRoundTripperWithContext(req.Context(), tlsConfig, tlsSettings, tlsTransport) if err != nil { return nil, err } @@ -865,8 +951,32 @@ func cloneRequest(r *http.Request) *http.Request { return r2 } +type tlsConfigOptions struct { + secretManager SecretManager +} + +// TLSConfigOption defines an option that can be applied to the HTTP client. +type TLSConfigOption func(options *tlsConfigOptions) + +// WithSecretManager allows setting the secret manager. +func withSecretManager(manager SecretManager) TLSConfigOption { + return func(opts *tlsConfigOptions) { + opts.secretManager = manager + } +} + // NewTLSConfig creates a new tls.Config from the given TLSConfig. -func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) { +func NewTLSConfig(cfg *TLSConfig, optFuncs ...TLSConfigOption) (*tls.Config, error) { + return NewTLSConfigWithContext(context.Background(), cfg, optFuncs...) +} + +// NewTLSConfig creates a new tls.Config from the given TLSConfig. +func NewTLSConfigWithContext(ctx context.Context, cfg *TLSConfig, optFuncs ...TLSConfigOption) (*tls.Config, error) { + opts := tlsConfigOptions{} + for _, f := range optFuncs { + f(&opts) + } + if err := cfg.Validate(); err != nil { return nil, err } @@ -885,9 +995,12 @@ func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) { // If a CA cert is provided then let's read it in so we can validate the // scrape target's certificate properly. - caSecret := secretFrom(Secret(cfg.CA), cfg.CAFile) + caSecret, err := secretFrom(opts.secretManager, Secret(cfg.CA), cfg.CAFile, cfg.CARef) + if err != nil { + return nil, fmt.Errorf("unable to use CA cert: %w", err) + } if caSecret != nil { - ca, err := caSecret.fetch() + ca, err := caSecret.fetch(ctx) if err != nil { return nil, fmt.Errorf("unable to read CA cert: %w", err) } @@ -903,10 +1016,16 @@ func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) { // If a client cert & key is provided then configure TLS config accordingly. if cfg.usingClientCert() && cfg.usingClientKey() { // Verify that client cert and key are valid. - if _, err := cfg.getClientCertificate(nil); err != nil { + if _, err := cfg.getClientCertificate(ctx, opts.secretManager); err != nil { return nil, err } - tlsConfig.GetClientCertificate = cfg.getClientCertificate + tlsConfig.GetClientCertificate = func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { + var ctx context.Context + if cri != nil { + ctx = cri.Context() + } + return cfg.getClientCertificate(ctx, opts.secretManager) + } } return tlsConfig, nil @@ -926,6 +1045,12 @@ type TLSConfig struct { CertFile string `yaml:"cert_file,omitempty" json:"cert_file,omitempty"` // The client key file for the targets. KeyFile string `yaml:"key_file,omitempty" json:"key_file,omitempty"` + // The CA cert to use for the targets. + CARef string `yaml:"ca_ref,omitempty" json:"ca_ref,omitempty"` + // The client cert for the targets. + CertRef string `yaml:"cert_ref,omitempty" json:"cert_ref,omitempty"` + // The client key for the targets. + KeyRef string `yaml:"key_ref,omitempty" json:"key_ref,omitempty"` // Used to verify the hostname for the targets. ServerName string `yaml:"server_name,omitempty" json:"server_name,omitempty"` // Disable target certificate validation. @@ -959,13 +1084,13 @@ func (c *TLSConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { // file-based fields for the TLS CA, client certificate, and client key are // used. func (c *TLSConfig) Validate() error { - if len(c.CA) > 0 && len(c.CAFile) > 0 { - return fmt.Errorf("at most one of ca and ca_file must be configured") + if nonZeroCount(len(c.CA) > 0, len(c.CAFile) > 0, len(c.CARef) > 0) > 1 { + return fmt.Errorf("at most one of ca, ca_file & ca_ref must be configured") } - if len(c.Cert) > 0 && len(c.CertFile) > 0 { - return fmt.Errorf("at most one of cert and cert_file must be configured") + if nonZeroCount(len(c.Cert) > 0, len(c.CertFile) > 0, len(c.CertRef) > 0) > 1 { + return fmt.Errorf("at most one of cert, cert_file & cert_ref must be configured") } - if len(c.Key) > 0 && len(c.KeyFile) > 0 { + if nonZeroCount(len(c.Key) > 0, len(c.KeyFile) > 0, len(c.KeyRef) > 0) > 1 { return fmt.Errorf("at most one of key and key_file must be configured") } @@ -979,39 +1104,57 @@ func (c *TLSConfig) Validate() error { } func (c *TLSConfig) usingClientCert() bool { - return len(c.Cert) > 0 || len(c.CertFile) > 0 + return len(c.Cert) > 0 || len(c.CertFile) > 0 || len(c.CertRef) > 0 } func (c *TLSConfig) usingClientKey() bool { - return len(c.Key) > 0 || len(c.KeyFile) > 0 + return len(c.Key) > 0 || len(c.KeyFile) > 0 || len(c.KeyRef) > 0 } -func (c *TLSConfig) roundTripperSettings() TLSRoundTripperSettings { - return TLSRoundTripperSettings{ - CA: secretFrom(Secret(c.CA), c.CAFile), - Cert: secretFrom(Secret(c.Cert), c.CertFile), - Key: secretFrom(c.Key, c.KeyFile), +func (c *TLSConfig) roundTripperSettings(secretManager SecretManager) (TLSRoundTripperSettings, error) { + ca, err := secretFrom(secretManager, Secret(c.CA), c.CAFile, c.CARef) + if err != nil { + return TLSRoundTripperSettings{}, err } + cert, err := secretFrom(secretManager, Secret(c.Cert), c.CertFile, c.CertRef) + if err != nil { + return TLSRoundTripperSettings{}, err + } + key, err := secretFrom(secretManager, c.Key, c.KeyFile, c.KeyRef) + if err != nil { + return TLSRoundTripperSettings{}, err + } + return TLSRoundTripperSettings{ + CA: ca, + Cert: cert, + Key: key, + }, nil } // getClientCertificate reads the pair of client cert and key from disk and returns a tls.Certificate. -func (c *TLSConfig) getClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) { +func (c *TLSConfig) getClientCertificate(ctx context.Context, secretManager SecretManager) (*tls.Certificate, error) { var ( certData, keyData string err error ) - certSecret := secretFrom(Secret(c.Cert), c.CertFile) + certSecret, err := secretFrom(secretManager, Secret(c.Cert), c.CertFile, c.CertRef) + if err != nil { + return nil, fmt.Errorf("unable to use client cert: %w", err) + } if certSecret != nil { - certData, err = certSecret.fetch() + certData, err = certSecret.fetch(ctx) if err != nil { return nil, fmt.Errorf("unable to read specified client cert: %w", err) } } - keySecret := secretFrom(Secret(c.Key), c.KeyFile) + keySecret, err := secretFrom(secretManager, Secret(c.Key), c.KeyFile, c.KeyRef) + if err != nil { + return nil, fmt.Errorf("unable to use client key: %w", err) + } if keySecret != nil { - keyData, err = keySecret.fetch() + keyData, err = keySecret.fetch(ctx) if err != nil { return nil, fmt.Errorf("unable to read specified client key: %w", err) } @@ -1063,6 +1206,15 @@ func NewTLSRoundTripper( cfg *tls.Config, settings TLSRoundTripperSettings, newRT func(*tls.Config) (http.RoundTripper, error), +) (http.RoundTripper, error) { + return NewTLSRoundTripperWithContext(context.Background(), cfg, settings, newRT) +} + +func NewTLSRoundTripperWithContext( + ctx context.Context, + cfg *tls.Config, + settings TLSRoundTripperSettings, + newRT func(*tls.Config) (http.RoundTripper, error), ) (http.RoundTripper, error) { t := &tlsRoundTripper{ settings: settings, @@ -1075,7 +1227,7 @@ func NewTLSRoundTripper( return nil, err } t.rt = rt - _, t.hashCAData, t.hashCertData, t.hashKeyData, err = t.getTLSDataWithHash() + _, t.hashCAData, t.hashCertData, t.hashKeyData, err = t.getTLSDataWithHash(ctx) if err != nil { return nil, err } @@ -1083,11 +1235,11 @@ func NewTLSRoundTripper( return t, nil } -func (t *tlsRoundTripper) getTLSDataWithHash() ([]byte, []byte, []byte, []byte, error) { +func (t *tlsRoundTripper) getTLSDataWithHash(ctx context.Context) ([]byte, []byte, []byte, []byte, error) { var caBytes, certBytes, keyBytes []byte if t.settings.CA != nil { - ca, err := t.settings.CA.fetch() + ca, err := t.settings.CA.fetch(ctx) if err != nil { return nil, nil, nil, nil, fmt.Errorf("unable to read CA cert: %w", err) } @@ -1095,7 +1247,7 @@ func (t *tlsRoundTripper) getTLSDataWithHash() ([]byte, []byte, []byte, []byte, } if t.settings.Cert != nil { - cert, err := t.settings.Cert.fetch() + cert, err := t.settings.Cert.fetch(ctx) if err != nil { return nil, nil, nil, nil, fmt.Errorf("unable to read client cert: %w", err) } @@ -1103,7 +1255,7 @@ func (t *tlsRoundTripper) getTLSDataWithHash() ([]byte, []byte, []byte, []byte, } if t.settings.Key != nil { - key, err := t.settings.Key.fetch() + key, err := t.settings.Key.fetch(ctx) if err != nil { return nil, nil, nil, nil, fmt.Errorf("unable to read client key: %w", err) } @@ -1127,7 +1279,7 @@ func (t *tlsRoundTripper) getTLSDataWithHash() ([]byte, []byte, []byte, []byte, // RoundTrip implements the http.RoundTrip interface. func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - caData, caHash, certHash, keyHash, err := t.getTLSDataWithHash() + caData, caHash, certHash, keyHash, err := t.getTLSDataWithHash(req.Context()) if err != nil { return nil, err } diff --git a/config/http_config_test.go b/config/http_config_test.go index d2c7503e..101898fe 100644 --- a/config/http_config_test.go +++ b/config/http_config_test.go @@ -81,11 +81,11 @@ var invalidHTTPClientConfigs = []struct { }, { httpClientConfigFile: "testdata/http.conf.basic-auth.too-much.bad.yaml", - errMsg: "at most one of basic_auth password & password_file must be configured", + errMsg: "at most one of basic_auth password, password_file & password_ref must be configured", }, { httpClientConfigFile: "testdata/http.conf.basic-auth.bad-username.yaml", - errMsg: "at most one of basic_auth username & username_file must be configured", + errMsg: "at most one of basic_auth username, username_file & username_ref must be configured", }, { httpClientConfigFile: "testdata/http.conf.mix-bearer-and-creds.bad.yaml", @@ -109,7 +109,7 @@ var invalidHTTPClientConfigs = []struct { }, { httpClientConfigFile: "testdata/http.conf.oauth2-secret-and-file-set.bad.yml", - errMsg: "at most one of oauth2 client_secret & client_secret_file must be configured", + errMsg: "at most one of oauth2 client_secret, client_secret_file & client_secret_ref must be configured", }, { httpClientConfigFile: "testdata/http.conf.oauth2-no-client-id.bad.yaml", @@ -888,7 +888,7 @@ func TestTLSConfigInvalidCA(t *testing.T) { ServerName: "", InsecureSkipVerify: false, }, - errorMessage: "at most one of cert and cert_file must be configured", + errorMessage: "at most one of cert, cert_file & cert_ref must be configured", }, { configTLSConfig: TLSConfig{ @@ -930,7 +930,7 @@ func TestBasicAuthNoPassword(t *testing.T) { t.Fatalf("Error casting to basic auth transport, %v", client.Transport) } - if username, _ := rt.username.fetch(); username != "user" { + if username, _ := rt.username.fetch(context.Background()); username != "user" { t.Errorf("Bad HTTP client username: %s", username) } if rt.password != nil { @@ -956,7 +956,7 @@ func TestBasicAuthNoUsername(t *testing.T) { if rt.username != nil { t.Errorf("Got unexpected username") } - if password, _ := rt.password.fetch(); password != "secret" { + if password, _ := rt.password.fetch(context.Background()); password != "secret" { t.Errorf("Unexpected HTTP client password: %s", password) } } @@ -976,14 +976,84 @@ func TestBasicAuthPasswordFile(t *testing.T) { t.Fatalf("Error casting to basic auth transport, %v", client.Transport) } - if username, _ := rt.username.fetch(); username != "user" { + if username, _ := rt.username.fetch(context.Background()); username != "user" { t.Errorf("Bad HTTP client username: %s", username) } - if password, _ := rt.password.fetch(); password != "foobar" { + if password, _ := rt.password.fetch(context.Background()); password != "foobar" { t.Errorf("Bad HTTP client password: %s", password) } } +type secretManager struct { + data map[string]string +} + +func (m *secretManager) Fetch(ctx context.Context, secretRef string) (string, error) { + secretData, ok := m.data[secretRef] + if !ok { + return "", fmt.Errorf("unknown secret %s", secretRef) + } + return secretData, nil +} + +func TestBasicAuthSecretManager(t *testing.T) { + cfg, _, err := LoadHTTPConfigFile("testdata/http.conf.basic-auth.ref.yaml") + if err != nil { + t.Fatalf("Error loading HTTP client config: %v", err) + } + manager := secretManager{ + data: map[string]string{ + "admin": "user", + "pass": "foobar", + }, + } + client, err := NewClientFromConfig(*cfg, "test", WithSecretManager(&manager)) + if err != nil { + t.Fatalf("Error creating HTTP Client: %v", err) + } + + rt, ok := client.Transport.(*basicAuthRoundTripper) + if !ok { + t.Fatalf("Error casting to basic auth transport, %v", client.Transport) + } + + if username, _ := rt.username.fetch(context.Background()); username != "user" { + t.Errorf("Bad HTTP client username: %s", username) + } + if password, _ := rt.password.fetch(context.Background()); password != "foobar" { + t.Errorf("Bad HTTP client password: %s", password) + } +} + +func TestBasicAuthSecretManagerNotFound(t *testing.T) { + cfg, _, err := LoadHTTPConfigFile("testdata/http.conf.basic-auth.ref.yaml") + if err != nil { + t.Fatalf("Error loading HTTP client config: %v", err) + } + manager := secretManager{ + data: map[string]string{ + "admin1": "user", + "foobar": "pass", + }, + } + client, err := NewClientFromConfig(*cfg, "test", WithSecretManager(&manager)) + if err != nil { + t.Fatalf("Error creating HTTP Client: %v", err) + } + + rt, ok := client.Transport.(*basicAuthRoundTripper) + if !ok { + t.Fatalf("Error casting to basic auth transport, %v", client.Transport) + } + + if _, err := rt.username.fetch(context.Background()); !strings.Contains(err.Error(), "unknown secret admin") { + t.Errorf("Unexpected error message: %s", err) + } + if _, err := rt.password.fetch(context.Background()); !strings.Contains(err.Error(), "unknown secret pass") { + t.Errorf("Unexpected error message: %s", err) + } +} + func TestBasicUsernameFile(t *testing.T) { cfg, _, err := LoadHTTPConfigFile("testdata/http.conf.basic-auth.username-file.good.yaml") if err != nil { @@ -999,10 +1069,10 @@ func TestBasicUsernameFile(t *testing.T) { t.Fatalf("Error casting to basic auth transport, %v", client.Transport) } - if username, _ := rt.username.fetch(); username != "testuser" { + if username, _ := rt.username.fetch(context.Background()); username != "testuser" { t.Errorf("Bad HTTP client username: %s", username) } - if password, _ := rt.password.fetch(); password != "foobar" { + if password, _ := rt.password.fetch(context.Background()); password != "foobar" { t.Errorf("Bad HTTP client passwordFile: %s", password) } } @@ -1392,7 +1462,7 @@ func TestTLSRoundTripperRaces(t *testing.T) { func TestHideHTTPClientConfigSecrets(t *testing.T) { c, _, err := LoadHTTPConfigFile("testdata/http.conf.good.yml") if err != nil { - t.Errorf("Error parsing %s: %s", "testdata/http.conf.good.yml", err) + t.Fatalf("Error parsing %s: %s", "testdata/http.conf.good.yml", err) } // String method must not reveal authentication credentials. diff --git a/config/testdata/http.conf.basic-auth.ref.yaml b/config/testdata/http.conf.basic-auth.ref.yaml new file mode 100644 index 00000000..68a7b3f8 --- /dev/null +++ b/config/testdata/http.conf.basic-auth.ref.yaml @@ -0,0 +1,3 @@ +basic_auth: + username_ref: admin + password_ref: pass \ No newline at end of file