Skip to content

Commit

Permalink
consolidate repeated code and add check for empty arn
Browse files Browse the repository at this point in the history
  • Loading branch information
vinay-gopalan committed Jan 24, 2024
1 parent c4c6ec7 commit 73ee86b
Showing 1 changed file with 18 additions and 50 deletions.
68 changes: 18 additions & 50 deletions awsutil/generate_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ type CredentialsConfig struct {
// Supported options: WithAccessKey, WithSecretKey, WithLogger, WithStsEndpoint,
// WithIamEndpoint, WithMaxRetries, WithRegion, WithHttpClient, WithRoleArn,
// WithRoleSessionName, WithRoleExternalId, WithRoleTags, WithWebIdentityTokenFile,
// WithWebIdentityToken.
// WithWebIdentityToken, WithWebIdentityTokenFetcher.
func NewCredentialsConfig(opt ...Option) (*CredentialsConfig, error) {
opts, err := getOpts(opt...)
if err != nil {
Expand Down Expand Up @@ -155,6 +155,10 @@ func NewCredentialsConfig(opt ...Option) (*CredentialsConfig, error) {
if len(c.WebIdentityToken) > 0 {
return nil, fmt.Errorf("web identity token specified without role ARN")
}

if c.WebIdentityTokenFetcher != nil {
return nil, fmt.Errorf("web identity token fetcher specified without role ARN")
}
}

c.HTTPClient = opts.withHttpClient
Expand Down Expand Up @@ -237,60 +241,24 @@ func (c *CredentialsConfig) GenerateCredentialChain(opt ...Option) (*credentials
roleSessionName = os.Getenv("AWS_ROLE_SESSION_NAME")
}
if roleARN != "" {
if tokenPath != "" {
// this session is only created to create the WebIdentityRoleProvider, variables used to
// assume a role are pulled from values provided in options. If the option values are
// not set, then the provider will default to using the environment variables.
c.log(hclog.Debug, "adding web identity provider", "roleARN", roleARN)
if tokenPath != "" || c.WebIdentityToken != "" || c.WebIdentityTokenFetcher != nil {
sess, err := session.NewSession()
if err != nil {
return nil, errors.Wrap(err, "error creating a new session to create a WebIdentityRoleProvider")
}
webIdentityProvider := stscreds.NewWebIdentityRoleProvider(sts.New(sess), roleARN, roleSessionName, tokenPath)

if opts.withSkipWebIdentityValidity {
// Add the web identity role credential provider without
// generating credentials to check validity first
providers = append(providers, webIdentityProvider)
} else {
// Check if the webIdentityProvider can successfully retrieve
// credentials (via sts:AssumeRole), and warn if there's a problem.
if _, err := webIdentityProvider.Retrieve(); err != nil {
c.log(hclog.Warn, "error assuming role", "roleARN", roleARN, "tokenPath", tokenPath, "sessionName", roleSessionName, "err", err)
} else {
// Add the web identity role credential provider
providers = append(providers, webIdentityProvider)
}
}
} else if c.WebIdentityToken != "" {
c.log(hclog.Debug, "adding web identity provider with token", "roleARN", roleARN)
sess, err := session.NewSession()
if err != nil {
return nil, errors.Wrap(err, "error creating a new session to create a WebIdentityRoleProvider with token")
}
webIdentityProvider := stscreds.NewWebIdentityRoleProviderWithToken(sts.New(sess), roleARN, roleSessionName, FetchTokenContents(c.WebIdentityToken))

if opts.withSkipWebIdentityValidity {
// Add the web identity role credential provider without
// generating credentials to check validity first
providers = append(providers, webIdentityProvider)
} else {
// Check if the webIdentityProvider can successfully retrieve
// credentials (via sts:AssumeRole), and warn if there's a problem.
if _, err := webIdentityProvider.Retrieve(); err != nil {
c.log(hclog.Warn, "error assuming role with WebIdentityToken", "roleARN", roleARN, "sessionName", roleSessionName, "err", err)
} else {
// Add the web identity role credential provider
providers = append(providers, webIdentityProvider)
}
}
} else if c.WebIdentityTokenFetcher != nil {
c.log(hclog.Debug, "adding web identity provider with token fetcher", "roleARN", roleARN)
sess, err := session.NewSession()
if err != nil {
return nil, errors.Wrap(err, "error creating a new session to create a WebIdentityRoleProvider with token fetcher")
var webIdentityProvider *stscreds.WebIdentityRoleProvider
switch {
case tokenPath != "":
c.log(hclog.Debug, "adding web identity provider", "roleARN", roleARN)
webIdentityProvider = stscreds.NewWebIdentityRoleProvider(sts.New(sess), roleARN, roleSessionName, tokenPath)
case c.WebIdentityToken != "":
c.log(hclog.Debug, "adding web identity provider with token", "roleARN", roleARN)
webIdentityProvider = stscreds.NewWebIdentityRoleProviderWithToken(sts.New(sess), roleARN, roleSessionName, FetchTokenContents(c.WebIdentityToken))
case c.WebIdentityTokenFetcher != nil:
c.log(hclog.Debug, "adding web identity provider with token fetcher", "roleARN", roleARN)
webIdentityProvider = stscreds.NewWebIdentityRoleProviderWithToken(sts.New(sess), roleARN, roleSessionName, c.WebIdentityTokenFetcher)
}
webIdentityProvider := stscreds.NewWebIdentityRoleProviderWithToken(sts.New(sess), roleARN, roleSessionName, c.WebIdentityTokenFetcher)

if opts.withSkipWebIdentityValidity {
// Add the web identity role credential provider without
Expand All @@ -300,7 +268,7 @@ func (c *CredentialsConfig) GenerateCredentialChain(opt ...Option) (*credentials
// Check if the webIdentityProvider can successfully retrieve
// credentials (via sts:AssumeRole), and warn if there's a problem.
if _, err := webIdentityProvider.Retrieve(); err != nil {
c.log(hclog.Warn, "error assuming role with WebIdentityToken", "roleARN", roleARN, "sessionName", roleSessionName, "err", err)
c.log(hclog.Warn, "error assuming role with web identity", "roleARN", roleARN, "sessionName", roleSessionName, "err", err)
} else {
// Add the web identity role credential provider
providers = append(providers, webIdentityProvider)
Expand Down

0 comments on commit 73ee86b

Please sign in to comment.