Skip to content

Commit

Permalink
Support HostPrefix in E2.0, enable E2.0 for all services
Browse files Browse the repository at this point in the history
  • Loading branch information
dagnir committed Nov 23, 2022
1 parent 70929c3 commit b0a4a97
Show file tree
Hide file tree
Showing 37 changed files with 595 additions and 105 deletions.
Expand Up @@ -18,7 +18,6 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import software.amazon.awssdk.codegen.emitters.GeneratorTask;
Expand Down Expand Up @@ -46,10 +45,6 @@ public EndpointProviderTasks(GeneratorTaskParams dependencies) {

@Override
protected List<GeneratorTask> createTasks() throws Exception {
if (!generatorTaskParams.getModel().getCustomizationConfig().useRuleBasedEndpoints()) {
return Collections.emptyList();
}

List<GeneratorTask> tasks = new ArrayList<>();
tasks.add(generateInterface());
tasks.add(generateParams());
Expand Down
Expand Up @@ -213,11 +213,6 @@ public class CustomizationConfig {

private boolean useGlobalEndpoint;

/**
* Whether Endpoints 2.0/rule based endpoints should be used for endpoint resolution.
*/
private boolean useRuleBasedEndpoints = false;

private List<String> interceptors = new ArrayList<>();

private CustomizationConfig() {
Expand Down Expand Up @@ -557,14 +552,6 @@ public void setSkipEndpointTests(Map<String, String> skipEndpointTests) {
this.skipEndpointTests = skipEndpointTests;
}

public boolean useRuleBasedEndpoints() {
return useRuleBasedEndpoints;
}

public void setUseRuleBasedEndpoints(boolean useRuleBasedEndpoints) {
this.useRuleBasedEndpoints = useRuleBasedEndpoints;
}

public List<String> getInterceptors() {
return interceptors;
}
Expand Down
Expand Up @@ -68,9 +68,7 @@ public TypeSpec poetSpec() {
}
}

if (endpointRulesSpecUtils.isEndpointRulesEnabled()) {
builder.addMethod(endpointProviderMethod());
}
builder.addMethod(endpointProviderMethod());

if (BearerAuthUtils.usesBearerAuth(model)) {
builder.addMethod(bearerTokenProviderMethod());
Expand Down
Expand Up @@ -111,11 +111,9 @@ public TypeSpec poetSpec() {
builder.addMethod(finalizeServiceConfigurationMethod());
defaultAwsAuthSignerMethod().ifPresent(builder::addMethod);
builder.addMethod(signingNameMethod());
if (endpointRulesSpecUtils.isEndpointRulesEnabled()) {
builder.addMethod(defaultEndpointProviderMethod());
}
builder.addMethod(defaultEndpointProviderMethod());

if (hasClientContextParams() && endpointRulesSpecUtils.isEndpointRulesEnabled()) {
if (hasClientContextParams()) {
model.getClientContextParams().forEach((n, m) -> {
builder.addMethod(clientContextParamSetter(n, m));
});
Expand Down Expand Up @@ -189,9 +187,8 @@ private MethodSpec mergeServiceDefaultsMethod() {
.addParameter(SdkClientConfiguration.class, "config")
.addCode("return config.merge(c -> c");

if (endpointRulesSpecUtils.isEndpointRulesEnabled()) {
builder.addCode(".option($T.ENDPOINT_PROVIDER, defaultEndpointProvider())", SdkClientOption.class);
}
builder.addCode(".option($T.ENDPOINT_PROVIDER, defaultEndpointProvider())", SdkClientOption.class);


if (defaultAwsAuthSignerMethod().isPresent()) {
builder.addCode(".option($T.SIGNER, defaultSigner())\n", SdkAdvancedClientOption.class);
Expand Down Expand Up @@ -259,11 +256,9 @@ private MethodSpec finalizeServiceConfigurationMethod() {

List<ClassName> builtInInterceptors = new ArrayList<>();

if (endpointRulesSpecUtils.isEndpointRulesEnabled()) {
builtInInterceptors.add(endpointRulesSpecUtils.resolverInterceptorName());
builtInInterceptors.add(endpointRulesSpecUtils.authSchemesInterceptorName());
builtInInterceptors.add(endpointRulesSpecUtils.requestModifierInterceptorName());
}
builtInInterceptors.add(endpointRulesSpecUtils.resolverInterceptorName());
builtInInterceptors.add(endpointRulesSpecUtils.authSchemesInterceptorName());
builtInInterceptors.add(endpointRulesSpecUtils.requestModifierInterceptorName());

for (String interceptor : model.getCustomizationConfig().getInterceptors()) {
builtInInterceptors.add(ClassName.bestGuess(interceptor));
Expand Down
Expand Up @@ -38,7 +38,6 @@
import software.amazon.awssdk.core.client.config.SdkAdvancedClientOption;
import software.amazon.awssdk.utils.internal.CodegenNamingUtils;


public class BaseClientBuilderInterface implements ClassSpec {
private final IntermediateModel model;
private final String basePackage;
Expand Down Expand Up @@ -73,14 +72,12 @@ public TypeSpec poetSpec() {
builder.addMethod(serviceConfigurationConsumerBuilderMethod());
}

if (endpointRulesSpecUtils.isEndpointRulesEnabled()) {
builder.addMethod(endpointProviderMethod());
builder.addMethod(endpointProviderMethod());

if (hasClientContextParams()) {
model.getClientContextParams().forEach((n, m) -> {
builder.addMethod(clientContextParamSetter(n, m));
});
}
if (hasClientContextParams()) {
model.getClientContextParams().forEach((n, m) -> {
builder.addMethod(clientContextParamSetter(n, m));
});
}

if (generateTokenProviderMethod()) {
Expand Down
Expand Up @@ -68,9 +68,7 @@ public TypeSpec poetSpec() {
}
}

if (endpointRulesSpecUtils.isEndpointRulesEnabled()) {
builder.addMethod(endpointProviderMethod());
}
builder.addMethod(endpointProviderMethod());

if (BearerAuthUtils.usesBearerAuth(model)) {
builder.addMethod(tokenProviderMethodImpl());
Expand Down
Expand Up @@ -88,7 +88,7 @@ private MethodSpec testsCasesMethod() {
b.addStatement("testCases.add(new $T($L, $L))",
EndpointProviderTestCase.class,
createTestCase(test),
TestGeneratorUtils.createExpect(test.getExpect()));
TestGeneratorUtils.createExpect(test.getExpect(), null));
});

b.addStatement("return testCases");
Expand Down
Expand Up @@ -20,7 +20,9 @@
import com.fasterxml.jackson.jr.stree.JrsString;
import com.squareup.javapoet.ClassName;
import com.squareup.javapoet.MethodSpec;
import com.squareup.javapoet.ParameterizedTypeName;
import com.squareup.javapoet.TypeSpec;
import java.util.Iterator;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletionException;
Expand All @@ -32,6 +34,8 @@
import software.amazon.awssdk.codegen.model.rules.endpoints.ParameterModel;
import software.amazon.awssdk.codegen.model.service.ClientContextParam;
import software.amazon.awssdk.codegen.model.service.ContextParam;
import software.amazon.awssdk.codegen.model.service.EndpointTrait;
import software.amazon.awssdk.codegen.model.service.HostPrefixProcessor;
import software.amazon.awssdk.codegen.model.service.StaticContextParam;
import software.amazon.awssdk.codegen.poet.ClassSpec;
import software.amazon.awssdk.codegen.poet.PoetExtension;
Expand All @@ -41,9 +45,12 @@
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute;
import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute;
import software.amazon.awssdk.endpoints.Endpoint;
import software.amazon.awssdk.utils.AttributeMap;
import software.amazon.awssdk.utils.HostnameValidator;
import software.amazon.awssdk.utils.StringUtils;

public class EndpointResolverInterceptorSpec implements ClassSpec {
private final IntermediateModel model;
Expand Down Expand Up @@ -76,6 +83,8 @@ public TypeSpec poetSpec() {
b.addMethod(setClientContextParamsMethod());
}

b.addMethod(hostPrefixMethod());

return b.build();
}

Expand Down Expand Up @@ -106,6 +115,13 @@ private MethodSpec modifyRequestMethod() {
b.beginControlFlow("try");
b.addStatement("$T result = $N.resolveEndpoint(ruleParams(context, executionAttributes)).join()", Endpoint.class,
providerVar);
b.addStatement("$T hostPrefix = hostPrefix(executionAttributes.getAttribute($T.OPERATION_NAME), context.request())",
ParameterizedTypeName.get(Optional.class, String.class), SdkExecutionAttribute.class);
b.beginControlFlow("if (hostPrefix.isPresent() && !$T.disableHostPrefixInjection(executionAttributes))",
endpointRulesSpecUtils.rulesRuntimeClassName("AwsEndpointProviderUtils"));
b.addStatement("result = $T.addHostPrefix(result, hostPrefix.get())",
endpointRulesSpecUtils.rulesRuntimeClassName("AwsEndpointProviderUtils"));
b.endControlFlow();
b.addStatement("executionAttributes.putAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT, result)");
b.addStatement("return context.request()");
b.endControlFlow();
Expand Down Expand Up @@ -171,7 +187,7 @@ private MethodSpec ruleParams() {
case AWS_S3_FORCE_PATH_STYLE:
case AWS_S3_USE_ARN_REGION:
case AWS_S3_CONTROL_USE_ARN_REGION:
// end of S3 specific builtins
// end of S3 specific builtins
case AWS_STS_USE_GLOBAL_ENDPOINT:
// V2 doesn't support this, only regional endpoints
return;
Expand Down Expand Up @@ -361,4 +377,66 @@ private MethodSpec setClientContextParamsMethod() {

return b.build();
}


private MethodSpec hostPrefixMethod() {
MethodSpec.Builder builder = MethodSpec.methodBuilder("hostPrefix")
.returns(ParameterizedTypeName.get(Optional.class, String.class))
.addParameter(String.class, "operationName")
.addParameter(SdkRequest.class, "request")
.addModifiers(Modifier.PRIVATE, Modifier.STATIC);

builder.beginControlFlow("switch (operationName)");

model.getOperations().forEach((name, opModel) -> {
String hostPrefix = getHostPrefix(opModel);
if (StringUtils.isBlank(hostPrefix)) {
return;
}

builder.beginControlFlow("case $S:", name);
HostPrefixProcessor processor = new HostPrefixProcessor(hostPrefix);

if (processor.c2jNames().isEmpty()) {
builder.addStatement("return $T.of($S)", Optional.class, processor.hostWithStringSpecifier());
} else {
String requestVar = opModel.getInput().getVariableName();
processor.c2jNames().forEach(c2jName -> {
builder.addStatement("$1T.validateHostnameCompliant(request.getValueForField($2S, $3T.class).orElse(null), "
+ "$2S, $4S)",
HostnameValidator.class,
c2jName,
String.class,
requestVar);
});

builder.addCode("return $T.of($T.format($S, ", Optional.class, String.class,
processor.hostWithStringSpecifier());
Iterator<String> c2jNamesIter = processor.c2jNames().listIterator();
while (c2jNamesIter.hasNext()) {
builder.addCode("request.getValueForField($S, $T.class).get()", c2jNamesIter.next(), String.class);
if (c2jNamesIter.hasNext()) {
builder.addCode(",");
}
}
builder.addStatement("))");
}
builder.endControlFlow();
});

builder.addCode("default:");
builder.addStatement("return $T.empty()", Optional.class);
builder.endControlFlow();

return builder.build();
}

private String getHostPrefix(OperationModel opModel) {
EndpointTrait endpointTrait = opModel.getEndpointTrait();
if (endpointTrait == null) {
return null;
}

return endpointTrait.getHostPrefix();
}
}
Expand Up @@ -217,7 +217,8 @@ private MethodSpec syncTestsSourceMethod() {
SyncTestCase.class,
test.getDocumentation(),
syncOperationCallLambda(opModel, test.getParams(), opInput.getOperationParams()),
TestGeneratorUtils.createExpect(test.getExpect()), getSkipReasonBlock(test.getDocumentation()));
TestGeneratorUtils.createExpect(test.getExpect(), opModel),
getSkipReasonBlock(test.getDocumentation()));

if (operationInputsIter.hasNext()) {
b.addCode(",");
Expand All @@ -228,7 +229,8 @@ private MethodSpec syncTestsSourceMethod() {
SyncTestCase.class,
test.getDocumentation(),
syncOperationCallLambda(defaultOpModel, test.getParams(), Collections.emptyMap()),
TestGeneratorUtils.createExpect(test.getExpect()), getSkipReasonBlock(test.getDocumentation()));
TestGeneratorUtils.createExpect(test.getExpect(), defaultOpModel),
getSkipReasonBlock(test.getDocumentation()));
}

if (testIter.hasNext()) {
Expand Down Expand Up @@ -355,7 +357,8 @@ private MethodSpec asyncTestsSourceMethod() {
AsyncTestCase.class,
test.getDocumentation(),
asyncOperationCallLambda(opModel, test.getParams(), opInput.getOperationParams()),
TestGeneratorUtils.createExpect(test.getExpect()), getSkipReasonBlock(test.getDocumentation()));
TestGeneratorUtils.createExpect(test.getExpect(), opModel),
getSkipReasonBlock(test.getDocumentation()));

if (operationInputsIter.hasNext()) {
b.addCode(",");
Expand All @@ -366,7 +369,8 @@ private MethodSpec asyncTestsSourceMethod() {
AsyncTestCase.class,
test.getDocumentation(),
asyncOperationCallLambda(defaultOpModel, test.getParams(), Collections.emptyMap()),
TestGeneratorUtils.createExpect(test.getExpect()), getSkipReasonBlock(test.getDocumentation()));
TestGeneratorUtils.createExpect(test.getExpect(), defaultOpModel),
getSkipReasonBlock(test.getDocumentation()));
}

if (testIter.hasNext()) {
Expand Down
Expand Up @@ -182,8 +182,4 @@ public boolean isS3Control() {
public TypeName resolverReturnType() {
return ParameterizedTypeName.get(CompletableFuture.class, Endpoint.class);
}

public boolean isEndpointRulesEnabled() {
return intermediateModel.getCustomizationConfig().useRuleBasedEndpoints();
}
}

0 comments on commit b0a4a97

Please sign in to comment.