Skip to content

Commit

Permalink
plugins/rest: refactor AWS credentials provider config validation (op…
Browse files Browse the repository at this point in the history
…en-policy-agent#4081)

* plugins/rest: refactor AWS credentials provider config validation

* plugins/rest: cleanup env var handling in tests

The TestNew case of rest_test.go had failed when run in isolation,
but passed when running all of the packages tests. With these
changes to how env vars are handled, it passed both alone and when
run with all the other tests.

Also migrated the TestWebIdentityCredentialService test to use
test.WithTempFS instead of dealing with the files itself.

Signed-off-by: Stephan Renatus <stephan.renatus@gmail.com>
  • Loading branch information
srenatus committed Dec 7, 2021
1 parent 81ec24a commit e8d1b5f
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 147 deletions.
233 changes: 113 additions & 120 deletions plugins/rest/aws_test.go
Expand Up @@ -46,11 +46,15 @@ func assertErr(expected string, actual error, t *testing.T) {
}

func TestEnvironmentCredentialService(t *testing.T) {
os.Setenv("AWS_ACCESS_KEY_ID", "")
os.Setenv("AWS_SECRET_ACCESS_KEY", "")
os.Setenv("AWS_REGION", "")
os.Setenv("AWS_SECURITY_TOKEN", "")
os.Setenv("AWS_SESSION_TOKEN", "")
reset := func() {
os.Unsetenv("AWS_ACCESS_KEY_ID")
os.Unsetenv("AWS_SECRET_ACCESS_KEY")
os.Unsetenv("AWS_REGION")
os.Unsetenv("AWS_SECURITY_TOKEN")
os.Unsetenv("AWS_SESSION_TOKEN")
}
reset()
t.Cleanup(reset) // reset again when we're done

cs := &awsEnvironmentCredentialService{}

Expand Down Expand Up @@ -180,13 +184,13 @@ func TestProfileCredentialServiceWithEnvVars(t *testing.T) {
defaultSecret := "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
defaultSessionToken := "AQoEXAMPLEH4aoAH0gNCAPy"
defaultRegion := "us-east-1"

profile := "profileName"
config := fmt.Sprintf(`
[default]
aws_access_key_id=%v
aws_secret_access_key=%v
aws_session_token=%v
`, defaultKey, defaultSecret, defaultSessionToken)
[%s]
aws_access_key_id=%s
aws_secret_access_key=%s
aws_session_token=%s
`, profile, defaultKey, defaultSecret, defaultSessionToken)

files := map[string]string{
"example.ini": config,
Expand All @@ -196,10 +200,14 @@ aws_session_token=%v
cfgPath := filepath.Join(path, "example.ini")

os.Setenv(awsCredentialsFileEnvVar, cfgPath)
os.Setenv(awsProfileEnvVar, "default")
os.Setenv(awsProfileEnvVar, profile)
os.Setenv(awsRegionEnvVar, defaultRegion)

defer os.Unsetenv(awsCredentialsFileEnvVar)
defer os.Unsetenv(awsProfileEnvVar)
t.Cleanup(func() {
os.Unsetenv(awsCredentialsFileEnvVar)
os.Unsetenv(awsProfileEnvVar)
os.Unsetenv(awsRegionEnvVar)
})

cs := &awsProfileCredentialService{}
creds, err := cs.credentials()
Expand All @@ -224,24 +232,28 @@ func TestProfileCredentialServiceWithDefaultPath(t *testing.T) {
defaultKey := "AKIAIOSFODNN7EXAMPLE"
defaultSecret := "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
defaultSessionToken := "AQoEXAMPLEH4aoAH0gNCAPy"
defaultRegion := "us-west-2"
defaultRegion := "us-west-22"

config := fmt.Sprintf(`
[default]
aws_access_key_id=%v
aws_secret_access_key=%v
aws_session_token=%v
aws_access_key_id=%s
aws_secret_access_key=%s
aws_session_token=%s
`, defaultKey, defaultSecret, defaultSessionToken)

files := map[string]string{}
oldUserProfile := os.Getenv("USERPROFILE")
oldHome := os.Getenv("HOME")

test.WithTempFS(files, func(path string) {

os.Setenv("USERPROFILE", path)
os.Setenv("HOME", path)

defer os.Unsetenv("USERPROFILE")
defer os.Unsetenv("HOME")
t.Cleanup(func() {
os.Setenv("USERPROFILE", oldUserProfile)
os.Setenv("HOME", oldHome)
})

cfgDir := filepath.Join(path, ".aws")
err := os.MkdirAll(cfgDir, os.ModePerm)
Expand Down Expand Up @@ -285,14 +297,17 @@ aws_access_key_id=accessKey
tests := []struct {
note string
config string
err string
}{
{
note: "no aws_access_key_id",
config: configNoAccessKeyID,
err: "does not contain \"aws_access_key_id\"",
},
{
note: "no aws_secret_access_key",
config: configNoSecret,
err: "does not contain \"aws_secret_access_key\"",
},
}

Expand All @@ -312,6 +327,9 @@ aws_access_key_id=accessKey
if err == nil {
t.Fatal("Expected error but got nil")
}
if !strings.Contains(err.Error(), tc.err) {
t.Errorf("expected error to contain %v, got %v", tc.err, err.Error())
}
})
})
}
Expand Down Expand Up @@ -779,6 +797,16 @@ func (t *ec2CredTestServer) stop() {
}

func TestWebIdentityCredentialService(t *testing.T) {
reset := func() {
os.Unsetenv("AWS_WEB_IDENTITY_TOKEN_FILE")
os.Unsetenv("AWS_ROLE_ARN")
os.Unsetenv("AWS_REGION")
}
reset()
t.Cleanup(reset)

os.Setenv("AWS_REGION", "us-west-1")

testAccessKey := "ASgeIAIOSFODNN7EXAMPLE"
ts := stsTestServer{
t: t,
Expand All @@ -791,115 +819,80 @@ func TestWebIdentityCredentialService(t *testing.T) {
logger: logging.Get(),
}

goodTokenFile, err := ioutil.TempFile(os.TempDir(), "opa-aws-test-")
if err != nil {
t.Errorf("Error while creating token file: %s", err)
return
}
t.Cleanup(func() {
err := os.Remove(goodTokenFile.Name())
if err != nil {
t.Fatalf("unable to remove goodTokenFile %q: %v", goodTokenFile.Name(), err)
}
})
_, err = goodTokenFile.WriteString("good-token")
if err != nil {
t.Errorf("Error while creating token file: %s", err)
return
}
err = goodTokenFile.Close()
if err != nil {
t.Errorf("Error while creating token file: %s", err)
return
}

badTokenFile, err := ioutil.TempFile(os.TempDir(), "opa-aws-test-")
if err != nil {
t.Errorf("Error while creating token file: %s", err)
return
}
t.Cleanup(func() {
err := os.Remove(badTokenFile.Name())
if err != nil {
t.Fatalf("unable to remove badTokenFile %q: %v", badTokenFile.Name(), err)
}
})
_, err = badTokenFile.WriteString("bad-token")
if err != nil {
t.Errorf("Error while creating token file: %s", err)
return
}
err = badTokenFile.Close()
if err != nil {
t.Errorf("Error while creating token file: %s", err)
return
files := map[string]string{
"good_token_file": "good-token",
"bad_token_file": "bad-token",
}

// wrong path: no AWS_ROLE_ARN set
err = cs.populateFromEnv()
assertErr("no AWS_ROLE_ARN set in environment", err, t)
os.Setenv("AWS_ROLE_ARN", "role:arn")

// wrong path: no AWS_WEB_IDENTITY_TOKEN_FILE set
err = cs.populateFromEnv()
assertErr("no AWS_WEB_IDENTITY_TOKEN_FILE set in environment", err, t)
os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", "/nonsense")

// happy path: both env vars set
err = cs.populateFromEnv()
if err != nil {
t.Errorf("Error while getting env vars: %s", err)
return
}
test.WithTempFS(files, func(path string) {
goodTokenFile := filepath.Join(path, "good_token_file")
badTokenFile := filepath.Join(path, "bad_token_file")

// wrong path: refresh with invalid web token file
err = cs.refreshFromService()
assertErr("unable to read web token for sts HTTP request: open /nonsense: no such file or directory", err, t)
// wrong path: no AWS_ROLE_ARN set
err := cs.populateFromEnv()
assertErr("no AWS_ROLE_ARN set in environment", err, t)
os.Setenv("AWS_ROLE_ARN", "role:arn")

// wrong path: refresh with "bad token"
os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", badTokenFile.Name())
_ = cs.populateFromEnv()
err = cs.refreshFromService()
assertErr("STS HTTP request returned unexpected status: 401 Unauthorized", err, t)
// wrong path: no AWS_WEB_IDENTITY_TOKEN_FILE set
err = cs.populateFromEnv()
assertErr("no AWS_WEB_IDENTITY_TOKEN_FILE set in environment", err, t)
os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", "/nonsense")

// happy path: refresh with "good token"
os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", goodTokenFile.Name())
_ = cs.populateFromEnv()
err = cs.refreshFromService()
if err != nil {
t.Fatalf("Unexpected err: %s", err)
}
// happy path: both env vars set
err = cs.populateFromEnv()
if err != nil {
t.Fatalf("Error while getting env vars: %s", err)
}

// happy path: refresh and get credentials
creds, _ := cs.credentials()
assertEq(creds.AccessKey, testAccessKey, t)
// wrong path: refresh with invalid web token file
err = cs.refreshFromService()
assertErr("unable to read web token for sts HTTP request: open /nonsense: no such file or directory", err, t)

// happy path: refresh with session and get credentials
cs.expiration = time.Now()
cs.SessionName = "TEST_SESSION"
creds, _ = cs.credentials()
assertEq(creds.AccessKey, testAccessKey, t)
// wrong path: refresh with "bad token"
os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", badTokenFile)
_ = cs.populateFromEnv()
err = cs.refreshFromService()
assertErr("STS HTTP request returned unexpected status: 401 Unauthorized", err, t)

// happy path: don't refresh, but get credentials
ts.accessKey = "OTHERKEY"
creds, _ = cs.credentials()
assertEq(creds.AccessKey, testAccessKey, t)
// happy path: refresh with "good token"
os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", goodTokenFile)
_ = cs.populateFromEnv()
err = cs.refreshFromService()
if err != nil {
t.Fatalf("Unexpected err: %s", err)
}

// happy/wrong path: refresh with "bad token" but return previous credentials
os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", badTokenFile.Name())
_ = cs.populateFromEnv()
cs.expiration = time.Now()
creds, err = cs.credentials()
assertEq(creds.AccessKey, testAccessKey, t)
assertErr("STS HTTP request returned unexpected status: 401 Unauthorized", err, t)

// wrong path: refresh with "bad token" but return previous credentials
os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", goodTokenFile.Name())
os.Setenv("AWS_ROLE_ARN", "BrokenRole")
_ = cs.populateFromEnv()
cs.expiration = time.Now()
creds, err = cs.credentials()
assertErr("failed to parse credential response from STS service: EOF", err, t)
// happy path: refresh and get credentials
creds, _ := cs.credentials()
assertEq(creds.AccessKey, testAccessKey, t)

// happy path: refresh with session and get credentials
cs.expiration = time.Now()
cs.SessionName = "TEST_SESSION"
creds, _ = cs.credentials()
assertEq(creds.AccessKey, testAccessKey, t)

// happy path: don't refresh, but get credentials
ts.accessKey = "OTHERKEY"
creds, _ = cs.credentials()
assertEq(creds.AccessKey, testAccessKey, t)

// happy/wrong path: refresh with "bad token" but return previous credentials
os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", badTokenFile)
_ = cs.populateFromEnv()
cs.expiration = time.Now()
creds, err = cs.credentials()
assertEq(creds.AccessKey, testAccessKey, t)
assertErr("STS HTTP request returned unexpected status: 401 Unauthorized", err, t)

// wrong path: refresh with "bad token" but return previous credentials
os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", goodTokenFile)
os.Setenv("AWS_ROLE_ARN", "BrokenRole")
_ = cs.populateFromEnv()
cs.expiration = time.Now()
creds, err = cs.credentials()
assertErr("failed to parse credential response from STS service: EOF", err, t)
})
}

func TestStsPath(t *testing.T) {
Expand Down
54 changes: 30 additions & 24 deletions plugins/rest/rest_auth.go
Expand Up @@ -532,42 +532,48 @@ func (ap *awsSigningAuthPlugin) NewClient(c Config) (*http.Client, error) {
return nil, err
}

if ap.AWSEnvironmentCredentials == nil && ap.AWSWebIdentityCredentials == nil && ap.AWSMetadataCredentials == nil &&
ap.AWSProfileCredentials == nil {
return nil, errors.New("a AWS credential service must be specified when S3 signing is enabled")
if err := ap.validateConfig(); err != nil {
return nil, err
}

if (ap.AWSEnvironmentCredentials != nil && ap.AWSMetadataCredentials != nil) ||
(ap.AWSEnvironmentCredentials != nil && ap.AWSWebIdentityCredentials != nil) ||
(ap.AWSEnvironmentCredentials != nil && ap.AWSProfileCredentials != nil) ||
(ap.AWSMetadataCredentials != nil && ap.AWSWebIdentityCredentials != nil) ||
(ap.AWSMetadataCredentials != nil && ap.AWSProfileCredentials != nil) ||
(ap.AWSWebIdentityCredentials != nil && ap.AWSProfileCredentials != nil) {
return nil, errors.New("exactly one AWS credential service must be specified when S3 signing is enabled")
if ap.logger == nil {
ap.logger = c.logger
}

return DefaultRoundTripperClient(t, *c.ResponseHeaderTimeoutSeconds), nil
}

func (ap *awsSigningAuthPlugin) Prepare(req *http.Request) error {
ap.logger.Debug("Signing request with AWS credentials.")
return signV4(req, ap.AWSService, ap.awsCredentialService(), time.Now())
}

func (ap *awsSigningAuthPlugin) validateConfig() error {
cfgs := map[bool]int{}
cfgs[ap.AWSEnvironmentCredentials != nil]++
cfgs[ap.AWSMetadataCredentials != nil]++
cfgs[ap.AWSWebIdentityCredentials != nil]++
cfgs[ap.AWSProfileCredentials != nil]++

switch n := cfgs[true]; {
case n == 0:
return errors.New("a AWS credential service must be specified when S3 signing is enabled")
case n > 1:
return errors.New("exactly one AWS credential service must be specified when S3 signing is enabled")
}

if ap.AWSMetadataCredentials != nil {
if ap.AWSMetadataCredentials.RegionName == "" {
return nil, errors.New("at least aws_region must be specified for AWS metadata credential service")
return errors.New("at least aws_region must be specified for AWS metadata credential service")
}
}
if ap.AWSWebIdentityCredentials != nil {
if err := ap.AWSWebIdentityCredentials.populateFromEnv(); err != nil {
return nil, err
return err
}
}

if ap.logger == nil {
ap.logger = c.logger
}
if ap.AWSService == "" {
ap.AWSService = awsSigv4SigningDefaultService
}

return DefaultRoundTripperClient(t, *c.ResponseHeaderTimeoutSeconds), nil
}

func (ap *awsSigningAuthPlugin) Prepare(req *http.Request) error {
ap.logger.Debug("Signing request with AWS credentials.")
err := signV4(req, ap.AWSService, ap.awsCredentialService(), time.Now())
return err
return nil
}

0 comments on commit e8d1b5f

Please sign in to comment.