diff --git a/plugins/rest/aws_test.go b/plugins/rest/aws_test.go index 8b66807164..e1d5ca6a3a 100644 --- a/plugins/rest/aws_test.go +++ b/plugins/rest/aws_test.go @@ -46,11 +46,15 @@ func assertErr(expected string, actual error, t *testing.T) { } func TestEnvironmentCredentialService(t *testing.T) { - os.Setenv("AWS_ACCESS_KEY_ID", "") - os.Setenv("AWS_SECRET_ACCESS_KEY", "") - os.Setenv("AWS_REGION", "") - os.Setenv("AWS_SECURITY_TOKEN", "") - os.Setenv("AWS_SESSION_TOKEN", "") + reset := func() { + os.Unsetenv("AWS_ACCESS_KEY_ID") + os.Unsetenv("AWS_SECRET_ACCESS_KEY") + os.Unsetenv("AWS_REGION") + os.Unsetenv("AWS_SECURITY_TOKEN") + os.Unsetenv("AWS_SESSION_TOKEN") + } + reset() + t.Cleanup(reset) // reset again when we're done cs := &awsEnvironmentCredentialService{} @@ -180,13 +184,13 @@ func TestProfileCredentialServiceWithEnvVars(t *testing.T) { defaultSecret := "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" defaultSessionToken := "AQoEXAMPLEH4aoAH0gNCAPy" defaultRegion := "us-east-1" - + profile := "profileName" config := fmt.Sprintf(` -[default] -aws_access_key_id=%v -aws_secret_access_key=%v -aws_session_token=%v -`, defaultKey, defaultSecret, defaultSessionToken) +[%s] +aws_access_key_id=%s +aws_secret_access_key=%s +aws_session_token=%s +`, profile, defaultKey, defaultSecret, defaultSessionToken) files := map[string]string{ "example.ini": config, @@ -196,10 +200,14 @@ aws_session_token=%v cfgPath := filepath.Join(path, "example.ini") os.Setenv(awsCredentialsFileEnvVar, cfgPath) - os.Setenv(awsProfileEnvVar, "default") + os.Setenv(awsProfileEnvVar, profile) + os.Setenv(awsRegionEnvVar, defaultRegion) - defer os.Unsetenv(awsCredentialsFileEnvVar) - defer os.Unsetenv(awsProfileEnvVar) + t.Cleanup(func() { + os.Unsetenv(awsCredentialsFileEnvVar) + os.Unsetenv(awsProfileEnvVar) + os.Unsetenv(awsRegionEnvVar) + }) cs := &awsProfileCredentialService{} creds, err := cs.credentials() @@ -224,24 +232,28 @@ func TestProfileCredentialServiceWithDefaultPath(t *testing.T) { defaultKey := "AKIAIOSFODNN7EXAMPLE" defaultSecret := "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" defaultSessionToken := "AQoEXAMPLEH4aoAH0gNCAPy" - defaultRegion := "us-west-2" + defaultRegion := "us-west-22" config := fmt.Sprintf(` [default] -aws_access_key_id=%v -aws_secret_access_key=%v -aws_session_token=%v +aws_access_key_id=%s +aws_secret_access_key=%s +aws_session_token=%s `, defaultKey, defaultSecret, defaultSessionToken) files := map[string]string{} + oldUserProfile := os.Getenv("USERPROFILE") + oldHome := os.Getenv("HOME") test.WithTempFS(files, func(path string) { os.Setenv("USERPROFILE", path) os.Setenv("HOME", path) - defer os.Unsetenv("USERPROFILE") - defer os.Unsetenv("HOME") + t.Cleanup(func() { + os.Setenv("USERPROFILE", oldUserProfile) + os.Setenv("HOME", oldHome) + }) cfgDir := filepath.Join(path, ".aws") err := os.MkdirAll(cfgDir, os.ModePerm) @@ -285,14 +297,17 @@ aws_access_key_id=accessKey tests := []struct { note string config string + err string }{ { note: "no aws_access_key_id", config: configNoAccessKeyID, + err: "does not contain \"aws_access_key_id\"", }, { note: "no aws_secret_access_key", config: configNoSecret, + err: "does not contain \"aws_secret_access_key\"", }, } @@ -312,6 +327,9 @@ aws_access_key_id=accessKey if err == nil { t.Fatal("Expected error but got nil") } + if !strings.Contains(err.Error(), tc.err) { + t.Errorf("expected error to contain %v, got %v", tc.err, err.Error()) + } }) }) } @@ -779,6 +797,16 @@ func (t *ec2CredTestServer) stop() { } func TestWebIdentityCredentialService(t *testing.T) { + reset := func() { + os.Unsetenv("AWS_WEB_IDENTITY_TOKEN_FILE") + os.Unsetenv("AWS_ROLE_ARN") + os.Unsetenv("AWS_REGION") + } + reset() + t.Cleanup(reset) + + os.Setenv("AWS_REGION", "us-west-1") + testAccessKey := "ASgeIAIOSFODNN7EXAMPLE" ts := stsTestServer{ t: t, @@ -791,115 +819,80 @@ func TestWebIdentityCredentialService(t *testing.T) { logger: logging.Get(), } - goodTokenFile, err := ioutil.TempFile(os.TempDir(), "opa-aws-test-") - if err != nil { - t.Errorf("Error while creating token file: %s", err) - return - } - t.Cleanup(func() { - err := os.Remove(goodTokenFile.Name()) - if err != nil { - t.Fatalf("unable to remove goodTokenFile %q: %v", goodTokenFile.Name(), err) - } - }) - _, err = goodTokenFile.WriteString("good-token") - if err != nil { - t.Errorf("Error while creating token file: %s", err) - return - } - err = goodTokenFile.Close() - if err != nil { - t.Errorf("Error while creating token file: %s", err) - return - } - - badTokenFile, err := ioutil.TempFile(os.TempDir(), "opa-aws-test-") - if err != nil { - t.Errorf("Error while creating token file: %s", err) - return - } - t.Cleanup(func() { - err := os.Remove(badTokenFile.Name()) - if err != nil { - t.Fatalf("unable to remove badTokenFile %q: %v", badTokenFile.Name(), err) - } - }) - _, err = badTokenFile.WriteString("bad-token") - if err != nil { - t.Errorf("Error while creating token file: %s", err) - return - } - err = badTokenFile.Close() - if err != nil { - t.Errorf("Error while creating token file: %s", err) - return + files := map[string]string{ + "good_token_file": "good-token", + "bad_token_file": "bad-token", } - // wrong path: no AWS_ROLE_ARN set - err = cs.populateFromEnv() - assertErr("no AWS_ROLE_ARN set in environment", err, t) - os.Setenv("AWS_ROLE_ARN", "role:arn") - - // wrong path: no AWS_WEB_IDENTITY_TOKEN_FILE set - err = cs.populateFromEnv() - assertErr("no AWS_WEB_IDENTITY_TOKEN_FILE set in environment", err, t) - os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", "/nonsense") - - // happy path: both env vars set - err = cs.populateFromEnv() - if err != nil { - t.Errorf("Error while getting env vars: %s", err) - return - } + test.WithTempFS(files, func(path string) { + goodTokenFile := filepath.Join(path, "good_token_file") + badTokenFile := filepath.Join(path, "bad_token_file") - // wrong path: refresh with invalid web token file - err = cs.refreshFromService() - assertErr("unable to read web token for sts HTTP request: open /nonsense: no such file or directory", err, t) + // wrong path: no AWS_ROLE_ARN set + err := cs.populateFromEnv() + assertErr("no AWS_ROLE_ARN set in environment", err, t) + os.Setenv("AWS_ROLE_ARN", "role:arn") - // wrong path: refresh with "bad token" - os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", badTokenFile.Name()) - _ = cs.populateFromEnv() - err = cs.refreshFromService() - assertErr("STS HTTP request returned unexpected status: 401 Unauthorized", err, t) + // wrong path: no AWS_WEB_IDENTITY_TOKEN_FILE set + err = cs.populateFromEnv() + assertErr("no AWS_WEB_IDENTITY_TOKEN_FILE set in environment", err, t) + os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", "/nonsense") - // happy path: refresh with "good token" - os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", goodTokenFile.Name()) - _ = cs.populateFromEnv() - err = cs.refreshFromService() - if err != nil { - t.Fatalf("Unexpected err: %s", err) - } + // happy path: both env vars set + err = cs.populateFromEnv() + if err != nil { + t.Fatalf("Error while getting env vars: %s", err) + } - // happy path: refresh and get credentials - creds, _ := cs.credentials() - assertEq(creds.AccessKey, testAccessKey, t) + // wrong path: refresh with invalid web token file + err = cs.refreshFromService() + assertErr("unable to read web token for sts HTTP request: open /nonsense: no such file or directory", err, t) - // happy path: refresh with session and get credentials - cs.expiration = time.Now() - cs.SessionName = "TEST_SESSION" - creds, _ = cs.credentials() - assertEq(creds.AccessKey, testAccessKey, t) + // wrong path: refresh with "bad token" + os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", badTokenFile) + _ = cs.populateFromEnv() + err = cs.refreshFromService() + assertErr("STS HTTP request returned unexpected status: 401 Unauthorized", err, t) - // happy path: don't refresh, but get credentials - ts.accessKey = "OTHERKEY" - creds, _ = cs.credentials() - assertEq(creds.AccessKey, testAccessKey, t) + // happy path: refresh with "good token" + os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", goodTokenFile) + _ = cs.populateFromEnv() + err = cs.refreshFromService() + if err != nil { + t.Fatalf("Unexpected err: %s", err) + } - // happy/wrong path: refresh with "bad token" but return previous credentials - os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", badTokenFile.Name()) - _ = cs.populateFromEnv() - cs.expiration = time.Now() - creds, err = cs.credentials() - assertEq(creds.AccessKey, testAccessKey, t) - assertErr("STS HTTP request returned unexpected status: 401 Unauthorized", err, t) - - // wrong path: refresh with "bad token" but return previous credentials - os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", goodTokenFile.Name()) - os.Setenv("AWS_ROLE_ARN", "BrokenRole") - _ = cs.populateFromEnv() - cs.expiration = time.Now() - creds, err = cs.credentials() - assertErr("failed to parse credential response from STS service: EOF", err, t) + // happy path: refresh and get credentials + creds, _ := cs.credentials() + assertEq(creds.AccessKey, testAccessKey, t) + + // happy path: refresh with session and get credentials + cs.expiration = time.Now() + cs.SessionName = "TEST_SESSION" + creds, _ = cs.credentials() + assertEq(creds.AccessKey, testAccessKey, t) + + // happy path: don't refresh, but get credentials + ts.accessKey = "OTHERKEY" + creds, _ = cs.credentials() + assertEq(creds.AccessKey, testAccessKey, t) + + // happy/wrong path: refresh with "bad token" but return previous credentials + os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", badTokenFile) + _ = cs.populateFromEnv() + cs.expiration = time.Now() + creds, err = cs.credentials() + assertEq(creds.AccessKey, testAccessKey, t) + assertErr("STS HTTP request returned unexpected status: 401 Unauthorized", err, t) + + // wrong path: refresh with "bad token" but return previous credentials + os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", goodTokenFile) + os.Setenv("AWS_ROLE_ARN", "BrokenRole") + _ = cs.populateFromEnv() + cs.expiration = time.Now() + creds, err = cs.credentials() + assertErr("failed to parse credential response from STS service: EOF", err, t) + }) } func TestStsPath(t *testing.T) { diff --git a/plugins/rest/rest_auth.go b/plugins/rest/rest_auth.go index 7d36bbc869..3b718d41c0 100644 --- a/plugins/rest/rest_auth.go +++ b/plugins/rest/rest_auth.go @@ -532,42 +532,48 @@ func (ap *awsSigningAuthPlugin) NewClient(c Config) (*http.Client, error) { return nil, err } - if ap.AWSEnvironmentCredentials == nil && ap.AWSWebIdentityCredentials == nil && ap.AWSMetadataCredentials == nil && - ap.AWSProfileCredentials == nil { - return nil, errors.New("a AWS credential service must be specified when S3 signing is enabled") + if err := ap.validateConfig(); err != nil { + return nil, err } - if (ap.AWSEnvironmentCredentials != nil && ap.AWSMetadataCredentials != nil) || - (ap.AWSEnvironmentCredentials != nil && ap.AWSWebIdentityCredentials != nil) || - (ap.AWSEnvironmentCredentials != nil && ap.AWSProfileCredentials != nil) || - (ap.AWSMetadataCredentials != nil && ap.AWSWebIdentityCredentials != nil) || - (ap.AWSMetadataCredentials != nil && ap.AWSProfileCredentials != nil) || - (ap.AWSWebIdentityCredentials != nil && ap.AWSProfileCredentials != nil) { - return nil, errors.New("exactly one AWS credential service must be specified when S3 signing is enabled") + if ap.logger == nil { + ap.logger = c.logger + } + + return DefaultRoundTripperClient(t, *c.ResponseHeaderTimeoutSeconds), nil +} + +func (ap *awsSigningAuthPlugin) Prepare(req *http.Request) error { + ap.logger.Debug("Signing request with AWS credentials.") + return signV4(req, ap.AWSService, ap.awsCredentialService(), time.Now()) +} + +func (ap *awsSigningAuthPlugin) validateConfig() error { + cfgs := map[bool]int{} + cfgs[ap.AWSEnvironmentCredentials != nil]++ + cfgs[ap.AWSMetadataCredentials != nil]++ + cfgs[ap.AWSWebIdentityCredentials != nil]++ + cfgs[ap.AWSProfileCredentials != nil]++ + + switch n := cfgs[true]; { + case n == 0: + return errors.New("a AWS credential service must be specified when S3 signing is enabled") + case n > 1: + return errors.New("exactly one AWS credential service must be specified when S3 signing is enabled") } + if ap.AWSMetadataCredentials != nil { if ap.AWSMetadataCredentials.RegionName == "" { - return nil, errors.New("at least aws_region must be specified for AWS metadata credential service") + return errors.New("at least aws_region must be specified for AWS metadata credential service") } } if ap.AWSWebIdentityCredentials != nil { if err := ap.AWSWebIdentityCredentials.populateFromEnv(); err != nil { - return nil, err + return err } } - - if ap.logger == nil { - ap.logger = c.logger - } if ap.AWSService == "" { ap.AWSService = awsSigv4SigningDefaultService } - - return DefaultRoundTripperClient(t, *c.ResponseHeaderTimeoutSeconds), nil -} - -func (ap *awsSigningAuthPlugin) Prepare(req *http.Request) error { - ap.logger.Debug("Signing request with AWS credentials.") - err := signV4(req, ap.AWSService, ap.awsCredentialService(), time.Now()) - return err + return nil } diff --git a/plugins/rest/rest_test.go b/plugins/rest/rest_test.go index bab02d041c..002e9f08bc 100644 --- a/plugins/rest/rest_test.go +++ b/plugins/rest/rest_test.go @@ -193,7 +193,7 @@ func TestNew(t *testing.T) { wantErr: true, }, { - name: "TooManyS3CredOptions", + name: "TooManyS3CredOptions/metadata+environment", input: `{ "name": "foo", "url": "http://localhost", @@ -209,6 +209,22 @@ func TestNew(t *testing.T) { }`, wantErr: true, }, + { + name: "TooManyS3CredOptions/metadata+profile+environment+webidentity", + input: `{ + "name": "foo", + "url": "http://localhost", + "credentials": { + "s3_signing": { + "profile_credentials": {}, + "environment_credentials": {}, + "web_identity_credentials": {}, + "metadata_credentials": {} + } + } + }`, + wantErr: true, + }, { name: "TooManyCredentialsOptions", input: `{ @@ -502,6 +518,7 @@ func TestNew(t *testing.T) { }, } }`, + wantErr: true, }, { name: "S3WebIdentityCreds", @@ -517,6 +534,7 @@ func TestNew(t *testing.T) { env: map[string]string{ awsRoleArnEnvVar: "TEST", awsWebIdentityTokenFileEnvVar: "TEST", + awsRegionEnvVar: "us-west-1", }, }, { @@ -639,11 +657,11 @@ func TestNew(t *testing.T) { _ = os.Setenv(key, val) } - defer func() { + t.Cleanup(func() { for key := range tc.env { _ = os.Unsetenv(key) } - }() + }) client, err := New([]byte(tc.input), ks, AuthPluginLookup(mockAuthPluginLookup)) if err != nil {