Skip to content

Commit

Permalink
Merge pull request #116 from hashicorp/VAULT-23359/add-identiy-token-…
Browse files Browse the repository at this point in the history
…fetcher

Add ability to set a custom identity token fetcher implementation
  • Loading branch information
vinay-gopalan committed Jan 24, 2024
2 parents 7a5e901 + 73ee86b commit 44cae24
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 28 deletions.
51 changes: 23 additions & 28 deletions awsutil/generate_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ type CredentialsConfig struct {
// identity token provider
WebIdentityToken string

// The web identity token fetcher to use with the web identity token provider
WebIdentityTokenFetcher stscreds.TokenFetcher

// The http.Client to use, or nil for the client to use its default
HTTPClient *http.Client

Expand All @@ -92,7 +95,7 @@ type CredentialsConfig struct {
// Supported options: WithAccessKey, WithSecretKey, WithLogger, WithStsEndpoint,
// WithIamEndpoint, WithMaxRetries, WithRegion, WithHttpClient, WithRoleArn,
// WithRoleSessionName, WithRoleExternalId, WithRoleTags, WithWebIdentityTokenFile,
// WithWebIdentityToken.
// WithWebIdentityToken, WithWebIdentityTokenFetcher.
func NewCredentialsConfig(opt ...Option) (*CredentialsConfig, error) {
opts, err := getOpts(opt...)
if err != nil {
Expand Down Expand Up @@ -134,6 +137,8 @@ func NewCredentialsConfig(opt ...Option) (*CredentialsConfig, error) {
}
c.WebIdentityToken = opts.withWebIdentityToken

c.WebIdentityTokenFetcher = opts.withWebIdentityTokenFetcher

if c.RoleARN == "" {
if c.RoleSessionName != "" {
return nil, fmt.Errorf("role session name specified without role ARN")
Expand All @@ -150,6 +155,10 @@ func NewCredentialsConfig(opt ...Option) (*CredentialsConfig, error) {
if len(c.WebIdentityToken) > 0 {
return nil, fmt.Errorf("web identity token specified without role ARN")
}

if c.WebIdentityTokenFetcher != nil {
return nil, fmt.Errorf("web identity token fetcher specified without role ARN")
}
}

c.HTTPClient = opts.withHttpClient
Expand Down Expand Up @@ -232,38 +241,24 @@ func (c *CredentialsConfig) GenerateCredentialChain(opt ...Option) (*credentials
roleSessionName = os.Getenv("AWS_ROLE_SESSION_NAME")
}
if roleARN != "" {
if tokenPath != "" {
// this session is only created to create the WebIdentityRoleProvider, variables used to
// assume a role are pulled from values provided in options. If the option values are
// not set, then the provider will default to using the environment variables.
c.log(hclog.Debug, "adding web identity provider", "roleARN", roleARN)
if tokenPath != "" || c.WebIdentityToken != "" || c.WebIdentityTokenFetcher != nil {
sess, err := session.NewSession()
if err != nil {
return nil, errors.Wrap(err, "error creating a new session to create a WebIdentityRoleProvider")
}
webIdentityProvider := stscreds.NewWebIdentityRoleProvider(sts.New(sess), roleARN, roleSessionName, tokenPath)

if opts.withSkipWebIdentityValidity {
// Add the web identity role credential provider without
// generating credentials to check validity first
providers = append(providers, webIdentityProvider)
} else {
// Check if the webIdentityProvider can successfully retrieve
// credentials (via sts:AssumeRole), and warn if there's a problem.
if _, err := webIdentityProvider.Retrieve(); err != nil {
c.log(hclog.Warn, "error assuming role", "roleARN", roleARN, "tokenPath", tokenPath, "sessionName", roleSessionName, "err", err)
} else {
// Add the web identity role credential provider
providers = append(providers, webIdentityProvider)
}
}
} else if c.WebIdentityToken != "" {
c.log(hclog.Debug, "adding web identity provider with token", "roleARN", roleARN)
sess, err := session.NewSession()
if err != nil {
return nil, errors.Wrap(err, "error creating a new session to create a WebIdentityRoleProvider with token")
var webIdentityProvider *stscreds.WebIdentityRoleProvider
switch {
case tokenPath != "":
c.log(hclog.Debug, "adding web identity provider", "roleARN", roleARN)
webIdentityProvider = stscreds.NewWebIdentityRoleProvider(sts.New(sess), roleARN, roleSessionName, tokenPath)
case c.WebIdentityToken != "":
c.log(hclog.Debug, "adding web identity provider with token", "roleARN", roleARN)
webIdentityProvider = stscreds.NewWebIdentityRoleProviderWithToken(sts.New(sess), roleARN, roleSessionName, FetchTokenContents(c.WebIdentityToken))
case c.WebIdentityTokenFetcher != nil:
c.log(hclog.Debug, "adding web identity provider with token fetcher", "roleARN", roleARN)
webIdentityProvider = stscreds.NewWebIdentityRoleProviderWithToken(sts.New(sess), roleARN, roleSessionName, c.WebIdentityTokenFetcher)
}
webIdentityProvider := stscreds.NewWebIdentityRoleProviderWithToken(sts.New(sess), roleARN, roleSessionName, FetchTokenContents(c.WebIdentityToken))

if opts.withSkipWebIdentityValidity {
// Add the web identity role credential provider without
Expand All @@ -273,7 +268,7 @@ func (c *CredentialsConfig) GenerateCredentialChain(opt ...Option) (*credentials
// Check if the webIdentityProvider can successfully retrieve
// credentials (via sts:AssumeRole), and warn if there's a problem.
if _, err := webIdentityProvider.Retrieve(); err != nil {
c.log(hclog.Warn, "error assuming role with WebIdentityToken", "roleARN", roleARN, "sessionName", roleSessionName, "err", err)
c.log(hclog.Warn, "error assuming role with web identity", "roleARN", roleARN, "sessionName", roleSessionName, "err", err)
} else {
// Add the web identity role credential provider
providers = append(providers, webIdentityProvider)
Expand Down
12 changes: 12 additions & 0 deletions awsutil/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http"
"time"

"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/hashicorp/go-hclog"
)
Expand Down Expand Up @@ -50,6 +51,7 @@ type options struct {
withWebIdentityTokenFile string
withWebIdentityToken string
withSkipWebIdentityValidity bool
withWebIdentityTokenFetcher stscreds.TokenFetcher
withHttpClient *http.Client
withValidityCheckTimeout time.Duration
withIAMAPIFunc IAMAPIFunc
Expand Down Expand Up @@ -124,6 +126,16 @@ func WithWebIdentityToken(with string) Option {
}
}

// WithWebIdentityTokenFetcher allows passing an STS TokenFetcher which
// allows the AWS SDK client automatically to refresh the web identity token
// from any source.
func WithWebIdentityTokenFetcher(with stscreds.TokenFetcher) Option {
return func(o *options) error {
o.withWebIdentityTokenFetcher = with
return nil
}
}

// WithSkipWebIdentityValidity allows controlling whether the validity check is
// skipped for the web identity provider
func WithSkipWebIdentityValidity(with bool) Option {
Expand Down
14 changes: 14 additions & 0 deletions awsutil/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,14 @@ func Test_GetOpts(t *testing.T) {
testOpts.withWebIdentityToken = "foo"
assert.Equal(t, opts, testOpts)
})
t.Run("WithWebIdentityTokenFetcher", func(t *testing.T) {
f := testFetcher{}
opts, err := getOpts(WithWebIdentityTokenFetcher(f))
require.NoError(t, err)
testOpts := getDefaultOptions()
testOpts.withWebIdentityTokenFetcher = f
assert.Equal(t, opts, testOpts)
})
t.Run("WithSkipWebIdentityValidity", func(t *testing.T) {
opts, err := getOpts(WithSkipWebIdentityValidity(true))
require.NoError(t, err)
Expand All @@ -185,3 +193,9 @@ func Test_GetOpts(t *testing.T) {
assert.Equal(t, opts, testOpts)
})
}

type testFetcher struct{}

func (testFetcher) FetchToken(_ aws.Context) ([]byte, error) {
return nil, nil
}

0 comments on commit 44cae24

Please sign in to comment.