Skip to content

Commit

Permalink
Set tokenProvider() for Bearer auth services
Browse files Browse the repository at this point in the history
  • Loading branch information
dagnir committed Dec 1, 2022
1 parent 966f995 commit 55c14ed
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 2 deletions.
Expand Up @@ -58,6 +58,7 @@
import software.amazon.awssdk.codegen.poet.ClassSpec;
import software.amazon.awssdk.codegen.poet.PoetExtension;
import software.amazon.awssdk.codegen.poet.PoetUtils;
import software.amazon.awssdk.codegen.utils.AuthUtils;
import software.amazon.awssdk.core.SdkSystemSetting;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.rules.testing.AsyncTestCase;
Expand Down Expand Up @@ -252,6 +253,9 @@ private CodeBlock syncOperationCallLambda(OperationModel opModel, Map<String, Tr
b.beginControlFlow("() -> ");
b.addStatement("$T builder = $T.builder()", syncClientBuilder(), syncClientClass());
b.addStatement("builder.credentialsProvider($T.CREDENTIALS_PROVIDER)", BaseRuleSetClientTest.class);
if (AuthUtils.usesBearerAuth(model)) {
b.addStatement("builder.tokenProvider($T.TOKEN_PROVIDER)", BaseRuleSetClientTest.class);
}
b.addStatement("builder.httpClient(getSyncHttpClient())");

if (params != null) {
Expand All @@ -276,6 +280,9 @@ private CodeBlock asyncOperationCallLambda(OperationModel opModel, Map<String, T
b.beginControlFlow("() -> ");
b.addStatement("$T builder = $T.builder()", asyncClientBuilder(), asyncClientClass());
b.addStatement("builder.credentialsProvider($T.CREDENTIALS_PROVIDER)", BaseRuleSetClientTest.class);
if (AuthUtils.usesBearerAuth(model)) {
b.addStatement("builder.tokenProvider($T.TOKEN_PROVIDER)", BaseRuleSetClientTest.class);
}
b.addStatement("builder.httpClient(getAsyncHttpClient())");

if (params != null) {
Expand Down
Expand Up @@ -4,6 +4,7 @@
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.auth.token.credentials.SdkTokenProvider;
import software.amazon.awssdk.awscore.client.config.AwsClientOption;
import software.amazon.awssdk.core.client.config.SdkClientConfiguration;
import software.amazon.awssdk.core.client.config.SdkClientOption;
import software.amazon.awssdk.services.json.endpoints.JsonEndpointProvider;

Expand All @@ -28,6 +29,8 @@ public DefaultJsonAsyncClientBuilder tokenProvider(SdkTokenProvider tokenProvide

@Override
protected final JsonAsyncClient buildClient() {
return new DefaultJsonAsyncClient(super.asyncClientConfiguration());
SdkClientConfiguration clientConfiguration = super.asyncClientConfiguration();
this.validateClientOptions(clientConfiguration);
return new DefaultJsonAsyncClient(clientConfiguration);
}
}
Expand Up @@ -20,6 +20,7 @@
import software.amazon.awssdk.services.json.endpoints.internal.JsonRequestSetEndpointInterceptor;
import software.amazon.awssdk.services.json.endpoints.internal.JsonResolveEndpointInterceptor;
import software.amazon.awssdk.utils.CollectionUtils;
import software.amazon.awssdk.utils.Validate;

/**
* Internal base class for {@link DefaultJsonClientBuilder} and {@link DefaultJsonAsyncClientBuilder}.
Expand Down Expand Up @@ -77,4 +78,11 @@ private SdkTokenProvider defaultTokenProvider() {
private Signer defaultTokenSigner() {
return BearerTokenSigner.create();
}

protected static void validateClientOptions(SdkClientConfiguration c) {
Validate.notNull(c.option(SdkAdvancedClientOption.TOKEN_SIGNER),
"The 'overrideConfiguration.advancedOption[TOKEN_SIGNER]' must be configured in the client builder.");
Validate.notNull(c.option(AwsClientOption.TOKEN_PROVIDER),
"The 'overrideConfiguration.advancedOption[TOKEN_PROVIDER]' must be configured in the client builder.");
}
}
Expand Up @@ -18,6 +18,7 @@
import software.amazon.awssdk.services.json.endpoints.internal.JsonRequestSetEndpointInterceptor;
import software.amazon.awssdk.services.json.endpoints.internal.JsonResolveEndpointInterceptor;
import software.amazon.awssdk.utils.CollectionUtils;
import software.amazon.awssdk.utils.Validate;

/**
* Internal base class for {@link DefaultJsonClientBuilder} and {@link DefaultJsonAsyncClientBuilder}.
Expand Down Expand Up @@ -78,4 +79,9 @@ protected final String signingName() {
private JsonEndpointProvider defaultEndpointProvider() {
return JsonEndpointProvider.defaultProvider();
}

protected static void validateClientOptions(SdkClientConfiguration c) {
Validate.notNull(c.option(SdkAdvancedClientOption.SIGNER),
"The 'overrideConfiguration.advancedOption[SIGNER]' must be configured in the client builder.");
}
}
Expand Up @@ -4,6 +4,7 @@
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.auth.token.credentials.SdkTokenProvider;
import software.amazon.awssdk.awscore.client.config.AwsClientOption;
import software.amazon.awssdk.core.client.config.SdkClientConfiguration;
import software.amazon.awssdk.core.client.config.SdkClientOption;
import software.amazon.awssdk.services.json.endpoints.JsonEndpointProvider;

Expand All @@ -28,6 +29,8 @@ public DefaultJsonClientBuilder tokenProvider(SdkTokenProvider tokenProvider) {

@Override
protected final JsonClient buildClient() {
return new DefaultJsonClient(super.syncClientConfiguration());
SdkClientConfiguration clientConfiguration = super.syncClientConfiguration();
this.validateClientOptions(clientConfiguration);
return new DefaultJsonClient(clientConfiguration);
}
}
Expand Up @@ -50,6 +50,7 @@ private static List<SyncTestCase> syncTestCases() {
new SyncTestCase("test case 1", () -> {
QueryClientBuilder builder = QueryClient.builder();
builder.credentialsProvider(BaseRuleSetClientTest.CREDENTIALS_PROVIDER);
builder.tokenProvider(BaseRuleSetClientTest.TOKEN_PROVIDER);
builder.httpClient(getSyncHttpClient());
builder.region(Region.of("us-east-1"));
APostOperationRequest request = APostOperationRequest.builder().build();
Expand All @@ -58,6 +59,7 @@ private static List<SyncTestCase> syncTestCases() {
new SyncTestCase("test case 2", () -> {
QueryClientBuilder builder = QueryClient.builder();
builder.credentialsProvider(BaseRuleSetClientTest.CREDENTIALS_PROVIDER);
builder.tokenProvider(BaseRuleSetClientTest.TOKEN_PROVIDER);
builder.httpClient(getSyncHttpClient());
builder.region(Region.of("us-east-1"));
builder.booleanContextParam(true);
Expand All @@ -68,6 +70,7 @@ private static List<SyncTestCase> syncTestCases() {
new SyncTestCase("test case 3", () -> {
QueryClientBuilder builder = QueryClient.builder();
builder.credentialsProvider(BaseRuleSetClientTest.CREDENTIALS_PROVIDER);
builder.tokenProvider(BaseRuleSetClientTest.TOKEN_PROVIDER);
builder.httpClient(getSyncHttpClient());
builder.region(Region.of("us-east-1"));
OperationWithContextParamRequest request = OperationWithContextParamRequest.builder()
Expand All @@ -77,6 +80,7 @@ private static List<SyncTestCase> syncTestCases() {
new SyncTestCase("test case 4", () -> {
QueryClientBuilder builder = QueryClient.builder();
builder.credentialsProvider(BaseRuleSetClientTest.CREDENTIALS_PROVIDER);
builder.tokenProvider(BaseRuleSetClientTest.TOKEN_PROVIDER);
builder.httpClient(getSyncHttpClient());
builder.region(Region.of("us-east-6"));
OperationWithContextParamRequest request = OperationWithContextParamRequest.builder()
Expand All @@ -87,13 +91,15 @@ private static List<SyncTestCase> syncTestCases() {
new SyncTestCase("For region us-iso-west-1 with FIPS enabled and DualStack enabled", () -> {
QueryClientBuilder builder = QueryClient.builder();
builder.credentialsProvider(BaseRuleSetClientTest.CREDENTIALS_PROVIDER);
builder.tokenProvider(BaseRuleSetClientTest.TOKEN_PROVIDER);
builder.httpClient(getSyncHttpClient());
APostOperationRequest request = APostOperationRequest.builder().build();
builder.build().aPostOperation(request);
}, Expect.builder().error("Should have been skipped!").build(), "Client builder does the validation"),
new SyncTestCase("Has complex operation input", () -> {
QueryClientBuilder builder = QueryClient.builder();
builder.credentialsProvider(BaseRuleSetClientTest.CREDENTIALS_PROVIDER);
builder.tokenProvider(BaseRuleSetClientTest.TOKEN_PROVIDER);
builder.httpClient(getSyncHttpClient());
OperationWithContextParamRequest request = OperationWithContextParamRequest.builder()
.nestedMember(ChecksumStructure.builder().checksumMode("foo").build()).build();
Expand All @@ -106,6 +112,7 @@ private static List<AsyncTestCase> asyncTestCases() {
new AsyncTestCase("test case 1", () -> {
QueryAsyncClientBuilder builder = QueryAsyncClient.builder();
builder.credentialsProvider(BaseRuleSetClientTest.CREDENTIALS_PROVIDER);
builder.tokenProvider(BaseRuleSetClientTest.TOKEN_PROVIDER);
builder.httpClient(getAsyncHttpClient());
builder.region(Region.of("us-east-1"));
APostOperationRequest request = APostOperationRequest.builder().build();
Expand All @@ -114,6 +121,7 @@ private static List<AsyncTestCase> asyncTestCases() {
new AsyncTestCase("test case 2", () -> {
QueryAsyncClientBuilder builder = QueryAsyncClient.builder();
builder.credentialsProvider(BaseRuleSetClientTest.CREDENTIALS_PROVIDER);
builder.tokenProvider(BaseRuleSetClientTest.TOKEN_PROVIDER);
builder.httpClient(getAsyncHttpClient());
builder.region(Region.of("us-east-1"));
builder.booleanContextParam(true);
Expand All @@ -124,6 +132,7 @@ private static List<AsyncTestCase> asyncTestCases() {
new AsyncTestCase("test case 3", () -> {
QueryAsyncClientBuilder builder = QueryAsyncClient.builder();
builder.credentialsProvider(BaseRuleSetClientTest.CREDENTIALS_PROVIDER);
builder.tokenProvider(BaseRuleSetClientTest.TOKEN_PROVIDER);
builder.httpClient(getAsyncHttpClient());
builder.region(Region.of("us-east-1"));
OperationWithContextParamRequest request = OperationWithContextParamRequest.builder()
Expand All @@ -133,6 +142,7 @@ private static List<AsyncTestCase> asyncTestCases() {
new AsyncTestCase("test case 4", () -> {
QueryAsyncClientBuilder builder = QueryAsyncClient.builder();
builder.credentialsProvider(BaseRuleSetClientTest.CREDENTIALS_PROVIDER);
builder.tokenProvider(BaseRuleSetClientTest.TOKEN_PROVIDER);
builder.httpClient(getAsyncHttpClient());
builder.region(Region.of("us-east-6"));
OperationWithContextParamRequest request = OperationWithContextParamRequest.builder()
Expand All @@ -143,13 +153,15 @@ private static List<AsyncTestCase> asyncTestCases() {
new AsyncTestCase("For region us-iso-west-1 with FIPS enabled and DualStack enabled", () -> {
QueryAsyncClientBuilder builder = QueryAsyncClient.builder();
builder.credentialsProvider(BaseRuleSetClientTest.CREDENTIALS_PROVIDER);
builder.tokenProvider(BaseRuleSetClientTest.TOKEN_PROVIDER);
builder.httpClient(getAsyncHttpClient());
APostOperationRequest request = APostOperationRequest.builder().build();
return builder.build().aPostOperation(request);
}, Expect.builder().error("Should have been skipped!").build(), "Client builder does the validation"),
new AsyncTestCase("Has complex operation input", () -> {
QueryAsyncClientBuilder builder = QueryAsyncClient.builder();
builder.credentialsProvider(BaseRuleSetClientTest.CREDENTIALS_PROVIDER);
builder.tokenProvider(BaseRuleSetClientTest.TOKEN_PROVIDER);
builder.httpClient(getAsyncHttpClient());
OperationWithContextParamRequest request = OperationWithContextParamRequest.builder()
.nestedMember(ChecksumStructure.builder().checksumMode("foo").build()).build();
Expand Down
Expand Up @@ -24,6 +24,8 @@
import static org.mockito.Mockito.when;

import java.net.URI;
import java.time.Instant;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.function.Supplier;
import org.junit.jupiter.api.Assumptions;
Expand All @@ -33,6 +35,10 @@
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.auth.token.credentials.SdkToken;
import software.amazon.awssdk.auth.token.credentials.SdkTokenProvider;
import software.amazon.awssdk.auth.token.credentials.StaticTokenProvider;
import software.amazon.awssdk.auth.token.credentials.aws.DefaultAwsTokenProvider;
import software.amazon.awssdk.core.rules.testing.model.Expect;
import software.amazon.awssdk.endpoints.Endpoint;
import software.amazon.awssdk.http.HttpExecuteRequest;
Expand All @@ -45,6 +51,7 @@
public abstract class BaseRuleSetClientTest {
protected static final AwsCredentialsProvider CREDENTIALS_PROVIDER =
StaticCredentialsProvider.create(AwsBasicCredentials.create("akid", "skid"));
protected static final SdkTokenProvider TOKEN_PROVIDER = StaticTokenProvider.create(new TestSdkToken());

private static SdkHttpClient syncHttpClient;
private static SdkAsyncHttpClient asyncHttpClient;
Expand Down Expand Up @@ -122,4 +129,17 @@ protected static SdkHttpClient getSyncHttpClient() {
protected static SdkAsyncHttpClient getAsyncHttpClient() {
return asyncHttpClient;
}

private static class TestSdkToken implements SdkToken {

@Override
public String token() {
return "TOKEN";
}

@Override
public Optional<Instant> expirationTime() {
return Optional.empty();
}
}
}

0 comments on commit 55c14ed

Please sign in to comment.