Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(credential-providers): source accountId from credential providers #6019

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
18 changes: 16 additions & 2 deletions clients/client-sts/src/defaultStsRoleAssumers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,17 +84,24 @@ export const getDefaultRoleAssumer = (
logger: logger as any,
});
}
const { Credentials } = await stsClient.send(new AssumeRoleCommand(params));
const { Credentials, AssumedRoleUser } = await stsClient.send(new AssumeRoleCommand(params));
if (!Credentials || !Credentials.AccessKeyId || !Credentials.SecretAccessKey) {
throw new Error(`Invalid response from STS.assumeRole call with role ${params.RoleArn}`);
}
let accountId;
try {
accountId = AssumedRoleUser.Arn.split(":")[4];
} catch (error) {
accountId = undefined;
}
return {
accessKeyId: Credentials.AccessKeyId,
secretAccessKey: Credentials.SecretAccessKey,
sessionToken: Credentials.SessionToken,
expiration: Credentials.Expiration,
// TODO(credentialScope): access normally when shape is updated.
credentialScope: (Credentials as any).CredentialScope,
accountId,
};
};
};
Expand Down Expand Up @@ -134,17 +141,24 @@ export const getDefaultRoleAssumerWithWebIdentity = (
logger: logger as any,
});
}
const { Credentials } = await stsClient.send(new AssumeRoleWithWebIdentityCommand(params));
const { Credentials, AssumedRoleUser } = await stsClient.send(new AssumeRoleWithWebIdentityCommand(params));
if (!Credentials || !Credentials.AccessKeyId || !Credentials.SecretAccessKey) {
throw new Error(`Invalid response from STS.assumeRoleWithWebIdentity call with role ${params.RoleArn}`);
}
let accountId;
try {
accountId = AssumedRoleUser.Arn.split(":")[4];
} catch (error) {
accountId = undefined;
}
return {
accessKeyId: Credentials.AccessKeyId,
secretAccessKey: Credentials.SecretAccessKey,
sessionToken: Credentials.SessionToken,
expiration: Credentials.Expiration,
// TODO(credentialScope): access normally when shape is updated.
credentialScope: (Credentials as any).CredentialScope,
accountId,
};
};
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,17 @@ describe("getDefaultRoleAssumer", () => {
);
});

it("should return accountId in the credentials", async () => {
const roleAssumer = getDefaultRoleAssumer();
const params: AssumeRoleCommandInput = {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
};
const sourceCred = { accessKeyId: "key", secretAccessKey: "secrete" };
const assumedRole = await roleAssumer(sourceCred, params);
expect(assumedRole.accountId).toEqual("123456789012");
});
siddsriv marked this conversation as resolved.
Show resolved Hide resolved

it("should use the STS client config", async () => {
const logger = console;
const region = "some-region";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,24 @@ export const getDefaultRoleAssumer = (
logger: logger as any,
});
}
const { Credentials } = await stsClient.send(new AssumeRoleCommand(params));
const { Credentials, AssumedRoleUser } = await stsClient.send(new AssumeRoleCommand(params));
if (!Credentials || !Credentials.AccessKeyId || !Credentials.SecretAccessKey) {
throw new Error(`Invalid response from STS.assumeRole call with role ${params.RoleArn}`);
}
let accountId;
try {
accountId = AssumedRoleUser.Arn.split(":")[4];
} catch (error) {
accountId = undefined;
}
return {
accessKeyId: Credentials.AccessKeyId,
secretAccessKey: Credentials.SecretAccessKey,
sessionToken: Credentials.SessionToken,
expiration: Credentials.Expiration,
// TODO(credentialScope): access normally when shape is updated.
credentialScope: (Credentials as any).CredentialScope,
accountId,
};
};
};
Expand Down Expand Up @@ -131,17 +138,24 @@ export const getDefaultRoleAssumerWithWebIdentity = (
logger: logger as any,
});
}
const { Credentials } = await stsClient.send(new AssumeRoleWithWebIdentityCommand(params));
const { Credentials, AssumedRoleUser } = await stsClient.send(new AssumeRoleWithWebIdentityCommand(params));
if (!Credentials || !Credentials.AccessKeyId || !Credentials.SecretAccessKey) {
throw new Error(`Invalid response from STS.assumeRoleWithWebIdentity call with role ${params.RoleArn}`);
}
let accountId;
try {
accountId = AssumedRoleUser.Arn.split(":")[4];
} catch (error) {
accountId = undefined;
}
return {
accessKeyId: Credentials.AccessKeyId,
secretAccessKey: Credentials.SecretAccessKey,
sessionToken: Credentials.SessionToken,
expiration: Credentials.Expiration,
// TODO(credentialScope): access normally when shape is updated.
credentialScope: (Credentials as any).CredentialScope,
accountId,
};
};
};
Expand Down
20 changes: 18 additions & 2 deletions packages/credential-provider-env/src/fromEnv.spec.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import { CredentialsProviderError } from "@smithy/property-provider";

import { ENV_EXPIRATION, ENV_KEY, ENV_SECRET, ENV_SESSION, fromEnv } from "./fromEnv";
import { ENV_ACCOUNT_ID, ENV_EXPIRATION, ENV_KEY, ENV_SECRET, ENV_SESSION, fromEnv } from "./fromEnv";

describe(fromEnv.name, () => {
const ORIGINAL_ENV = process.env;
const mockAccessKeyId = "mockAccessKeyId";
const mockSecretAccessKey = "mockSecretAccessKey";
const mockSessionToken = "mockSessionToken";
const mockExpiration = new Date().toISOString();
const mockAccountId = "123456789012";

beforeEach(() => {
process.env = {
Expand All @@ -16,6 +17,7 @@ describe(fromEnv.name, () => {
[ENV_SECRET]: mockSecretAccessKey,
[ENV_SESSION]: mockSessionToken,
[ENV_EXPIRATION]: mockExpiration,
[ENV_ACCOUNT_ID]: mockAccountId,
};
});

Expand All @@ -30,19 +32,33 @@ describe(fromEnv.name, () => {
secretAccessKey: mockSecretAccessKey,
sessionToken: mockSessionToken,
expiration: new Date(mockExpiration),
accountId: mockAccountId,
});
});

it("can create credentials without a session token or expiration", async () => {
it("can create credentials without a session token, accountId, or expiration", async () => {
delete process.env[ENV_SESSION];
delete process.env[ENV_EXPIRATION];
delete process.env[ENV_ACCOUNT_ID];
const receivedCreds = await fromEnv()();
expect(receivedCreds).toStrictEqual({
accessKeyId: mockAccessKeyId,
secretAccessKey: mockSecretAccessKey,
});
});

it("should include accountId when it is provided in environment variables", async () => {
process.env[ENV_ACCOUNT_ID] = mockAccountId;
const receivedCreds = await fromEnv()();
expect(receivedCreds).toHaveProperty("accountId", mockAccountId);
});

it("should not include accountId when it is not provided in environment variables", async () => {
delete process.env[ENV_ACCOUNT_ID]; // Ensure accountId is not set
const receivedCreds = await fromEnv()();
expect(receivedCreds).not.toHaveProperty("accountId");
});

it.each([ENV_KEY, ENV_SECRET])("throws if env['%s'] is not found", async (key) => {
delete process.env[key];
const expectedError = new CredentialsProviderError("Unable to find environment variable credentials.");
Expand Down
6 changes: 6 additions & 0 deletions packages/credential-provider-env/src/fromEnv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ export const ENV_EXPIRATION = "AWS_CREDENTIAL_EXPIRATION";
* @internal
*/
export const ENV_CREDENTIAL_SCOPE = "AWS_CREDENTIAL_SCOPE";
/**
* @internal
*/
export const ENV_ACCOUNT_ID = "AWS_ACCOUNT_ID";

/**
* @internal
Expand All @@ -41,6 +45,7 @@ export const fromEnv =
const sessionToken: string | undefined = process.env[ENV_SESSION];
const expiry: string | undefined = process.env[ENV_EXPIRATION];
const credentialScope: string | undefined = process.env[ENV_CREDENTIAL_SCOPE];
const accountId: string | undefined = process.env[ENV_ACCOUNT_ID];

if (accessKeyId && secretAccessKey) {
return {
Expand All @@ -49,6 +54,7 @@ export const fromEnv =
...(sessionToken && { sessionToken }),
...(expiry && { expiration: new Date(expiry) }),
...(credentialScope && { credentialScope }),
...(accountId && { accountId }),
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ const getMockStaticCredsProfile = () => ({
aws_secret_access_key: "mock_aws_secret_access_key",
aws_session_token: "mock_aws_session_token",
aws_credential_scope: "mock_aws_credential_scope",
aws_account_id: "mock_aws_account_id",
});

describe(isStaticCredsProfile.name, () => {
Expand Down Expand Up @@ -32,6 +33,12 @@ describe(isStaticCredsProfile.name, () => {
});
});

it.each(["aws_account_id"])("value at '%s' is not of type string | undefined", (key) => {
[true, null, 1, NaN, {}].forEach((value) => {
expect(isStaticCredsProfile({ ...getMockStaticCredsProfile(), [key]: value })).toEqual(false);
});
});

it("returns true for StaticCredentialsProfile", () => {
expect(isStaticCredsProfile(getMockStaticCredsProfile())).toEqual(true);
});
Expand All @@ -46,6 +53,7 @@ describe(resolveStaticCredentials.name, () => {
secretAccessKey: mockProfile.aws_secret_access_key,
sessionToken: mockProfile.aws_session_token,
credentialScope: mockProfile.aws_credential_scope,
accountId: mockProfile.aws_account_id,
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export interface StaticCredsProfile extends Profile {
aws_secret_access_key: string;
aws_session_token?: string;
aws_credential_scope?: string;
aws_account_id?: string;
}

/**
Expand All @@ -35,5 +36,6 @@ export const resolveStaticCredentials = (
secretAccessKey: profile.aws_secret_access_key,
sessionToken: profile.aws_session_token,
credentialScope: profile.aws_credential_scope,
accountId: profile.aws_account_id,
});
};
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ export type ProcessCredentials = {
SessionToken?: string;
Expiration?: number;
CredentialScope?: string;
AccountId?: string;
};
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { AwsCredentialIdentity } from "@smithy/types";
import { AwsCredentialIdentity, ParsedIniData } from "@smithy/types";

import { getValidatedProcessCredentials } from "./getValidatedProcessCredentials";
import { ProcessCredentials } from "./ProcessCredentials";
Expand All @@ -9,40 +9,60 @@ describe(getValidatedProcessCredentials.name, () => {
const mockSecretAccessKey = "mockSecretAccessKey";
const mockSessionToken = "mockSessionToken";
const mockExpiration = Date.now() + 24 * 60 * 60 * 1000;
const mockAccountId = "123456789012";

const mockProfiles: ParsedIniData = {
[mockProfileName]: {
aws_account_id: mockAccountId,
},
};

const getMockProcessCreds = (): ProcessCredentials => ({
Version: 1,
AccessKeyId: mockAccessKeyId,
SecretAccessKey: mockSecretAccessKey,
SessionToken: mockSessionToken,
Expiration: mockExpiration,
AccountId: mockAccountId,
});

it.each([undefined, 2])("throws Error when Version is %s", (Version) => {
expect(() => {
getValidatedProcessCredentials(mockProfileName, {
...getMockProcessCreds(),
Version,
});
getValidatedProcessCredentials(
mockProfileName,
{
...getMockProcessCreds(),
Version,
},
mockProfiles
);
}).toThrow(`Profile ${mockProfileName} credential_process did not return Version 1.`);
});

it.each(["AccessKeyId", "SecretAccessKey"])("throws Error when '%s' is not defined", (key) => {
expect(() => {
getValidatedProcessCredentials(mockProfileName, {
...getMockProcessCreds(),
[key]: undefined,
});
getValidatedProcessCredentials(
mockProfileName,
{
...getMockProcessCreds(),
[key]: undefined,
},
mockProfiles
);
}).toThrow(`Profile ${mockProfileName} credential_process returned invalid credentials.`);
});

it("throws error when credentials are expired", () => {
const expirationDayBefore = Date.now() - 24 * 60 * 60 * 1000;
expect(() => {
getValidatedProcessCredentials(mockProfileName, {
...getMockProcessCreds(),
Expiration: expirationDayBefore,
});
getValidatedProcessCredentials(
mockProfileName,
{
...getMockProcessCreds(),
Expiration: expirationDayBefore,
},
mockProfiles
);
}).toThrow(`Profile ${mockProfileName} credential_process returned expired credentials.`);
});

Expand All @@ -52,18 +72,23 @@ describe(getValidatedProcessCredentials.name, () => {
secretAccessKey: data.SecretAccessKey,
...(data.SessionToken && { sessionToken: data.SessionToken }),
...(data.Expiration && { expiration: new Date(data.Expiration) }),
...(data.AccountId && { accountId: data.AccountId }),
});

it("with all values", () => {
const mockProcessCreds = getMockProcessCreds();
const mockOutputCreds = getValidatedCredentials(mockProcessCreds);
expect(getValidatedProcessCredentials(mockProfileName, mockProcessCreds)).toStrictEqual(mockOutputCreds);
expect(getValidatedProcessCredentials(mockProfileName, mockProcessCreds, mockProfiles)).toStrictEqual(
mockOutputCreds
);
});

it.each(["SessionToken", "Expiration"])("without '%s'", (key) => {
const mockProcessCreds = { ...getMockProcessCreds(), [key]: undefined };
const mockOutputCreds = getValidatedCredentials(mockProcessCreds);
expect(getValidatedProcessCredentials(mockProfileName, mockProcessCreds)).toStrictEqual(mockOutputCreds);
expect(getValidatedProcessCredentials(mockProfileName, mockProcessCreds, mockProfiles)).toStrictEqual(
mockOutputCreds
);
});
});
});
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { AwsCredentialIdentity } from "@smithy/types";
import { AwsCredentialIdentity, ParsedIniData } from "@smithy/types";

import { ProcessCredentials } from "./ProcessCredentials";

Expand All @@ -7,7 +7,8 @@ import { ProcessCredentials } from "./ProcessCredentials";
*/
export const getValidatedProcessCredentials = (
profileName: string,
data: ProcessCredentials
data: ProcessCredentials,
profiles: ParsedIniData
): AwsCredentialIdentity => {
if (data.Version !== 1) {
throw Error(`Profile ${profileName} credential_process did not return Version 1.`);
Expand All @@ -25,11 +26,17 @@ export const getValidatedProcessCredentials = (
}
}

let accountId = data.AccountId;
if (!accountId && profiles?.[profileName]?.aws_account_id) {
accountId = profiles[profileName].aws_account_id;
}

return {
accessKeyId: data.AccessKeyId,
secretAccessKey: data.SecretAccessKey,
...(data.SessionToken && { sessionToken: data.SessionToken }),
...(data.Expiration && { expiration: new Date(data.Expiration) }),
...(data.CredentialScope && { credentialScope: data.CredentialScope }),
...(accountId && { accountId }),
};
};