From e8d1b5f49073097cfb38efb2c5c4616a3ae9547f Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Tue, 7 Dec 2021 18:38:03 +0100 Subject: [PATCH] plugins/rest: refactor AWS credentials provider config validation (#4081) * plugins/rest: refactor AWS credentials provider config validation * plugins/rest: cleanup env var handling in tests The TestNew case of rest_test.go had failed when run in isolation, but passed when running all of the packages tests. With these changes to how env vars are handled, it passed both alone and when run with all the other tests. Also migrated the TestWebIdentityCredentialService test to use test.WithTempFS instead of dealing with the files itself. Signed-off-by: Stephan Renatus --- plugins/rest/aws_test.go | 233 ++++++++++++++++++-------------------- plugins/rest/rest_auth.go | 54 +++++---- plugins/rest/rest_test.go | 24 +++- 3 files changed, 164 insertions(+), 147 deletions(-) diff --git a/plugins/rest/aws_test.go b/plugins/rest/aws_test.go index 8b66807164..e1d5ca6a3a 100644 --- a/plugins/rest/aws_test.go +++ b/plugins/rest/aws_test.go @@ -46,11 +46,15 @@ func assertErr(expected string, actual error, t *testing.T) { } func TestEnvironmentCredentialService(t *testing.T) { - os.Setenv("AWS_ACCESS_KEY_ID", "") - os.Setenv("AWS_SECRET_ACCESS_KEY", "") - os.Setenv("AWS_REGION", "") - os.Setenv("AWS_SECURITY_TOKEN", "") - os.Setenv("AWS_SESSION_TOKEN", "") + reset := func() { + os.Unsetenv("AWS_ACCESS_KEY_ID") + os.Unsetenv("AWS_SECRET_ACCESS_KEY") + os.Unsetenv("AWS_REGION") + os.Unsetenv("AWS_SECURITY_TOKEN") + os.Unsetenv("AWS_SESSION_TOKEN") + } + reset() + t.Cleanup(reset) // reset again when we're done cs := &awsEnvironmentCredentialService{} @@ -180,13 +184,13 @@ func TestProfileCredentialServiceWithEnvVars(t *testing.T) { defaultSecret := "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" defaultSessionToken := "AQoEXAMPLEH4aoAH0gNCAPy" defaultRegion := "us-east-1" - + profile := "profileName" config := fmt.Sprintf(` -[default] -aws_access_key_id=%v -aws_secret_access_key=%v -aws_session_token=%v -`, defaultKey, defaultSecret, defaultSessionToken) +[%s] +aws_access_key_id=%s +aws_secret_access_key=%s +aws_session_token=%s +`, profile, defaultKey, defaultSecret, defaultSessionToken) files := map[string]string{ "example.ini": config, @@ -196,10 +200,14 @@ aws_session_token=%v cfgPath := filepath.Join(path, "example.ini") os.Setenv(awsCredentialsFileEnvVar, cfgPath) - os.Setenv(awsProfileEnvVar, "default") + os.Setenv(awsProfileEnvVar, profile) + os.Setenv(awsRegionEnvVar, defaultRegion) - defer os.Unsetenv(awsCredentialsFileEnvVar) - defer os.Unsetenv(awsProfileEnvVar) + t.Cleanup(func() { + os.Unsetenv(awsCredentialsFileEnvVar) + os.Unsetenv(awsProfileEnvVar) + os.Unsetenv(awsRegionEnvVar) + }) cs := &awsProfileCredentialService{} creds, err := cs.credentials() @@ -224,24 +232,28 @@ func TestProfileCredentialServiceWithDefaultPath(t *testing.T) { defaultKey := "AKIAIOSFODNN7EXAMPLE" defaultSecret := "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" defaultSessionToken := "AQoEXAMPLEH4aoAH0gNCAPy" - defaultRegion := "us-west-2" + defaultRegion := "us-west-22" config := fmt.Sprintf(` [default] -aws_access_key_id=%v -aws_secret_access_key=%v -aws_session_token=%v +aws_access_key_id=%s +aws_secret_access_key=%s +aws_session_token=%s `, defaultKey, defaultSecret, defaultSessionToken) files := map[string]string{} + oldUserProfile := os.Getenv("USERPROFILE") + oldHome := os.Getenv("HOME") test.WithTempFS(files, func(path string) { os.Setenv("USERPROFILE", path) os.Setenv("HOME", path) - defer os.Unsetenv("USERPROFILE") - defer os.Unsetenv("HOME") + t.Cleanup(func() { + os.Setenv("USERPROFILE", oldUserProfile) + os.Setenv("HOME", oldHome) + }) cfgDir := filepath.Join(path, ".aws") err := os.MkdirAll(cfgDir, os.ModePerm) @@ -285,14 +297,17 @@ aws_access_key_id=accessKey tests := []struct { note string config string + err string }{ { note: "no aws_access_key_id", config: configNoAccessKeyID, + err: "does not contain \"aws_access_key_id\"", }, { note: "no aws_secret_access_key", config: configNoSecret, + err: "does not contain \"aws_secret_access_key\"", }, } @@ -312,6 +327,9 @@ aws_access_key_id=accessKey if err == nil { t.Fatal("Expected error but got nil") } + if !strings.Contains(err.Error(), tc.err) { + t.Errorf("expected error to contain %v, got %v", tc.err, err.Error()) + } }) }) } @@ -779,6 +797,16 @@ func (t *ec2CredTestServer) stop() { } func TestWebIdentityCredentialService(t *testing.T) { + reset := func() { + os.Unsetenv("AWS_WEB_IDENTITY_TOKEN_FILE") + os.Unsetenv("AWS_ROLE_ARN") + os.Unsetenv("AWS_REGION") + } + reset() + t.Cleanup(reset) + + os.Setenv("AWS_REGION", "us-west-1") + testAccessKey := "ASgeIAIOSFODNN7EXAMPLE" ts := stsTestServer{ t: t, @@ -791,115 +819,80 @@ func TestWebIdentityCredentialService(t *testing.T) { logger: logging.Get(), } - goodTokenFile, err := ioutil.TempFile(os.TempDir(), "opa-aws-test-") - if err != nil { - t.Errorf("Error while creating token file: %s", err) - return - } - t.Cleanup(func() { - err := os.Remove(goodTokenFile.Name()) - if err != nil { - t.Fatalf("unable to remove goodTokenFile %q: %v", goodTokenFile.Name(), err) - } - }) - _, err = goodTokenFile.WriteString("good-token") - if err != nil { - t.Errorf("Error while creating token file: %s", err) - return - } - err = goodTokenFile.Close() - if err != nil { - t.Errorf("Error while creating token file: %s", err) - return - } - - badTokenFile, err := ioutil.TempFile(os.TempDir(), "opa-aws-test-") - if err != nil { - t.Errorf("Error while creating token file: %s", err) - return - } - t.Cleanup(func() { - err := os.Remove(badTokenFile.Name()) - if err != nil { - t.Fatalf("unable to remove badTokenFile %q: %v", badTokenFile.Name(), err) - } - }) - _, err = badTokenFile.WriteString("bad-token") - if err != nil { - t.Errorf("Error while creating token file: %s", err) - return - } - err = badTokenFile.Close() - if err != nil { - t.Errorf("Error while creating token file: %s", err) - return + files := map[string]string{ + "good_token_file": "good-token", + "bad_token_file": "bad-token", } - // wrong path: no AWS_ROLE_ARN set - err = cs.populateFromEnv() - assertErr("no AWS_ROLE_ARN set in environment", err, t) - os.Setenv("AWS_ROLE_ARN", "role:arn") - - // wrong path: no AWS_WEB_IDENTITY_TOKEN_FILE set - err = cs.populateFromEnv() - assertErr("no AWS_WEB_IDENTITY_TOKEN_FILE set in environment", err, t) - os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", "/nonsense") - - // happy path: both env vars set - err = cs.populateFromEnv() - if err != nil { - t.Errorf("Error while getting env vars: %s", err) - return - } + test.WithTempFS(files, func(path string) { + goodTokenFile := filepath.Join(path, "good_token_file") + badTokenFile := filepath.Join(path, "bad_token_file") - // wrong path: refresh with invalid web token file - err = cs.refreshFromService() - assertErr("unable to read web token for sts HTTP request: open /nonsense: no such file or directory", err, t) + // wrong path: no AWS_ROLE_ARN set + err := cs.populateFromEnv() + assertErr("no AWS_ROLE_ARN set in environment", err, t) + os.Setenv("AWS_ROLE_ARN", "role:arn") - // wrong path: refresh with "bad token" - os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", badTokenFile.Name()) - _ = cs.populateFromEnv() - err = cs.refreshFromService() - assertErr("STS HTTP request returned unexpected status: 401 Unauthorized", err, t) + // wrong path: no AWS_WEB_IDENTITY_TOKEN_FILE set + err = cs.populateFromEnv() + assertErr("no AWS_WEB_IDENTITY_TOKEN_FILE set in environment", err, t) + os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", "/nonsense") - // happy path: refresh with "good token" - os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", goodTokenFile.Name()) - _ = cs.populateFromEnv() - err = cs.refreshFromService() - if err != nil { - t.Fatalf("Unexpected err: %s", err) - } + // happy path: both env vars set + err = cs.populateFromEnv() + if err != nil { + t.Fatalf("Error while getting env vars: %s", err) + } - // happy path: refresh and get credentials - creds, _ := cs.credentials() - assertEq(creds.AccessKey, testAccessKey, t) + // wrong path: refresh with invalid web token file + err = cs.refreshFromService() + assertErr("unable to read web token for sts HTTP request: open /nonsense: no such file or directory", err, t) - // happy path: refresh with session and get credentials - cs.expiration = time.Now() - cs.SessionName = "TEST_SESSION" - creds, _ = cs.credentials() - assertEq(creds.AccessKey, testAccessKey, t) + // wrong path: refresh with "bad token" + os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", badTokenFile) + _ = cs.populateFromEnv() + err = cs.refreshFromService() + assertErr("STS HTTP request returned unexpected status: 401 Unauthorized", err, t) - // happy path: don't refresh, but get credentials - ts.accessKey = "OTHERKEY" - creds, _ = cs.credentials() - assertEq(creds.AccessKey, testAccessKey, t) + // happy path: refresh with "good token" + os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", goodTokenFile) + _ = cs.populateFromEnv() + err = cs.refreshFromService() + if err != nil { + t.Fatalf("Unexpected err: %s", err) + } - // happy/wrong path: refresh with "bad token" but return previous credentials - os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", badTokenFile.Name()) - _ = cs.populateFromEnv() - cs.expiration = time.Now() - creds, err = cs.credentials() - assertEq(creds.AccessKey, testAccessKey, t) - assertErr("STS HTTP request returned unexpected status: 401 Unauthorized", err, t) - - // wrong path: refresh with "bad token" but return previous credentials - os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", goodTokenFile.Name()) - os.Setenv("AWS_ROLE_ARN", "BrokenRole") - _ = cs.populateFromEnv() - cs.expiration = time.Now() - creds, err = cs.credentials() - assertErr("failed to parse credential response from STS service: EOF", err, t) + // happy path: refresh and get credentials + creds, _ := cs.credentials() + assertEq(creds.AccessKey, testAccessKey, t) + + // happy path: refresh with session and get credentials + cs.expiration = time.Now() + cs.SessionName = "TEST_SESSION" + creds, _ = cs.credentials() + assertEq(creds.AccessKey, testAccessKey, t) + + // happy path: don't refresh, but get credentials + ts.accessKey = "OTHERKEY" + creds, _ = cs.credentials() + assertEq(creds.AccessKey, testAccessKey, t) + + // happy/wrong path: refresh with "bad token" but return previous credentials + os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", badTokenFile) + _ = cs.populateFromEnv() + cs.expiration = time.Now() + creds, err = cs.credentials() + assertEq(creds.AccessKey, testAccessKey, t) + assertErr("STS HTTP request returned unexpected status: 401 Unauthorized", err, t) + + // wrong path: refresh with "bad token" but return previous credentials + os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", goodTokenFile) + os.Setenv("AWS_ROLE_ARN", "BrokenRole") + _ = cs.populateFromEnv() + cs.expiration = time.Now() + creds, err = cs.credentials() + assertErr("failed to parse credential response from STS service: EOF", err, t) + }) } func TestStsPath(t *testing.T) { diff --git a/plugins/rest/rest_auth.go b/plugins/rest/rest_auth.go index 7d36bbc869..3b718d41c0 100644 --- a/plugins/rest/rest_auth.go +++ b/plugins/rest/rest_auth.go @@ -532,42 +532,48 @@ func (ap *awsSigningAuthPlugin) NewClient(c Config) (*http.Client, error) { return nil, err } - if ap.AWSEnvironmentCredentials == nil && ap.AWSWebIdentityCredentials == nil && ap.AWSMetadataCredentials == nil && - ap.AWSProfileCredentials == nil { - return nil, errors.New("a AWS credential service must be specified when S3 signing is enabled") + if err := ap.validateConfig(); err != nil { + return nil, err } - if (ap.AWSEnvironmentCredentials != nil && ap.AWSMetadataCredentials != nil) || - (ap.AWSEnvironmentCredentials != nil && ap.AWSWebIdentityCredentials != nil) || - (ap.AWSEnvironmentCredentials != nil && ap.AWSProfileCredentials != nil) || - (ap.AWSMetadataCredentials != nil && ap.AWSWebIdentityCredentials != nil) || - (ap.AWSMetadataCredentials != nil && ap.AWSProfileCredentials != nil) || - (ap.AWSWebIdentityCredentials != nil && ap.AWSProfileCredentials != nil) { - return nil, errors.New("exactly one AWS credential service must be specified when S3 signing is enabled") + if ap.logger == nil { + ap.logger = c.logger + } + + return DefaultRoundTripperClient(t, *c.ResponseHeaderTimeoutSeconds), nil +} + +func (ap *awsSigningAuthPlugin) Prepare(req *http.Request) error { + ap.logger.Debug("Signing request with AWS credentials.") + return signV4(req, ap.AWSService, ap.awsCredentialService(), time.Now()) +} + +func (ap *awsSigningAuthPlugin) validateConfig() error { + cfgs := map[bool]int{} + cfgs[ap.AWSEnvironmentCredentials != nil]++ + cfgs[ap.AWSMetadataCredentials != nil]++ + cfgs[ap.AWSWebIdentityCredentials != nil]++ + cfgs[ap.AWSProfileCredentials != nil]++ + + switch n := cfgs[true]; { + case n == 0: + return errors.New("a AWS credential service must be specified when S3 signing is enabled") + case n > 1: + return errors.New("exactly one AWS credential service must be specified when S3 signing is enabled") } + if ap.AWSMetadataCredentials != nil { if ap.AWSMetadataCredentials.RegionName == "" { - return nil, errors.New("at least aws_region must be specified for AWS metadata credential service") + return errors.New("at least aws_region must be specified for AWS metadata credential service") } } if ap.AWSWebIdentityCredentials != nil { if err := ap.AWSWebIdentityCredentials.populateFromEnv(); err != nil { - return nil, err + return err } } - - if ap.logger == nil { - ap.logger = c.logger - } if ap.AWSService == "" { ap.AWSService = awsSigv4SigningDefaultService } - - return DefaultRoundTripperClient(t, *c.ResponseHeaderTimeoutSeconds), nil -} - -func (ap *awsSigningAuthPlugin) Prepare(req *http.Request) error { - ap.logger.Debug("Signing request with AWS credentials.") - err := signV4(req, ap.AWSService, ap.awsCredentialService(), time.Now()) - return err + return nil } diff --git a/plugins/rest/rest_test.go b/plugins/rest/rest_test.go index bab02d041c..002e9f08bc 100644 --- a/plugins/rest/rest_test.go +++ b/plugins/rest/rest_test.go @@ -193,7 +193,7 @@ func TestNew(t *testing.T) { wantErr: true, }, { - name: "TooManyS3CredOptions", + name: "TooManyS3CredOptions/metadata+environment", input: `{ "name": "foo", "url": "http://localhost", @@ -209,6 +209,22 @@ func TestNew(t *testing.T) { }`, wantErr: true, }, + { + name: "TooManyS3CredOptions/metadata+profile+environment+webidentity", + input: `{ + "name": "foo", + "url": "http://localhost", + "credentials": { + "s3_signing": { + "profile_credentials": {}, + "environment_credentials": {}, + "web_identity_credentials": {}, + "metadata_credentials": {} + } + } + }`, + wantErr: true, + }, { name: "TooManyCredentialsOptions", input: `{ @@ -502,6 +518,7 @@ func TestNew(t *testing.T) { }, } }`, + wantErr: true, }, { name: "S3WebIdentityCreds", @@ -517,6 +534,7 @@ func TestNew(t *testing.T) { env: map[string]string{ awsRoleArnEnvVar: "TEST", awsWebIdentityTokenFileEnvVar: "TEST", + awsRegionEnvVar: "us-west-1", }, }, { @@ -639,11 +657,11 @@ func TestNew(t *testing.T) { _ = os.Setenv(key, val) } - defer func() { + t.Cleanup(func() { for key := range tc.env { _ = os.Unsetenv(key) } - }() + }) client, err := New([]byte(tc.input), ks, AuthPluginLookup(mockAuthPluginLookup)) if err != nil {