diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index cf15871921..f97ce21668 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -214,3 +214,128 @@ message = "Several breaking changes have been made to errors. See [the upgrade g references = ["smithy-rs#1926", "smithy-rs#1819"] meta = { "breaking" = true, "tada" = false, "bug" = false } author = "jdisanti" + +[[smithy-rs]] +message = """ +[Constraint traits](https://awslabs.github.io/smithy/2.0/spec/constraint-traits.html) in server SDKs are beginning to be supported. The following are now supported: + +* The `length` trait on `string` shapes. +* The `length` trait on `map` shapes. + +Upon receiving a request that violates the modeled constraints, the server SDK will reject it with a message indicating why. + +Unsupported (constraint trait, target shape) combinations will now fail at code generation time, whereas previously they were just ignored. This is a breaking change to raise awareness in service owners of their server SDKs behaving differently than what was modeled. To continue generating a server SDK with unsupported constraint traits, set `codegenConfig.ignoreUnsupportedConstraints` to `true` in your `smithy-build.json`. +""" +references = ["smithy-rs#1199", "smithy-rs#1342", "smithy-rs#1401"] +meta = { "breaking" = true, "tada" = true, "bug" = false, "target" = "server" } +author = "david-perez" + +[[smithy-rs]] +message = """ +Server SDKs now generate "constrained types" for constrained shapes. Constrained types are [newtypes](https://rust-unofficial.github.io/patterns/patterns/behavioural/newtype.html) that encapsulate the modeled constraints. They constitute a [widespread pattern to guarantee domain invariants](https://www.lpalmieri.com/posts/2020-12-11-zero-to-production-6-domain-modelling/) and promote correctness in your business logic. So, for example, the model: + +```smithy +@length(min: 1, max: 69) +string NiceString +``` + +will now render a `struct NiceString(String)`. Instantiating a `NiceString` is a fallible operation: + +```rust +let data: String = ... ; +let nice_string = NiceString::try_from(data).expect("data is not nice"); +``` + +A failed attempt to instantiate a constrained type will yield a `ConstraintViolation` error type you may want to handle. This type's API is subject to change. + +Constrained types _guarantee_, by virtue of the type system, that your service's operation outputs adhere to the modeled constraints. To learn more about the motivation for constrained types and how they work, see [the RFC](https://github.com/awslabs/smithy-rs/pull/1199). + +If you'd like to opt-out of generating constrained types, you can set `codegenConfig.publicConstrainedTypes` to `false`. Note that if you do, the generated server SDK will still honor your operation input's modeled constraints upon receiving a request, but will not help you in writing business logic code that adheres to the constraints, and _will not prevent you from returning responses containing operation outputs that violate said constraints_. +""" +references = ["smithy-rs#1342", "smithy-rs#1119"] +meta = { "breaking" = true, "tada" = true, "bug" = false, "target" = "server" } +author = "david-perez" + +[[smithy-rs]] +message = """ +Structure builders in server SDKs have undergone significant changes. + +The API surface has been reduced. It is now simpler and closely follows what you would get when using the [`derive_builder`](https://docs.rs/derive_builder/latest/derive_builder/) crate: + +1. Builders no longer have `set_*` methods taking in `Option`. You must use the unprefixed method, named exactly after the structure's field name, and taking in a value _whose type matches exactly that of the structure's field_. +2. Builders no longer have convenience methods to pass in an element for a field whose type is a vector or a map. You must pass in the entire contents of the collection up front. +3. Builders no longer implement [`PartialEq`](https://doc.rust-lang.org/std/cmp/trait.PartialEq.html). + +Bug fixes: + +4. Builders now always fail to build if a value for a `required` member is not provided. Previously, builders were falling back to a default value (e.g. `""` for `String`s) for some shapes. This was a bug. + +Additions: + +5. A structure `Structure` with builder `Builder` now implements `TryFrom for Structure` or `From for Structure`, depending on whether the structure [is constrained](https://awslabs.github.io/smithy/2.0/spec/constraint-traits.html) or not, respectively. + +To illustrate how to migrate to the new API, consider the example model below. + +```smithy +structure Pokemon { + @required + name: String, + @required + description: String, + @required + evolvesTo: PokemonList +} + +list PokemonList { + member: Pokemon +} +``` + +In the Rust code below, note the references calling out the changes described in the numbered list above. + +Before: + +```rust +let eevee_builder = Pokemon::builder() + // (1) `set_description` takes in `Some`. + .set_description(Some("Su código genético es muy inestable. Puede evolucionar en diversas razas de Pokémon.".to_owned())) + // (2) Convenience method to add one element to the `evolvesTo` list. + .evolves_to(vaporeon) + .evolves_to(jolteon) + .evolves_to(flareon); + +// (3) Builder types can be compared. +assert_ne!(eevee_builder, Pokemon::builder()); + +// (4) Builds fine even though we didn't provide a value for `name`, which is `required`! +let _eevee = eevee_builder.build(); +``` + +After: + +```rust +let eevee_builder = Pokemon::builder() + // (1) `set_description` no longer exists. Use `description`, which directly takes in `String`. + .description("Su código genético es muy inestable. Puede evolucionar en diversas razas de Pokémon.".to_owned()) + // (2) Convenience methods removed; provide the entire collection up front. + .evolves_to(vec![vaporeon, jolteon, flareon]); + +// (3) Binary operation `==` cannot be applied to `pokemon::Builder`. +// assert_ne!(eevee_builder, Pokemon::builder()); + +// (4) `required` member `name` was not set. +// (5) Builder type can be fallibly converted to the structure using `TryFrom` or `TryInto`. +let _error = Pokemon::try_from(eevee_builder).expect_err("name was not provided"); +``` +""" +references = ["smithy-rs#1714", "smithy-rs#1342"] +meta = { "breaking" = true, "tada" = true, "bug" = true, "target" = "server" } +author = "david-perez" + +[[smithy-rs]] +message = """ +Server SDKs now correctly reject operation inputs that don't set values for `required` structure members. Previously, in some scenarios, server SDKs would accept the request and set a default value for the member (e.g. `""` for a `String`), even when the member shape did not have [Smithy IDL v2's `default` trait](https://awslabs.github.io/smithy/2.0/spec/type-refinement-traits.html#smithy-api-default-trait) attached. The `default` trait is [still unsupported](https://github.com/awslabs/smithy-rs/issues/1860). +""" +references = ["smithy-rs#1714", "smithy-rs#1342", "smithy-rs#1860"] +meta = { "breaking" = true, "tada" = false, "bug" = true, "target" = "server" } +author = "david-perez" diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiator.kt index 5e47701f55..b74079cc21 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiator.kt @@ -6,20 +6,34 @@ package software.amazon.smithy.rust.codegen.client.smithy.generators import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.Instantiator +import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName private fun enumFromStringFn(enumSymbol: Symbol, data: String): Writable = writable { rust("#T::from($data)", enumSymbol) } +class ClientBuilderKindBehavior(val codegenContext: CodegenContext) : Instantiator.BuilderKindBehavior { + override fun hasFallibleBuilder(shape: StructureShape): Boolean = + BuilderGenerator.hasFallibleBuilder(shape, codegenContext.symbolProvider) + + override fun setterName(memberShape: MemberShape): String = memberShape.setterName() + + override fun doesSetterTakeInOption(memberShape: MemberShape): Boolean = true +} + fun clientInstantiator(codegenContext: CodegenContext) = Instantiator( codegenContext.symbolProvider, codegenContext.model, codegenContext.runtimeConfig, + ClientBuilderKindBehavior(codegenContext), ::enumFromStringFn, ) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt index 912a0f668a..0f1eac1d67 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt @@ -26,7 +26,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations -import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.error.errorSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.http.ResponseBindingGenerator @@ -332,7 +332,7 @@ class HttpBoundProtocolTraitImplGenerator( } } - val err = if (StructureGenerator.hasFallibleBuilder(outputShape, symbolProvider)) { + val err = if (BuilderGenerator.hasFallibleBuilder(outputShape, symbolProvider)) { ".map_err(${format(errorSymbol)}::unhandled)?" } else "" diff --git a/codegen-core/common-test-models/constraints.smithy b/codegen-core/common-test-models/constraints.smithy new file mode 100644 index 0000000000..d43ea8b7b2 --- /dev/null +++ b/codegen-core/common-test-models/constraints.smithy @@ -0,0 +1,453 @@ +$version: "1.0" + +namespace com.amazonaws.constraints + +use aws.protocols#restJson1 +use smithy.framework#ValidationException + +/// A service to test aspects of code generation where shapes have constraint traits. +@restJson1 +@title("ConstraintsService") +service ConstraintsService { + operations: [ + // TODO Rename as {Verb}[{Qualifier}]{Noun}: https://github.com/awslabs/smithy-rs/pull/1342#discussion_r980936650 + ConstrainedShapesOperation, + ConstrainedHttpBoundShapesOperation, + ConstrainedRecursiveShapesOperation, + // `httpQueryParams` and `httpPrefixHeaders` are structurually + // exclusive, so we need one operation per target shape type + // combination. + QueryParamsTargetingLengthMapOperation, + QueryParamsTargetingMapOfLengthStringOperation, + QueryParamsTargetingMapOfEnumStringOperation, + QueryParamsTargetingMapOfListOfLengthStringOperation, + QueryParamsTargetingMapOfSetOfLengthStringOperation, + QueryParamsTargetingMapOfListOfEnumStringOperation, + HttpPrefixHeadersTargetingLengthMapOperation, + // TODO(https://github.com/awslabs/smithy-rs/issues/1431) + // HttpPrefixHeadersTargetingMapOfEnumStringOperation, + + NonStreamingBlobOperation, + + StreamingBlobOperation, + EventStreamsOperation, + ], +} + +@http(uri: "/constrained-shapes-operation", method: "POST") +operation ConstrainedShapesOperation { + input: ConstrainedShapesOperationInputOutput, + output: ConstrainedShapesOperationInputOutput, + errors: [ValidationException] +} + +@http(uri: "/constrained-http-bound-shapes-operation/{lengthStringLabel}/{enumStringLabel}", method: "POST") +operation ConstrainedHttpBoundShapesOperation { + input: ConstrainedHttpBoundShapesOperationInputOutput, + output: ConstrainedHttpBoundShapesOperationInputOutput, + errors: [ValidationException] +} + +@http(uri: "/constrained-recursive-shapes-operation", method: "POST") +operation ConstrainedRecursiveShapesOperation { + input: ConstrainedRecursiveShapesOperationInputOutput, + output: ConstrainedRecursiveShapesOperationInputOutput, + errors: [ValidationException] +} + +@http(uri: "/query-params-targeting-length-map", method: "POST") +operation QueryParamsTargetingLengthMapOperation { + input: QueryParamsTargetingLengthMapOperationInputOutput, + output: QueryParamsTargetingLengthMapOperationInputOutput, + errors: [ValidationException] +} + +@http(uri: "/query-params-targeting-map-of-length-string-operation", method: "POST") +operation QueryParamsTargetingMapOfLengthStringOperation { + input: QueryParamsTargetingMapOfLengthStringOperationInputOutput, + output: QueryParamsTargetingMapOfLengthStringOperationInputOutput, + errors: [ValidationException] +} + +@http(uri: "/query-params-targeting-map-of-enum-string-operation", method: "POST") +operation QueryParamsTargetingMapOfEnumStringOperation { + input: QueryParamsTargetingMapOfEnumStringOperationInputOutput, + output: QueryParamsTargetingMapOfEnumStringOperationInputOutput, + errors: [ValidationException] +} + +@http(uri: "/query-params-targeting-map-of-list-of-length-string-operation", method: "POST") +operation QueryParamsTargetingMapOfListOfLengthStringOperation { + input: QueryParamsTargetingMapOfListOfLengthStringOperationInputOutput, + output: QueryParamsTargetingMapOfListOfLengthStringOperationInputOutput, + errors: [ValidationException] +} + +@http(uri: "/query-params-targeting-map-of-set-of-length-string-operation", method: "POST") +operation QueryParamsTargetingMapOfSetOfLengthStringOperation { + input: QueryParamsTargetingMapOfSetOfLengthStringOperationInputOutput, + output: QueryParamsTargetingMapOfSetOfLengthStringOperationInputOutput, + errors: [ValidationException] +} + +@http(uri: "/query-params-targeting-map-of-list-of-enum-string-operation", method: "POST") +operation QueryParamsTargetingMapOfListOfEnumStringOperation { + input: QueryParamsTargetingMapOfListOfEnumStringOperationInputOutput, + output: QueryParamsTargetingMapOfListOfEnumStringOperationInputOutput, + errors: [ValidationException] +} + +@http(uri: "/http-prefix-headers-targeting-length-map-operation", method: "POST") +operation HttpPrefixHeadersTargetingLengthMapOperation { + input: HttpPrefixHeadersTargetingLengthMapOperationInputOutput, + output: HttpPrefixHeadersTargetingLengthMapOperationInputOutput, + errors: [ValidationException], +} + +@http(uri: "/http-prefix-headers-targeting-map-of-enum-string-operation", method: "POST") +operation HttpPrefixHeadersTargetingMapOfEnumStringOperation { + input: HttpPrefixHeadersTargetingMapOfEnumStringOperationInputOutput, + output: HttpPrefixHeadersTargetingMapOfEnumStringOperationInputOutput, + errors: [ValidationException] +} + +@http(uri: "/non-streaming-blob-operation", method: "POST") +operation NonStreamingBlobOperation { + input: NonStreamingBlobOperationInputOutput, + output: NonStreamingBlobOperationInputOutput, +} + +@http(uri: "/streaming-blob-operation", method: "POST") +operation StreamingBlobOperation { + input: StreamingBlobOperationInputOutput, + output: StreamingBlobOperationInputOutput, +} + +@http(uri: "/event-streams-operation", method: "POST") +operation EventStreamsOperation { + input: EventStreamsOperationInputOutput, + output: EventStreamsOperationInputOutput, +} + +structure ConstrainedShapesOperationInputOutput { + @required + conA: ConA, +} + +structure ConstrainedHttpBoundShapesOperationInputOutput { + @required + @httpLabel + lengthStringLabel: LengthString, + + @required + @httpLabel + enumStringLabel: EnumString, + + // TODO(https://github.com/awslabs/smithy-rs/issues/1394) `@required` not working + // @required + @httpPrefixHeaders("X-Prefix-Headers-") + lengthStringHeaderMap: MapOfLengthString, + + @httpHeader("X-Length") + lengthStringHeader: LengthString, + + // @httpHeader("X-Length-MediaType") + // lengthStringHeaderWithMediaType: MediaTypeLengthString, + + @httpHeader("X-Length-Set") + lengthStringSetHeader: SetOfLengthString, + + @httpHeader("X-Length-List") + lengthStringListHeader: ListOfLengthString, + + // TODO(https://github.com/awslabs/smithy-rs/issues/1431) + // @httpHeader("X-Enum") + //enumStringHeader: EnumString, + + // @httpHeader("X-Enum-List") + // enumStringListHeader: ListOfEnumString, + + @httpQuery("lengthString") + lengthStringQuery: LengthString, + + @httpQuery("enumString") + enumStringQuery: EnumString, + + @httpQuery("lengthStringList") + lengthStringListQuery: ListOfLengthString, + + @httpQuery("lengthStringSet") + lengthStringSetQuery: SetOfLengthString, + + @httpQuery("enumStringList") + enumStringListQuery: ListOfEnumString, +} + +structure HttpPrefixHeadersTargetingLengthMapOperationInputOutput { + @httpPrefixHeaders("X-Prefix-Headers-LengthMap-") + lengthMap: ConBMap, +} + +structure HttpPrefixHeadersTargetingMapOfEnumStringOperationInputOutput { + @httpPrefixHeaders("X-Prefix-Headers-MapOfEnumString-") + mapOfEnumString: MapOfEnumString, +} + +structure QueryParamsTargetingLengthMapOperationInputOutput { + @httpQueryParams + lengthMap: ConBMap +} + +structure QueryParamsTargetingMapOfLengthStringOperationInputOutput { + @httpQueryParams + mapOfLengthString: MapOfLengthString +} + +structure QueryParamsTargetingMapOfEnumStringOperationInputOutput { + @httpQueryParams + mapOfEnumString: MapOfEnumString +} + +structure QueryParamsTargetingMapOfListOfLengthStringOperationInputOutput { + @httpQueryParams + mapOfListOfLengthString: MapOfListOfLengthString +} + +structure QueryParamsTargetingMapOfSetOfLengthStringOperationInputOutput { + @httpQueryParams + mapOfSetOfLengthString: MapOfSetOfLengthString +} + +structure QueryParamsTargetingMapOfListOfEnumStringOperationInputOutput { + @httpQueryParams + mapOfListOfEnumString: MapOfListOfEnumString +} + +structure NonStreamingBlobOperationInputOutput { + @httpPayload + nonStreamingBlob: NonStreamingBlob, +} + +structure StreamingBlobOperationInputOutput { + @httpPayload + streamingBlob: StreamingBlob, +} + +structure EventStreamsOperationInputOutput { + @httpPayload + events: Event, +} + +@streaming +union Event { + regularMessage: EventStreamRegularMessage, + errorMessage: EventStreamErrorMessage, +} + +structure EventStreamRegularMessage { + messageContent: String + // TODO(https://github.com/awslabs/smithy/issues/1388): Can't add a constraint trait here until the semantics are clarified. + // messageContent: LengthString +} + +@error("server") +structure EventStreamErrorMessage { + messageContent: String + // TODO(https://github.com/awslabs/smithy/issues/1388): Can't add a constraint trait here until the semantics are clarified. + // messageContent: LengthString +} + +// TODO(https://github.com/awslabs/smithy/issues/1389): Can't add a constraint trait here until the semantics are clarified. +@streaming +blob StreamingBlob + +blob NonStreamingBlob + +structure ConA { + @required + conB: ConB, + + optConB: ConB, + + lengthString: LengthString, + minLengthString: MinLengthString, + maxLengthString: MaxLengthString, + fixedLengthString: FixedLengthString, + + conBList: ConBList, + conBList2: ConBList2, + + conBSet: ConBSet, + + conBMap: ConBMap, + + mapOfMapOfListOfListOfConB: MapOfMapOfListOfListOfConB, + + constrainedUnion: ConstrainedUnion, + enumString: EnumString, + + listOfLengthString: ListOfLengthString, + setOfLengthString: SetOfLengthString, + mapOfLengthString: MapOfLengthString, + + nonStreamingBlob: NonStreamingBlob +} + +map MapOfLengthString { + key: LengthString, + value: LengthString, +} + +map MapOfEnumString { + key: EnumString, + value: EnumString, +} + +map MapOfListOfLengthString { + key: LengthString, + value: ListOfLengthString, +} + +map MapOfListOfEnumString { + key: EnumString, + value: ListOfEnumString, +} + +map MapOfSetOfLengthString { + key: LengthString, + value: SetOfLengthString, +} + +@length(min: 2, max: 8) +list LengthListOfLengthString { + member: LengthString +} + +@length(min: 2, max: 69) +string LengthString + +@length(min: 2) +string MinLengthString + +@length(min: 69) +string MaxLengthString + +@length(min: 69, max: 69) +string FixedLengthString + +@mediaType("video/quicktime") +@length(min: 1, max: 69) +string MediaTypeLengthString + +/// A union with constrained members. +union ConstrainedUnion { + enumString: EnumString, + lengthString: LengthString, + + constrainedStructure: ConB, + conBList: ConBList, + conBSet: ConBSet, + conBMap: ConBMap, +} + +@enum([ + { + value: "t2.nano", + name: "T2_NANO", + }, + { + value: "t2.micro", + name: "T2_MICRO", + }, + { + value: "m256.mega", + name: "M256_MEGA", + } +]) +string EnumString + +set SetOfLengthString { + member: LengthString +} + +list ListOfLengthString { + member: LengthString +} + +list ListOfEnumString { + member: EnumString +} + +structure ConB { + @required + nice: String, + @required + int: Integer, + + optNice: String, + optInt: Integer +} + +structure ConstrainedRecursiveShapesOperationInputOutput { + nested: RecursiveShapesInputOutputNested1, + + @required + recursiveList: RecursiveList +} + +structure RecursiveShapesInputOutputNested1 { + @required + recursiveMember: RecursiveShapesInputOutputNested2 +} + +structure RecursiveShapesInputOutputNested2 { + recursiveMember: RecursiveShapesInputOutputNested1, +} + +list RecursiveList { + member: RecursiveShapesInputOutputNested1 +} + +list ConBList { + member: NestedList +} + +list ConBList2 { + member: ConB +} + +list NestedList { + member: ConB +} + +set ConBSet { + member: NestedSet +} + +set NestedSet { + member: String +} + +@length(min: 1, max: 69) +map ConBMap { + key: String, + value: LengthString +} + +@error("client") +structure ErrorWithLengthStringMessage { + // TODO Doesn't work yet because constrained string types don't implement + // `AsRef`. + // @required + // message: LengthString +} + +map MapOfMapOfListOfListOfConB { + key: String, + value: MapOfListOfListOfConB +} + +map MapOfListOfListOfConB { + key: String, + value: ConBList +} diff --git a/codegen-core/common-test-models/misc.smithy b/codegen-core/common-test-models/misc.smithy index a98c0fa21f..69bdcc2cd8 100644 --- a/codegen-core/common-test-models/misc.smithy +++ b/codegen-core/common-test-models/misc.smithy @@ -5,6 +5,7 @@ namespace aws.protocoltests.misc use aws.protocols#restJson1 use smithy.test#httpRequestTests use smithy.test#httpResponseTests +use smithy.framework#ValidationException /// A service to test miscellaneous aspects of code generation where protocol /// selection is not relevant. If you want to test something protocol-specific, @@ -54,10 +55,11 @@ map MapA { /// This operation tests that (de)serializing required values from a nested /// shape works correctly. -@http(uri: "/innerRequiredShapeOperation", method: "POST") +@http(uri: "/requiredInnerShapeOperation", method: "POST") operation RequiredInnerShapeOperation { input: RequiredInnerShapeOperationInputOutput, output: RequiredInnerShapeOperationInputOutput, + errors: [ValidationException], } structure RequiredInnerShapeOperationInputOutput { @@ -236,6 +238,7 @@ operation AcceptHeaderStarService {} operation RequiredHeaderCollectionOperation { input: RequiredHeaderCollectionOperationInputOutput, output: RequiredHeaderCollectionOperationInputOutput, + errors: [ValidationException] } structure RequiredHeaderCollectionOperationInputOutput { diff --git a/codegen-core/common-test-models/naming-obstacle-course-ops.smithy b/codegen-core/common-test-models/naming-obstacle-course-ops.smithy index 087d99b750..f54b27e76f 100644 --- a/codegen-core/common-test-models/naming-obstacle-course-ops.smithy +++ b/codegen-core/common-test-models/naming-obstacle-course-ops.smithy @@ -5,6 +5,7 @@ use smithy.test#httpRequestTests use smithy.test#httpResponseTests use aws.protocols#awsJson1_1 use aws.api#service +use smithy.framework#ValidationException /// Confounds model generation machinery with lots of problematic names @awsJson1_1 @@ -41,17 +42,20 @@ service Config { } ]) operation ReservedWordsAsMembers { - input: ReservedWords + input: ReservedWords, + errors: [ValidationException], } // tests that module names are properly escaped operation Match { input: ReservedWords + errors: [ValidationException], } // Should generate a PascalCased `RpcEchoInput` struct. operation RPCEcho { input: ReservedWords + errors: [ValidationException], } structure ReservedWords { diff --git a/codegen-core/common-test-models/pokemon-common.smithy b/codegen-core/common-test-models/pokemon-common.smithy index 3198cb8c74..d213a16b15 100644 --- a/codegen-core/common-test-models/pokemon-common.smithy +++ b/codegen-core/common-test-models/pokemon-common.smithy @@ -2,6 +2,8 @@ $version: "1.0" namespace com.aws.example +use smithy.framework#ValidationException + /// A Pokémon species forms the basis for at least one Pokémon. @title("Pokémon Species") resource PokemonSpecies { @@ -17,7 +19,7 @@ resource PokemonSpecies { operation GetPokemonSpecies { input: GetPokemonSpeciesInput, output: GetPokemonSpeciesOutput, - errors: [ResourceNotFoundException], + errors: [ResourceNotFoundException, ValidationException], } @input diff --git a/codegen-core/common-test-models/pokemon.smithy b/codegen-core/common-test-models/pokemon.smithy index e955cdd218..d42185e31c 100644 --- a/codegen-core/common-test-models/pokemon.smithy +++ b/codegen-core/common-test-models/pokemon.smithy @@ -3,6 +3,7 @@ $version: "1.0" namespace com.aws.example.rust use aws.protocols#restJson1 +use smithy.framework#ValidationException use com.aws.example#PokemonSpecies use com.aws.example#GetServerStatistics use com.aws.example#DoNothing @@ -31,13 +32,13 @@ resource Storage { read: GetStorage, } -/// Retrieve information about your Pokedex. +/// Retrieve information about your Pokédex. @readonly @http(uri: "/pokedex/{user}", method: "GET") operation GetStorage { input: GetStorageInput, output: GetStorageOutput, - errors: [ResourceNotFoundException, NotAuthorized], + errors: [ResourceNotFoundException, NotAuthorized, ValidationException], } /// Not authorized to access Pokémon storage. @@ -74,7 +75,7 @@ structure GetStorageOutput { operation CapturePokemon { input: CapturePokemonEventsInput, output: CapturePokemonEventsOutput, - errors: [UnsupportedRegionError, ThrottlingError] + errors: [UnsupportedRegionError, ThrottlingError, ValidationException] } @input @@ -140,7 +141,6 @@ structure InvalidPokeballError { } @error("server") structure MasterBallUnsuccessful { - @required message: String, } diff --git a/codegen-core/common-test-models/rest-json-extras.smithy b/codegen-core/common-test-models/rest-json-extras.smithy index 10224b0c7c..65b7fcc8f1 100644 --- a/codegen-core/common-test-models/rest-json-extras.smithy +++ b/codegen-core/common-test-models/rest-json-extras.smithy @@ -6,6 +6,7 @@ use aws.protocols#restJson1 use aws.api#service use smithy.test#httpRequestTests use smithy.test#httpResponseTests +use smithy.framework#ValidationException apply QueryPrecedence @httpRequestTests([ { @@ -120,6 +121,7 @@ structure StringPayloadInput { documentation: "Primitive ints should not be serialized when they are unset", uri: "/primitive-document", method: "POST", + appliesTo: "client", body: "{}", headers: { "Content-Type": "application/json" }, params: {}, @@ -152,7 +154,8 @@ structure PrimitiveIntDocument { ]) @http(uri: "/primitive", method: "POST") operation PrimitiveIntHeader { - output: PrimitiveIntHeaderInput + output: PrimitiveIntHeaderInput, + errors: [ValidationException], } integer PrimitiveInt @@ -174,7 +177,8 @@ structure PrimitiveIntHeaderInput { } ]) operation EnumQuery { - input: EnumQueryInput + input: EnumQueryInput, + errors: [ValidationException], } structure EnumQueryInput { @@ -226,6 +230,7 @@ structure MapWithEnumKeyInputOutput { operation MapWithEnumKeyOp { input: MapWithEnumKeyInputOutput, output: MapWithEnumKeyInputOutput, + errors: [ValidationException], } @@ -265,6 +270,7 @@ structure EscapedStringValuesInputOutput { operation EscapedStringValues { input: EscapedStringValuesInputOutput, output: EscapedStringValuesInputOutput, + errors: [ValidationException], } list NonSparseList { diff --git a/codegen-core/common-test-models/simple.smithy b/codegen-core/common-test-models/simple.smithy index 6e094abe1f..c685c02f67 100644 --- a/codegen-core/common-test-models/simple.smithy +++ b/codegen-core/common-test-models/simple.smithy @@ -5,6 +5,7 @@ namespace com.amazonaws.simple use aws.protocols#restJson1 use smithy.test#httpRequestTests use smithy.test#httpResponseTests +use smithy.framework#ValidationException @restJson1 @title("SimpleService") @@ -74,7 +75,7 @@ resource Service { operation RegisterService { input: RegisterServiceInputRequest, output: RegisterServiceOutputResponse, - errors: [ResourceAlreadyExists] + errors: [ResourceAlreadyExists, ValidationException] } @documentation("Service register input structure") @@ -116,6 +117,7 @@ structure HealthcheckOutputResponse { operation StoreServiceBlob { input: StoreServiceBlobInput, output: StoreServiceBlobOutput + errors: [ValidationException] } @documentation("Store a blob for a service id input structure") diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt index bd231f8094..970c1a63ed 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt @@ -112,6 +112,9 @@ class InlineDependency( fun unwrappedXmlErrors(runtimeConfig: RuntimeConfig): InlineDependency = forRustFile("rest_xml_unwrapped_errors", CargoDependency.smithyXml(runtimeConfig)) + + fun constrained(): InlineDependency = + forRustFile("constrained") } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt index 6ff484bbaa..120798258f 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt @@ -148,6 +148,12 @@ sealed class RustType { } } + data class MaybeConstrained(override val member: RustType) : RustType(), Container { + val runtimeType: RuntimeType = RuntimeType.MaybeConstrained() + override val name = runtimeType.name!! + override val namespace = runtimeType.namespace + } + data class Box(override val member: RustType) : RustType(), Container { override val name = "Box" override val namespace = "std::boxed" @@ -237,6 +243,7 @@ fun RustType.render(fullyQualified: Boolean = true): String { is RustType.Box -> "${this.name}<${this.member.render(fullyQualified)}>" is RustType.Dyn -> "${this.name} ${this.member.render(fullyQualified)}" is RustType.Opaque -> this.name + is RustType.MaybeConstrained -> "${this.name}<${this.member.render(fullyQualified)}>" } return "$namespace$base" } @@ -380,6 +387,7 @@ sealed class Attribute { companion object { val AllowDeadCode = Custom("allow(dead_code)") val AllowDeprecated = Custom("allow(deprecated)") + val AllowUnused = Custom("allow(unused)") val AllowUnusedMut = Custom("allow(unused_mut)") val DocHidden = Custom("doc(hidden)") val DocInline = Custom("doc(inline)") diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegator.kt index bbd77b7c74..f1739324a7 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegator.kt @@ -171,6 +171,25 @@ open class RustCrate( } } +val ErrorsModule = RustModule.public("error", documentation = "All error types that operations can return.") +val OperationsModule = RustModule.public("operation", documentation = "All operations that this crate can perform.") +val ModelsModule = RustModule.public("model", documentation = "Data structures used by operation inputs/outputs.") +val InputsModule = RustModule.public("input", documentation = "Input structures for operations.") +val OutputsModule = RustModule.public("output", documentation = "Output structures for operations.") +val ConfigModule = RustModule.public("config", documentation = "Client configuration.") + +/** + * Allowlist of modules that will be exposed publicly in generated crates + */ +val DefaultPublicModules = setOf( + ErrorsModule, + OperationsModule, + ModelsModule, + InputsModule, + OutputsModule, + ConfigModule, +).associateBy { it.name } + /** * Finalize all the writers by: * - inlining inline dependencies that have been used diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt index d79059065f..4fc233df32 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt @@ -196,7 +196,9 @@ data class RuntimeType(val name: String?, val dependency: RustDependency?, val n val Debug = stdfmt.member("Debug") val Default: RuntimeType = RuntimeType("Default", dependency = null, namespace = "std::default") val Display = stdfmt.member("Display") + val Eq = std.member("cmp::Eq") val From = RuntimeType("From", dependency = null, namespace = "std::convert") + val Hash = std.member("hash::Hash") val TryFrom = RuntimeType("TryFrom", dependency = null, namespace = "std::convert") val PartialEq = std.member("cmp::PartialEq") val StdError = RuntimeType("Error", dependency = null, namespace = "std::error") @@ -256,6 +258,9 @@ data class RuntimeType(val name: String?, val dependency: RustDependency?, val n func, CargoDependency.SmithyProtocolTestHelpers(runtimeConfig), "aws_smithy_protocol_test", ) + fun ConstrainedTrait() = RuntimeType("Constrained", InlineDependency.constrained(), namespace = "crate::constrained") + fun MaybeConstrained() = RuntimeType("MaybeConstrained", InlineDependency.constrained(), namespace = "crate::constrained") + val http = CargoDependency.Http.asType() fun Http(path: String): RuntimeType = RuntimeType(name = path, dependency = CargoDependency.Http, namespace = "http") diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt index d72bafb5b9..238a1177bc 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt @@ -9,6 +9,7 @@ import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.NullableIndex +import software.amazon.smithy.model.knowledge.NullableIndex.CheckMode import software.amazon.smithy.model.shapes.BigDecimalShape import software.amazon.smithy.model.shapes.BigIntegerShape import software.amazon.smithy.model.shapes.BlobShape @@ -77,29 +78,65 @@ data class SymbolLocation(val namespace: String) { val filename = "$namespace.rs" } -val Models = SymbolLocation("model") -val Errors = SymbolLocation("error") -val Operations = SymbolLocation("operation") -val Inputs = SymbolLocation("input") -val Outputs = SymbolLocation("output") +val Models = SymbolLocation(ModelsModule.name) +val Errors = SymbolLocation(ErrorsModule.name) +val Operations = SymbolLocation(OperationsModule.name) +val Serializers = SymbolLocation("serializer") +val Inputs = SymbolLocation(InputsModule.name) +val Outputs = SymbolLocation(OutputsModule.name) +val Unconstrained = SymbolLocation("unconstrained") +val Constrained = SymbolLocation("constrained") /** * Make the Rust type of a symbol optional (hold `Option`) * * This is idempotent and will have no change if the type is already optional. */ -fun Symbol.makeOptional(): Symbol { - return if (isOptional()) { +fun Symbol.makeOptional(): Symbol = + if (isOptional()) { this } else { val rustType = RustType.Option(this.rustType()) - Symbol.builder().rustType(rustType) + Symbol.builder() + .rustType(rustType) + .addReference(this) + .name(rustType.name) + .build() + } + +/** + * Make the Rust type of a symbol boxed (hold `Box`). + * + * This is idempotent and will have no change if the type is already boxed. + */ +fun Symbol.makeRustBoxed(): Symbol = + if (isRustBoxed()) { + this + } else { + val rustType = RustType.Box(this.rustType()) + Symbol.builder() + .rustType(rustType) + .addReference(this) + .name(rustType.name) + .build() + } + +/** + * Make the Rust type of a symbol wrapped in `MaybeConstrained`. (hold `MaybeConstrained`). + * + * This is idempotent and will have no change if the type is already `MaybeConstrained`. + */ +fun Symbol.makeMaybeConstrained(): Symbol = + if (this.rustType() is RustType.MaybeConstrained) { + this + } else { + val rustType = RustType.MaybeConstrained(this.rustType()) + Symbol.builder() .rustType(rustType) .addReference(this) .name(rustType.name) .build() } -} /** * Map the [RustType] of a symbol with [f]. @@ -208,9 +245,6 @@ open class SymbolVisitor( return RuntimeType.Blob(config.runtimeConfig).toSymbol() } - private fun handleOptionality(symbol: Symbol, member: MemberShape): Symbol = - symbol.letIf(nullableIndex.isMemberNullable(member, config.nullabilityCheckMode)) { symbol.makeOptional() } - /** * Produce `Box` when the shape has the `RustBoxTrait` */ @@ -227,7 +261,7 @@ open class SymbolVisitor( } private fun simpleShape(shape: SimpleShape): Symbol { - return symbolBuilder(SimpleShapes.getValue(shape::class)).setDefault(Default.RustDefault).build() + return symbolBuilder(shape, SimpleShapes.getValue(shape::class)).setDefault(Default.RustDefault).build() } override fun booleanShape(shape: BooleanShape): Symbol = simpleShape(shape) @@ -240,7 +274,7 @@ open class SymbolVisitor( override fun stringShape(shape: StringShape): Symbol { return if (shape.hasTrait()) { val rustType = RustType.Opaque(shape.contextName(serviceShape).toPascalCase()) - symbolBuilder(rustType).locatedIn(Models).build() + symbolBuilder(shape, rustType).locatedIn(Models).build() } else { simpleShape(shape) } @@ -248,16 +282,16 @@ open class SymbolVisitor( override fun listShape(shape: ListShape): Symbol { val inner = this.toSymbol(shape.member) - return symbolBuilder(RustType.Vec(inner.rustType())).addReference(inner).build() + return symbolBuilder(shape, RustType.Vec(inner.rustType())).addReference(inner).build() } override fun setShape(shape: SetShape): Symbol { val inner = this.toSymbol(shape.member) val builder = if (model.expectShape(shape.member.target).isStringShape) { - symbolBuilder(RustType.HashSet(inner.rustType())) + symbolBuilder(shape, RustType.HashSet(inner.rustType())) } else { // only strings get put into actual sets because floats are unhashable - symbolBuilder(RustType.Vec(inner.rustType())) + symbolBuilder(shape, RustType.Vec(inner.rustType())) } return builder.addReference(inner).build() } @@ -267,7 +301,7 @@ open class SymbolVisitor( require(target.isStringShape) { "unexpected key shape: ${shape.key}: $target [keys must be strings]" } val key = this.toSymbol(shape.key) val value = this.toSymbol(shape.value) - return symbolBuilder(RustType.HashMap(key.rustType(), value.rustType())).addReference(key) + return symbolBuilder(shape, RustType.HashMap(key.rustType(), value.rustType())).addReference(key) .addReference(value).build() } @@ -285,6 +319,7 @@ open class SymbolVisitor( override fun operationShape(shape: OperationShape): Symbol { return symbolBuilder( + shape, RustType.Opaque( shape.contextName(serviceShape) .replaceFirstChar { it.uppercase() }, @@ -309,7 +344,7 @@ open class SymbolVisitor( val name = shape.contextName(serviceShape).toPascalCase().letIf(isError && config.renameExceptions) { it.replace("Exception", "Error") } - val builder = symbolBuilder(RustType.Opaque(name)) + val builder = symbolBuilder(shape, RustType.Opaque(name)) return when { isError -> builder.locatedIn(Errors) isInput -> builder.locatedIn(Inputs) @@ -320,29 +355,50 @@ open class SymbolVisitor( override fun unionShape(shape: UnionShape): Symbol { val name = shape.contextName(serviceShape).toPascalCase() - val builder = symbolBuilder(RustType.Opaque(name)).locatedIn(Models) + val builder = symbolBuilder(shape, RustType.Opaque(name)).locatedIn(Models) return builder.build() } override fun memberShape(shape: MemberShape): Symbol { val target = model.expectShape(shape.target) - // Handle boxing first so we end up with Option>, not Box> - return handleOptionality(handleRustBoxing(toSymbol(target), shape), shape) + // Handle boxing first so we end up with Option>, not Box>. + return handleOptionality( + handleRustBoxing(toSymbol(target), shape), + shape, + nullableIndex, + config.nullabilityCheckMode, + ) } override fun timestampShape(shape: TimestampShape?): Symbol { return RuntimeType.DateTime(config.runtimeConfig).toSymbol() } +} - private fun symbolBuilder(rustType: RustType): Symbol.Builder { - return Symbol.builder().rustType(rustType).name(rustType.name) - // Every symbol that actually gets defined somewhere should set a definition file - // If we ever generate a `thisisabug.rs`, there is a bug in our symbol generation - .definitionFile("thisisabug.rs") - } +/** + * Boxes and returns [symbol], the symbol for the target of the member shape [shape], if [shape] is annotated with + * [RustBoxTrait]; otherwise returns [symbol] unchanged. + * + * See `RecursiveShapeBoxer.kt` for the model transformation pass that annotates model shapes with [RustBoxTrait]. + */ +fun handleRustBoxing(symbol: Symbol, shape: MemberShape): Symbol = + if (shape.hasTrait()) { + symbol.makeRustBoxed() + } else symbol + +fun symbolBuilder(shape: Shape?, rustType: RustType): Symbol.Builder { + val builder = Symbol.builder().putProperty(SHAPE_KEY, shape) + return builder.rustType(rustType) + .name(rustType.name) + // Every symbol that actually gets defined somewhere should set a definition file + // If we ever generate a `thisisabug.rs`, there is a bug in our symbol generation + .definitionFile("thisisabug.rs") } +fun handleOptionality(symbol: Symbol, member: MemberShape, nullableIndex: NullableIndex, nullabilityCheckMode: CheckMode): Symbol = + symbol.letIf(nullableIndex.isMemberNullable(member, nullabilityCheckMode)) { symbol.makeOptional() } + // TODO(chore): Move this to a useful place private const val RUST_TYPE_KEY = "rusttype" private const val SHAPE_KEY = "shape" @@ -388,11 +444,6 @@ fun Symbol.isOptional(): Boolean = when (this.rustType()) { else -> false } -/** - * Get the referenced symbol for T if [this] is an Option, [this] otherwise - */ -fun Symbol.extractSymbolFromOption(): Symbol = this.mapRustType { it.stripOuter() } - fun Symbol.isRustBoxed(): Boolean = rustType().stripOuter() is RustType.Box // Symbols should _always_ be created with a Rust type & shape attached diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt index 9ae2da62ad..0dec0b1a65 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt @@ -6,6 +6,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.StructureShape @@ -28,14 +29,24 @@ import software.amazon.smithy.rust.codegen.core.smithy.Default import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.canUseDefault import software.amazon.smithy.rust.codegen.core.smithy.defaultValue import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.makeOptional import software.amazon.smithy.rust.codegen.core.smithy.rustType +import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.toSnakeCase +// TODO(https://github.com/awslabs/smithy-rs/issues/1401) This builder generator is only used by the client. +// Move this entire file, and its tests, to `codegen-client`. + +fun builderSymbolFn(symbolProvider: RustSymbolProvider): (StructureShape) -> Symbol = { structureShape -> + structureShape.builderSymbol(symbolProvider) +} + fun StructureShape.builderSymbol(symbolProvider: RustSymbolProvider): Symbol { val structureSymbol = symbolProvider.toSymbol(this) val builderNamespace = RustReservedWords.escapeIfNeeded(structureSymbol.name.toSnakeCase()) @@ -65,6 +76,23 @@ class BuilderGenerator( private val symbolProvider: RustSymbolProvider, private val shape: StructureShape, ) { + companion object { + /** + * Returns whether a structure shape, whose builder has been generated with [BuilderGenerator], requires a + * fallible builder to be constructed. + */ + fun hasFallibleBuilder(structureShape: StructureShape, symbolProvider: SymbolProvider): Boolean = + // All operation inputs should have fallible builders in case a new required field is added in the future. + structureShape.hasTrait() || + structureShape + .members() + .map { symbolProvider.toSymbol(it) }.any { + // If any members are not optional && we can't use a default, we need to + // generate a fallible builder. + !it.isOptional() && !it.canUseDefault() + } + } + private val runtimeConfig = symbolProvider.config().runtimeConfig private val members: List = shape.allMembers.values.toList() private val structureSymbol = symbolProvider.toSymbol(shape) @@ -79,7 +107,7 @@ class BuilderGenerator( } private fun renderBuildFn(implBlockWriter: RustWriter) { - val fallibleBuilder = StructureGenerator.hasFallibleBuilder(shape, symbolProvider) + val fallibleBuilder = hasFallibleBuilder(shape, symbolProvider) val outputSymbol = symbolProvider.toSymbol(shape) val returnType = when (fallibleBuilder) { true -> "Result<${implBlockWriter.format(outputSymbol)}, ${implBlockWriter.format(runtimeConfig.operationBuildError())}>" diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt index e735ec5e11..eed5b5462e 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt @@ -85,8 +85,8 @@ open class EnumGenerator( private val model: Model, private val symbolProvider: RustSymbolProvider, private val writer: RustWriter, - private val shape: StringShape, - private val enumTrait: EnumTrait, + protected val shape: StringShape, + protected val enumTrait: EnumTrait, ) { protected val symbol: Symbol = symbolProvider.toSymbol(shape) protected val enumName: String = symbol.name diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt index a58c0c65be..7973af7603 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt @@ -62,12 +62,14 @@ open class Instantiator( private val symbolProvider: RustSymbolProvider, private val model: Model, private val runtimeConfig: RuntimeConfig, + /** Behavior of the builder type used for structure shapes. */ + private val builderKindBehavior: BuilderKindBehavior, /** * A function that given a symbol for an enum shape and a string, returns a writable to instantiate the enum with * the string value. **/ private val enumFromStringFn: (Symbol, String) -> Writable, - /** Fill out required fields with a default value **/ + /** Fill out required fields with a default value. **/ private val defaultsForRequiredFields: Boolean = false, ) { data class Ctx( @@ -76,6 +78,20 @@ open class Instantiator( val lowercaseMapKeys: Boolean = false, ) + /** + * Client and server structures have different builder types. `Instantiator` needs to know how the builder + * type behaves to generate code for it. + */ + interface BuilderKindBehavior { + fun hasFallibleBuilder(shape: StructureShape): Boolean + + // Client structure builders have two kinds of setters: one that always takes in `Option`, and one that takes + // in the structure field's type. The latter's method name is the field's name, whereas the former is prefixed + // with `set_`. Client instantiators call the `set_*` builder setters. + fun setterName(memberShape: MemberShape): String + fun doesSetterTakeInOption(memberShape: MemberShape): Boolean + } + fun render(writer: RustWriter, shape: Shape, data: Node, ctx: Ctx = Ctx()) { when (shape) { // Compound Shapes @@ -165,7 +181,9 @@ open class Instantiator( writer.conditionalBlock( "Some(", ")", - conditional = model.expectShape(memberShape.container) is StructureShape || symbol.isOptional(), + // The conditions are not commutative: note client builders always take in `Option`. + conditional = symbol.isOptional() || + (model.expectShape(memberShape.container) is StructureShape && builderKindBehavior.doesSetterTakeInOption(memberShape)), ) { writer.conditionalBlock( "Box::new(", @@ -277,7 +295,8 @@ open class Instantiator( */ private fun renderStructure(writer: RustWriter, shape: StructureShape, data: ObjectNode, ctx: Ctx) { fun renderMemberHelper(memberShape: MemberShape, value: Node) { - writer.withBlock(".${memberShape.setterName()}(", ")") { + val setterName = builderKindBehavior.setterName(memberShape) + writer.withBlock(".$setterName(", ")") { renderMember(this, memberShape, value, ctx) } } @@ -297,8 +316,9 @@ open class Instantiator( val memberShape = shape.expectMember(key.value) renderMemberHelper(memberShape, value) } + writer.rust(".build()") - if (StructureGenerator.hasFallibleBuilder(shape, symbolProvider)) { + if (builderKindBehavior.hasFallibleBuilder(shape)) { writer.rust(".unwrap()") } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGenerator.kt index cbd0a1395c..3972307ac0 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGenerator.kt @@ -28,16 +28,12 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider -import software.amazon.smithy.rust.codegen.core.smithy.canUseDefault import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorGenerator -import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.renamedFrom import software.amazon.smithy.rust.codegen.core.smithy.rustType -import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.getTrait -import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.redactIfNecessary fun RustWriter.implBlock(structureShape: Shape, symbolProvider: SymbolProvider, block: Writable) { @@ -68,20 +64,6 @@ open class StructureGenerator( } } - companion object { - /** Returns whether a structure shape requires a fallible builder to be generated. */ - fun hasFallibleBuilder(structureShape: StructureShape, symbolProvider: SymbolProvider): Boolean = - // All operation inputs should have fallible builders in case a new required field is added in the future. - structureShape.hasTrait() || - structureShape - .allMembers - .values.map { symbolProvider.toSymbol(it) }.any { - // If any members are not optional && we can't use a default, we need to - // generate a fallible builder - !it.isOptional() && !it.canUseDefault() - } - } - /** * Search for lifetimes used by the members of the struct and generate a declaration. * e.g. `<'a, 'b>` @@ -100,7 +82,8 @@ open class StructureGenerator( } else "" } - /** Render a custom debug implementation + /** + * Render a custom debug implementation * When [SensitiveTrait] support is required, render a custom debug implementation to redact sensitive data */ private fun renderDebugImpl() { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorGenerator.kt index 07f9315b85..4303fa1a68 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorGenerator.kt @@ -9,10 +9,14 @@ import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.model.traits.RetryableTrait +import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.asDeref +import software.amazon.smithy.rust.codegen.core.rustlang.render import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig @@ -20,6 +24,8 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.StdError import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.smithy.mapRustType +import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.util.REDACTION import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.errorMessageMember @@ -82,10 +88,23 @@ class ErrorGenerator( } } if (messageShape != null) { - val (returnType, message) = if (symbolProvider.toSymbol(messageShape).isOptional()) { - "Option<&str>" to "self.${symbolProvider.toMemberName(messageShape)}.as_deref()" + val messageSymbol = symbolProvider.toSymbol(messageShape).mapRustType { t -> t.asDeref() } + val messageType = messageSymbol.rustType() + val memberName = symbolProvider.toMemberName(messageShape) + val (returnType, message) = if (messageType.stripOuter() is RustType.Opaque) { + // The string shape has a constraint trait that makes its symbol be a wrapper tuple struct. + if (messageSymbol.isOptional()) { + "Option<&${messageType.stripOuter().render()}>" to + "self.$memberName.as_ref()" + } else { + "&${messageType.render()}" to "&self.$memberName" + } } else { - "&str" to "self.${symbolProvider.toMemberName(messageShape)}.as_ref()" + if (messageSymbol.isOptional()) { + messageType.render() to "self.$memberName.as_deref()" + } else { + messageType.render() to "self.$memberName.as_ref()" + } } rust( diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt index 2a54f56741..97a703ab98 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt @@ -6,15 +6,19 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators.http import software.amazon.smithy.codegen.core.CodegenException +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.model.knowledge.HttpBinding import software.amazon.smithy.model.knowledge.HttpBindingIndex import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.CollectionShape import software.amazon.smithy.model.shapes.DocumentShape +import software.amazon.smithy.model.shapes.ListShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.SimpleShape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape @@ -38,6 +42,8 @@ import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedSectionGenerator +import software.amazon.smithy.rust.codegen.core.smithy.customize.Section import software.amazon.smithy.rust.codegen.core.smithy.generators.operationBuildError import software.amazon.smithy.rust.codegen.core.smithy.makeOptional import software.amazon.smithy.rust.codegen.core.smithy.mapRustType @@ -68,6 +74,18 @@ enum class HttpMessageType { REQUEST, RESPONSE } +/** + * Class describing an HTTP binding (de)serialization section that can be used in a customization. + */ +sealed class HttpBindingSection(name: String) : Section(name) { + data class BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders(val variableName: String, val shape: MapShape) : + HttpBindingSection("BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders") + data class AfterDeserializingIntoAHashMapOfHttpPrefixHeaders(val memberShape: MemberShape) : + HttpBindingSection("AfterDeserializingIntoAHashMapOfHttpPrefixHeaders") +} + +typealias HttpBindingCustomization = NamedSectionGenerator + /** * This class generates Rust functions that (de)serialize data from/to an HTTP message. * They are useful for *both*: @@ -88,12 +106,15 @@ enum class HttpMessageType { */ class HttpBindingGenerator( private val protocol: Protocol, - codegenContext: CodegenContext, + private val codegenContext: CodegenContext, + private val symbolProvider: SymbolProvider, private val operationShape: OperationShape, + /** Function that maps a StructureShape into its builder symbol */ + private val builderSymbol: (StructureShape) -> Symbol, + private val customizations: List = listOf(), ) { private val runtimeConfig = codegenContext.runtimeConfig - private val symbolProvider = codegenContext.symbolProvider - private val target = codegenContext.target + private val codegenTarget = codegenContext.target private val model = codegenContext.model private val service = codegenContext.serviceShape private val index = HttpBindingIndex.of(model) @@ -120,7 +141,7 @@ class HttpBindingGenerator( val fnName = "deser_header_${fnName(operationShape, binding)}" return RuntimeType.forInlineFun(fnName, httpSerdeModule) { rustBlock( - "pub fn $fnName(header_map: &#T::HeaderMap) -> std::result::Result<#T, #T::ParseError>", + "pub(crate) fn $fnName(header_map: &#T::HeaderMap) -> std::result::Result<#T, #T::ParseError>", RuntimeType.http, outputT, headerUtil, @@ -134,7 +155,6 @@ class HttpBindingGenerator( fun generateDeserializePrefixHeaderFn(binding: HttpBindingDescriptor): RuntimeType { check(binding.location == HttpBinding.Location.PREFIX_HEADERS) val outputSymbol = symbolProvider.toSymbol(binding.member) - check(outputSymbol.rustType().stripOuter() is RustType.HashMap) { outputSymbol.rustType() } val target = model.expectShape(binding.member.target) check(target is MapShape) val fnName = "deser_prefix_header_${fnName(operationShape, binding)}" @@ -151,7 +171,7 @@ class HttpBindingGenerator( val returnTypeSymbol = outputSymbol.mapRustType { it.asOptional() } return RuntimeType.forInlineFun(fnName, httpSerdeModule) { rustBlock( - "pub fn $fnName(header_map: &#T::HeaderMap) -> std::result::Result<#T, #T::ParseError>", + "pub(crate) fn $fnName(header_map: &#T::HeaderMap) -> std::result::Result<#T, #T::ParseError>", RuntimeType.http, returnTypeSymbol, headerUtil, @@ -162,13 +182,19 @@ class HttpBindingGenerator( let out: std::result::Result<_, _> = headers.map(|(key, header_name)| { let values = header_map.get_all(header_name); #T(values.iter()).map(|v| (key.to_string(), v.expect( - "we have checked there is at least one value for this header name; please file a bug report under https://github.com/awslabs/smithy-rs/issues - "))) + "we have checked there is at least one value for this header name; please file a bug report under https://github.com/awslabs/smithy-rs/issues" + ))) }).collect(); - out.map(Some) """, headerUtil, inner, ) + + for (customization in customizations) { + customization.section( + HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders(binding.member), + )(this) + } + rust("out.map(Some)") } } } @@ -221,12 +247,10 @@ class HttpBindingGenerator( private fun RustWriter.bindEventStreamOutput(operationShape: OperationShape, targetShape: UnionShape) { val unmarshallerConstructorFn = EventStreamUnmarshallerGenerator( protocol, - model, - runtimeConfig, - symbolProvider, + codegenContext, operationShape, targetShape, - target, + builderSymbol, ).render() rustTemplate( """ @@ -280,7 +304,7 @@ class HttpBindingGenerator( } } if (targetShape.hasTrait()) { - if (target == CodegenTarget.SERVER) { + if (codegenTarget == CodegenTarget.SERVER) { rust( "Ok(#T::try_from(body_str)?)", symbolProvider.toSymbol(targetShape), @@ -312,19 +336,20 @@ class HttpBindingGenerator( * Parse a value from a header. * This function produces an expression which produces the precise type required by the target shape. */ - private fun RustWriter.deserializeFromHeader(targetType: Shape, memberShape: MemberShape) { - val rustType = symbolProvider.toSymbol(targetType).rustType().stripOuter() + private fun RustWriter.deserializeFromHeader(targetShape: Shape, memberShape: MemberShape) { + val rustType = symbolProvider.toSymbol(targetShape).rustType().stripOuter() // Normally, we go through a flow that looks for `,`s but that's wrong if the output // is just a single string (which might include `,`s.). // MediaType doesn't include `,` since it's base64, send that through the normal path - if (targetType is StringShape && !targetType.hasTrait()) { + if (targetShape is StringShape && !targetShape.hasTrait()) { rust("#T::one_or_none(headers)", headerUtil) return } - val (coreType, coreShape) = if (targetType is CollectionShape) { - rustType.stripOuter() to model.expectShape(targetType.member.target) + val (coreType, coreShape) = if (targetShape is CollectionShape) { + val coreShape = model.expectShape(targetShape.member.target) + symbolProvider.toSymbol(coreShape).rustType() to coreShape } else { - rustType to targetType + rustType to targetShape } val parsedValue = safeName() if (coreType == dateTime) { @@ -336,18 +361,18 @@ class HttpBindingGenerator( ) val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat) rust( - "let $parsedValue: Vec<${coreType.render(true)}> = #T::many_dates(headers, #T)?;", + "let $parsedValue: Vec<${coreType.render()}> = #T::many_dates(headers, #T)?;", headerUtil, timestampFormatType, ) } else if (coreShape.isPrimitive()) { rust( - "let $parsedValue = #T::read_many_primitive::<${coreType.render(fullyQualified = true)}>(headers)?;", + "let $parsedValue = #T::read_many_primitive::<${coreType.render()}>(headers)?;", headerUtil, ) } else { rust( - "let $parsedValue: Vec<${coreType.render(fullyQualified = true)}> = #T::read_many_from_str(headers)?;", + "let $parsedValue: Vec<${coreType.render()}> = #T::read_many_from_str(headers)?;", headerUtil, ) if (coreShape.hasTrait()) { @@ -386,17 +411,36 @@ class HttpBindingGenerator( }) """, ) - else -> rustTemplate( - """ - if $parsedValue.len() > 1 { - Err(#{header_util}::ParseError::new_with_message(format!("expected one item but found {}", $parsedValue.len()))) + else -> { + if (targetShape is ListShape) { + // This is a constrained list shape and we must therefore be generating a server SDK. + check(codegenTarget == CodegenTarget.SERVER) + check(rustType is RustType.Opaque) + rust( + """ + Ok(if !$parsedValue.is_empty() { + Some(#T($parsedValue)) + } else { + None + }) + """, + symbolProvider.toSymbol(targetShape), + ) } else { - let mut $parsedValue = $parsedValue; - Ok($parsedValue.pop()) + check(targetShape is SimpleShape) + rustTemplate( + """ + if $parsedValue.len() > 1 { + Err(#{header_util}::ParseError::new_with_message(format!("expected one item but found {}", $parsedValue.len()))) + } else { + let mut $parsedValue = $parsedValue; + Ok($parsedValue.pop()) + } + """, + "header_util" to headerUtil, + ) } - """, - "header_util" to headerUtil, - ) + } } } @@ -475,16 +519,20 @@ class HttpBindingGenerator( val targetShape = model.expectShape(memberShape.target) val memberSymbol = symbolProvider.toSymbol(memberShape) val memberName = symbolProvider.toMemberName(memberShape) - ifSet(targetShape, memberSymbol, "&input.$memberName") { field -> - val isListHeader = targetShape is CollectionShape listForEach(targetShape, field) { innerField, targetId -> val innerMemberType = model.expectShape(targetId) if (innerMemberType.isPrimitive()) { val encoder = CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Encoder") rust("let mut encoder = #T::from(${autoDeref(innerField)});", encoder) } - val formatted = headerFmtFun(this, innerMemberType, memberShape, innerField, isListHeader) + val formatted = headerFmtFun( + this, + innerMemberType, + memberShape, + innerField, + isListHeader = targetShape is CollectionShape, + ) val safeName = safeName("formatted") write("let $safeName = $formatted;") rustBlock("if !$safeName.is_empty()") { @@ -519,6 +567,11 @@ class HttpBindingGenerator( val valueTargetShape = model.expectShape(targetShape.value.target) ifSet(targetShape, memberSymbol, "&input.$memberName") { field -> + for (customization in customizations) { + customization.section( + HttpBindingSection.BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders(field, targetShape), + )(this) + } rustTemplate( """ for (k, v) in $field { @@ -539,6 +592,7 @@ class HttpBindingGenerator( })?; builder = builder.header(header_name, header_value); } + """, "build_error" to runtimeConfig.operationBuildError(), ) @@ -564,7 +618,7 @@ class HttpBindingGenerator( val func = writer.format(RuntimeType.Base64Encode(runtimeConfig)) "$func(&$targetName)" } else { - quoteValue("AsRef::::as_ref($targetName)") + quoteValue("$targetName.as_str()") } } target.isTimestampShape -> { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/RequestBindingGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/RequestBindingGenerator.kt index dafaeea252..c4e75d297b 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/RequestBindingGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/RequestBindingGenerator.kt @@ -5,6 +5,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators.http +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.knowledge.HttpBinding import software.amazon.smithy.model.knowledge.HttpBindingIndex import software.amazon.smithy.model.pattern.SmithyPattern @@ -12,6 +13,7 @@ import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency @@ -24,6 +26,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.generators.OperationBuildError +import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.operationBuildError import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol @@ -45,6 +48,8 @@ fun SmithyPattern.rustFormatString(prefix: String, separator: String): String { return base.dq() } +// TODO(https://github.com/awslabs/smithy-rs/issues/1901) Move to `codegen-client` and update docs. +// `MakeOperationGenerator` needs to be moved to `codegen-client` first, which is not easy. /** * Generates methods to serialize and deserialize requests based on the HTTP trait. Specifically: * 1. `fn update_http_request(builder: http::request::Builder) -> Builder` @@ -62,7 +67,9 @@ class RequestBindingGenerator( private val symbolProvider = codegenContext.symbolProvider private val runtimeConfig = codegenContext.runtimeConfig private val httpTrait = protocol.httpBindingResolver.httpTrait(operationShape) - private val httpBindingGenerator = HttpBindingGenerator(protocol, codegenContext, operationShape) + private fun builderSymbol(shape: StructureShape): Symbol = shape.builderSymbol(symbolProvider) + private val httpBindingGenerator = + HttpBindingGenerator(protocol, codegenContext, codegenContext.symbolProvider, operationShape, ::builderSymbol) private val index = HttpBindingIndex.of(model) private val Encoder = CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Encoder") @@ -99,7 +106,7 @@ class RequestBindingGenerator( rust( """ let builder = #{T}(input, builder)?; - """.trimIndent(), + """, addHeadersFn, ) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/ResponseBindingGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/ResponseBindingGenerator.kt index 1de4cd2897..fde74bbd77 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/ResponseBindingGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/ResponseBindingGenerator.kt @@ -5,19 +5,27 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators.http +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol +// TODO(https://github.com/awslabs/smithy-rs/issues/1901) Move to `codegen-client` and update docs. +// `MakeOperationGenerator` needs to be moved to `codegen-client` first, which is not easy. class ResponseBindingGenerator( protocol: Protocol, - codegenContext: CodegenContext, + private val codegenContext: CodegenContext, operationShape: OperationShape, ) { - private val httpBindingGenerator = HttpBindingGenerator(protocol, codegenContext, operationShape) + private fun builderSymbol(shape: StructureShape): Symbol = shape.builderSymbol(codegenContext.symbolProvider) + + private val httpBindingGenerator = + HttpBindingGenerator(protocol, codegenContext, codegenContext.symbolProvider, operationShape, ::builderSymbol) fun generateDeserializeHeaderFn(binding: HttpBindingDescriptor): RuntimeType = httpBindingGenerator.generateDeserializeHeaderFn(binding) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/MakeOperationGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/MakeOperationGenerator.kt index c77ef20c4a..94d4cd28f1 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/MakeOperationGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/MakeOperationGenerator.kt @@ -33,6 +33,8 @@ import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.letIf +// TODO(https://github.com/awslabs/smithy-rs/issues/1901): Move to `codegen-client`. + /** Generates the `make_operation` function on input structs */ open class MakeOperationGenerator( protected val codegenContext: CodegenContext, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt index 5869e77f89..d527c0bc76 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt @@ -19,6 +19,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbolFn import software.amazon.smithy.rust.codegen.core.smithy.generators.serializationError import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator @@ -129,8 +130,14 @@ open class AwsJson( override fun additionalRequestHeaders(operationShape: OperationShape): List> = listOf("x-amz-target" to "${codegenContext.serviceShape.id.name}.${operationShape.id.name}") - override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator = - JsonParserGenerator(codegenContext, httpBindingResolver, ::awsJsonFieldName) + override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { + return JsonParserGenerator( + codegenContext, + httpBindingResolver, + ::awsJsonFieldName, + builderSymbolFn(codegenContext.symbolProvider), + ) + } override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = AwsJsonSerializerGenerator(codegenContext, httpBindingResolver) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQuery.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQuery.kt index d2bd4eb9fc..5bd7c2ab60 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQuery.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQuery.kt @@ -6,9 +6,11 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols import software.amazon.smithy.aws.traits.protocols.AwsQueryErrorTrait +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model import software.amazon.smithy.model.pattern.UriPattern import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.ToShapeId import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.model.traits.TimestampFormatTrait @@ -19,6 +21,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.AwsQueryParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.AwsQuerySerializerGenerator @@ -55,8 +58,11 @@ class AwsQueryProtocol(private val codegenContext: CodegenContext) : Protocol { override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.DATE_TIME - override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator = - AwsQueryParserGenerator(codegenContext, awsQueryErrors) + override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { + fun builderSymbol(shape: StructureShape): Symbol = + shape.builderSymbol(codegenContext.symbolProvider) + return AwsQueryParserGenerator(codegenContext, awsQueryErrors, ::builderSymbol) + } override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = AwsQuerySerializerGenerator(codegenContext) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Ec2Query.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Ec2Query.kt index 3f9dca4ca6..5f5ab09ef0 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Ec2Query.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Ec2Query.kt @@ -5,8 +5,10 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.pattern.UriPattern import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.model.traits.TimestampFormatTrait import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency @@ -16,6 +18,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.Ec2QueryParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.Ec2QuerySerializerGenerator @@ -46,8 +49,11 @@ class Ec2QueryProtocol(private val codegenContext: CodegenContext) : Protocol { override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.DATE_TIME - override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator = - Ec2QueryParserGenerator(codegenContext, ec2QueryErrors) + override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { + fun builderSymbol(shape: StructureShape): Symbol = + shape.builderSymbol(codegenContext.symbolProvider) + return Ec2QueryParserGenerator(codegenContext, ec2QueryErrors, ::builderSymbol) + } override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = Ec2QuerySerializerGenerator(codegenContext) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt index 7a25fabfdb..31c5ddae5c 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt @@ -5,6 +5,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape @@ -20,6 +21,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.asType import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerGenerator @@ -85,8 +87,11 @@ open class RestJson(val codegenContext: CodegenContext) : Protocol { override fun additionalErrorResponseHeaders(errorShape: StructureShape): List> = listOf("x-amzn-errortype" to errorShape.id.name) - override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator = - JsonParserGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName) + override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { + fun builderSymbol(shape: StructureShape): Symbol = + shape.builderSymbol(codegenContext.symbolProvider) + return JsonParserGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName, ::builderSymbol) + } override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = JsonSerializerGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestXml.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestXml.kt index 268abc0d66..3108b98dd8 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestXml.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestXml.kt @@ -6,7 +6,9 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols import software.amazon.smithy.aws.traits.protocols.RestXmlTrait +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.TimestampFormatTrait import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustModule @@ -15,6 +17,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.RestXmlParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator @@ -45,7 +48,9 @@ open class RestXml(val codegenContext: CodegenContext) : Protocol { TimestampFormatTrait.Format.DATE_TIME override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { - return RestXmlParserGenerator(codegenContext, restXmlErrors) + fun builderSymbol(shape: StructureShape): Symbol = + shape.builderSymbol(codegenContext.symbolProvider) + return RestXmlParserGenerator(codegenContext, restXmlErrors, ::builderSymbol) } override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGenerator.kt index cb0569c22d..a53bc2ab1b 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGenerator.kt @@ -5,6 +5,8 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType @@ -27,10 +29,12 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType class AwsQueryParserGenerator( codegenContext: CodegenContext, xmlErrors: RuntimeType, + builderSymbol: (shape: StructureShape) -> Symbol, private val xmlBindingTraitParserGenerator: XmlBindingTraitParserGenerator = XmlBindingTraitParserGenerator( codegenContext, xmlErrors, + builderSymbol, ) { context, inner -> val operationName = codegenContext.symbolProvider.toSymbol(context.shape).name val responseWrapperName = operationName + "Response" diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGenerator.kt index c33e257377..f59f2df556 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGenerator.kt @@ -5,6 +5,8 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType @@ -25,10 +27,12 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType class Ec2QueryParserGenerator( codegenContext: CodegenContext, xmlErrors: RuntimeType, + builderSymbol: (shape: StructureShape) -> Symbol, private val xmlBindingTraitParserGenerator: XmlBindingTraitParserGenerator = XmlBindingTraitParserGenerator( codegenContext, xmlErrors, + builderSymbol, ) { context, inner -> val operationName = codegenContext.symbolProvider.toSymbol(context.shape).name val responseWrapperName = operationName + "Response" diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt index c476fcdb9d..056be65fc4 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt @@ -6,7 +6,6 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse import software.amazon.smithy.codegen.core.Symbol -import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.BooleanShape import software.amazon.smithy.model.shapes.ByteShape @@ -26,18 +25,19 @@ import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.asType +import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.withBlock +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.error.eventStreamErrorSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant +import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticEventStreamUnionTrait import software.amazon.smithy.rust.codegen.core.smithy.transformers.eventStreamErrors @@ -48,19 +48,22 @@ import software.amazon.smithy.rust.codegen.core.util.toPascalCase class EventStreamUnmarshallerGenerator( private val protocol: Protocol, - private val model: Model, - runtimeConfig: RuntimeConfig, - private val symbolProvider: RustSymbolProvider, + codegenContext: CodegenContext, private val operationShape: OperationShape, private val unionShape: UnionShape, - private val target: CodegenTarget, + /** Function that maps a StructureShape into its builder symbol */ + private val builderSymbol: (StructureShape) -> Symbol, ) { + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + private val codegenTarget = codegenContext.target + private val runtimeConfig = codegenContext.runtimeConfig private val unionSymbol = symbolProvider.toSymbol(unionShape) private val smithyEventStream = CargoDependency.SmithyEventStream(runtimeConfig) - private val errorSymbol = if (target == CodegenTarget.SERVER && unionShape.eventStreamErrors().isEmpty()) { + private val errorSymbol = if (codegenTarget == CodegenTarget.SERVER && unionShape.eventStreamErrors().isEmpty()) { RuntimeType("MessageStreamError", smithyEventStream, "aws_smithy_http::event_stream").toSymbol() } else { - unionShape.eventStreamErrorSymbol(model, symbolProvider, target).toSymbol() + unionShape.eventStreamErrorSymbol(model, symbolProvider, codegenTarget).toSymbol() } private val eventStreamSerdeModule = RustModule.private("event_stream_serde") private val codegenScope = arrayOf( @@ -149,7 +152,7 @@ class EventStreamUnmarshallerGenerator( } } rustBlock("_unknown_variant => ") { - when (target.renderUnknownVariant()) { + when (codegenTarget.renderUnknownVariant()) { true -> rustTemplate( "Ok(#{UnmarshalledMessage}::Event(#{Output}::${UnionGenerator.UnknownVariantName}))", "Output" to unionSymbol, @@ -191,7 +194,7 @@ class EventStreamUnmarshallerGenerator( ) } else -> { - rust("let mut builder = #T::builder();", symbolProvider.toSymbol(unionStruct)) + rust("let mut builder = #T::default();", builderSymbol(unionStruct)) val payloadMember = unionStruct.members().firstOrNull { it.hasTrait() } if (payloadMember != null) { renderUnmarshallEventPayload(payloadMember) @@ -225,18 +228,19 @@ class EventStreamUnmarshallerGenerator( } private fun RustWriter.renderUnmarshallEventHeader(member: MemberShape) { - val memberName = symbolProvider.toMemberName(member) - withBlock("builder = builder.$memberName(", ");") { - when (val target = model.expectShape(member.target)) { - is BooleanShape -> rustTemplate("#{expect_fns}::expect_bool(header)?", *codegenScope) - is ByteShape -> rustTemplate("#{expect_fns}::expect_byte(header)?", *codegenScope) - is ShortShape -> rustTemplate("#{expect_fns}::expect_int16(header)?", *codegenScope) - is IntegerShape -> rustTemplate("#{expect_fns}::expect_int32(header)?", *codegenScope) - is LongShape -> rustTemplate("#{expect_fns}::expect_int64(header)?", *codegenScope) - is BlobShape -> rustTemplate("#{expect_fns}::expect_byte_array(header)?", *codegenScope) - is StringShape -> rustTemplate("#{expect_fns}::expect_string(header)?", *codegenScope) - is TimestampShape -> rustTemplate("#{expect_fns}::expect_timestamp(header)?", *codegenScope) - else -> throw IllegalStateException("unsupported event stream header shape type: $target") + withBlock("builder = builder.${member.setterName()}(", ");") { + conditionalBlock("Some(", ")", member.isOptional) { + when (val target = model.expectShape(member.target)) { + is BooleanShape -> rustTemplate("#{expect_fns}::expect_bool(header)?", *codegenScope) + is ByteShape -> rustTemplate("#{expect_fns}::expect_byte(header)?", *codegenScope) + is ShortShape -> rustTemplate("#{expect_fns}::expect_int16(header)?", *codegenScope) + is IntegerShape -> rustTemplate("#{expect_fns}::expect_int32(header)?", *codegenScope) + is LongShape -> rustTemplate("#{expect_fns}::expect_int64(header)?", *codegenScope) + is BlobShape -> rustTemplate("#{expect_fns}::expect_byte_array(header)?", *codegenScope) + is StringShape -> rustTemplate("#{expect_fns}::expect_string(header)?", *codegenScope) + is TimestampShape -> rustTemplate("#{expect_fns}::expect_timestamp(header)?", *codegenScope) + else -> throw IllegalStateException("unsupported event stream header shape type: $target") + } } } } @@ -259,31 +263,33 @@ class EventStreamUnmarshallerGenerator( *codegenScope, ) } - val memberName = symbolProvider.toMemberName(member) - withBlock("builder = builder.$memberName(", ");") { - when (target) { - is BlobShape -> { - rustTemplate("#{Blob}::new(message.payload().as_ref())", *codegenScope) - } - is StringShape -> { - rustTemplate( - """ - std::str::from_utf8(message.payload()) - .map_err(|_| #{Error}::unmarshalling("message payload is not valid UTF-8"))? - """, - *codegenScope, - ) - } - is UnionShape, is StructureShape -> { - renderParseProtocolPayload(member) + withBlock("builder = builder.${member.setterName()}(", ");") { + conditionalBlock("Some(", ")", member.isOptional) { + when (target) { + is BlobShape -> { + rustTemplate("#{Blob}::new(message.payload().as_ref())", *codegenScope) + } + is StringShape -> { + rustTemplate( + """ + std::str::from_utf8(message.payload()) + .map_err(|_| #{Error}::unmarshalling("message payload is not valid UTF-8"))? + .to_owned() + """, + *codegenScope, + ) + } + is UnionShape, is StructureShape -> { + renderParseProtocolPayload(member) + } } } } } private fun RustWriter.renderParseProtocolPayload(member: MemberShape) { - val parser = protocol.structuredDataParser(operationShape).payloadParser(member) val memberName = symbolProvider.toMemberName(member) + val parser = protocol.structuredDataParser(operationShape).payloadParser(member) rustTemplate( """ #{parser}(&message.payload()[..]) @@ -297,7 +303,7 @@ class EventStreamUnmarshallerGenerator( } private fun RustWriter.renderUnmarshallError() { - when (target) { + when (codegenTarget) { CodegenTarget.CLIENT -> { rustTemplate( """ @@ -326,12 +332,12 @@ class EventStreamUnmarshallerGenerator( rustBlock("${member.memberName.dq()} $matchOperator ") { // TODO(EventStream): Errors on the operation can be disjoint with errors in the union, // so we need to generate a new top-level Error type for each event stream union. - when (target) { + when (codegenTarget) { CodegenTarget.CLIENT -> { val target = model.expectShape(member.target, StructureShape::class.java) val parser = protocol.structuredDataParser(operationShape).errorParser(target) if (parser != null) { - rust("let mut builder = #T::builder();", symbolProvider.toSymbol(target)) + rust("let mut builder = #T::default();", builderSymbol(target)) rustTemplate( """ builder = #{parser}(&message.payload()[..], builder) @@ -354,7 +360,7 @@ class EventStreamUnmarshallerGenerator( val target = model.expectShape(member.target, StructureShape::class.java) val parser = protocol.structuredDataParser(operationShape).errorParser(target) val mut = if (parser != null) { " mut" } else { "" } - rust("let$mut builder = #T::builder();", symbolProvider.toSymbol(target)) + rust("let$mut builder = #T::default();", builderSymbol(target)) if (parser != null) { rustTemplate( """ @@ -387,7 +393,7 @@ class EventStreamUnmarshallerGenerator( rust("}") } } - when (target) { + when (codegenTarget) { CodegenTarget.CLIENT -> { rustTemplate("Ok(#{UnmarshalledMessage}::Error(#{OpError}::generic(generic)))", *codegenScope) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt index 43c18287a5..a69a8e9651 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt @@ -5,6 +5,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.BooleanShape import software.amazon.smithy.model.shapes.CollectionShape @@ -13,6 +14,7 @@ import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.NumberShape import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.TimestampShape @@ -37,12 +39,13 @@ import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.canUseDefault -import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator +import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedSectionGenerator +import software.amazon.smithy.rust.codegen.core.smithy.customize.Section import software.amazon.smithy.rust.codegen.core.smithy.generators.TypeConversionGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName +import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.isRustBoxed import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation @@ -54,16 +57,45 @@ import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.outputShape import software.amazon.smithy.utils.StringUtils +/** + * Class describing a JSON parser section that can be used in a customization. + */ +sealed class JsonParserSection(name: String) : Section(name) { + data class BeforeBoxingDeserializedMember(val shape: MemberShape) : JsonParserSection("BeforeBoxingDeserializedMember") +} + +/** + * Customization for the JSON parser. + */ +typealias JsonParserCustomization = NamedSectionGenerator + +data class ReturnSymbolToParse(val symbol: Symbol, val isUnconstrained: Boolean) + class JsonParserGenerator( - private val codegenContext: CodegenContext, + codegenContext: CodegenContext, private val httpBindingResolver: HttpBindingResolver, /** Function that maps a MemberShape into a JSON field name */ private val jsonName: (MemberShape) -> String, + /** Function that maps a StructureShape into its builder symbol */ + private val builderSymbol: (StructureShape) -> Symbol, + /** + * Whether we should parse a value for a shape into its associated unconstrained type. For example, when the shape + * is a `StructureShape`, we should construct and return a builder instead of building into the final `struct` the + * user gets. This is only relevant for the server, that parses the incoming request and only after enforces + * constraint traits. + * + * The function returns a data class that signals the return symbol that should be parsed, and whether it's + * unconstrained or not. + */ + private val returnSymbolToParse: (Shape) -> ReturnSymbolToParse = { shape -> + ReturnSymbolToParse(codegenContext.symbolProvider.toSymbol(shape), false) + }, + private val customizations: List = listOf(), ) : StructuredDataParserGenerator { private val model = codegenContext.model private val symbolProvider = codegenContext.symbolProvider private val runtimeConfig = codegenContext.runtimeConfig - private val target = codegenContext.target + private val codegenTarget = codegenContext.target private val smithyJson = CargoDependency.smithyJson(runtimeConfig).asType() private val jsonDeserModule = RustModule.private("json_deser") private val typeConversionGenerator = TypeConversionGenerator(model, symbolProvider, runtimeConfig) @@ -94,14 +126,14 @@ class JsonParserGenerator( */ private fun structureParser( fnName: String, - structureShape: StructureShape, + builderSymbol: Symbol, includedMembers: List, ): RuntimeType { return RuntimeType.forInlineFun(fnName, jsonDeserModule) { val unusedMut = if (includedMembers.isEmpty()) "##[allow(unused_mut)] " else "" rustBlockTemplate( - "pub fn $fnName(value: &[u8], ${unusedMut}mut builder: #{Builder}) -> Result<#{Builder}, #{Error}>", - "Builder" to structureShape.builderSymbol(symbolProvider), + "pub(crate) fn $fnName(value: &[u8], ${unusedMut}mut builder: #{Builder}) -> Result<#{Builder}, #{Error}>", + "Builder" to builderSymbol, *codegenScope, ) { rustTemplate( @@ -159,7 +191,7 @@ class JsonParserGenerator( } val outputShape = operationShape.outputShape(model) val fnName = symbolProvider.deserializeFunctionName(operationShape) - return structureParser(fnName, outputShape, httpDocumentMembers) + return structureParser(fnName, builderSymbol(outputShape), httpDocumentMembers) } override fun errorParser(errorShape: StructureShape): RuntimeType? { @@ -167,13 +199,13 @@ class JsonParserGenerator( return null } val fnName = symbolProvider.deserializeFunctionName(errorShape) + "_json_err" - return structureParser(fnName, errorShape, errorShape.members().toList()) + return structureParser(fnName, builderSymbol(errorShape), errorShape.members().toList()) } private fun orEmptyJson(): RuntimeType = RuntimeType.forInlineFun("or_empty_doc", jsonDeserModule) { rust( """ - pub fn or_empty_doc(data: &[u8]) -> &[u8] { + pub(crate) fn or_empty_doc(data: &[u8]) -> &[u8] { if data.is_empty() { b"{}" } else { @@ -191,7 +223,7 @@ class JsonParserGenerator( } val inputShape = operationShape.inputShape(model) val fnName = symbolProvider.deserializeFunctionName(operationShape) - return structureParser(fnName, inputShape, includedMembers) + return structureParser(fnName, builderSymbol(inputShape), includedMembers) } private fun RustWriter.expectEndOfTokenStream() { @@ -208,8 +240,29 @@ class JsonParserGenerator( rustBlock("match key.to_unescaped()?.as_ref()") { for (member in members) { rustBlock("${jsonName(member).dq()} =>") { - withBlock("builder = builder.${member.setterName()}(", ");") { - deserializeMember(member) + when (codegenTarget) { + CodegenTarget.CLIENT -> { + withBlock("builder = builder.${member.setterName()}(", ");") { + deserializeMember(member) + } + } + CodegenTarget.SERVER -> { + if (symbolProvider.toSymbol(member).isOptional()) { + withBlock("builder = builder.${member.setterName()}(", ");") { + deserializeMember(member) + } + } else { + rust("if let Some(v) = ") + deserializeMember(member) + rust( + """ + { + builder = builder.${member.setterName()}(v); + } + """, + ) + } + } } } } @@ -234,6 +287,9 @@ class JsonParserGenerator( } val symbol = symbolProvider.toSymbol(memberShape) if (symbol.isRustBoxed()) { + for (customization in customizations) { + customization.section(JsonParserSection.BeforeBoxingDeserializedMember(memberShape))(this) + } rust(".map(Box::new)") } } @@ -250,15 +306,8 @@ class JsonParserGenerator( withBlock("$escapedStrName.to_unescaped().map(|u|", ")") { when (target.hasTrait()) { true -> { - if (convertsToEnumInServer(target)) { - rustTemplate( - """ - #{EnumSymbol}::try_from(u.as_ref()) - .map_err(|e| #{Error}::custom(format!("unknown variant {}", e))) - """, - "EnumSymbol" to symbolProvider.toSymbol(target), - *codegenScope, - ) + if (returnSymbolToParse(target).isUnconstrained) { + rust("u.into_owned()") } else { rust("#T::from(u.as_ref())", symbolProvider.toSymbol(target)) } @@ -268,12 +317,8 @@ class JsonParserGenerator( } } - private fun convertsToEnumInServer(shape: StringShape) = target == CodegenTarget.SERVER && shape.hasTrait() - private fun RustWriter.deserializeString(target: StringShape) { - // Additional `.transpose()?` because we can't use `?` inside the closures that parsed the string. - val additionalTranspose = ".transpose()?".repeat(if (convertsToEnumInServer(target)) 2 else 1) - withBlockTemplate("#{expect_string_or_null}(tokens.next())?.map(|s|", ")$additionalTranspose", *codegenScope) { + withBlockTemplate("#{expect_string_or_null}(tokens.next())?.map(|s|", ").transpose()?", *codegenScope) { deserializeStringInner(target, "s") } } @@ -287,9 +332,10 @@ class JsonParserGenerator( rustTemplate( """ #{expect_number_or_null}(tokens.next())? - .map(|v| v.try_into()) + .map(#{NumberType}::try_from) .transpose()? """, + "NumberType" to symbolProvider.toSymbol(target), *codegenScope, ) } @@ -311,16 +357,17 @@ class JsonParserGenerator( private fun RustWriter.deserializeCollection(shape: CollectionShape) { val fnName = symbolProvider.deserializeFunctionName(shape) val isSparse = shape.hasTrait() + val (returnSymbol, returnUnconstrainedType) = returnSymbolToParse(shape) val parser = RuntimeType.forInlineFun(fnName, jsonDeserModule) { // Allow non-snake-case since some SDK models have lists with names prefixed with `__listOf__`, // which become `__list_of__`, and the Rust compiler warning doesn't like multiple adjacent underscores. rustBlockTemplate( """ - ##[allow(clippy::type_complexity, non_snake_case)] - pub fn $fnName<'a, I>(tokens: &mut #{Peekable}) -> Result, #{Error}> + ##[allow(non_snake_case)] + pub(crate) fn $fnName<'a, I>(tokens: &mut #{Peekable}) -> Result, #{Error}> where I: Iterator, #{Error}>> """, - "Shape" to symbolProvider.toSymbol(shape), + "ReturnType" to returnSymbol, *codegenScope, ) { startArrayOrNull { @@ -346,7 +393,11 @@ class JsonParserGenerator( } } } - rust("Ok(Some(items))") + if (returnUnconstrainedType) { + rust("Ok(Some(#{T}(items)))", returnSymbol) + } else { + rust("Ok(Some(items))") + } } } } @@ -357,16 +408,17 @@ class JsonParserGenerator( val keyTarget = model.expectShape(shape.key.target) as StringShape val fnName = symbolProvider.deserializeFunctionName(shape) val isSparse = shape.hasTrait() + val returnSymbolToParse = returnSymbolToParse(shape) val parser = RuntimeType.forInlineFun(fnName, jsonDeserModule) { // Allow non-snake-case since some SDK models have maps with names prefixed with `__mapOf__`, // which become `__map_of__`, and the Rust compiler warning doesn't like multiple adjacent underscores. rustBlockTemplate( """ - ##[allow(clippy::type_complexity, non_snake_case)] - pub fn $fnName<'a, I>(tokens: &mut #{Peekable}) -> Result, #{Error}> + ##[allow(non_snake_case)] + pub(crate) fn $fnName<'a, I>(tokens: &mut #{Peekable}) -> Result, #{Error}> where I: Iterator, #{Error}>> """, - "Shape" to symbolProvider.toSymbol(shape), + "ReturnType" to returnSymbolToParse.symbol, *codegenScope, ) { startObjectOrNull { @@ -378,9 +430,6 @@ class JsonParserGenerator( withBlock("let value =", ";") { deserializeMember(shape.value) } - if (convertsToEnumInServer(keyTarget)) { - rust("let key = key?;") - } if (isSparse) { rust("map.insert(key, value);") } else { @@ -389,7 +438,11 @@ class JsonParserGenerator( } } } - rust("Ok(Some(map))") + if (returnSymbolToParse.isUnconstrained) { + rust("Ok(Some(#{T}(map)))", returnSymbolToParse.symbol) + } else { + rust("Ok(Some(map))") + } } } } @@ -398,29 +451,25 @@ class JsonParserGenerator( private fun RustWriter.deserializeStruct(shape: StructureShape) { val fnName = symbolProvider.deserializeFunctionName(shape) - val symbol = symbolProvider.toSymbol(shape) + val returnSymbolToParse = returnSymbolToParse(shape) val nestedParser = RuntimeType.forInlineFun(fnName, jsonDeserModule) { rustBlockTemplate( """ - pub fn $fnName<'a, I>(tokens: &mut #{Peekable}) -> Result, #{Error}> + pub(crate) fn $fnName<'a, I>(tokens: &mut #{Peekable}) -> Result, #{Error}> where I: Iterator, #{Error}>> """, - "Shape" to symbol, + "ReturnType" to returnSymbolToParse.symbol, *codegenScope, ) { startObjectOrNull { Attribute.AllowUnusedMut.render(this) - rustTemplate("let mut builder = #{Shape}::builder();", *codegenScope, "Shape" to symbol) + rustTemplate("let mut builder = #{Builder}::default();", *codegenScope, "Builder" to builderSymbol(shape)) deserializeStructInner(shape.members()) - withBlock("Ok(Some(builder.build()", "))") { - if (StructureGenerator.hasFallibleBuilder(shape, symbolProvider)) { - rustTemplate( - """.map_err(|err| #{Error}::new( - #{ErrorReason}::Custom(format!("{}", err).into()), None) - )?""", - *codegenScope, - ) - } + // Only call `build()` if the builder is not fallible. Otherwise, return the builder. + if (returnSymbolToParse.isUnconstrained) { + rust("Ok(Some(builder))") + } else { + rust("Ok(Some(builder.build()))") } } } @@ -430,15 +479,15 @@ class JsonParserGenerator( private fun RustWriter.deserializeUnion(shape: UnionShape) { val fnName = symbolProvider.deserializeFunctionName(shape) - val symbol = symbolProvider.toSymbol(shape) + val returnSymbolToParse = returnSymbolToParse(shape) val nestedParser = RuntimeType.forInlineFun(fnName, jsonDeserModule) { rustBlockTemplate( """ - pub fn $fnName<'a, I>(tokens: &mut #{Peekable}) -> Result, #{Error}> + pub(crate) fn $fnName<'a, I>(tokens: &mut #{Peekable}) -> Result, #{Error}> where I: Iterator, #{Error}>> """, *codegenScope, - "Shape" to symbol, + "Shape" to returnSymbolToParse.symbol, ) { rust("let mut variant = None;") rustBlock("match tokens.next().transpose()?") { @@ -462,14 +511,14 @@ class JsonParserGenerator( for (member in shape.members()) { val variantName = symbolProvider.toMemberName(member) rustBlock("${jsonName(member).dq()} =>") { - withBlock("Some(#T::$variantName(", "))", symbol) { + withBlock("Some(#T::$variantName(", "))", returnSymbolToParse.symbol) { deserializeMember(member) unwrapOrDefaultOrError(member) } } } - when (target.renderUnknownVariant()) { - // in client mode, resolve an unknown union variant to the unknown variant + when (codegenTarget.renderUnknownVariant()) { + // In client mode, resolve an unknown union variant to the unknown variant. true -> rustTemplate( """ _ => { @@ -477,9 +526,11 @@ class JsonParserGenerator( Some(#{Union}::${UnionGenerator.UnknownVariantName}) } """, - "Union" to symbol, *codegenScope, + "Union" to returnSymbolToParse.symbol, + *codegenScope, ) - // in server mode, use strict parsing + // In server mode, use strict parsing. + // Consultation: https://github.com/awslabs/smithy/issues/1222 false -> rustTemplate( """variant => return Err(#{Error}::custom(format!("unexpected union variant: {}", variant)))""", *codegenScope, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/RestXmlParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/RestXmlParserGenerator.kt index ed41cfd85a..156d025b9b 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/RestXmlParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/RestXmlParserGenerator.kt @@ -5,6 +5,8 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType @@ -17,10 +19,12 @@ import software.amazon.smithy.rust.codegen.core.util.orNull class RestXmlParserGenerator( codegenContext: CodegenContext, xmlErrors: RuntimeType, + builderSymbol: (shape: StructureShape) -> Symbol, private val xmlBindingTraitParserGenerator: XmlBindingTraitParserGenerator = XmlBindingTraitParserGenerator( codegenContext, xmlErrors, + builderSymbol, ) { context, inner -> val shapeName = context.outputShapeName // Get the non-synthetic version of the outputShape and check to see if it has the `AllowInvalidXmlRoot` trait diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt index 841975c229..37b0d1e616 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt @@ -7,6 +7,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse import software.amazon.smithy.aws.traits.customizations.S3UnwrappedXmlOutputTrait import software.amazon.smithy.codegen.core.CodegenException +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.HttpBinding import software.amazon.smithy.model.knowledge.HttpBindingIndex @@ -41,9 +42,8 @@ import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName import software.amazon.smithy.rust.codegen.core.smithy.isOptional @@ -71,6 +71,7 @@ data class OperationWrapperContext( class XmlBindingTraitParserGenerator( codegenContext: CodegenContext, private val xmlErrors: RuntimeType, + private val builderSymbol: (shape: StructureShape) -> Symbol, private val writeOperationWrapper: RustWriter.(OperationWrapperContext, OperationInnerWriteable) -> Unit, ) : StructuredDataParserGenerator { @@ -187,7 +188,7 @@ class XmlBindingTraitParserGenerator( Attribute.AllowUnusedMut.render(this) rustBlock( "pub fn $fnName(inp: &[u8], mut builder: #1T) -> Result<#1T, #2T>", - outputShape.builderSymbol(symbolProvider), + builderSymbol(outputShape), xmlError, ) { rustTemplate( @@ -220,7 +221,7 @@ class XmlBindingTraitParserGenerator( Attribute.AllowUnusedMut.render(this) rustBlock( "pub fn $fnName(inp: &[u8], mut builder: #1T) -> Result<#1T, #2T>", - errorShape.builderSymbol(symbolProvider), + builderSymbol(errorShape), xmlError, ) { val members = errorShape.errorXmlMembers() @@ -254,7 +255,7 @@ class XmlBindingTraitParserGenerator( Attribute.AllowUnusedMut.render(this) rustBlock( "pub fn $fnName(inp: &[u8], mut builder: #1T) -> Result<#1T, #2T>", - inputShape.builderSymbol(symbolProvider), + builderSymbol(inputShape), xmlError, ) { rustTemplate( @@ -476,7 +477,7 @@ class XmlBindingTraitParserGenerator( rust("let _ = decoder;") } withBlock("Ok(builder.build()", ")") { - if (StructureGenerator.hasFallibleBuilder(shape, symbolProvider)) { + if (BuilderGenerator.hasFallibleBuilder(shape, symbolProvider)) { // NOTE:(rcoh) This branch is unreachable given the current nullability rules. // Only synthetic inputs can have fallible builders, but synthetic inputs can never be parsed // (because they're inputs, only outputs will be parsed!) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt index 415a0a8de8..4ea25612b6 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt @@ -19,7 +19,6 @@ import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.TimestampShape import software.amazon.smithy.model.shapes.UnionShape -import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.model.traits.TimestampFormatTrait.Format.EPOCH_SECONDS import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustModule @@ -48,35 +47,37 @@ import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.expectTrait -import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.outputShape /** - * Class describing a JSON section that can be used in a customization. + * Class describing a JSON serializer section that can be used in a customization. */ -sealed class JsonSection(name: String) : Section(name) { +sealed class JsonSerializerSection(name: String) : Section(name) { /** Mutate the server error object prior to finalization. Eg: this can be used to inject `__type` to record the error type. */ - data class ServerError(val structureShape: StructureShape, val jsonObject: String) : JsonSection("ServerError") + data class ServerError(val structureShape: StructureShape, val jsonObject: String) : JsonSerializerSection("ServerError") + + /** Mutate a map prior to it being serialized. **/ + data class BeforeIteratingOverMap(val shape: MapShape, val valueExpression: ValueExpression) : JsonSerializerSection("BeforeIteratingOverMap") /** Mutate the input object prior to finalization. */ - data class InputStruct(val structureShape: StructureShape, val jsonObject: String) : JsonSection("InputStruct") + data class InputStruct(val structureShape: StructureShape, val jsonObject: String) : JsonSerializerSection("InputStruct") /** Mutate the output object prior to finalization. */ - data class OutputStruct(val structureShape: StructureShape, val jsonObject: String) : JsonSection("OutputStruct") + data class OutputStruct(val structureShape: StructureShape, val jsonObject: String) : JsonSerializerSection("OutputStruct") } /** - * JSON customization. + * Customization for the JSON serializer. */ -typealias JsonCustomization = NamedSectionGenerator +typealias JsonSerializerCustomization = NamedSectionGenerator class JsonSerializerGenerator( codegenContext: CodegenContext, private val httpBindingResolver: HttpBindingResolver, /** Function that maps a MemberShape into a JSON field name */ private val jsonName: (MemberShape) -> String, - private val customizations: List = listOf(), + private val customizations: List = listOf(), ) : StructuredDataSerializerGenerator { private data class Context( /** Expression that retrieves a JsonValueWriter from either a JsonObjectWriter or JsonArrayWriter */ @@ -154,7 +155,7 @@ class JsonSerializerGenerator( private val model = codegenContext.model private val symbolProvider = codegenContext.symbolProvider - private val target = codegenContext.target + private val codegenTarget = codegenContext.target private val runtimeConfig = codegenContext.runtimeConfig private val smithyTypes = CargoDependency.SmithyTypes(runtimeConfig).asType() private val smithyJson = CargoDependency.smithyJson(runtimeConfig).asType() @@ -180,7 +181,7 @@ class JsonSerializerGenerator( fnName: String, structureShape: StructureShape, includedMembers: List, - makeSection: (StructureShape, String) -> JsonSection, + makeSection: (StructureShape, String) -> JsonSerializerSection, ): RuntimeType { return RuntimeType.forInlineFun(fnName, operationSerModule) { rustBlockTemplate( @@ -251,7 +252,7 @@ class JsonSerializerGenerator( rust("let mut out = String::new();") rustTemplate("let mut object = #{JsonObjectWriter}::new(&mut out);", *codegenScope) serializeStructure(StructContext("object", "input", inputShape), httpDocumentMembers) - customizations.forEach { it.section(JsonSection.InputStruct(inputShape, "object"))(this) } + customizations.forEach { it.section(JsonSerializerSection.InputStruct(inputShape, "object"))(this) } rust("object.finish();") rustTemplate("Ok(#{SdkBody}::from(out))", *codegenScope) } @@ -293,7 +294,7 @@ class JsonSerializerGenerator( val outputShape = operationShape.outputShape(model) val fnName = symbolProvider.serializeFunctionName(outputShape) - return serverSerializer(fnName, outputShape, httpDocumentMembers, JsonSection::OutputStruct) + return serverSerializer(fnName, outputShape, httpDocumentMembers, JsonSerializerSection::OutputStruct) } override fun serverErrorSerializer(shape: ShapeId): RuntimeType { @@ -302,7 +303,7 @@ class JsonSerializerGenerator( httpBindingResolver.errorResponseBindings(shape).filter { it.location == HttpLocation.DOCUMENT } .map { it.member } val fnName = symbolProvider.serializeFunctionName(errorShape) - return serverSerializer(fnName, errorShape, includedMembers, JsonSection::ServerError) + return serverSerializer(fnName, errorShape, includedMembers, JsonSerializerSection::ServerError) } private fun RustWriter.serializeStructure( @@ -358,6 +359,7 @@ class JsonSerializerGenerator( private fun RustWriter.serializeMemberValue(context: MemberContext, target: Shape) { val writer = context.writerExpression val value = context.valueExpression + when (target) { is StringShape -> rust("$writer.string(${value.name}.as_str());") is BooleanShape -> rust("$writer.boolean(${value.asValue()});") @@ -430,12 +432,11 @@ class JsonSerializerGenerator( private fun RustWriter.serializeMap(context: Context) { val keyName = safeName("key") val valueName = safeName("value") + for (customization in customizations) { + customization.section(JsonSerializerSection.BeforeIteratingOverMap(context.shape, context.valueExpression))(this) + } rustBlock("for ($keyName, $valueName) in ${context.valueExpression.asRef()}") { - val keyTarget = model.expectShape(context.shape.key.target) - val keyExpression = when (keyTarget.hasTrait()) { - true -> "$keyName.as_str()" - else -> keyName - } + val keyExpression = "$keyName.as_str()" serializeMember(MemberContext.mapMember(context, keyExpression, valueName)) } } @@ -456,7 +457,7 @@ class JsonSerializerGenerator( serializeMember(MemberContext.unionMember(context, "inner", member, jsonName)) } } - if (target.renderUnknownVariant()) { + if (codegenTarget.renderUnknownVariant()) { rustTemplate( "#{Union}::${UnionGenerator.UnknownVariantName} => return Err(#{Error}::unknown_variant(${unionSymbol.name.dq()}))", "Union" to unionSymbol, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt index 2ca848e075..cf8a77a76f 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt @@ -63,7 +63,7 @@ class XmlBindingTraitSerializerGenerator( private val runtimeConfig = codegenContext.runtimeConfig private val model = codegenContext.model private val smithyXml = CargoDependency.smithyXml(runtimeConfig).asType() - private val target = codegenContext.target + private val codegenTarget = codegenContext.target private val codegenScope = arrayOf( "XmlWriter" to smithyXml.member("encode::XmlWriter"), @@ -291,7 +291,14 @@ class XmlBindingTraitSerializerGenerator( private fun RustWriter.serializeRawMember(member: MemberShape, input: String) { when (model.expectShape(member.target)) { is StringShape -> { - rust("$input.as_str()") + // The `input` expression always evaluates to a reference type at this point, but if it does so because + // it's preceded by the `&` operator, calling `as_str()` on it will upset Clippy. + val dereferenced = if (input.startsWith("&")) { + autoDeref(input) + } else { + input + } + rust("$dereferenced.as_str()") } is BooleanShape, is NumberShape -> { rust( @@ -399,7 +406,7 @@ class XmlBindingTraitSerializerGenerator( } } - if (target.renderUnknownVariant()) { + if (codegenTarget.renderUnknownVariant()) { rustTemplate( "#{Union}::${UnionGenerator.UnknownVariantName} => return Err(#{Error}::unknown_variant(${unionSymbol.name.dq()}))", "Union" to unionSymbol, diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt index 9dc1e0592d..abe650017d 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt @@ -10,6 +10,7 @@ import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.node.Node import software.amazon.smithy.model.node.StringNode import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape @@ -17,13 +18,13 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer -import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder -import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider +import software.amazon.smithy.rust.codegen.core.testutil.testCodegenContext import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.lookup @@ -83,15 +84,28 @@ class InstantiatorTest { } """.asSmithyModel().let { RecursiveShapeBoxer.transform(it) } - private val symbolProvider = testSymbolProvider(model) - private val runtimeConfig = TestRuntimeConfig + private val codegenContext = testCodegenContext(model) + private val symbolProvider = codegenContext.symbolProvider + private val runtimeConfig = codegenContext.runtimeConfig + // This is the exact same behavior of the client. + private class BuilderKindBehavior(val codegenContext: CodegenContext) : Instantiator.BuilderKindBehavior { + override fun hasFallibleBuilder(shape: StructureShape) = + BuilderGenerator.hasFallibleBuilder(shape, codegenContext.symbolProvider) + + override fun setterName(memberShape: MemberShape) = memberShape.setterName() + + override fun doesSetterTakeInOption(memberShape: MemberShape) = true + } + + // This can be empty since the actual behavior is tested in `ClientInstantiatorTest` and `ServerInstantiatorTest`. private fun enumFromStringFn(symbol: Symbol, data: String) = writable { } @Test fun `generate unions`() { val union = model.lookup("com.test#MyUnion") - val sut = Instantiator(symbolProvider, model, runtimeConfig, ::enumFromStringFn) + val sut = + Instantiator(symbolProvider, model, runtimeConfig, BuilderKindBehavior(codegenContext), ::enumFromStringFn) val data = Node.parse("""{ "stringVariant": "ok!" }""") val project = TestWorkspace.testProject() @@ -110,7 +124,8 @@ class InstantiatorTest { @Test fun `generate struct builders`() { val structure = model.lookup("com.test#MyStruct") - val sut = Instantiator(symbolProvider, model, runtimeConfig, ::enumFromStringFn) + val sut = + Instantiator(symbolProvider, model, runtimeConfig, BuilderKindBehavior(codegenContext), ::enumFromStringFn) val data = Node.parse("""{ "bar": 10, "foo": "hello" }""") val project = TestWorkspace.testProject() @@ -134,7 +149,8 @@ class InstantiatorTest { @Test fun `generate builders for boxed structs`() { val structure = model.lookup("com.test#WithBox") - val sut = Instantiator(symbolProvider, model, runtimeConfig, ::enumFromStringFn) + val sut = + Instantiator(symbolProvider, model, runtimeConfig, BuilderKindBehavior(codegenContext), ::enumFromStringFn) val data = Node.parse( """ { @@ -172,7 +188,8 @@ class InstantiatorTest { @Test fun `generate lists`() { val data = Node.parse("""["bar", "foo"]""") - val sut = Instantiator(symbolProvider, model, runtimeConfig, ::enumFromStringFn) + val sut = + Instantiator(symbolProvider, model, runtimeConfig, BuilderKindBehavior(codegenContext), ::enumFromStringFn) val project = TestWorkspace.testProject() project.withModule(RustModule.Model) { @@ -180,16 +197,21 @@ class InstantiatorTest { withBlock("let result = ", ";") { sut.render(this, model.lookup("com.test#MyList"), data) } - rust("""assert_eq!(result, vec!["bar".to_owned(), "foo".to_owned()]);""") } + project.compileAndTest() } - project.compileAndTest() } @Test fun `generate sparse lists`() { val data = Node.parse(""" [ "bar", "foo", null ] """) - val sut = Instantiator(symbolProvider, model, runtimeConfig, ::enumFromStringFn) + val sut = Instantiator( + symbolProvider, + model, + runtimeConfig, + BuilderKindBehavior(codegenContext), + ::enumFromStringFn, + ) val project = TestWorkspace.testProject() project.withModule(RustModule.Model) { @@ -207,14 +229,20 @@ class InstantiatorTest { fun `generate maps of maps`() { val data = Node.parse( """ - { - "k1": { "map": {} }, - "k2": { "map": { "k3": {} } }, - "k3": { } - } - """, + { + "k1": { "map": {} }, + "k2": { "map": { "k3": {} } }, + "k3": { } + } + """, + ) + val sut = Instantiator( + symbolProvider, + model, + runtimeConfig, + BuilderKindBehavior(codegenContext), + ::enumFromStringFn, ) - val sut = Instantiator(symbolProvider, model, runtimeConfig, ::enumFromStringFn) val inner = model.lookup("com.test#Inner") val project = TestWorkspace.testProject() @@ -226,11 +254,11 @@ class InstantiatorTest { } rust( """ - assert_eq!(result.len(), 3); - assert_eq!(result.get("k1").unwrap().map.as_ref().unwrap().len(), 0); - assert_eq!(result.get("k2").unwrap().map.as_ref().unwrap().len(), 1); - assert_eq!(result.get("k3").unwrap().map, None); - """, + assert_eq!(result.len(), 3); + assert_eq!(result.get("k1").unwrap().map.as_ref().unwrap().len(), 0); + assert_eq!(result.get("k2").unwrap().map.as_ref().unwrap().len(), 1); + assert_eq!(result.get("k3").unwrap().map, None); + """, ) } } @@ -241,7 +269,13 @@ class InstantiatorTest { fun `blob inputs are binary data`() { // "Parameter values that contain binary data MUST be defined using values // that can be represented in plain text (for example, use "foo" and not "Zm9vCg==")." - val sut = Instantiator(symbolProvider, model, runtimeConfig, ::enumFromStringFn) + val sut = Instantiator( + symbolProvider, + model, + runtimeConfig, + BuilderKindBehavior(codegenContext), + ::enumFromStringFn, + ) val project = TestWorkspace.testProject() project.withModule(RustModule.Model) { diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGeneratorTest.kt index cee3fb1472..b90543bea3 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGeneratorTest.kt @@ -10,6 +10,7 @@ import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbolFn import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig @@ -45,7 +46,11 @@ class AwsQueryParserGeneratorTest { val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model) val symbolProvider = codegenContext.symbolProvider - val parserGenerator = AwsQueryParserGenerator(codegenContext, RuntimeType.wrappedXmlErrors(TestRuntimeConfig)) + val parserGenerator = AwsQueryParserGenerator( + codegenContext, + RuntimeType.wrappedXmlErrors(TestRuntimeConfig), + builderSymbolFn(symbolProvider), + ) val operationParser = parserGenerator.operationParser(model.lookup("test#SomeOperation"))!! val project = TestWorkspace.testProject(testSymbolProvider(model)) diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGeneratorTest.kt index b8eb1e77bb..7b835d8223 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGeneratorTest.kt @@ -10,6 +10,7 @@ import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbolFn import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig @@ -45,7 +46,11 @@ class Ec2QueryParserGeneratorTest { val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model) val symbolProvider = codegenContext.symbolProvider - val parserGenerator = Ec2QueryParserGenerator(codegenContext, RuntimeType.wrappedXmlErrors(TestRuntimeConfig)) + val parserGenerator = Ec2QueryParserGenerator( + codegenContext, + RuntimeType.wrappedXmlErrors(TestRuntimeConfig), + builderSymbolFn(symbolProvider), + ) val operationParser = parserGenerator.operationParser(model.lookup("test#SomeOperation"))!! val project = TestWorkspace.testProject(testSymbolProvider(model)) diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt index 049e456d8a..5cd898a29f 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt @@ -6,12 +6,14 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse import org.junit.jupiter.api.Test +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpTraitHttpBindingResolver import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolContentTypes import software.amazon.smithy.rust.codegen.core.smithy.protocols.restJsonFieldName @@ -115,10 +117,14 @@ class JsonParserGeneratorTest { val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model) val symbolProvider = codegenContext.symbolProvider + fun builderSymbol(shape: StructureShape): Symbol = + shape.builderSymbol(symbolProvider) + val parserGenerator = JsonParserGenerator( codegenContext, HttpTraitHttpBindingResolver(model, ProtocolContentTypes.consistent("application/json")), ::restJsonFieldName, + ::builderSymbol, ) val operationGenerator = parserGenerator.operationParser(model.lookup("test#Op")) val payloadGenerator = parserGenerator.payloadParser(model.lookup("test#OpOutput\$top")) diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt index c4932fe71a..52bb4e1e7e 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt @@ -13,6 +13,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbolFn import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig @@ -96,6 +97,7 @@ internal class XmlBindingTraitParserGeneratorTest { val parserGenerator = XmlBindingTraitParserGenerator( codegenContext, RuntimeType.wrappedXmlErrors(TestRuntimeConfig), + builderSymbolFn(symbolProvider), ) { _, inner -> inner("decoder") } val operationParser = parserGenerator.operationParser(model.lookup("test#Op"))!! val project = TestWorkspace.testProject(testSymbolProvider(model)) diff --git a/codegen-server-test/build.gradle.kts b/codegen-server-test/build.gradle.kts index 3d01b403ab..e63401daff 100644 --- a/codegen-server-test/build.gradle.kts +++ b/codegen-server-test/build.gradle.kts @@ -33,6 +33,7 @@ dependencies { implementation("software.amazon.smithy:smithy-aws-protocol-tests:$smithyVersion") implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") + implementation("software.amazon.smithy:smithy-validation-model:$smithyVersion") } val allCodegenTests = "../codegen-core/common-test-models".let { commonModels -> @@ -40,13 +41,24 @@ val allCodegenTests = "../codegen-core/common-test-models".let { commonModels -> CodegenTest("crate#Config", "naming_test_ops", imports = listOf("$commonModels/naming-obstacle-course-ops.smithy")), CodegenTest("naming_obs_structs#NamingObstacleCourseStructs", "naming_test_structs", imports = listOf("$commonModels/naming-obstacle-course-structs.smithy")), CodegenTest("com.amazonaws.simple#SimpleService", "simple", imports = listOf("$commonModels/simple.smithy")), + CodegenTest( + "com.amazonaws.constraints#ConstraintsService", "constraints_without_public_constrained_types", + imports = listOf("$commonModels/constraints.smithy"), + extraConfig = """, "codegen": { "publicConstrainedTypes": false } """, + ), + CodegenTest("com.amazonaws.constraints#ConstraintsService", "constraints", imports = listOf("$commonModels/constraints.smithy")), CodegenTest("aws.protocoltests.restjson#RestJson", "rest_json"), CodegenTest("aws.protocoltests.restjson#RestJsonExtras", "rest_json_extras", imports = listOf("$commonModels/rest-json-extras.smithy")), - CodegenTest("aws.protocoltests.restjson.validation#RestJsonValidation", "rest_json_validation"), + CodegenTest("aws.protocoltests.restjson.validation#RestJsonValidation", "rest_json_validation", + extraConfig = """, "codegen": { "ignoreUnsupportedConstraints": true } """, + ), CodegenTest("aws.protocoltests.json10#JsonRpc10", "json_rpc10"), CodegenTest("aws.protocoltests.json#JsonProtocol", "json_rpc11"), CodegenTest("aws.protocoltests.misc#MiscService", "misc", imports = listOf("$commonModels/misc.smithy")), - CodegenTest("com.amazonaws.ebs#Ebs", "ebs", imports = listOf("$commonModels/ebs.json")), + CodegenTest("com.amazonaws.ebs#Ebs", "ebs", + imports = listOf("$commonModels/ebs.json"), + extraConfig = """, "codegen": { "ignoreUnsupportedConstraints": true } """, + ), CodegenTest("com.amazonaws.s3#AmazonS3", "s3"), CodegenTest("com.aws.example.rust#PokemonService", "pokemon-service-server-sdk", imports = listOf("$commonModels/pokemon.smithy", "$commonModels/pokemon-common.smithy")), ) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonCodegenServerPlugin.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonCodegenServerPlugin.kt index a2fa85b72c..218eb8af1b 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonCodegenServerPlugin.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonCodegenServerPlugin.kt @@ -18,6 +18,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitor import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig import software.amazon.smithy.rust.codegen.server.python.smithy.customizations.DECORATORS +import software.amazon.smithy.rust.codegen.server.smithy.ConstrainedShapeSymbolProvider import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.customizations.ServerRequiredCustomizations import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator @@ -65,14 +66,17 @@ class PythonCodegenServerPlugin : SmithyBuildPlugin { model: Model, serviceShape: ServiceShape, symbolVisitorConfig: SymbolVisitorConfig, + constrainedTypes: Boolean = true, ) = // Rename a set of symbols that do not implement `PyClass` and have been wrapped in // `aws_smithy_http_server_python::types`. PythonServerSymbolVisitor(model, serviceShape = serviceShape, config = symbolVisitorConfig) + // Generate public constrained types for directly constrained shapes. + // In the Python server project, this is only done to generate constrained types for simple shapes (e.g. + // a `string` shape with the `length` trait), but these always remain `pub(crate)`. + .let { if (constrainedTypes) ConstrainedShapeSymbolProvider(it, model, serviceShape) else it } // Generate different types for EventStream shapes (e.g. transcribe streaming) - .let { - EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model, CodegenTarget.SERVER) - } + .let { EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model, CodegenTarget.SERVER) } // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes .let { BaseSymbolMetadataProvider(it, model, additionalAttributes = listOf()) } // Streaming shapes need different derives (e.g. they cannot derive Eq) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt index 32a6fe3887..9c38f61187 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt @@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.server.python.smithy import software.amazon.smithy.build.PluginContext import software.amazon.smithy.codegen.core.CodegenException +import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.NullableIndex import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.StringShape @@ -15,18 +16,17 @@ import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.rust.codegen.client.smithy.customize.RustCodegenDecorator +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig -import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock -import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerEnumGenerator import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerServiceGenerator import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerStructureGenerator import software.amazon.smithy.rust.codegen.server.smithy.DefaultServerPublicModules import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenVisitor +import software.amazon.smithy.rust.codegen.server.smithy.ServerSymbolProviders import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerProtocolLoader @@ -61,14 +61,44 @@ class PythonServerCodegenVisitor( ) .protocolFor(context.model, service) protocolGeneratorFactory = generator + model = codegenDecorator.transformModel(service, baseModel) - symbolProvider = PythonCodegenServerPlugin.baseSymbolProvider(model, service, symbolVisitorConfig) - // Override `codegenContext` which carries the symbolProvider. - codegenContext = ServerCodegenContext(model, symbolProvider, service, protocol, settings) + // `publicConstrainedTypes` must always be `false` for the Python server, since Python generates its own + // wrapper newtypes. + settings = settings.copy(codegenConfig = settings.codegenConfig.copy(publicConstrainedTypes = false)) + + fun baseSymbolProviderFactory( + model: Model, + serviceShape: ServiceShape, + symbolVisitorConfig: SymbolVisitorConfig, + publicConstrainedTypes: Boolean, + ) = PythonCodegenServerPlugin.baseSymbolProvider(model, serviceShape, symbolVisitorConfig, publicConstrainedTypes) + + val serverSymbolProviders = ServerSymbolProviders.from( + model, + service, + symbolVisitorConfig, + settings.codegenConfig.publicConstrainedTypes, + ::baseSymbolProviderFactory, + ) + + // Override `codegenContext` which carries the various symbol providers. + codegenContext = + ServerCodegenContext( + model, + serverSymbolProviders.symbolProvider, + service, + protocol, + settings, + serverSymbolProviders.unconstrainedShapeSymbolProvider, + serverSymbolProviders.constrainedShapeSymbolProvider, + serverSymbolProviders.constraintViolationSymbolProvider, + serverSymbolProviders.pubCrateConstrainedShapeSymbolProvider, + ) // Override `rustCrate` which carries the symbolProvider. - rustCrate = RustCrate(context.fileManifest, symbolProvider, DefaultServerPublicModules, settings.codegenConfig) + rustCrate = RustCrate(context.fileManifest, codegenContext.symbolProvider, DefaultServerPublicModules, settings.codegenConfig) // Override `protocolGenerator` which carries the symbolProvider. protocolGenerator = protocolGeneratorFactory.buildProtocolGenerator(codegenContext) } @@ -88,13 +118,9 @@ class PythonServerCodegenVisitor( rustCrate.useShapeWriter(shape) { // Use Python specific structure generator that adds the #[pyclass] attribute // and #[pymethods] implementation. - PythonServerStructureGenerator(model, symbolProvider, this, shape).render(CodegenTarget.SERVER) - val builderGenerator = - BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape) - builderGenerator.render(this) - implBlock(shape, symbolProvider) { - builderGenerator.renderConvenienceMethod(this) - } + PythonServerStructureGenerator(model, codegenContext.symbolProvider, this, shape).render(CodegenTarget.SERVER) + + renderStructureShapeBuilder(shape, this) } } @@ -104,12 +130,9 @@ class PythonServerCodegenVisitor( * Although raw strings require no code generation, enums are actually [EnumTrait] applied to string shapes. */ override fun stringShape(shape: StringShape) { - logger.info("[rust-server-codegen] Generating an enum $shape") - shape.getTrait()?.also { enum -> - rustCrate.useShapeWriter(shape) { - PythonServerEnumGenerator(model, symbolProvider, this, shape, enum, codegenContext.runtimeConfig).render() - } - } + fun pythonServerEnumGeneratorFactory(codegenContext: ServerCodegenContext, writer: RustWriter, shape: StringShape) = + PythonServerEnumGenerator(codegenContext, writer, shape) + stringShape(shape, ::pythonServerEnumGeneratorFactory) } /** diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEnumGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEnumGenerator.kt index b804d12bbb..cad7bad677 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEnumGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEnumGenerator.kt @@ -5,9 +5,7 @@ package software.amazon.smithy.rust.codegen.server.python.smithy.generators -import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.StringShape -import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable @@ -16,10 +14,9 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig -import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCargoDependency +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerEnumGenerator /** @@ -28,13 +25,10 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerEnumGe * some utility functions like `__str__()` and `__repr__()`. */ class PythonServerEnumGenerator( - model: Model, - symbolProvider: RustSymbolProvider, + codegenContext: ServerCodegenContext, private val writer: RustWriter, - private val shape: StringShape, - enumTrait: EnumTrait, - runtimeConfig: RuntimeConfig, -) : ServerEnumGenerator(model, symbolProvider, writer, shape, enumTrait, runtimeConfig) { + shape: StringShape, +) : ServerEnumGenerator(codegenContext, writer, shape) { private val pyo3Symbols = listOf(PythonServerCargoDependency.PyO3.asType()) @@ -48,11 +42,6 @@ class PythonServerEnumGenerator( Attribute.Custom("pyo3::pyclass", symbols = pyo3Symbols).render(writer) } - override fun renderFromForStr() { - renderPyClass() - super.renderFromForStr() - } - private fun renderPyO3Methods() { Attribute.Custom("pyo3::pymethods", symbols = pyo3Symbols).render(writer) writer.rustTemplate( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProvider.kt new file mode 100644 index 0000000000..92ff5faf77 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProvider.kt @@ -0,0 +1,109 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.knowledge.NullableIndex +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.traits.LengthTrait +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.smithy.Models +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.contextName +import software.amazon.smithy.rust.codegen.core.smithy.handleOptionality +import software.amazon.smithy.rust.codegen.core.smithy.handleRustBoxing +import software.amazon.smithy.rust.codegen.core.smithy.locatedIn +import software.amazon.smithy.rust.codegen.core.smithy.rustType +import software.amazon.smithy.rust.codegen.core.smithy.symbolBuilder +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.toPascalCase + +/** + * The [ConstrainedShapeSymbolProvider] returns, for a given _directly_ + * constrained shape, a symbol whose Rust type can hold the constrained values. + * + * For all shapes with supported traits directly attached to them, this type is + * a [RustType.Opaque] wrapper tuple newtype holding the inner constrained + * type. + * + * The symbols this symbol provider returns are always public and exposed to + * the end user. + * + * This symbol provider is meant to be used "deep" within the wrapped symbol + * providers chain, just above the core base symbol provider, `SymbolVisitor`. + * + * If the shape is _transitively but not directly_ constrained, use + * [PubCrateConstrainedShapeSymbolProvider] instead, which returns symbols + * whose associated types are `pub(crate)` and thus not exposed to the end + * user. + */ +class ConstrainedShapeSymbolProvider( + private val base: RustSymbolProvider, + private val model: Model, + private val serviceShape: ServiceShape, +) : WrappingSymbolProvider(base) { + private val nullableIndex = NullableIndex.of(model) + + private fun publicConstrainedSymbolForMapShape(shape: Shape): Symbol { + check(shape is MapShape) + + val rustType = RustType.Opaque(shape.contextName(serviceShape).toPascalCase()) + return symbolBuilder(shape, rustType).locatedIn(Models).build() + } + + override fun toSymbol(shape: Shape): Symbol { + return when (shape) { + is MemberShape -> { + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) Member shapes can have constraint traits + // (constraint trait precedence). + val target = model.expectShape(shape.target) + val targetSymbol = this.toSymbol(target) + // Handle boxing first so we end up with `Option>`, not `Box>`. + handleOptionality(handleRustBoxing(targetSymbol, shape), shape, nullableIndex, base.config().nullabilityCheckMode) + } + is MapShape -> { + if (shape.isDirectlyConstrained(base)) { + check(shape.hasTrait()) { "Only the `length` constraint trait can be applied to maps" } + publicConstrainedSymbolForMapShape(shape) + } else { + val keySymbol = this.toSymbol(shape.key) + val valueSymbol = this.toSymbol(shape.value) + symbolBuilder(shape, RustType.HashMap(keySymbol.rustType(), valueSymbol.rustType())) + .addReference(keySymbol) + .addReference(valueSymbol) + .build() + } + } + is CollectionShape -> { + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) Both arms return the same because we haven't + // implemented any constraint trait on collection shapes yet. + if (shape.isDirectlyConstrained(base)) { + val inner = this.toSymbol(shape.member) + symbolBuilder(shape, RustType.Vec(inner.rustType())).addReference(inner).build() + } else { + val inner = this.toSymbol(shape.member) + symbolBuilder(shape, RustType.Vec(inner.rustType())).addReference(inner).build() + } + } + is StringShape -> { + if (shape.isDirectlyConstrained(base)) { + val rustType = RustType.Opaque(shape.contextName(serviceShape).toPascalCase()) + symbolBuilder(shape, rustType).locatedIn(Models).build() + } else { + base.toSymbol(shape) + } + } + else -> base.toSymbol(shape) + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintViolationSymbolProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintViolationSymbolProvider.kt new file mode 100644 index 0000000000..119db96452 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintViolationSymbolProvider.kt @@ -0,0 +1,123 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.smithy.Models +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.contextName +import software.amazon.smithy.rust.codegen.core.smithy.rustType +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase +import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderSymbol + +/** + * The [ConstraintViolationSymbolProvider] returns, for a given constrained + * shape, a symbol whose Rust type can hold information about constraint + * violations that may occur when building the shape from unconstrained values. + * + * So, for example, given the model: + * + * ```smithy + * @pattern("\\w+") + * @length(min: 1, max: 69) + * string NiceString + * + * structure Structure { + * @required + * niceString: NiceString + * } + * ``` + * + * A `NiceString` built from an arbitrary Rust `String` may give rise to at + * most two constraint trait violations: one for `pattern`, one for `length`. + * Similarly, the shape `Structure` can fail to be built when a value for + * `niceString` is not provided. + * + * Said type is always called `ConstraintViolation`, and resides in a bespoke + * module inside the same module as the _public_ constrained type the user is + * exposed to. When the user is _not_ exposed to the constrained type, the + * constraint violation type's module is a child of the `model` module. + * + * It is the responsibility of the caller to ensure that the shape is + * constrained (either directly or transitively) before using this symbol + * provider. This symbol provider intentionally crashes if the shape is not + * constrained. + */ +class ConstraintViolationSymbolProvider( + private val base: RustSymbolProvider, + private val model: Model, + private val serviceShape: ServiceShape, + private val publicConstrainedTypes: Boolean, +) : WrappingSymbolProvider(base) { + private val constraintViolationName = "ConstraintViolation" + + private fun constraintViolationSymbolForCollectionOrMapOrUnionShape(shape: Shape): Symbol { + check(shape is CollectionShape || shape is MapShape || shape is UnionShape) + + val symbol = base.toSymbol(shape) + val constraintViolationNamespace = + "${symbol.namespace.let { it.ifEmpty { "crate::${Models.namespace}" } }}::${ + RustReservedWords.escapeIfNeeded( + shape.contextName(serviceShape).toSnakeCase(), + ) + }" + val rustType = RustType.Opaque(constraintViolationName, constraintViolationNamespace) + return Symbol.builder() + .rustType(rustType) + .name(rustType.name) + .namespace(rustType.namespace, "::") + .definitionFile(symbol.definitionFile) + .build() + } + + override fun toSymbol(shape: Shape): Symbol { + check(shape.canReachConstrainedShape(model, base)) + + return when (shape) { + is MapShape, is CollectionShape, is UnionShape -> { + constraintViolationSymbolForCollectionOrMapOrUnionShape(shape) + } + is StructureShape -> { + val builderSymbol = shape.serverBuilderSymbol(base, pubCrate = !publicConstrainedTypes) + + val namespace = builderSymbol.namespace + val rustType = RustType.Opaque(constraintViolationName, namespace) + Symbol.builder() + .rustType(rustType) + .name(rustType.name) + .namespace(rustType.namespace, "::") + .definitionFile(builderSymbol.definitionFile) + .build() + } + is StringShape -> { + val namespace = "crate::${Models.namespace}::${ + RustReservedWords.escapeIfNeeded( + shape.contextName(serviceShape).toSnakeCase(), + ) + }" + val rustType = RustType.Opaque(constraintViolationName, namespace) + Symbol.builder() + .rustType(rustType) + .name(rustType.name) + .namespace(rustType.namespace, "::") + .definitionFile(Models.filename) + .build() + } + else -> TODO("Constraint traits on other shapes not implemented yet: $shape") + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/Constraints.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/Constraints.kt new file mode 100644 index 0000000000..82102be18a --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/Constraints.kt @@ -0,0 +1,133 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import software.amazon.smithy.codegen.core.SymbolProvider +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.neighbor.Walker +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.SimpleShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.model.traits.LengthTrait +import software.amazon.smithy.model.traits.PatternTrait +import software.amazon.smithy.model.traits.RangeTrait +import software.amazon.smithy.model.traits.RequiredTrait +import software.amazon.smithy.model.traits.UniqueItemsTrait +import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE +import software.amazon.smithy.rust.codegen.core.util.hasTrait + +/** + * This file contains utilities to work with constrained shapes. + */ + +/** + * Whether the shape has any trait that could cause a request to be rejected with a constraint violation, _whether + * we support it or not_. + */ +fun Shape.hasConstraintTrait() = + hasTrait() || + hasTrait() || + hasTrait() || + hasTrait() || + hasTrait() || + hasTrait() + +/** + * We say a shape is _directly_ constrained if: + * + * - it has a constraint trait, or; + * - in the case of it being an aggregate shape, one of its member shapes has a constraint trait. + * + * Note that an aggregate shape whose member shapes do not have constraint traits but that has a member whose target is + * a constrained shape is _not_ directly constrained. + * + * At the moment only a subset of constraint traits are implemented on a subset of shapes; that's why we match against + * a subset of shapes in each arm, and check for a subset of constraint traits attached to the shape in the arm's + * (with these subsets being smaller than what [the spec] accounts for). + * + * [the spec]: https://awslabs.github.io/smithy/2.0/spec/constraint-traits.html + */ +fun Shape.isDirectlyConstrained(symbolProvider: SymbolProvider): Boolean = when (this) { + is StructureShape -> { + // TODO(https://github.com/awslabs/smithy-rs/issues/1302, https://github.com/awslabs/smithy/issues/1179): + // The only reason why the functions in this file have + // to take in a `SymbolProvider` is because non-`required` blob streaming members are interpreted as + // `required`, so we can't use `member.isOptional` here. + this.members().map { symbolProvider.toSymbol(it) }.any { !it.isOptional() } + } + is MapShape -> this.hasTrait() + is StringShape -> this.hasTrait() || this.hasTrait() + else -> false +} + +fun MemberShape.hasConstraintTraitOrTargetHasConstraintTrait(model: Model, symbolProvider: SymbolProvider): Boolean = + this.isDirectlyConstrained(symbolProvider) || (model.expectShape(this.target).isDirectlyConstrained(symbolProvider)) + +fun Shape.isTransitivelyButNotDirectlyConstrained(model: Model, symbolProvider: SymbolProvider): Boolean = + !this.isDirectlyConstrained(symbolProvider) && this.canReachConstrainedShape(model, symbolProvider) + +fun Shape.canReachConstrainedShape(model: Model, symbolProvider: SymbolProvider): Boolean = + if (this is MemberShape) { + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) Constraint traits on member shapes are not implemented + // yet. Also, note that a walker over a member shape can, perhaps counterintuitively, reach the _containing_ shape, + // so we can't simply delegate to the `else` branch when we implement them. + this.targetCanReachConstrainedShape(model, symbolProvider) + } else { + Walker(model).walkShapes(this).toSet().any { it.isDirectlyConstrained(symbolProvider) } + } + +fun MemberShape.targetCanReachConstrainedShape(model: Model, symbolProvider: SymbolProvider): Boolean = + model.expectShape(this.target).canReachConstrainedShape(model, symbolProvider) + +fun Shape.hasPublicConstrainedWrapperTupleType(model: Model, publicConstrainedTypes: Boolean): Boolean = when (this) { + is MapShape -> publicConstrainedTypes && this.hasTrait() + is StringShape -> !this.hasTrait() && (publicConstrainedTypes && this.hasTrait()) + is MemberShape -> model.expectShape(this.target).hasPublicConstrainedWrapperTupleType(model, publicConstrainedTypes) + else -> false +} + +fun Shape.wouldHaveConstrainedWrapperTupleTypeWerePublicConstrainedTypesEnabled(model: Model): Boolean = + hasPublicConstrainedWrapperTupleType(model, true) + +/** + * Helper function to determine whether a shape will map to a _public_ constrained wrapper tuple type. + * + * This function is used in core code generators, so it takes in a [CodegenContext] that is downcast + * to [ServerCodegenContext] when generating servers. + */ +fun workingWithPublicConstrainedWrapperTupleType(shape: Shape, model: Model, publicConstrainedTypes: Boolean): Boolean = + shape.hasPublicConstrainedWrapperTupleType(model, publicConstrainedTypes) + +/** + * Returns whether a shape's type _name_ contains a non-public type when `publicConstrainedTypes` is `false`. + * + * For example, a `Vec` contains a non-public type, because `crate::model::LengthString` + * is `pub(crate)` when `publicConstrainedTypes` is `false` + * + * Note that a structure shape's type _definition_ may contain non-public types, but its _name_ is always public. + * + * Note how we short-circuit on `publicConstrainedTypes = true`, but we still require it to be passed in instead of laying + * the responsibility on the caller, for API safety usage. + */ +fun Shape.typeNameContainsNonPublicType( + model: Model, + symbolProvider: SymbolProvider, + publicConstrainedTypes: Boolean, +): Boolean = !publicConstrainedTypes && when (this) { + is SimpleShape -> wouldHaveConstrainedWrapperTupleTypeWerePublicConstrainedTypesEnabled(model) + is MemberShape -> model.expectShape(this.target).typeNameContainsNonPublicType(model, symbolProvider, publicConstrainedTypes) + is CollectionShape -> this.canReachConstrainedShape(model, symbolProvider) + is MapShape -> this.canReachConstrainedShape(model, symbolProvider) + is StructureShape, is UnionShape -> false + else -> UNREACHABLE("the above arms should be exhaustive, but we received shape: $this") +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/LengthTraitValidationErrorMessage.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/LengthTraitValidationErrorMessage.kt new file mode 100644 index 0000000000..b15b2dc8f0 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/LengthTraitValidationErrorMessage.kt @@ -0,0 +1,21 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import software.amazon.smithy.model.traits.LengthTrait + +fun LengthTrait.validationErrorMessage(): String { + val beginning = "Value with length {} at '{}' failed to satisfy constraint: Member must have length " + val ending = if (this.min.isPresent && this.max.isPresent) { + "between ${this.min.get()} and ${this.max.get()}, inclusive" + } else if (this.min.isPresent) ( + "greater than or equal to ${this.min.get()}" + ) else { + check(this.max.isPresent) + "less than or equal to ${this.max.get()}" + } + return "$beginning$ending" +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProvider.kt new file mode 100644 index 0000000000..e63e18c7ac --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProvider.kt @@ -0,0 +1,124 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.knowledge.NullableIndex +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.SimpleShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.smithy.Constrained +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.handleOptionality +import software.amazon.smithy.rust.codegen.core.smithy.handleRustBoxing +import software.amazon.smithy.rust.codegen.core.smithy.rustType +import software.amazon.smithy.rust.codegen.core.util.PANIC +import software.amazon.smithy.rust.codegen.core.util.toPascalCase +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase + +/** + * The [PubCrateConstrainedShapeSymbolProvider] returns, for a given + * _transitively but not directly_ constrained shape, a symbol whose Rust type + * can hold the constrained values. + * + * For collection and map shapes, this type is a [RustType.Opaque] wrapper + * tuple newtype holding a container over the inner constrained type. For + * member shapes, it's whatever their target shape resolves to. + * + * The class name is prefixed with `PubCrate` because the symbols it returns + * have associated types that are generated as `pub(crate)`. See the + * `PubCrate*Generator` classes to see how these types are generated. + * + * It is important that this symbol provider does _not_ wrap + * [ConstrainedShapeSymbolProvider], since otherwise it will eventually + * delegate to it and generate a symbol with a `pub` type. + * + * Note simple shapes cannot be transitively and not directly constrained at + * the same time, so this symbol provider is only implemented for aggregate shapes. + * The symbol provider will intentionally crash in such a case to avoid the caller + * incorrectly using it. + * + * Note also that for the purposes of this symbol provider, a member shape is + * transitively but not directly constrained only in the case where it itself + * is not directly constrained and its target also is not directly constrained. + * + * If the shape is _directly_ constrained, use [ConstrainedShapeSymbolProvider] + * instead. + */ +class PubCrateConstrainedShapeSymbolProvider( + private val base: RustSymbolProvider, + private val model: Model, + private val serviceShape: ServiceShape, +) : WrappingSymbolProvider(base) { + private val nullableIndex = NullableIndex.of(model) + + private fun constrainedSymbolForCollectionOrMapShape(shape: Shape): Symbol { + check(shape is CollectionShape || shape is MapShape) + + val name = constrainedTypeNameForCollectionOrMapShape(shape, serviceShape) + val namespace = "crate::${Constrained.namespace}::${RustReservedWords.escapeIfNeeded(name.toSnakeCase())}" + val rustType = RustType.Opaque(name, namespace) + return Symbol.builder() + .rustType(rustType) + .name(rustType.name) + .namespace(rustType.namespace, "::") + .definitionFile(Constrained.filename) + .build() + } + + private fun errorMessage(shape: Shape) = + "This symbol provider was called with $shape. However, it can only be called with a shape that is transitively constrained." + + override fun toSymbol(shape: Shape): Symbol { + require(shape.isTransitivelyButNotDirectlyConstrained(model, base)) { errorMessage(shape) } + + return when (shape) { + is CollectionShape, is MapShape -> { + constrainedSymbolForCollectionOrMapShape(shape) + } + is MemberShape -> { + require(model.expectShape(shape.container).isStructureShape) { + "This arm is only exercised by `ServerBuilderGenerator`" + } + require(!shape.hasConstraintTraitOrTargetHasConstraintTrait(model, base)) { errorMessage(shape) } + + val targetShape = model.expectShape(shape.target) + + if (targetShape is SimpleShape) { + base.toSymbol(shape) + } else { + val targetSymbol = this.toSymbol(targetShape) + // Handle boxing first so we end up with `Option>`, not `Box>`. + handleOptionality(handleRustBoxing(targetSymbol, shape), shape, nullableIndex, base.config().nullabilityCheckMode) + } + } + is StructureShape, is UnionShape -> { + // Structure shapes and union shapes always generate a [RustType.Opaque] constrained type. + base.toSymbol(shape) + } + else -> { + check(shape is SimpleShape) + // The rest of the shape types are simple shapes, which are impossible to be transitively but not + // directly constrained; directly constrained shapes generate public constrained types. + PANIC(errorMessage(shape)) + } + } + } +} + +fun constrainedTypeNameForCollectionOrMapShape(shape: Shape, serviceShape: ServiceShape): String { + check(shape is CollectionShape || shape is MapShape) + return "${shape.id.getName(serviceShape).toPascalCase()}Constrained" +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstraintViolationSymbolProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstraintViolationSymbolProvider.kt new file mode 100644 index 0000000000..05a8d635a4 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstraintViolationSymbolProvider.kt @@ -0,0 +1,37 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.rustType + +/** + * This is only used when `publicConstrainedTypes` is `false`. + * + * This must wrap [ConstraintViolationSymbolProvider]. + */ +class PubCrateConstraintViolationSymbolProvider( + private val base: ConstraintViolationSymbolProvider, +) : WrappingSymbolProvider(base) { + override fun toSymbol(shape: Shape): Symbol { + val baseSymbol = base.toSymbol(shape) + // If the shape is a structure shape, the module where its builder is hosted when `publicConstrainedTypes` is + // `false` is already suffixed with `_internal`. + if (shape is StructureShape) { + return baseSymbol + } + val baseRustType = baseSymbol.rustType() + val newNamespace = baseSymbol.namespace + "_internal" + return baseSymbol.toBuilder() + .rustType(RustType.Opaque(baseRustType.name, newNamespace)) + .namespace(newNamespace, baseSymbol.namespaceDelimiter) + .build() + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenServerPlugin.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenServerPlugin.kt index ac63fd9eeb..5bbc073431 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenServerPlugin.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenServerPlugin.kt @@ -24,10 +24,12 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.Ser import java.util.logging.Level import java.util.logging.Logger -/** Rust Codegen Plugin - * This is the entrypoint for code generation, triggered by the smithy-build plugin. - * `resources/META-INF.services/software.amazon.smithy.build.SmithyBuildPlugin` refers to this class by name which - * enables the smithy-build plugin to invoke `execute` with all of the Smithy plugin context + models. +/** + * Rust Codegen Plugin + * + * This is the entrypoint for code generation, triggered by the smithy-build plugin. + * `resources/META-INF.services/software.amazon.smithy.build.SmithyBuildPlugin` refers to this class by name which + * enables the smithy-build plugin to invoke `execute` with all of the Smithy plugin context + models. */ class RustCodegenServerPlugin : SmithyBuildPlugin { private val logger = Logger.getLogger(javaClass.name) @@ -51,8 +53,8 @@ class RustCodegenServerPlugin : SmithyBuildPlugin { } companion object { - /** SymbolProvider - * When generating code, smithy types need to be converted into Rust types—that is the core role of the symbol provider + /** + * When generating code, smithy types need to be converted into Rust types—that is the core role of the symbol provider. * * The Symbol provider is composed of a base [SymbolVisitor] which handles the core functionality, then is layered * with other symbol providers, documented inline, to handle the full scope of Smithy types. @@ -61,12 +63,13 @@ class RustCodegenServerPlugin : SmithyBuildPlugin { model: Model, serviceShape: ServiceShape, symbolVisitorConfig: SymbolVisitorConfig, + constrainedTypes: Boolean = true, ) = SymbolVisitor(model, serviceShape = serviceShape, config = symbolVisitorConfig) + // Generate public constrained types for directly constrained shapes. + .let { if (constrainedTypes) ConstrainedShapeSymbolProvider(it, model, serviceShape) else it } // Generate different types for EventStream shapes (e.g. transcribe streaming) - .let { - EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model, CodegenTarget.SERVER) - } + .let { EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model, CodegenTarget.SERVER) } // Generate [ByteStream] instead of `Blob` for streaming binary shapes (e.g. S3 GetObject) .let { StreamingShapeSymbolProvider(it, model) } // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenContext.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenContext.kt index 0cc39ac647..a0ad38f04f 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenContext.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenContext.kt @@ -27,6 +27,10 @@ data class ServerCodegenContext( override val serviceShape: ServiceShape, override val protocol: ShapeId, override val settings: ServerRustSettings, + val unconstrainedShapeSymbolProvider: UnconstrainedShapeSymbolProvider, + val constrainedShapeSymbolProvider: RustSymbolProvider, + val constraintViolationSymbolProvider: ConstraintViolationSymbolProvider, + val pubCrateConstrainedShapeSymbolProvider: PubCrateConstrainedShapeSymbolProvider, ) : CodegenContext( model, symbolProvider, serviceShape, protocol, settings, CodegenTarget.SERVER, ) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt index de10207c53..c0a68a26b7 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt @@ -6,25 +6,33 @@ package software.amazon.smithy.rust.codegen.server.smithy import software.amazon.smithy.build.PluginContext +import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.NullableIndex import software.amazon.smithy.model.neighbor.Walker +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.SetShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.ShapeVisitor import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.model.traits.LengthTrait import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.rust.codegen.client.smithy.customize.RustCodegenDecorator import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.smithy.Constrained import software.amazon.smithy.rust.codegen.core.smithy.CoreRustSettings +import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule import software.amazon.smithy.rust.codegen.core.smithy.RustCrate -import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig -import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator +import software.amazon.smithy.rust.codegen.core.smithy.Unconstrained import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock @@ -33,13 +41,29 @@ import software.amazon.smithy.rust.codegen.core.smithy.transformers.EventStreamN import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer import software.amazon.smithy.rust.codegen.core.util.CommandFailed -import software.amazon.smithy.rust.codegen.core.util.getTrait +import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.runCommand +import software.amazon.smithy.rust.codegen.server.smithy.generators.ConstrainedMapGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.ConstrainedStringGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.ConstrainedTraitForEnumGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.MapConstraintViolationGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.PubCrateConstrainedCollectionGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.PubCrateConstrainedMapGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGeneratorWithoutPublicConstrainedTypes import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerEnumGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerServiceGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerStructureConstrainedTraitImpl +import software.amazon.smithy.rust.codegen.server.smithy.generators.UnconstrainedCollectionGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.UnconstrainedMapGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.UnconstrainedUnionGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerProtocolLoader +import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput +import software.amazon.smithy.rust.codegen.server.smithy.transformers.AttachValidationExceptionToConstrainedOperationInputsInAllowList +import software.amazon.smithy.rust.codegen.server.smithy.transformers.RemoveEbsModelValidationException +import software.amazon.smithy.rust.codegen.server.smithy.transformers.ShapesReachableFromOperationInputTagger import java.util.logging.Logger val DefaultServerPublicModules = setOf( @@ -47,7 +71,6 @@ val DefaultServerPublicModules = setOf( RustModule.Model, RustModule.Input, RustModule.Output, - RustModule.Config, ).associateBy { it.name } /** @@ -60,15 +83,18 @@ open class ServerCodegenVisitor( ) : ShapeVisitor.Default() { protected val logger = Logger.getLogger(javaClass.name) - protected val settings = ServerRustSettings.from(context.model, context.settings) + protected var settings = ServerRustSettings.from(context.model, context.settings) - protected var symbolProvider: RustSymbolProvider protected var rustCrate: RustCrate private val fileManifest = context.fileManifest protected var model: Model protected var codegenContext: ServerCodegenContext protected var protocolGeneratorFactory: ProtocolGeneratorFactory protected var protocolGenerator: ServerProtocolGenerator + private val unconstrainedModule = + RustModule.private(Unconstrained.namespace, "Unconstrained types for constrained shapes.") + private val constrainedModule = + RustModule.private(Constrained.namespace, "Constrained types for constrained shapes.") init { val symbolVisitorConfig = @@ -77,6 +103,7 @@ open class ServerCodegenVisitor( renameExceptions = false, nullabilityCheckMode = NullableIndex.CheckMode.SERVER, ) + val baseModel = baselineTransform(context.model) val service = settings.getService(baseModel) val (protocol, generator) = @@ -88,18 +115,30 @@ open class ServerCodegenVisitor( ) .protocolFor(context.model, service) protocolGeneratorFactory = generator + model = codegenDecorator.transformModel(service, baseModel) - symbolProvider = RustCodegenServerPlugin.baseSymbolProvider(model, service, symbolVisitorConfig) + + val serverSymbolProviders = ServerSymbolProviders.from( + model, + service, + symbolVisitorConfig, + settings.codegenConfig.publicConstrainedTypes, + RustCodegenServerPlugin::baseSymbolProvider, + ) codegenContext = ServerCodegenContext( model, - symbolProvider, + serverSymbolProviders.symbolProvider, service, protocol, settings, + serverSymbolProviders.unconstrainedShapeSymbolProvider, + serverSymbolProviders.constrainedShapeSymbolProvider, + serverSymbolProviders.constraintViolationSymbolProvider, + serverSymbolProviders.pubCrateConstrainedShapeSymbolProvider, ) - rustCrate = RustCrate(context.fileManifest, symbolProvider, DefaultServerPublicModules, settings.codegenConfig) + rustCrate = RustCrate(context.fileManifest, codegenContext.symbolProvider, DefaultServerPublicModules, settings.codegenConfig) protocolGenerator = protocolGeneratorFactory.buildProtocolGenerator(codegenContext) } @@ -117,6 +156,13 @@ open class ServerCodegenVisitor( .let(RecursiveShapeBoxer::transform) // Normalize operations by adding synthetic input and output shapes to every operation .let(OperationNormalizer::transform) + // Remove the EBS model's own `ValidationException`, which collides with `smithy.framework#ValidationException` + .let(RemoveEbsModelValidationException::transform) + // Attach the `smithy.framework#ValidationException` error to operations whose inputs are constrained, + // if they belong to a service in an allowlist + .let(AttachValidationExceptionToConstrainedOperationInputsInAllowList::transform) + // Tag aggregate shapes reachable from operation input + .let(ShapesReachableFromOperationInputTagger::transform) // Normalize event stream operations .let(EventStreamNormalizer::transform) @@ -139,9 +185,26 @@ open class ServerCodegenVisitor( */ fun execute() { val service = settings.getService(model) - logger.info( + logger.warning( "[rust-server-codegen] Generating Rust server for service $service, protocol ${codegenContext.protocol}", ) + + for (validationResult in listOf( + validateOperationsWithConstrainedInputHaveValidationExceptionAttached( + model, + service, + ), + validateUnsupportedConstraints(model, service, codegenContext.settings.codegenConfig), + )) { + for (logMessage in validationResult.messages) { + // TODO(https://github.com/awslabs/smithy-rs/issues/1756): These are getting duplicated. + logger.log(logMessage.level, logMessage.message) + } + if (validationResult.shouldAbort) { + throw CodegenException("Unsupported constraints feature used; see error messages above for resolution") + } + } + val serviceShapes = Walker(model).walkShapes(service) serviceShapes.forEach { it.accept(this) } codegenDecorator.extras(codegenContext, rustCrate) @@ -159,7 +222,7 @@ open class ServerCodegenVisitor( timeout = settings.codegenConfig.formatTimeoutSeconds.toLong(), ) } catch (err: CommandFailed) { - logger.warning( + logger.info( "[rust-server-codegen] Failed to run cargo fmt: [${service.id}]\n${err.output}", ) } @@ -180,12 +243,110 @@ open class ServerCodegenVisitor( override fun structureShape(shape: StructureShape) { logger.info("[rust-server-codegen] Generating a structure $shape") rustCrate.useShapeWriter(shape) { - StructureGenerator(model, symbolProvider, this, shape).render(CodegenTarget.SERVER) - val builderGenerator = - BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape) - builderGenerator.render(this) - this.implBlock(shape, symbolProvider) { - builderGenerator.renderConvenienceMethod(this) + StructureGenerator(model, codegenContext.symbolProvider, this, shape).render(CodegenTarget.SERVER) + + renderStructureShapeBuilder(shape, this) + } + } + + protected fun renderStructureShapeBuilder( + shape: StructureShape, + writer: RustWriter, + ) { + if (codegenContext.settings.codegenConfig.publicConstrainedTypes || shape.isReachableFromOperationInput()) { + val serverBuilderGenerator = ServerBuilderGenerator(codegenContext, shape) + serverBuilderGenerator.render(writer) + + if (codegenContext.settings.codegenConfig.publicConstrainedTypes) { + writer.implBlock(shape, codegenContext.symbolProvider) { + serverBuilderGenerator.renderConvenienceMethod(this) + } + } + } + + if (shape.isReachableFromOperationInput()) { + ServerStructureConstrainedTraitImpl( + codegenContext.symbolProvider, + codegenContext.settings.codegenConfig.publicConstrainedTypes, + shape, + writer, + ).render() + } + + if (!codegenContext.settings.codegenConfig.publicConstrainedTypes) { + val serverBuilderGeneratorWithoutPublicConstrainedTypes = + ServerBuilderGeneratorWithoutPublicConstrainedTypes(codegenContext, shape) + serverBuilderGeneratorWithoutPublicConstrainedTypes.render(writer) + + writer.implBlock(shape, codegenContext.symbolProvider) { + serverBuilderGeneratorWithoutPublicConstrainedTypes.renderConvenienceMethod(this) + } + } + } + + override fun listShape(shape: ListShape) = collectionShape(shape) + override fun setShape(shape: SetShape) = collectionShape(shape) + + private fun collectionShape(shape: CollectionShape) { + if (shape.isReachableFromOperationInput() && shape.canReachConstrainedShape( + model, + codegenContext.symbolProvider, + ) + ) { + logger.info("[rust-server-codegen] Generating an unconstrained type for collection shape $shape") + rustCrate.withModule(unconstrainedModule) unconstrainedModuleWriter@{ + rustCrate.withModule(ModelsModule) modelsModuleWriter@{ + UnconstrainedCollectionGenerator( + codegenContext, + this@unconstrainedModuleWriter, + this@modelsModuleWriter, + shape, + ).render() + } + } + + logger.info("[rust-server-codegen] Generating a constrained type for collection shape $shape") + rustCrate.withModule(constrainedModule) { + PubCrateConstrainedCollectionGenerator(codegenContext, this, shape).render() + } + } + } + + override fun mapShape(shape: MapShape) { + val renderUnconstrainedMap = + shape.isReachableFromOperationInput() && shape.canReachConstrainedShape( + model, + codegenContext.symbolProvider, + ) + if (renderUnconstrainedMap) { + logger.info("[rust-server-codegen] Generating an unconstrained type for map $shape") + rustCrate.withModule(unconstrainedModule) { + UnconstrainedMapGenerator(codegenContext, this, shape).render() + } + + if (!shape.isDirectlyConstrained(codegenContext.symbolProvider)) { + logger.info("[rust-server-codegen] Generating a constrained type for map $shape") + rustCrate.withModule(constrainedModule) { + PubCrateConstrainedMapGenerator(codegenContext, this, shape).render() + } + } + } + + val isDirectlyConstrained = shape.isDirectlyConstrained(codegenContext.symbolProvider) + if (isDirectlyConstrained) { + rustCrate.withModule(ModelsModule) { + ConstrainedMapGenerator( + codegenContext, + this, + shape, + if (renderUnconstrainedMap) codegenContext.unconstrainedShapeSymbolProvider.toSymbol(shape) else null, + ).render() + } + } + + if (isDirectlyConstrained || renderUnconstrainedMap) { + rustCrate.withModule(ModelsModule) { + MapConstraintViolationGenerator(codegenContext, this, shape).render() } } } @@ -196,10 +357,36 @@ open class ServerCodegenVisitor( * Although raw strings require no code generation, enums are actually [EnumTrait] applied to string shapes. */ override fun stringShape(shape: StringShape) { - logger.info("[rust-server-codegen] Generating an enum $shape") - shape.getTrait()?.also { enum -> + fun serverEnumGeneratorFactory(codegenContext: ServerCodegenContext, writer: RustWriter, shape: StringShape) = + ServerEnumGenerator(codegenContext, writer, shape) + stringShape(shape, ::serverEnumGeneratorFactory) + } + + protected fun stringShape( + shape: StringShape, + enumShapeGeneratorFactory: (codegenContext: ServerCodegenContext, writer: RustWriter, shape: StringShape) -> ServerEnumGenerator, + ) { + if (shape.hasTrait()) { + logger.info("[rust-server-codegen] Generating an enum $shape") rustCrate.useShapeWriter(shape) { - ServerEnumGenerator(model, symbolProvider, this, shape, enum, codegenContext.runtimeConfig).render() + enumShapeGeneratorFactory(codegenContext, this, shape).render() + ConstrainedTraitForEnumGenerator(model, codegenContext.symbolProvider, this, shape).render() + } + } + + if (shape.hasTrait() && shape.hasTrait()) { + logger.warning( + """ + String shape $shape has an `enum` trait and the `length` trait. This is valid according to the Smithy + IDL v1 spec, but it's unclear what the semantics are. In any case, the Smithy core libraries should enforce the + constraints (which it currently does not), not each code generator. + See https://github.com/awslabs/smithy/issues/1121f for more information. + """.trimIndent().replace("\n", " "), + ) + } else if (!shape.hasTrait() && shape.isDirectlyConstrained(codegenContext.symbolProvider)) { + logger.info("[rust-server-codegen] Generating a constrained string $shape") + rustCrate.withModule(ModelsModule) { + ConstrainedStringGenerator(codegenContext, this, shape).render() } } } @@ -212,9 +399,27 @@ open class ServerCodegenVisitor( * This function _does not_ generate any serializers. */ override fun unionShape(shape: UnionShape) { - logger.info("[rust-server-codegen] Generating an union $shape") + logger.info("[rust-server-codegen] Generating an union shape $shape") rustCrate.useShapeWriter(shape) { - UnionGenerator(model, symbolProvider, this, shape, renderUnknownVariant = false).render() + UnionGenerator(model, codegenContext.symbolProvider, this, shape, renderUnknownVariant = false).render() + } + + if (shape.isReachableFromOperationInput() && shape.canReachConstrainedShape( + model, + codegenContext.symbolProvider, + ) + ) { + logger.info("[rust-server-codegen] Generating an unconstrained type for union shape $shape") + rustCrate.withModule(unconstrainedModule) unconstrainedModuleWriter@{ + rustCrate.withModule(ModelsModule) modelsModuleWriter@{ + UnconstrainedUnionGenerator( + codegenContext, + this@unconstrainedModuleWriter, + this@modelsModuleWriter, + shape, + ).render() + } + } } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt index 4b7f18ba89..d1d74d80bf 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt @@ -24,9 +24,6 @@ object ServerRuntimeType { fun Router(runtimeConfig: RuntimeConfig) = RuntimeType("Router", ServerCargoDependency.SmithyHttpServer(runtimeConfig), "${runtimeConfig.crateSrcPrefix}_http_server::routing") - fun RequestSpecModule(runtimeConfig: RuntimeConfig) = - RuntimeType("request_spec", ServerCargoDependency.SmithyHttpServer(runtimeConfig), "${runtimeConfig.crateSrcPrefix}_http_server::routing") - fun OperationHandler(runtimeConfig: RuntimeConfig) = forInlineDependency(ServerInlineDependency.serverOperationHandler(runtimeConfig)) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRustSettings.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRustSettings.kt index d9a8389600..dbfc8356a2 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRustSettings.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRustSettings.kt @@ -74,20 +74,35 @@ data class ServerRustSettings( } } +/** + * [publicConstrainedTypes]: Generate constrained wrapper newtypes for constrained shapes + * [ignoreUnsupportedConstraints]: Generate model even though unsupported constraints are present + */ data class ServerCodegenConfig( override val formatTimeoutSeconds: Int = defaultFormatTimeoutSeconds, override val debugMode: Boolean = defaultDebugMode, + val publicConstrainedTypes: Boolean = defaultPublicConstrainedTypes, + val ignoreUnsupportedConstraints: Boolean = defaultIgnoreUnsupportedConstraints, ) : CoreCodegenConfig( formatTimeoutSeconds, debugMode, ) { companion object { - // Note `node` is unused, because at the moment `ServerCodegenConfig` has the same properties as - // `CodegenConfig`. In the future, the server will have server-specific codegen options just like the client - // does. + private const val defaultPublicConstrainedTypes = true + private const val defaultIgnoreUnsupportedConstraints = false + fun fromCodegenConfigAndNode(coreCodegenConfig: CoreCodegenConfig, node: Optional) = - ServerCodegenConfig( - formatTimeoutSeconds = coreCodegenConfig.formatTimeoutSeconds, - debugMode = coreCodegenConfig.debugMode, - ) + if (node.isPresent) { + ServerCodegenConfig( + formatTimeoutSeconds = coreCodegenConfig.formatTimeoutSeconds, + debugMode = coreCodegenConfig.debugMode, + publicConstrainedTypes = node.get().getBooleanMemberOrDefault("publicConstrainedTypes", defaultPublicConstrainedTypes), + ignoreUnsupportedConstraints = node.get().getBooleanMemberOrDefault("ignoreUnsupportedConstraints", defaultIgnoreUnsupportedConstraints), + ) + } else { + ServerCodegenConfig( + formatTimeoutSeconds = coreCodegenConfig.formatTimeoutSeconds, + debugMode = coreCodegenConfig.debugMode, + ) + } } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerSymbolProviders.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerSymbolProviders.kt new file mode 100644 index 0000000000..e2b77c90fd --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerSymbolProviders.kt @@ -0,0 +1,65 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig + +/** + * Just a handy class to centralize initialization all the symbol providers required by the server code generators, to + * make the init blocks of the codegen visitors ([ServerCodegenVisitor] and [PythonServerCodegenVisitor]), and the + * unit test setup code, shorter and DRYer. + */ +class ServerSymbolProviders private constructor( + val symbolProvider: RustSymbolProvider, + val unconstrainedShapeSymbolProvider: UnconstrainedShapeSymbolProvider, + val constrainedShapeSymbolProvider: RustSymbolProvider, + val constraintViolationSymbolProvider: ConstraintViolationSymbolProvider, + val pubCrateConstrainedShapeSymbolProvider: PubCrateConstrainedShapeSymbolProvider, +) { + companion object { + fun from( + model: Model, + service: ServiceShape, + symbolVisitorConfig: SymbolVisitorConfig, + publicConstrainedTypes: Boolean, + baseSymbolProviderFactory: (model: Model, service: ServiceShape, symbolVisitorConfig: SymbolVisitorConfig, publicConstrainedTypes: Boolean) -> RustSymbolProvider, + ): ServerSymbolProviders { + val baseSymbolProvider = baseSymbolProviderFactory(model, service, symbolVisitorConfig, publicConstrainedTypes) + return ServerSymbolProviders( + symbolProvider = baseSymbolProvider, + constrainedShapeSymbolProvider = baseSymbolProviderFactory( + model, + service, + symbolVisitorConfig, + true, + ), + unconstrainedShapeSymbolProvider = UnconstrainedShapeSymbolProvider( + baseSymbolProviderFactory( + model, + service, + symbolVisitorConfig, + false, + ), + model, service, publicConstrainedTypes, + ), + pubCrateConstrainedShapeSymbolProvider = PubCrateConstrainedShapeSymbolProvider( + baseSymbolProvider, + model, + service, + ), + constraintViolationSymbolProvider = ConstraintViolationSymbolProvider( + baseSymbolProvider, + model, + service, + publicConstrainedTypes, + ), + ) + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/UnconstrainedShapeSymbolProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/UnconstrainedShapeSymbolProvider.kt new file mode 100644 index 0000000000..9fa2182e6b --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/UnconstrainedShapeSymbolProvider.kt @@ -0,0 +1,166 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.knowledge.NullableIndex +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.smithy.Default +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.Unconstrained +import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.handleOptionality +import software.amazon.smithy.rust.codegen.core.smithy.handleRustBoxing +import software.amazon.smithy.rust.codegen.core.smithy.rustType +import software.amazon.smithy.rust.codegen.core.smithy.setDefault +import software.amazon.smithy.rust.codegen.core.smithy.symbolBuilder +import software.amazon.smithy.rust.codegen.core.util.toPascalCase +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase +import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderSymbol + +/** + * The [UnconstrainedShapeSymbolProvider] returns, _for a given constrained + * shape_, a symbol whose Rust type can hold the corresponding unconstrained + * values. + * + * For collection and map shapes, this type is a [RustType.Opaque] wrapper + * tuple newtype holding a container over the inner unconstrained type. For + * structure shapes, it's their builder type. For union shapes, it's an enum + * whose variants are the corresponding unconstrained variants. For simple + * shapes, it's whatever the regular base symbol provider returns. + * + * So, for example, given the following model: + * + * ```smithy + * list ListA { + * member: ListB + * } + * + * list ListB { + * member: Structure + * } + * + * structure Structure { + * @required + * string: String + * } + * ``` + * + * `ListB` is not _directly_ constrained, but it is constrained, because it + * holds `Structure`s, that are constrained. So the corresponding unconstrained + * symbol has Rust type `struct + * ListBUnconstrained(std::vec::Vec)`. + * Likewise, `ListA` is also constrained. Its unconstrained symbol has Rust + * type `struct ListAUnconstrained(std::vec::Vec)`. + * + * For an _unconstrained_ shape and for simple shapes, this symbol provider + * delegates to the base symbol provider. It is therefore important that this + * symbol provider _not_ wrap [PublicConstrainedShapeSymbolProvider] (from the + * `codegen-server` subproject), because that symbol provider will return a + * constrained type for shapes that have constraint traits attached. + */ +class UnconstrainedShapeSymbolProvider( + private val base: RustSymbolProvider, + private val model: Model, + private val serviceShape: ServiceShape, + private val publicConstrainedTypes: Boolean, +) : WrappingSymbolProvider(base) { + private val nullableIndex = NullableIndex.of(model) + + private fun unconstrainedSymbolForCollectionOrMapOrUnionShape(shape: Shape): Symbol { + check(shape is CollectionShape || shape is MapShape || shape is UnionShape) + + val name = unconstrainedTypeNameForCollectionOrMapOrUnionShape(shape, serviceShape) + val namespace = "crate::${Unconstrained.namespace}::${RustReservedWords.escapeIfNeeded(name.toSnakeCase())}" + val rustType = RustType.Opaque(name, namespace) + return Symbol.builder() + .rustType(rustType) + .name(rustType.name) + .namespace(rustType.namespace, "::") + .definitionFile(Unconstrained.filename) + .build() + } + + override fun toSymbol(shape: Shape): Symbol = + when (shape) { + is CollectionShape -> { + if (shape.canReachConstrainedShape(model, base)) { + unconstrainedSymbolForCollectionOrMapOrUnionShape(shape) + } else { + base.toSymbol(shape) + } + } + is MapShape -> { + if (shape.canReachConstrainedShape(model, base)) { + unconstrainedSymbolForCollectionOrMapOrUnionShape(shape) + } else { + base.toSymbol(shape) + } + } + is StructureShape -> { + if (shape.canReachConstrainedShape(model, base)) { + shape.serverBuilderSymbol(base, !publicConstrainedTypes) + } else { + base.toSymbol(shape) + } + } + is UnionShape -> { + if (shape.canReachConstrainedShape(model, base)) { + unconstrainedSymbolForCollectionOrMapOrUnionShape(shape) + } else { + base.toSymbol(shape) + } + } + is MemberShape -> { + // There are only two cases where we use this symbol provider on a member shape. + // + // 1. When generating deserializers for HTTP-bound member shapes. See, for example: + // * how [HttpBindingGenerator] generates deserializers for a member shape with the `httpPrefixHeaders` + // trait targeting a map shape of string keys and values; or + // * how [ServerHttpBoundProtocolGenerator] deserializes for a member shape with the `httpQuery` + // trait targeting a collection shape that can reach a constrained shape. + // + // 2. When generating members for unconstrained unions. See [UnconstrainedUnionGenerator]. + if (shape.targetCanReachConstrainedShape(model, base)) { + val targetShape = model.expectShape(shape.target) + val targetSymbol = this.toSymbol(targetShape) + // Handle boxing first so we end up with `Option>`, not `Box>`. + handleOptionality(handleRustBoxing(targetSymbol, shape), shape, nullableIndex, base.config().nullabilityCheckMode) + } else { + base.toSymbol(shape) + } + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) Constraint traits on member shapes are not + // implemented yet. + } + is StringShape -> { + if (shape.canReachConstrainedShape(model, base)) { + symbolBuilder(shape, RustType.String).setDefault(Default.RustDefault).build() + } else { + base.toSymbol(shape) + } + } + else -> base.toSymbol(shape) + } +} + +/** + * Unconstrained type names are always suffixed with `Unconstrained` for clarity, even though we could dispense with it + * given that they all live inside the `unconstrained` module, so they don't collide with the constrained types. + */ +fun unconstrainedTypeNameForCollectionOrMapOrUnionShape(shape: Shape, serviceShape: ServiceShape): String { + check(shape is CollectionShape || shape is MapShape || shape is UnionShape) + return "${shape.id.getName(serviceShape).toPascalCase()}Unconstrained" +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraints.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraints.kt new file mode 100644 index 0000000000..d487689b20 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraints.kt @@ -0,0 +1,248 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.neighbor.Walker +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.EnumShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.SetShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.model.traits.LengthTrait +import software.amazon.smithy.model.traits.PatternTrait +import software.amazon.smithy.model.traits.RangeTrait +import software.amazon.smithy.model.traits.RequiredTrait +import software.amazon.smithy.model.traits.StreamingTrait +import software.amazon.smithy.model.traits.Trait +import software.amazon.smithy.model.traits.UniqueItemsTrait +import software.amazon.smithy.rust.codegen.core.util.expectTrait +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.inputShape +import software.amazon.smithy.rust.codegen.core.util.orNull +import java.util.logging.Level + +private sealed class UnsupportedConstraintMessageKind { + private val constraintTraitsUberIssue = "https://github.com/awslabs/smithy-rs/issues/1401" + + fun intoLogMessage(ignoreUnsupportedConstraints: Boolean): LogMessage { + fun buildMessage(intro: String, willSupport: Boolean, trackingIssue: String) = + """ + $intro + This is not supported in the smithy-rs server SDK. + ${ if (willSupport) "It will be supported in the future." else "" } + See the tracking issue ($trackingIssue). + If you want to go ahead and generate the server SDK ignoring unsupported constraint traits, set the key `ignoreUnsupportedConstraintTraits` + inside the `runtimeConfig.codegenConfig` JSON object in your `smithy-build.json` to `true`. + """.trimIndent().replace("\n", " ") + + fun buildMessageShapeHasUnsupportedConstraintTrait(shape: Shape, constraintTrait: Trait, trackingIssue: String) = + buildMessage( + "The ${shape.type} shape `${shape.id}` has the constraint trait `${constraintTrait.toShapeId()}` attached.", + willSupport = true, + trackingIssue, + ) + + val level = if (ignoreUnsupportedConstraints) Level.WARNING else Level.SEVERE + + return when (this) { + is UnsupportedConstraintOnMemberShape -> LogMessage( + level, + buildMessageShapeHasUnsupportedConstraintTrait(shape, constraintTrait, constraintTraitsUberIssue), + ) + is UnsupportedConstraintOnShapeReachableViaAnEventStream -> LogMessage( + level, + buildMessage( + """ + The ${shape.type} shape `${shape.id}` has the constraint trait `${constraintTrait.toShapeId()}` attached. + This shape is also part of an event stream; it is unclear what the semantics for constrained shapes in event streams are. + """.trimIndent().replace("\n", " "), + willSupport = false, + "https://github.com/awslabs/smithy/issues/1388", + ), + ) + is UnsupportedLengthTraitOnStreamingBlobShape -> LogMessage( + level, + buildMessage( + """ + The ${shape.type} shape `${shape.id}` has both the `${lengthTrait.toShapeId()}` and `${streamingTrait.toShapeId()}` constraint traits attached. + It is unclear what the semantics for streaming blob shapes are. + """.trimIndent().replace("\n", " "), + willSupport = false, + "https://github.com/awslabs/smithy/issues/1389", + ), + ) + is UnsupportedLengthTraitOnCollectionOrOnBlobShape -> LogMessage( + level, + buildMessageShapeHasUnsupportedConstraintTrait(shape, lengthTrait, constraintTraitsUberIssue), + ) + is UnsupportedPatternTraitOnStringShape -> LogMessage( + level, + buildMessageShapeHasUnsupportedConstraintTrait(shape, patternTrait, constraintTraitsUberIssue), + ) + is UnsupportedRangeTraitOnShape -> LogMessage( + level, + buildMessageShapeHasUnsupportedConstraintTrait(shape, rangeTrait, constraintTraitsUberIssue), + ) + } + } +} +private data class OperationWithConstrainedInputWithoutValidationException(val shape: OperationShape) +private data class UnsupportedConstraintOnMemberShape(val shape: MemberShape, val constraintTrait: Trait) : UnsupportedConstraintMessageKind() +private data class UnsupportedConstraintOnShapeReachableViaAnEventStream(val shape: Shape, val constraintTrait: Trait) : UnsupportedConstraintMessageKind() +private data class UnsupportedLengthTraitOnStreamingBlobShape(val shape: BlobShape, val lengthTrait: LengthTrait, val streamingTrait: StreamingTrait) : UnsupportedConstraintMessageKind() +private data class UnsupportedLengthTraitOnCollectionOrOnBlobShape(val shape: Shape, val lengthTrait: LengthTrait) : UnsupportedConstraintMessageKind() +private data class UnsupportedPatternTraitOnStringShape(val shape: Shape, val patternTrait: PatternTrait) : UnsupportedConstraintMessageKind() +private data class UnsupportedRangeTraitOnShape(val shape: Shape, val rangeTrait: RangeTrait) : UnsupportedConstraintMessageKind() + +data class LogMessage(val level: Level, val message: String) +data class ValidationResult(val shouldAbort: Boolean, val messages: List) + +private val allConstraintTraits = setOf( + LengthTrait::class.java, + PatternTrait::class.java, + RangeTrait::class.java, + UniqueItemsTrait::class.java, + EnumTrait::class.java, + RequiredTrait::class.java, +) +private val unsupportedConstraintsOnMemberShapes = allConstraintTraits - RequiredTrait::class.java + +fun validateOperationsWithConstrainedInputHaveValidationExceptionAttached(model: Model, service: ServiceShape): ValidationResult { + // Traverse the model and error out if an operation uses constrained input, but it does not have + // `ValidationException` attached in `errors`. https://github.com/awslabs/smithy-rs/pull/1199#discussion_r809424783 + // TODO(https://github.com/awslabs/smithy-rs/issues/1401): This check will go away once we add support for + // `disableDefaultValidation` set to `true`, allowing service owners to map from constraint violations to operation errors. + val walker = Walker(model) + val operationsWithConstrainedInputWithoutValidationExceptionSet = walker.walkShapes(service) + .filterIsInstance() + .asSequence() + .filter { operationShape -> + // Walk the shapes reachable via this operation input. + walker.walkShapes(operationShape.inputShape(model)) + .any { it is SetShape || it is EnumShape || it.hasConstraintTrait() } + } + .filter { !it.errors.contains(ShapeId.from("smithy.framework#ValidationException")) } + .map { OperationWithConstrainedInputWithoutValidationException(it) } + .toSet() + + val messages = + operationsWithConstrainedInputWithoutValidationExceptionSet.map { + LogMessage( + Level.SEVERE, + """ + Operation ${it.shape.id} takes in input that is constrained + (https://awslabs.github.io/smithy/2.0/spec/constraint-traits.html), and as such can fail with a validation + exception. You must model this behavior in the operation shape in your model file. + """.trimIndent().replace("\n", "") + + """ + + ```smithy + use smithy.framework#ValidationException + + operation ${it.shape.id.name} { + ... + errors: [..., ValidationException] // <-- Add this. + } + ``` + """.trimIndent(), + ) + } + + return ValidationResult(shouldAbort = messages.any { it.level == Level.SEVERE }, messages) +} + +fun validateUnsupportedConstraints(model: Model, service: ServiceShape, codegenConfig: ServerCodegenConfig): ValidationResult { + // Traverse the model and error out if: + val walker = Walker(model) + + // 1. Constraint traits on member shapes are used. [Constraint trait precedence] has not been implemented yet. + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) + // [Constraint trait precedence]: https://awslabs.github.io/smithy/2.0/spec/model.html#applying-traits + val unsupportedConstraintOnMemberShapeSet = walker + .walkShapes(service) + .asSequence() + .filterIsInstance() + .filterMapShapesToTraits(unsupportedConstraintsOnMemberShapes) + .map { (shape, trait) -> UnsupportedConstraintOnMemberShape(shape as MemberShape, trait) } + .toSet() + + // 2. Constraint traits on streaming blob shapes are used. Their semantics are unclear. + // TODO(https://github.com/awslabs/smithy/issues/1389) + val unsupportedLengthTraitOnStreamingBlobShapeSet = walker + .walkShapes(service) + .asSequence() + .filterIsInstance() + .filter { it.hasTrait() && it.hasTrait() } + .map { UnsupportedLengthTraitOnStreamingBlobShape(it, it.expectTrait(), it.expectTrait()) } + .toSet() + + // 3. Constraint traits in event streams are used. Their semantics are unclear. + // TODO(https://github.com/awslabs/smithy/issues/1388) + val unsupportedConstraintOnShapeReachableViaAnEventStreamSet = walker + .walkShapes(service) + .asSequence() + .filterIsInstance() + .filter { it.hasTrait() } + .flatMap { walker.walkShapes(it) } + .filterMapShapesToTraits(allConstraintTraits) + .map { (shape, trait) -> UnsupportedConstraintOnShapeReachableViaAnEventStream(shape, trait) } + .toSet() + + // 4. Length trait on collection shapes or on blob shapes is used. It has not been implemented yet for these target types. + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) + val unsupportedLengthTraitOnCollectionOrOnBlobShapeSet = walker + .walkShapes(service) + .asSequence() + .filter { it is CollectionShape || it is BlobShape } + .filter { it.hasTrait() } + .map { UnsupportedLengthTraitOnCollectionOrOnBlobShape(it, it.expectTrait()) } + .toSet() + + // 5. Pattern trait on string shapes is used. It has not been implemented yet. + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) + val unsupportedPatternTraitOnStringShapeSet = walker + .walkShapes(service) + .asSequence() + .filterIsInstance() + .filterMapShapesToTraits(setOf(PatternTrait::class.java)) + .map { (shape, patternTrait) -> UnsupportedPatternTraitOnStringShape(shape, patternTrait as PatternTrait) } + .toSet() + + // 6. Range trait on any shape is used. It has not been implemented yet. + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) + val unsupportedRangeTraitOnShapeSet = walker + .walkShapes(service) + .asSequence() + .filterMapShapesToTraits(setOf(RangeTrait::class.java)) + .map { (shape, rangeTrait) -> UnsupportedRangeTraitOnShape(shape, rangeTrait as RangeTrait) } + .toSet() + + val messages = + unsupportedConstraintOnMemberShapeSet.map { it.intoLogMessage(codegenConfig.ignoreUnsupportedConstraints) } + + unsupportedLengthTraitOnStreamingBlobShapeSet.map { it.intoLogMessage(codegenConfig.ignoreUnsupportedConstraints) } + + unsupportedConstraintOnShapeReachableViaAnEventStreamSet.map { it.intoLogMessage(codegenConfig.ignoreUnsupportedConstraints) } + + unsupportedLengthTraitOnCollectionOrOnBlobShapeSet.map { it.intoLogMessage(codegenConfig.ignoreUnsupportedConstraints) } + + unsupportedPatternTraitOnStringShapeSet.map { it.intoLogMessage(codegenConfig.ignoreUnsupportedConstraints) } + + unsupportedRangeTraitOnShapeSet.map { it.intoLogMessage(codegenConfig.ignoreUnsupportedConstraints) } + + return ValidationResult(shouldAbort = messages.any { it.level == Level.SEVERE }, messages) +} + +/** + * Returns a sequence over pairs `(shape, trait)`. + * The returned sequence contains one pair per shape in the input iterable that has attached a trait contained in [traits]. + */ +private fun Sequence.filterMapShapesToTraits(traits: Set>): Sequence> = + this.map { shape -> shape to traits.mapNotNull { shape.getTrait(it).orNull() } } + .flatMap { (shape, traits) -> traits.map { shape to it } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeIteratingOverMapJsonCustomization.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeIteratingOverMapJsonCustomization.kt new file mode 100644 index 0000000000..820f5bc8b7 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeIteratingOverMapJsonCustomization.kt @@ -0,0 +1,38 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.customizations + +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerCustomization +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerSection +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.workingWithPublicConstrainedWrapperTupleType + +/** + * A customization to, just before we iterate over a _constrained_ map shape in a JSON serializer, unwrap the wrapper + * newtype and take a shared reference to the actual `std::collections::HashMap` within it. + */ +class BeforeIteratingOverMapJsonCustomization(private val codegenContext: ServerCodegenContext) : JsonSerializerCustomization() { + override fun section(section: JsonSerializerSection): Writable = when (section) { + is JsonSerializerSection.BeforeIteratingOverMap -> writable { + if (workingWithPublicConstrainedWrapperTupleType( + section.shape, + codegenContext.model, + codegenContext.settings.codegenConfig.publicConstrainedTypes, + ) + ) { + // Note that this particular implementation just so happens to work because when the customization + // is invoked in the JSON serializer, the value expression is guaranteed to be a variable binding name. + // If the expression in the future were to be more complex, we wouldn't be able to write the left-hand + // side of this assignment. + rust("""let ${section.valueExpression.name} = &${section.valueExpression.name}.0;""") + } + } + else -> emptySection + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGenerator.kt new file mode 100644 index 0000000000..677350dd67 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGenerator.kt @@ -0,0 +1,160 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.traits.LengthTrait +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.documentShape +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.util.expectTrait +import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext + +/** + * [ConstrainedMapGenerator] generates a wrapper tuple newtype holding a constrained `std::collections::HashMap`. + * This type can be built from unconstrained values, yielding a `ConstraintViolation` when the input does not satisfy + * the constraints. + * + * The [`length` trait] is the only constraint trait applicable to map shapes. + * + * If [unconstrainedSymbol] is provided, the `MaybeConstrained` trait is implemented for the constrained type, using the + * [unconstrainedSymbol]'s associated type as the associated type for the trait. + * + * [`length` trait]: https://awslabs.github.io/smithy/1.0/spec/core/constraint-traits.html#length-trait + */ +class ConstrainedMapGenerator( + val codegenContext: ServerCodegenContext, + val writer: RustWriter, + val shape: MapShape, + private val unconstrainedSymbol: Symbol? = null, +) { + private val model = codegenContext.model + private val constrainedShapeSymbolProvider = codegenContext.constrainedShapeSymbolProvider + private val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes + private val constraintViolationSymbolProvider = + with(codegenContext.constraintViolationSymbolProvider) { + if (publicConstrainedTypes) { + this + } else { + PubCrateConstraintViolationSymbolProvider(this) + } + } + private val symbolProvider = codegenContext.symbolProvider + + fun render() { + // The `length` trait is the only constraint trait applicable to map shapes. + val lengthTrait = shape.expectTrait() + + val name = constrainedShapeSymbolProvider.toSymbol(shape).name + val inner = "std::collections::HashMap<#{KeySymbol}, #{ValueSymbol}>" + val constraintViolation = constraintViolationSymbolProvider.toSymbol(shape) + + val condition = if (lengthTrait.min.isPresent && lengthTrait.max.isPresent) { + "(${lengthTrait.min.get()}..=${lengthTrait.max.get()}).contains(&length)" + } else if (lengthTrait.min.isPresent) { + "${lengthTrait.min.get()} <= length" + } else { + "length <= ${lengthTrait.max.get()}" + } + + val constrainedTypeVisibility = if (publicConstrainedTypes) { + Visibility.PUBLIC + } else { + Visibility.PUBCRATE + } + val constrainedTypeMetadata = RustMetadata( + Attribute.Derives(setOf(RuntimeType.Debug, RuntimeType.Clone, RuntimeType.PartialEq)), + visibility = constrainedTypeVisibility, + ) + + val codegenScope = arrayOf( + "KeySymbol" to constrainedShapeSymbolProvider.toSymbol(model.expectShape(shape.key.target)), + "ValueSymbol" to constrainedShapeSymbolProvider.toSymbol(model.expectShape(shape.value.target)), + "From" to RuntimeType.From, + "TryFrom" to RuntimeType.TryFrom, + "ConstraintViolation" to constraintViolation, + ) + + writer.documentShape(shape, model, note = rustDocsNote(name)) + constrainedTypeMetadata.render(writer) + writer.rustTemplate("struct $name(pub(crate) $inner);", *codegenScope) + if (constrainedTypeVisibility == Visibility.PUBCRATE) { + Attribute.AllowUnused.render(writer) + } + writer.rustTemplate( + """ + impl $name { + /// ${rustDocsInnerMethod(inner)} + pub fn inner(&self) -> &$inner { + &self.0 + } + + /// ${rustDocsIntoInnerMethod(inner)} + pub fn into_inner(self) -> $inner { + self.0 + } + } + + impl #{TryFrom}<$inner> for $name { + type Error = #{ConstraintViolation}; + + /// ${rustDocsTryFromMethod(name, inner)} + fn try_from(value: $inner) -> Result { + let length = value.len(); + if $condition { + Ok(Self(value)) + } else { + Err(#{ConstraintViolation}::Length(length)) + } + } + } + + impl #{From}<$name> for $inner { + fn from(value: $name) -> Self { + value.into_inner() + } + } + """, + *codegenScope, + ) + + if (!publicConstrainedTypes && isValueConstrained(shape, model, symbolProvider)) { + writer.rustTemplate( + """ + impl #{From}<$name> for #{FullyUnconstrainedSymbol} { + fn from(value: $name) -> Self { + value + .into_inner() + .into_iter() + .map(|(k, v)| (k, v.into())) + .collect() + } + } + """, + *codegenScope, + "FullyUnconstrainedSymbol" to symbolProvider.toSymbol(shape), + ) + } + + if (unconstrainedSymbol != null) { + writer.rustTemplate( + """ + impl #{ConstrainedTrait} for $name { + type Unconstrained = #{UnconstrainedSymbol}; + } + """, + "ConstrainedTrait" to RuntimeType.ConstrainedTrait(), + "UnconstrainedSymbol" to unconstrainedSymbol, + ) + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGeneratorCommon.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGeneratorCommon.kt new file mode 100644 index 0000000000..fb5ce1daee --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGeneratorCommon.kt @@ -0,0 +1,22 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.codegen.core.SymbolProvider +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained + +/** + * Common helper functions used in [UnconstrainedMapGenerator] and [MapConstraintViolationGenerator]. + */ + +fun isKeyConstrained(shape: StringShape, symbolProvider: SymbolProvider) = shape.isDirectlyConstrained(symbolProvider) + +fun isValueConstrained(shape: Shape, model: Model, symbolProvider: SymbolProvider): Boolean = + shape.canReachConstrainedShape(model, symbolProvider) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedShapeGeneratorCommon.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedShapeGeneratorCommon.kt new file mode 100644 index 0000000000..d0d447cdad --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedShapeGeneratorCommon.kt @@ -0,0 +1,24 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +/** + * Functions shared amongst the constrained shape generators, to keep them DRY and consistent. + */ + +fun rustDocsNote(typeName: String) = + "this is a constrained type because its corresponding modeled Smithy shape has one or more " + + "[constraint traits]. Use [`parse`] or [`$typeName::TryFrom`] to construct values of this type." + + "[constraint traits]: https://awslabs.github.io/smithy/1.0/spec/core/constraint-traits.html" + +fun rustDocsTryFromMethod(typeName: String, inner: String) = + "Constructs a `$typeName` from an [`$inner`], failing when the provided value does not satisfy the modeled constraints." + +fun rustDocsInnerMethod(inner: String) = + "Returns an immutable reference to the underlying [`$inner`]." + +fun rustDocsIntoInnerMethod(inner: String) = + "Consumes the value, returning the underlying [`$inner`]." diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt new file mode 100644 index 0000000000..a63801cce4 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt @@ -0,0 +1,183 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.traits.LengthTrait +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.documentShape +import software.amazon.smithy.rust.codegen.core.rustlang.render +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained +import software.amazon.smithy.rust.codegen.core.util.expectTrait +import software.amazon.smithy.rust.codegen.core.util.redactIfNecessary +import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput +import software.amazon.smithy.rust.codegen.server.smithy.validationErrorMessage + +/** + * [ConstrainedStringGenerator] generates a wrapper tuple newtype holding a constrained `String`. + * This type can be built from unconstrained values, yielding a `ConstraintViolation` when the input does not satisfy + * the constraints. + */ +class ConstrainedStringGenerator( + val codegenContext: ServerCodegenContext, + val writer: RustWriter, + val shape: StringShape, +) { + val model = codegenContext.model + val constrainedShapeSymbolProvider = codegenContext.constrainedShapeSymbolProvider + val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes + private val constraintViolationSymbolProvider = + with(codegenContext.constraintViolationSymbolProvider) { + if (publicConstrainedTypes) { + this + } else { + PubCrateConstraintViolationSymbolProvider(this) + } + } + + fun render() { + val lengthTrait = shape.expectTrait() + + val symbol = constrainedShapeSymbolProvider.toSymbol(shape) + val name = symbol.name + val inner = RustType.String.render() + val constraintViolation = constraintViolationSymbolProvider.toSymbol(shape) + + val condition = if (lengthTrait.min.isPresent && lengthTrait.max.isPresent) { + "(${lengthTrait.min.get()}..=${lengthTrait.max.get()}).contains(&length)" + } else if (lengthTrait.min.isPresent) { + "${lengthTrait.min.get()} <= length" + } else { + "length <= ${lengthTrait.max.get()}" + } + + val constrainedTypeVisibility = if (publicConstrainedTypes) { + Visibility.PUBLIC + } else { + Visibility.PUBCRATE + } + val constrainedTypeMetadata = RustMetadata( + Attribute.Derives(setOf(RuntimeType.Debug, RuntimeType.Clone, RuntimeType.PartialEq, RuntimeType.Eq, RuntimeType.Hash)), + visibility = constrainedTypeVisibility, + ) + + // Note that we're using the linear time check `chars().count()` instead of `len()` on the input value, since the + // Smithy specification says the `length` trait counts the number of Unicode code points when applied to string shapes. + // https://awslabs.github.io/smithy/1.0/spec/core/constraint-traits.html#length-trait + writer.documentShape(shape, model, note = rustDocsNote(name)) + constrainedTypeMetadata.render(writer) + writer.rust("struct $name(pub(crate) $inner);") + if (constrainedTypeVisibility == Visibility.PUBCRATE) { + Attribute.AllowUnused.render(writer) + } + writer.rustTemplate( + """ + impl $name { + /// Extracts a string slice containing the entire underlying `String`. + pub fn as_str(&self) -> &str { + &self.0 + } + + /// ${rustDocsInnerMethod(inner)} + pub fn inner(&self) -> &$inner { + &self.0 + } + + /// ${rustDocsIntoInnerMethod(inner)} + pub fn into_inner(self) -> $inner { + self.0 + } + } + + impl #{ConstrainedTrait} for $name { + type Unconstrained = $inner; + } + + impl #{From}<$inner> for #{MaybeConstrained} { + fn from(value: $inner) -> Self { + Self::Unconstrained(value) + } + } + + impl #{Display} for $name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + ${shape.redactIfNecessary(model, "self.0")}.fmt(f) + } + } + + impl #{TryFrom}<$inner> for $name { + type Error = #{ConstraintViolation}; + + /// ${rustDocsTryFromMethod(name, inner)} + fn try_from(value: $inner) -> Result { + let length = value.chars().count(); + if $condition { + Ok(Self(value)) + } else { + Err(#{ConstraintViolation}::Length(length)) + } + } + } + + impl #{From}<$name> for $inner { + fn from(value: $name) -> Self { + value.into_inner() + } + } + """, + "ConstrainedTrait" to RuntimeType.ConstrainedTrait(), + "ConstraintViolation" to constraintViolation, + "MaybeConstrained" to symbol.makeMaybeConstrained(), + "Display" to RuntimeType.Display, + "From" to RuntimeType.From, + "TryFrom" to RuntimeType.TryFrom, + ) + + val constraintViolationModuleName = constraintViolation.namespace.split(constraintViolation.namespaceDelimiter).last() + writer.withModule(RustModule(constraintViolationModuleName, RustMetadata(visibility = constrainedTypeVisibility))) { + rust( + """ + ##[derive(Debug, PartialEq)] + pub enum ${constraintViolation.name} { + Length(usize), + } + """, + ) + + if (shape.isReachableFromOperationInput()) { + rustBlock("impl ${constraintViolation.name}") { + rustBlockTemplate( + "pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField", + "String" to RuntimeType.String, + ) { + rustBlock("match self") { + rust( + """ + Self::Length(length) => crate::model::ValidationExceptionField { + message: format!("${lengthTrait.validationErrorMessage()}", length, &path), + path, + }, + """, + ) + } + } + } + } + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedTraitForEnumGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedTraitForEnumGenerator.kt new file mode 100644 index 0000000000..288065d75c --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedTraitForEnumGenerator.kt @@ -0,0 +1,51 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained +import software.amazon.smithy.rust.codegen.core.util.expectTrait + +/** + * [ConstrainedTraitForEnumGenerator] generates code that implements the [RuntimeType.ConstrainedTrait] trait on an + * enum shape. + */ +class ConstrainedTraitForEnumGenerator( + val model: Model, + val symbolProvider: RustSymbolProvider, + val writer: RustWriter, + val shape: StringShape, +) { + fun render() { + shape.expectTrait() + + val symbol = symbolProvider.toSymbol(shape) + val name = symbol.name + val unconstrainedType = "String" + + writer.rustTemplate( + """ + impl #{ConstrainedTrait} for $name { + type Unconstrained = $unconstrainedType; + } + + impl From<$unconstrainedType> for #{MaybeConstrained} { + fn from(value: $unconstrainedType) -> Self { + Self::Unconstrained(value) + } + } + """, + "ConstrainedTrait" to RuntimeType.ConstrainedTrait(), + "MaybeConstrained" to symbol.makeMaybeConstrained(), + ) + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/MapConstraintViolationGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/MapConstraintViolationGenerator.kt new file mode 100644 index 0000000000..684d83322f --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/MapConstraintViolationGenerator.kt @@ -0,0 +1,121 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.traits.LengthTrait +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.util.getTrait +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput +import software.amazon.smithy.rust.codegen.server.smithy.validationErrorMessage + +class MapConstraintViolationGenerator( + codegenContext: ServerCodegenContext, + private val modelsModuleWriter: RustWriter, + val shape: MapShape, +) { + private val model = codegenContext.model + private val constrainedShapeSymbolProvider = codegenContext.constrainedShapeSymbolProvider + private val symbolProvider = codegenContext.symbolProvider + private val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes + private val constraintViolationSymbolProvider = + with(codegenContext.constraintViolationSymbolProvider) { + if (publicConstrainedTypes) { + this + } else { + PubCrateConstraintViolationSymbolProvider(this) + } + } + + fun render() { + val keyShape = model.expectShape(shape.key.target, StringShape::class.java) + val valueShape = model.expectShape(shape.value.target) + val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape) + val constraintViolationName = constraintViolationSymbol.name + + val constraintViolationCodegenScopeMutableList: MutableList> = mutableListOf() + if (isKeyConstrained(keyShape, symbolProvider)) { + constraintViolationCodegenScopeMutableList.add("KeyConstraintViolationSymbol" to constraintViolationSymbolProvider.toSymbol(keyShape)) + } + if (isValueConstrained(valueShape, model, symbolProvider)) { + constraintViolationCodegenScopeMutableList.add("ValueConstraintViolationSymbol" to constraintViolationSymbolProvider.toSymbol(valueShape)) + constraintViolationCodegenScopeMutableList.add("KeySymbol" to constrainedShapeSymbolProvider.toSymbol(keyShape)) + } + val constraintViolationCodegenScope = constraintViolationCodegenScopeMutableList.toTypedArray() + + val constraintViolationVisibility = if (publicConstrainedTypes) { + Visibility.PUBLIC + } else { + Visibility.PUBCRATE + } + modelsModuleWriter.withModule( + RustModule( + constraintViolationSymbol.namespace.split(constraintViolationSymbol.namespaceDelimiter).last(), + RustMetadata(visibility = constraintViolationVisibility), + ), + ) { + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) We should really have two `ConstraintViolation` + // types here. One will just have variants for each constraint trait on the map shape, for use by the user. + // The other one will have variants if the shape's key or value is directly or transitively constrained, + // and is for use by the framework. + rustTemplate( + """ + ##[derive(Debug, PartialEq)] + pub${ if (constraintViolationVisibility == Visibility.PUBCRATE) " (crate) " else "" } enum $constraintViolationName { + ${if (shape.hasTrait()) "Length(usize)," else ""} + ${if (isKeyConstrained(keyShape, symbolProvider)) "##[doc(hidden)] Key(#{KeyConstraintViolationSymbol})," else ""} + ${if (isValueConstrained(valueShape, model, symbolProvider)) "##[doc(hidden)] Value(#{KeySymbol}, #{ValueConstraintViolationSymbol})," else ""} + } + """, + *constraintViolationCodegenScope, + ) + + if (shape.isReachableFromOperationInput()) { + rustBlock("impl $constraintViolationName") { + rustBlockTemplate( + "pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField", + "String" to RuntimeType.String, + ) { + rustBlock("match self") { + shape.getTrait()?.also { + rust( + """ + Self::Length(length) => crate::model::ValidationExceptionField { + message: format!("${it.validationErrorMessage()}", length, &path), + path, + }, + """, + ) + } + if (isKeyConstrained(keyShape, symbolProvider)) { + // Note how we _do not_ append the key's member name to the path. This is intentional, as + // per the `RestJsonMalformedLengthMapKey` test. Note keys are always strings. + // https://github.com/awslabs/smithy/blob/ee0b4ff90daaaa5101f32da936c25af8c91cc6e9/smithy-aws-protocol-tests/model/restJson1/validation/malformed-length.smithy#L296-L295 + rust("""Self::Key(key_constraint_violation) => key_constraint_violation.as_validation_exception_field(path),""") + } + if (isValueConstrained(valueShape, model, symbolProvider)) { + // `as_str()` works with regular `String`s and constrained string shapes. + rust("""Self::Value(key, value_constraint_violation) => value_constraint_violation.as_validation_exception_field(path + "/" + key.as_str()),""") + } + } + } + } + } + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedCollectionGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedCollectionGenerator.kt new file mode 100644 index 0000000000..b789c2166d --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedCollectionGenerator.kt @@ -0,0 +1,148 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained +import software.amazon.smithy.rust.codegen.server.smithy.isTransitivelyButNotDirectlyConstrained +import software.amazon.smithy.rust.codegen.server.smithy.typeNameContainsNonPublicType + +/** + * A generator for a wrapper tuple newtype over a collection shape's symbol + * type. + * + * This newtype is for a collection shape that is _transitively_ constrained, + * but not directly. That is, the collection shape does not have a constraint + * trait attached, but the members it holds reach a constrained shape. The + * generated newtype is therefore `pub(crate)`, as the class name indicates, + * and is not available to end users. After deserialization, upon constraint + * traits' enforcement, this type is converted into the regular `Vec` the user + * sees via the generated converters. + * + * TODO(https://github.com/awslabs/smithy-rs/issues/1401) If the collection + * shape is _directly_ constrained, use [ConstrainedCollectionGenerator] + * instead. + */ +class PubCrateConstrainedCollectionGenerator( + val codegenContext: ServerCodegenContext, + val writer: RustWriter, + val shape: CollectionShape, +) { + private val model = codegenContext.model + private val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes + private val unconstrainedShapeSymbolProvider = codegenContext.unconstrainedShapeSymbolProvider + private val pubCrateConstrainedShapeSymbolProvider = codegenContext.pubCrateConstrainedShapeSymbolProvider + private val constrainedShapeSymbolProvider = codegenContext.constrainedShapeSymbolProvider + private val symbolProvider = codegenContext.symbolProvider + + fun render() { + check(shape.canReachConstrainedShape(model, symbolProvider)) + + val symbol = symbolProvider.toSymbol(shape) + val constrainedSymbol = pubCrateConstrainedShapeSymbolProvider.toSymbol(shape) + + val unconstrainedSymbol = unconstrainedShapeSymbolProvider.toSymbol(shape) + val moduleName = constrainedSymbol.namespace.split(constrainedSymbol.namespaceDelimiter).last() + val name = constrainedSymbol.name + val innerShape = model.expectShape(shape.member.target) + val innerConstrainedSymbol = if (innerShape.isTransitivelyButNotDirectlyConstrained(model, symbolProvider)) { + pubCrateConstrainedShapeSymbolProvider.toSymbol(innerShape) + } else { + constrainedShapeSymbolProvider.toSymbol(innerShape) + } + + val codegenScope = arrayOf( + "InnerConstrainedSymbol" to innerConstrainedSymbol, + "ConstrainedTrait" to RuntimeType.ConstrainedTrait(), + "UnconstrainedSymbol" to unconstrainedSymbol, + "Symbol" to symbol, + "From" to RuntimeType.From, + ) + + writer.withModule(RustModule(moduleName, RustMetadata(visibility = Visibility.PUBCRATE))) { + rustTemplate( + """ + ##[derive(Debug, Clone)] + pub(crate) struct $name(pub(crate) std::vec::Vec<#{InnerConstrainedSymbol}>); + + impl #{ConstrainedTrait} for $name { + type Unconstrained = #{UnconstrainedSymbol}; + } + """, + *codegenScope, + ) + + if (publicConstrainedTypes) { + // If the target member shape is itself _not_ directly constrained, and is an aggregate non-Structure shape, + // then its corresponding constrained type is the `pub(crate)` wrapper tuple type, which needs converting into + // the public type the user is exposed to. The two types are isomorphic, and we can convert between them using + // `From`. So we track this particular case here in order to iterate over the list's members and convert + // each of them. + // + // Note that we could add the iteration code unconditionally and it would still be correct, but the `into()` calls + // would be useless. Clippy flags this as [`useless_conversion`]. We could deactivate the lint, but it's probably + // best that we just don't emit a useless iteration, lest the compiler not optimize it away (see [Godbolt]), + // and to make the generated code a little bit simpler. + // + // [`useless_conversion`]: https://rust-lang.github.io/rust-clippy/master/index.html#useless_conversion. + // [Godbolt]: https://godbolt.org/z/eheWebWMa + val innerNeedsConstraining = + !innerShape.isDirectlyConstrained(symbolProvider) && (innerShape is CollectionShape || innerShape is MapShape) + + rustTemplate( + """ + impl #{From}<#{Symbol}> for $name { + fn from(v: #{Symbol}) -> Self { + ${ if (innerNeedsConstraining) { + "Self(v.into_iter().map(|item| item.into()).collect())" + } else { + "Self(v)" + } } + } + } + + impl #{From}<$name> for #{Symbol} { + fn from(v: $name) -> Self { + ${ if (innerNeedsConstraining) { + "v.0.into_iter().map(|item| item.into()).collect()" + } else { + "v.0" + } } + } + } + """, + *codegenScope, + ) + } else { + val innerNeedsConversion = innerShape.typeNameContainsNonPublicType(model, symbolProvider, publicConstrainedTypes) + + rustTemplate( + """ + impl #{From}<$name> for #{Symbol} { + fn from(v: $name) -> Self { + ${ if (innerNeedsConversion) { + "v.0.into_iter().map(|item| item.into()).collect()" + } else { + "v.0" + } } + } + } + """, + *codegenScope, + ) + } + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedMapGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedMapGenerator.kt new file mode 100644 index 0000000000..591b11b7ed --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedMapGenerator.kt @@ -0,0 +1,142 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained +import software.amazon.smithy.rust.codegen.server.smithy.isTransitivelyButNotDirectlyConstrained +import software.amazon.smithy.rust.codegen.server.smithy.typeNameContainsNonPublicType + +/** + * A generator for a wrapper tuple newtype over a map shape's symbol type. + * + * This newtype is for a map shape that is _transitively_ constrained, but not + * directly. That is, the map shape does not have a constraint trait attached, + * but the keys and/or values it holds reach a constrained shape. The generated + * newtype is therefore `pub(crate)`, as the class name indicates, and is not + * available to end users. After deserialization, upon constraint traits' + * enforcement, this type is converted into the regular `HashMap` the user sees + * via the generated converters. + * + * If the map shape is _directly_ constrained, use [ConstrainedMapGenerator] + * instead. + */ +class PubCrateConstrainedMapGenerator( + val codegenContext: ServerCodegenContext, + val writer: RustWriter, + val shape: MapShape, +) { + private val model = codegenContext.model + private val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes + private val unconstrainedShapeSymbolProvider = codegenContext.unconstrainedShapeSymbolProvider + private val pubCrateConstrainedShapeSymbolProvider = codegenContext.pubCrateConstrainedShapeSymbolProvider + private val constrainedShapeSymbolProvider = codegenContext.constrainedShapeSymbolProvider + private val symbolProvider = codegenContext.symbolProvider + + fun render() { + check(shape.canReachConstrainedShape(model, symbolProvider)) + + val symbol = symbolProvider.toSymbol(shape) + val unconstrainedSymbol = unconstrainedShapeSymbolProvider.toSymbol(shape) + val constrainedSymbol = pubCrateConstrainedShapeSymbolProvider.toSymbol(shape) + val moduleName = constrainedSymbol.namespace.split(constrainedSymbol.namespaceDelimiter).last() + val name = constrainedSymbol.name + val keyShape = model.expectShape(shape.key.target, StringShape::class.java) + val valueShape = model.expectShape(shape.value.target) + val keySymbol = constrainedShapeSymbolProvider.toSymbol(keyShape) + val valueSymbol = if (valueShape.isTransitivelyButNotDirectlyConstrained(model, symbolProvider)) { + pubCrateConstrainedShapeSymbolProvider.toSymbol(valueShape) + } else { + constrainedShapeSymbolProvider.toSymbol(valueShape) + } + + val codegenScope = arrayOf( + "KeySymbol" to keySymbol, + "ValueSymbol" to valueSymbol, + "ConstrainedTrait" to RuntimeType.ConstrainedTrait(), + "UnconstrainedSymbol" to unconstrainedSymbol, + "Symbol" to symbol, + "From" to RuntimeType.From, + ) + + writer.withModule(RustModule(moduleName, RustMetadata(visibility = Visibility.PUBCRATE))) { + rustTemplate( + """ + ##[derive(Debug, Clone)] + pub(crate) struct $name(pub(crate) std::collections::HashMap<#{KeySymbol}, #{ValueSymbol}>); + + impl #{ConstrainedTrait} for $name { + type Unconstrained = #{UnconstrainedSymbol}; + } + """, + *codegenScope, + ) + + if (publicConstrainedTypes) { + // Unless the map holds an aggregate shape as its value shape whose symbol's type is _not_ `pub(crate)`, the + // `.into()` calls are useless. + // See the comment in [ConstrainedCollectionShape] for a more detailed explanation. + val innerNeedsConstraining = + !valueShape.isDirectlyConstrained(symbolProvider) && (valueShape is CollectionShape || valueShape is MapShape) + + rustTemplate( + """ + impl #{From}<#{Symbol}> for $name { + fn from(v: #{Symbol}) -> Self { + ${ if (innerNeedsConstraining) { + "Self(v.into_iter().map(|(k, v)| (k, v.into())).collect())" + } else { + "Self(v)" + } } + } + } + + impl #{From}<$name> for #{Symbol} { + fn from(v: $name) -> Self { + ${ if (innerNeedsConstraining) { + "v.0.into_iter().map(|(k, v)| (k, v.into())).collect()" + } else { + "v.0" + } } + } + } + """, + *codegenScope, + ) + } else { + val keyNeedsConversion = keyShape.typeNameContainsNonPublicType(model, symbolProvider, publicConstrainedTypes) + val valueNeedsConversion = valueShape.typeNameContainsNonPublicType(model, symbolProvider, publicConstrainedTypes) + + rustTemplate( + """ + impl #{From}<$name> for #{Symbol} { + fn from(v: $name) -> Self { + ${ if (keyNeedsConversion || valueNeedsConversion) { + val keyConversion = if (keyNeedsConversion) { ".into()" } else { "" } + val valueConversion = if (valueNeedsConversion) { ".into()" } else { "" } + "v.0.into_iter().map(|(k, v)| (k$keyConversion, v$valueConversion)).collect()" + } else { + "v.0" + } } + } + } + """, + *codegenScope, + ) + } + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolations.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolations.kt new file mode 100644 index 0000000000..f2c572eedc --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolations.kt @@ -0,0 +1,218 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.codegen.core.SymbolProvider +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.docs +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.smithy.makeRustBoxed +import software.amazon.smithy.rust.codegen.core.smithy.traits.RustBoxTrait +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.letIf +import software.amazon.smithy.rust.codegen.core.util.toPascalCase +import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape + +/** + * Renders constraint violation types that arise when building a structure shape builder. + * + * Used by [ServerBuilderGenerator] and [ServerBuilderGeneratorWithoutPublicConstrainedTypes]. + */ +class ServerBuilderConstraintViolations( + codegenContext: ServerCodegenContext, + private val shape: StructureShape, + private val builderTakesInUnconstrainedTypes: Boolean, +) { + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + private val constraintViolationSymbolProvider = + with(codegenContext.constraintViolationSymbolProvider) { + if (codegenContext.settings.codegenConfig.publicConstrainedTypes) { + this + } else { + PubCrateConstraintViolationSymbolProvider(this) + } + } + private val members: List = shape.allMembers.values.toList() + val all = members.flatMap { member -> + listOfNotNull( + forMember(member), + builderConstraintViolationForMember(member), + ) + } + + fun render( + writer: RustWriter, + visibility: Visibility, + nonExhaustive: Boolean, + shouldRenderAsValidationExceptionFieldList: Boolean, + ) { + Attribute.Derives(setOf(RuntimeType.Debug, RuntimeType.PartialEq)).render(writer) + writer.docs("Holds one variant for each of the ways the builder can fail.") + if (nonExhaustive) Attribute.NonExhaustive.render(writer) + val constraintViolationSymbolName = constraintViolationSymbolProvider.toSymbol(shape).name + writer.rustBlock("pub${ if (visibility == Visibility.PUBCRATE) " (crate) " else "" } enum $constraintViolationSymbolName") { + renderConstraintViolations(writer) + } + renderImplDisplayConstraintViolation(writer) + writer.rust("impl #T for ConstraintViolation { }", RuntimeType.StdError) + + if (shouldRenderAsValidationExceptionFieldList) { + renderAsValidationExceptionFieldList(writer) + } + } + + /** + * Returns the builder failure associated with the `member` field if its target is constrained. + */ + fun builderConstraintViolationForMember(member: MemberShape) = + if (builderTakesInUnconstrainedTypes && member.targetCanReachConstrainedShape(model, symbolProvider)) { + ConstraintViolation(member, ConstraintViolationKind.CONSTRAINED_SHAPE_FAILURE) + } else { + null + } + + /** + * Returns the builder failure associated with the [member] field if it is `required`. + */ + fun forMember(member: MemberShape): ConstraintViolation? { + check(members.contains(member)) + // TODO(https://github.com/awslabs/smithy-rs/issues/1302, https://github.com/awslabs/smithy/issues/1179): See above. + return if (symbolProvider.toSymbol(member).isOptional()) { + null + } else { + ConstraintViolation(member, ConstraintViolationKind.MISSING_MEMBER) + } + } + + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) This impl does not take into account the `sensitive` trait. + // When constraint violation error messages are adjusted to match protocol tests, we should ensure it's honored. + private fun renderImplDisplayConstraintViolation(writer: RustWriter) { + writer.rustBlock("impl #T for ConstraintViolation", RuntimeType.Display) { + rustBlock("fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result") { + rustBlock("match self") { + all.forEach { + val arm = if (it.hasInner()) { + "ConstraintViolation::${it.name()}(_)" + } else { + "ConstraintViolation::${it.name()}" + } + rust("""$arm => write!(f, "${it.message(symbolProvider, model)}"),""") + } + } + } + } + } + + private fun renderConstraintViolations(writer: RustWriter) { + for (constraintViolation in all) { + when (constraintViolation.kind) { + ConstraintViolationKind.MISSING_MEMBER -> { + writer.docs("${constraintViolation.message(symbolProvider, model).replaceFirstChar { it.uppercaseChar() }}.") + writer.rust("${constraintViolation.name()},") + } + + ConstraintViolationKind.CONSTRAINED_SHAPE_FAILURE -> { + val targetShape = model.expectShape(constraintViolation.forMember.target) + + val constraintViolationSymbol = + constraintViolationSymbolProvider.toSymbol(targetShape) + // If the corresponding structure's member is boxed, box this constraint violation symbol too. + .letIf(constraintViolation.forMember.hasTrait()) { + it.makeRustBoxed() + } + + // Note we cannot express the inner constraint violation as `>::Error`, because `T` might + // be `pub(crate)` and that would leak `T` in a public interface. + writer.docs("${constraintViolation.message(symbolProvider, model)}.".replaceFirstChar { it.uppercaseChar() }) + Attribute.DocHidden.render(writer) + writer.rust("${constraintViolation.name()}(#T),", constraintViolationSymbol) + } + } + } + } + + private fun renderAsValidationExceptionFieldList(writer: RustWriter) { + val validationExceptionFieldWritable = writable { + rustBlock("match self") { + all.forEach { + if (it.hasInner()) { + rust("""ConstraintViolation::${it.name()}(inner) => inner.as_validation_exception_field(path + "/${it.forMember.memberName}"),""") + } else { + rust( + """ + ConstraintViolation::${it.name()} => crate::model::ValidationExceptionField { + message: format!("Value null at '{}/${it.forMember.memberName}' failed to satisfy constraint: Member must not be null", path), + path: path + "/${it.forMember.memberName}", + }, + """, + ) + } + } + } + } + + writer.rustTemplate( + """ + impl ConstraintViolation { + pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { + #{ValidationExceptionFieldWritable:W} + } + } + """, + "ValidationExceptionFieldWritable" to validationExceptionFieldWritable, + "String" to RuntimeType.String, + ) + } +} + +/** + * The kinds of constraint violations that can occur when building the builder. + */ +enum class ConstraintViolationKind { + // A field is required but was not provided. + MISSING_MEMBER, + + // An unconstrained type was provided for a field targeting a constrained shape, but it failed to convert into the constrained type. + CONSTRAINED_SHAPE_FAILURE, +} + +data class ConstraintViolation(val forMember: MemberShape, val kind: ConstraintViolationKind) { + fun name() = when (kind) { + ConstraintViolationKind.MISSING_MEMBER -> "Missing${forMember.memberName.toPascalCase()}" + ConstraintViolationKind.CONSTRAINED_SHAPE_FAILURE -> forMember.memberName.toPascalCase() + } + + /** + * Whether the constraint violation is a Rust tuple struct with one element. + */ + fun hasInner() = kind == ConstraintViolationKind.CONSTRAINED_SHAPE_FAILURE + + /** + * A message for a `ConstraintViolation` variant. This is used in both Rust documentation and the `Display` trait implementation. + */ + fun message(symbolProvider: SymbolProvider, model: Model): String { + val memberName = symbolProvider.toMemberName(forMember) + val structureSymbol = symbolProvider.toSymbol(model.expectShape(forMember.container)) + return when (kind) { + ConstraintViolationKind.MISSING_MEMBER -> "`$memberName` was not provided but it is required when building `${structureSymbol.name}`" + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) Nest errors. Adjust message following protocol tests. + ConstraintViolationKind.CONSTRAINED_SHAPE_FAILURE -> "constraint violation occurred building member `$memberName` when building `${structureSymbol.name}`" + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt new file mode 100644 index 0000000000..6f319f6719 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt @@ -0,0 +1,542 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.codegen.core.SymbolProvider +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock +import software.amazon.smithy.rust.codegen.core.rustlang.deprecatedShape +import software.amazon.smithy.rust.codegen.core.rustlang.docs +import software.amazon.smithy.rust.codegen.core.rustlang.documentShape +import software.amazon.smithy.rust.codegen.core.rustlang.implInto +import software.amazon.smithy.rust.codegen.core.rustlang.render +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter +import software.amazon.smithy.rust.codegen.core.rustlang.withBlock +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata +import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.smithy.isRustBoxed +import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained +import software.amazon.smithy.rust.codegen.core.smithy.makeOptional +import software.amazon.smithy.rust.codegen.core.smithy.makeRustBoxed +import software.amazon.smithy.rust.codegen.core.smithy.mapRustType +import software.amazon.smithy.rust.codegen.core.smithy.rustType +import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.letIf +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType +import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.hasConstraintTraitOrTargetHasConstraintTrait +import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput +import software.amazon.smithy.rust.codegen.server.smithy.wouldHaveConstrainedWrapperTupleTypeWerePublicConstrainedTypesEnabled + +/** + * Generates a builder for the Rust type associated with the [StructureShape]. + * + * This generator is meant for use by the server project. Clients use the [BuilderGenerator] from the `codegen-client` + * Gradle subproject instead. + * + * This builder is different in that it enforces [constraint traits] upon calling `.build()`. If any constraint + * violations occur, the `build` method returns them. + * + * These are the main differences with the builders generated by the client's [BuilderGenerator]: + * + * - The design of this builder is simpler and closely follows what you get when using the [derive_builder] crate: + * * The builder has one method per struct member named _exactly_ like the struct member and whose input type + * matches _exactly_ the struct's member type. This method is generated by [renderBuilderMemberFn]. + * * The builder has one _setter_ method (i.e. prefixed with `set_`) per struct member whose input type is the + * corresponding _unconstrained type_ for the member. This method is always `pub(crate)` and meant for use for + * server deserializers only. + * * There are no convenience methods to add items to vector and hash map struct members. + * - The builder is not `PartialEq`. This is because the builder's members may or may not have been constrained (their + * types hold `MaybeConstrained`), and so it doesn't make sense to compare e.g. two builders holding the same data + * values, but one builder holds the member in the constrained variant while the other one holds it in the unconstrained + * variant. + * - The builder always implements `TryFrom for Structure` or `From for Structure`, depending on whether + * the structure is constrained (and hence enforcing the constraints might yield an error) or not, respectively. + * + * The builder is `pub(crate)` when `publicConstrainedTypes` is `false`, since in this case the user is never exposed + * to constrained types, and only the server's deserializers need to enforce constraint traits upon receiving a request. + * The user is exposed to [ServerBuilderGeneratorWithoutPublicConstrainedTypes] in this case instead, which intentionally + * _does not_ enforce constraints. + * + * [constraint traits]: https://awslabs.github.io/smithy/2.0/spec/constraint-traits.html + * [derive_builder]: https://docs.rs/derive_builder/latest/derive_builder/index.html + */ +class ServerBuilderGenerator( + codegenContext: ServerCodegenContext, + private val shape: StructureShape, +) { + companion object { + /** + * Returns whether a structure shape, whose builder has been generated with [ServerBuilderGenerator], requires a + * fallible builder to be constructed. + */ + fun hasFallibleBuilder( + structureShape: StructureShape, + model: Model, + symbolProvider: SymbolProvider, + takeInUnconstrainedTypes: Boolean, + ): Boolean = + if (takeInUnconstrainedTypes) { + structureShape.canReachConstrainedShape(model, symbolProvider) + } else { + structureShape + .members() + .map { symbolProvider.toSymbol(it) } + .any { !it.isOptional() } + } + } + + private val takeInUnconstrainedTypes = shape.isReachableFromOperationInput() + private val model = codegenContext.model + private val runtimeConfig = codegenContext.runtimeConfig + private val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes + private val visibility = if (publicConstrainedTypes) Visibility.PUBLIC else Visibility.PUBCRATE + private val symbolProvider = codegenContext.symbolProvider + private val constrainedShapeSymbolProvider = codegenContext.constrainedShapeSymbolProvider + private val pubCrateConstrainedShapeSymbolProvider = codegenContext.pubCrateConstrainedShapeSymbolProvider + private val members: List = shape.allMembers.values.toList() + private val structureSymbol = symbolProvider.toSymbol(shape) + private val builderSymbol = shape.serverBuilderSymbol(codegenContext) + private val moduleName = builderSymbol.namespace.split(builderSymbol.namespaceDelimiter).last() + private val isBuilderFallible = hasFallibleBuilder(shape, model, symbolProvider, takeInUnconstrainedTypes) + private val serverBuilderConstraintViolations = + ServerBuilderConstraintViolations(codegenContext, shape, takeInUnconstrainedTypes) + + private val codegenScope = arrayOf( + "RequestRejection" to ServerRuntimeType.RequestRejection(runtimeConfig), + "Structure" to structureSymbol, + "From" to RuntimeType.From, + "TryFrom" to RuntimeType.TryFrom, + "MaybeConstrained" to RuntimeType.MaybeConstrained(), + ) + + fun render(writer: RustWriter) { + writer.docs("See #D.", structureSymbol) + writer.withModule(RustModule(moduleName, RustMetadata(visibility = visibility))) { + renderBuilder(this) + } + } + + private fun renderBuilder(writer: RustWriter) { + if (isBuilderFallible) { + serverBuilderConstraintViolations.render( + writer, + visibility, + nonExhaustive = true, + shouldRenderAsValidationExceptionFieldList = shape.isReachableFromOperationInput(), + ) + + // Only generate converter from `ConstraintViolation` into `RequestRejection` if the structure shape is + // an operation input shape. + if (shape.hasTrait()) { + renderImplFromConstraintViolationForRequestRejection(writer) + } + + if (takeInUnconstrainedTypes) { + renderImplFromBuilderForMaybeConstrained(writer) + } + + renderTryFromBuilderImpl(writer) + } else { + renderFromBuilderImpl(writer) + } + + writer.docs("A builder for #D.", structureSymbol) + // Matching derives to the main structure, - `PartialEq` (see class documentation for why), + `Default` + // since we are a builder and everything is optional. + val baseDerives = structureSymbol.expectRustMetadata().derives + val derives = baseDerives.derives.intersect(setOf(RuntimeType.Debug, RuntimeType.Clone)) + RuntimeType.Default + baseDerives.copy(derives = derives).render(writer) + writer.rustBlock("pub${ if (visibility == Visibility.PUBCRATE) " (crate)" else "" } struct Builder") { + members.forEach { renderBuilderMember(this, it) } + } + + writer.rustBlock("impl Builder") { + for (member in members) { + if (publicConstrainedTypes) { + renderBuilderMemberFn(this, member) + } + + if (takeInUnconstrainedTypes) { + renderBuilderMemberSetterFn(this, member) + } + } + renderBuildFn(this) + } + } + + private fun renderImplFromConstraintViolationForRequestRejection(writer: RustWriter) { + writer.rustTemplate( + """ + impl #{From} for #{RequestRejection} { + fn from(constraint_violation: ConstraintViolation) -> Self { + let first_validation_exception_field = constraint_violation.as_validation_exception_field("".to_owned()); + let validation_exception = crate::error::ValidationException { + message: format!("1 validation error detected. {}", &first_validation_exception_field.message), + field_list: Some(vec![first_validation_exception_field]), + }; + Self::ConstraintViolation( + crate::operation_ser::serialize_structure_crate_error_validation_exception(&validation_exception) + .expect("impossible") + ) + } + } + """, + *codegenScope, + ) + } + + private fun renderImplFromBuilderForMaybeConstrained(writer: RustWriter) { + writer.rustTemplate( + """ + impl #{From} for #{StructureMaybeConstrained} { + fn from(builder: Builder) -> Self { + Self::Unconstrained(builder) + } + } + """, + *codegenScope, + "StructureMaybeConstrained" to structureSymbol.makeMaybeConstrained(), + ) + } + + private fun renderBuildFn(implBlockWriter: RustWriter) { + implBlockWriter.docs("""Consumes the builder and constructs a #D.""", structureSymbol) + if (isBuilderFallible) { + implBlockWriter.docs( + """ + The builder fails to construct a #D if a [`ConstraintViolation`] occurs. + """, + structureSymbol, + ) + + if (serverBuilderConstraintViolations.all.size > 1) { + implBlockWriter.docs("If the builder fails, it will return the _first_ encountered [`ConstraintViolation`].") + } + } + implBlockWriter.rustTemplate( + """ + pub fn build(self) -> #{ReturnType:W} { + self.build_enforcing_all_constraints() + } + """, + "ReturnType" to buildFnReturnType(isBuilderFallible, structureSymbol), + ) + renderBuildEnforcingAllConstraintsFn(implBlockWriter) + } + + private fun renderBuildEnforcingAllConstraintsFn(implBlockWriter: RustWriter) { + implBlockWriter.rustBlockTemplate( + "fn build_enforcing_all_constraints(self) -> #{ReturnType:W}", + "ReturnType" to buildFnReturnType(isBuilderFallible, structureSymbol), + ) { + conditionalBlock("Ok(", ")", conditional = isBuilderFallible) { + coreBuilder(this) + } + } + } + + fun renderConvenienceMethod(implBlock: RustWriter) { + implBlock.docs("Creates a new builder-style object to manufacture #D.", structureSymbol) + implBlock.rustBlock("pub fn builder() -> #T", builderSymbol) { + write("#T::default()", builderSymbol) + } + } + + private fun renderBuilderMember(writer: RustWriter, member: MemberShape) { + val memberSymbol = builderMemberSymbol(member) + val memberName = constrainedShapeSymbolProvider.toMemberName(member) + // Builder members are crate-public to enable using them directly in serializers/deserializers. + // During XML deserialization, `builder..take` is used to append to lists and maps. + writer.write("pub(crate) $memberName: #T,", memberSymbol) + } + + /** + * Render a `foo` method to set shape member `foo`. The caller must provide a value with the exact same type + * as the shape member's type. + * + * This method is meant for use by the user; it is not used by the generated crate's (de)serializers. + * + * This method is only generated when `publicConstrainedTypes` is `true`. Otherwise, the user has at their disposal + * the method from [ServerBuilderGeneratorWithoutPublicConstrainedTypes]. + */ + private fun renderBuilderMemberFn( + writer: RustWriter, + member: MemberShape, + ) { + check(publicConstrainedTypes) + val symbol = symbolProvider.toSymbol(member) + val memberName = symbolProvider.toMemberName(member) + + val hasBox = symbol.mapRustType { it.stripOuter() }.isRustBoxed() + val wrapInMaybeConstrained = takeInUnconstrainedTypes && member.targetCanReachConstrainedShape(model, symbolProvider) + + writer.documentShape(member, model) + writer.deprecatedShape(member) + + if (hasBox && wrapInMaybeConstrained) { + // In the case of recursive shapes, the member might be boxed. If so, and the member is also constrained, the + // implementation of this function needs to immediately unbox the value to wrap it in `MaybeConstrained`, + // and then re-box. Clippy warns us that we could have just taken in an unboxed value to avoid this round-trip + // to the heap. However, that will make the builder take in a value whose type does not exactly match the + // shape member's type. + // We don't want to introduce API asymmetry just for this particular case, so we disable the lint. + Attribute.Custom("allow(clippy::boxed_local)").render(writer) + } + writer.rustBlock("pub fn $memberName(mut self, input: ${symbol.rustType().render()}) -> Self") { + withBlock("self.$memberName = ", "; self") { + conditionalBlock("Some(", ")", conditional = !symbol.isOptional()) { + val maybeConstrainedVariant = + "${symbol.makeMaybeConstrained().rustType().namespace}::MaybeConstrained::Constrained" + + var varExpr = if (symbol.isOptional()) "v" else "input" + if (hasBox) varExpr = "*$varExpr" + if (!constrainedTypeHoldsFinalType(member)) varExpr = "($varExpr).into()" + + if (wrapInMaybeConstrained) { + conditionalBlock("input.map(##[allow(clippy::redundant_closure)] |v| ", ")", conditional = symbol.isOptional()) { + conditionalBlock("Box::new(", ")", conditional = hasBox) { + rust("$maybeConstrainedVariant($varExpr)") + } + } + } else { + write("input") + } + } + } + } + } + + /** + * Returns whether the constrained builder member type (the type on which the `Constrained` trait is implemented) + * is the final type the user sees when receiving the built struct. This is true when the corresponding constrained + * type is public and not `pub(crate)`, which happens when the target is a structure shape, a union shape, or is + * directly constrained. + * + * An example where this returns false is when the member shape targets a list whose members are lists of structures + * having at least one `required` member. In this case the member shape is transitively but not directly constrained, + * so the generated constrained type is `pub(crate)` and needs converting into the final type the user will be + * exposed to. + * + * See [PubCrateConstrainedShapeSymbolProvider] too. + */ + private fun constrainedTypeHoldsFinalType(member: MemberShape): Boolean { + val targetShape = model.expectShape(member.target) + return targetShape is StructureShape || + targetShape is UnionShape || + member.hasConstraintTraitOrTargetHasConstraintTrait(model, symbolProvider) + } + + /** + * Render a `set_foo` method. + * This method is able to take in unconstrained types for constrained shapes, like builders of structs in the case + * of structure shapes. + * + * This method is only used by deserializers at the moment and is therefore `pub(crate)`. + */ + private fun renderBuilderMemberSetterFn( + writer: RustWriter, + member: MemberShape, + ) { + val builderMemberSymbol = builderMemberSymbol(member) + val inputType = builderMemberSymbol.rustType().stripOuter().implInto() + .letIf( + // TODO(https://github.com/awslabs/smithy-rs/issues/1302, https://github.com/awslabs/smithy/issues/1179): + // The only reason why this condition can't simply be `member.isOptional` + // is because non-`required` blob streaming members are interpreted as + // `required`, so we can't use `member.isOptional` here. + symbolProvider.toSymbol(member).isOptional(), + ) { "Option<$it>" } + val memberName = symbolProvider.toMemberName(member) + + writer.documentShape(member, model) + // Setter names will never hit a reserved word and therefore never need escaping. + writer.rustBlock("pub(crate) fn set_${member.memberName.toSnakeCase()}(mut self, input: $inputType) -> Self") { + rust( + """ + self.$memberName = ${ + // TODO(https://github.com/awslabs/smithy-rs/issues/1302, https://github.com/awslabs/smithy/issues/1179): See above. + if (symbolProvider.toSymbol(member).isOptional()) { + "input.map(|v| v.into())" + } else { + "Some(input.into())" + } + }; + self + """, + ) + } + } + + private fun renderTryFromBuilderImpl(writer: RustWriter) { + writer.rustTemplate( + """ + impl #{TryFrom} for #{Structure} { + type Error = ConstraintViolation; + + fn try_from(builder: Builder) -> Result { + builder.build() + } + } + """, + *codegenScope, + ) + } + + private fun renderFromBuilderImpl(writer: RustWriter) { + writer.rustTemplate( + """ + impl #{From} for #{Structure} { + fn from(builder: Builder) -> Self { + builder.build() + } + } + """, + *codegenScope, + ) + } + + /** + * Returns the symbol for a builder's member. + * All builder members are optional, but only some are `Option`s where `T` needs to be constrained. + */ + private fun builderMemberSymbol(member: MemberShape): Symbol = + if (takeInUnconstrainedTypes && member.targetCanReachConstrainedShape(model, symbolProvider)) { + val strippedOption = if (member.hasConstraintTraitOrTargetHasConstraintTrait(model, symbolProvider)) { + constrainedShapeSymbolProvider.toSymbol(member) + } else { + pubCrateConstrainedShapeSymbolProvider.toSymbol(member) + } + // Strip the `Option` in case the member is not `required`. + .mapRustType { it.stripOuter() } + + val hadBox = strippedOption.isRustBoxed() + strippedOption + // Strip the `Box` in case the member can reach itself recursively. + .mapRustType { it.stripOuter() } + // Wrap it in the Cow-like `constrained::MaybeConstrained` type, since we know the target member shape can + // reach a constrained shape. + .makeMaybeConstrained() + // Box it in case the member can reach itself recursively. + .letIf(hadBox) { it.makeRustBoxed() } + // Ensure we always end up with an `Option`. + .makeOptional() + } else { + constrainedShapeSymbolProvider.toSymbol(member).makeOptional() + } + + /** + * Writes the code to instantiate the struct the builder builds. + * + * Builder member types are either: + * 1. `Option>`; or + * 2. `Option`. + * + * Where `U` is a constrained type. + * + * The structs they build have member types: + * a) `Option`; or + * b) `T`. + * + * `U` is equal to `T` when: + * - the shape for `U` has a constraint trait and `publicConstrainedTypes` is `true`; or + * - the member shape is a structure or union shape. + * Otherwise, `U` is always a `pub(crate)` tuple newtype holding `T`. + * + * For each member, this function first safely unwraps case 1. into 2., then converts `U` into `T` if necessary, + * and then converts into b) if necessary. + */ + private fun coreBuilder(writer: RustWriter) { + writer.rustBlock("#T", structureSymbol) { + for (member in members) { + val memberName = symbolProvider.toMemberName(member) + + withBlock("$memberName: self.$memberName", ",") { + // Write the modifier(s). + serverBuilderConstraintViolations.builderConstraintViolationForMember(member)?.also { constraintViolation -> + val hasBox = builderMemberSymbol(member) + .mapRustType { it.stripOuter() } + .isRustBoxed() + if (hasBox) { + rustTemplate( + """ + .map(|v| match *v { + #{MaybeConstrained}::Constrained(x) => Ok(Box::new(x)), + #{MaybeConstrained}::Unconstrained(x) => Ok(Box::new(x.try_into()?)), + }) + .map(|res| + res${ if (constrainedTypeHoldsFinalType(member)) "" else ".map(|v| v.into())" } + .map_err(|err| ConstraintViolation::${constraintViolation.name()}(Box::new(err))) + ) + .transpose()? + """, + *codegenScope, + ) + } else { + rustTemplate( + """ + .map(|v| match v { + #{MaybeConstrained}::Constrained(x) => Ok(x), + #{MaybeConstrained}::Unconstrained(x) => x.try_into(), + }) + .map(|res| + res${if (constrainedTypeHoldsFinalType(member)) "" else ".map(|v| v.into())"} + .map_err(ConstraintViolation::${constraintViolation.name()}) + ) + .transpose()? + """, + *codegenScope, + ) + + // Constrained types are not public and this is a member shape that would have generated a + // public constrained type, were the setting to be enabled. + // We've just checked the constraints hold by going through the non-public + // constrained type, but the user wants to work with the unconstrained type, so we have to + // unwrap it. + if (!publicConstrainedTypes && member.wouldHaveConstrainedWrapperTupleTypeWerePublicConstrainedTypesEnabled(model)) { + rust( + ".map(|v: #T| v.into())", + constrainedShapeSymbolProvider.toSymbol(model.expectShape(member.target)), + ) + } + } + } + serverBuilderConstraintViolations.forMember(member)?.also { + rust(".ok_or(ConstraintViolation::${it.name()})?") + } + } + } + } + } +} + +fun buildFnReturnType(isBuilderFallible: Boolean, structureSymbol: Symbol) = writable { + if (isBuilderFallible) { + rust("Result<#T, ConstraintViolation>", structureSymbol) + } else { + rust("#T", structureSymbol) + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorWithoutPublicConstrainedTypes.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorWithoutPublicConstrainedTypes.kt new file mode 100644 index 0000000000..897bdc1166 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorWithoutPublicConstrainedTypes.kt @@ -0,0 +1,238 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.codegen.core.SymbolProvider +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock +import software.amazon.smithy.rust.codegen.core.rustlang.deprecatedShape +import software.amazon.smithy.rust.codegen.core.rustlang.docs +import software.amazon.smithy.rust.codegen.core.rustlang.documentShape +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.withBlock +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata +import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.smithy.makeOptional +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType + +/** + * Generates a builder for the Rust type associated with the [StructureShape]. + * + * This builder is similar in design to [ServerBuilderGenerator], so consult its documentation in that regard. However, + * this builder has a few differences. + * + * Unlike [ServerBuilderGenerator], this builder only enforces constraints that are baked into the type system _when + * `publicConstrainedTypes` is false_. So in terms of honoring the Smithy spec, this builder only enforces enums + * and the `required` trait. + * + * Unlike [ServerBuilderGenerator], this builder is always public. It is the only builder type the user is exposed to + * when `publicConstrainedTypes` is false. + */ +class ServerBuilderGeneratorWithoutPublicConstrainedTypes( + codegenContext: ServerCodegenContext, + shape: StructureShape, +) { + companion object { + /** + * Returns whether a structure shape, whose builder has been generated with + * [ServerBuilderGeneratorWithoutPublicConstrainedTypes], requires a fallible builder to be constructed. + * + * This builder only enforces the `required` trait. + */ + fun hasFallibleBuilder( + structureShape: StructureShape, + symbolProvider: SymbolProvider, + ): Boolean = + structureShape + .members() + .map { symbolProvider.toSymbol(it) } + .any { !it.isOptional() } + } + + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + private val members: List = shape.allMembers.values.toList() + private val structureSymbol = symbolProvider.toSymbol(shape) + + private val builderSymbol = shape.serverBuilderSymbol(symbolProvider, false) + private val moduleName = builderSymbol.namespace.split("::").last() + private val isBuilderFallible = hasFallibleBuilder(shape, symbolProvider) + private val serverBuilderConstraintViolations = + ServerBuilderConstraintViolations(codegenContext, shape, builderTakesInUnconstrainedTypes = false) + + private val codegenScope = arrayOf( + "RequestRejection" to ServerRuntimeType.RequestRejection(codegenContext.runtimeConfig), + "Structure" to structureSymbol, + "From" to RuntimeType.From, + "TryFrom" to RuntimeType.TryFrom, + "MaybeConstrained" to RuntimeType.MaybeConstrained(), + ) + + fun render(writer: RustWriter) { + writer.docs("See #D.", structureSymbol) + writer.withModule(RustModule.public(moduleName)) { + renderBuilder(this) + } + } + + private fun renderBuilder(writer: RustWriter) { + if (isBuilderFallible) { + serverBuilderConstraintViolations.render( + writer, + Visibility.PUBLIC, + nonExhaustive = false, + shouldRenderAsValidationExceptionFieldList = false, + ) + + renderTryFromBuilderImpl(writer) + } else { + renderFromBuilderImpl(writer) + } + + writer.docs("A builder for #D.", structureSymbol) + // Matching derives to the main structure, - `PartialEq` (to be consistent with [ServerBuilderGenerator]), + `Default` + // since we are a builder and everything is optional. + val baseDerives = structureSymbol.expectRustMetadata().derives + val derives = baseDerives.derives.intersect(setOf(RuntimeType.Debug, RuntimeType.Clone)) + RuntimeType.Default + baseDerives.copy(derives = derives).render(writer) + writer.rustBlock("pub struct Builder") { + members.forEach { renderBuilderMember(this, it) } + } + + writer.rustBlock("impl Builder") { + for (member in members) { + renderBuilderMemberFn(this, member) + } + renderBuildFn(this) + } + } + + private fun renderBuildFn(implBlockWriter: RustWriter) { + implBlockWriter.docs("""Consumes the builder and constructs a #D.""", structureSymbol) + if (isBuilderFallible) { + implBlockWriter.docs( + """ + The builder fails to construct a #D if you do not provide a value for all non-`Option`al members. + """, + structureSymbol, + ) + } + implBlockWriter.rustTemplate( + """ + pub fn build(self) -> #{ReturnType:W} { + self.build_enforcing_required_and_enum_traits() + } + """, + "ReturnType" to buildFnReturnType(isBuilderFallible, structureSymbol), + ) + renderBuildEnforcingRequiredAndEnumTraitsFn(implBlockWriter) + } + + private fun renderBuildEnforcingRequiredAndEnumTraitsFn(implBlockWriter: RustWriter) { + implBlockWriter.rustBlockTemplate( + "fn build_enforcing_required_and_enum_traits(self) -> #{ReturnType:W}", + "ReturnType" to buildFnReturnType(isBuilderFallible, structureSymbol), + ) { + conditionalBlock("Ok(", ")", conditional = isBuilderFallible) { + coreBuilder(this) + } + } + } + + private fun coreBuilder(writer: RustWriter) { + writer.rustBlock("#T", structureSymbol) { + for (member in members) { + val memberName = symbolProvider.toMemberName(member) + + withBlock("$memberName: self.$memberName", ",") { + serverBuilderConstraintViolations.forMember(member)?.also { + rust(".ok_or(ConstraintViolation::${it.name()})?") + } + } + } + } + } + + fun renderConvenienceMethod(implBlock: RustWriter) { + implBlock.docs("Creates a new builder-style object to manufacture #D.", structureSymbol) + implBlock.rustBlock("pub fn builder() -> #T", builderSymbol) { + write("#T::default()", builderSymbol) + } + } + + private fun renderBuilderMember(writer: RustWriter, member: MemberShape) { + val memberSymbol = builderMemberSymbol(member) + val memberName = symbolProvider.toMemberName(member) + // Builder members are crate-public to enable using them directly in serializers/deserializers. + // During XML deserialization, `builder..take` is used to append to lists and maps. + writer.write("pub(crate) $memberName: #T,", memberSymbol) + } + + /** + * Render a `foo` method to set shape member `foo`. The caller must provide a value with the exact same type + * as the shape member's type. + * + * This method is meant for use by the user; it is not used by the generated crate's (de)serializers. + */ + private fun renderBuilderMemberFn(writer: RustWriter, member: MemberShape) { + val memberSymbol = symbolProvider.toSymbol(member) + val memberName = symbolProvider.toMemberName(member) + + writer.documentShape(member, model) + writer.deprecatedShape(member) + + writer.rustBlock("pub fn $memberName(mut self, input: #T) -> Self", memberSymbol) { + withBlock("self.$memberName = ", "; self") { + conditionalBlock("Some(", ")", conditional = !memberSymbol.isOptional()) { + rust("input") + } + } + } + } + + private fun renderTryFromBuilderImpl(writer: RustWriter) { + writer.rustTemplate( + """ + impl #{TryFrom} for #{Structure} { + type Error = ConstraintViolation; + + fn try_from(builder: Builder) -> Result { + builder.build() + } + } + """, + *codegenScope, + ) + } + + private fun renderFromBuilderImpl(writer: RustWriter) { + writer.rustTemplate( + """ + impl #{From} for #{Structure} { + fn from(builder: Builder) -> Self { + builder.build() + } + } + """, + *codegenScope, + ) + } + + /** + * Returns the symbol for a builder's member. + */ + private fun builderMemberSymbol(member: MemberShape): Symbol = symbolProvider.toSymbol(member).makeOptional() +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderSymbol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderSymbol.kt new file mode 100644 index 0000000000..a8ee7fd8f6 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderSymbol.kt @@ -0,0 +1,35 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.codegen.core.SymbolProvider +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.smithy.rustType +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext + +fun StructureShape.serverBuilderSymbol(codegenContext: ServerCodegenContext): Symbol = + this.serverBuilderSymbol(codegenContext.symbolProvider, !codegenContext.settings.codegenConfig.publicConstrainedTypes) + +fun StructureShape.serverBuilderSymbol(symbolProvider: SymbolProvider, pubCrate: Boolean): Symbol { + val structureSymbol = symbolProvider.toSymbol(this) + val builderNamespace = RustReservedWords.escapeIfNeeded(structureSymbol.name.toSnakeCase()) + + if (pubCrate) { + "_internal" + } else { + "" + } + val rustType = RustType.Opaque("Builder", "${structureSymbol.namespace}::$builderNamespace") + return Symbol.builder() + .rustType(rustType) + .name(rustType.name) + .namespace(rustType.namespace, "::") + .definitionFile(structureSymbol.definitionFile) + .build() +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt index 323a40ebf0..1514750cda 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt @@ -4,81 +4,111 @@ */ package software.amazon.smithy.rust.codegen.server.smithy.generators -import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.StringShape -import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator import software.amazon.smithy.rust.codegen.core.util.dq -import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType +import software.amazon.smithy.rust.codegen.core.util.expectTrait +import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput open class ServerEnumGenerator( - model: Model, - symbolProvider: RustSymbolProvider, + val codegenContext: ServerCodegenContext, private val writer: RustWriter, shape: StringShape, - enumTrait: EnumTrait, - private val runtimeConfig: RuntimeConfig, -) : EnumGenerator(model, symbolProvider, writer, shape, enumTrait) { +) : EnumGenerator(codegenContext.model, codegenContext.symbolProvider, writer, shape, shape.expectTrait()) { override var target: CodegenTarget = CodegenTarget.SERVER - private val errorStruct = "${enumName}UnknownVariantError" + + private val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes + private val constraintViolationSymbolProvider = + with(codegenContext.constraintViolationSymbolProvider) { + if (publicConstrainedTypes) { + this + } else { + PubCrateConstraintViolationSymbolProvider(this) + } + } + private val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape) + private val constraintViolationName = constraintViolationSymbol.name + private val codegenScope = arrayOf( + "String" to RuntimeType.String, + ) override fun renderFromForStr() { - writer.rust( - """ - ##[derive(Debug, PartialEq, Eq, Hash)] - pub struct $errorStruct(String); - """, - ) + writer.withModule( + RustModule.public(constraintViolationSymbol.namespace.split(constraintViolationSymbol.namespaceDelimiter).last()), + ) { + rustTemplate( + """ + ##[derive(Debug, PartialEq)] + pub struct $constraintViolationName(pub(crate) #{String}); + """, + *codegenScope, + ) + + if (shape.isReachableFromOperationInput()) { + val enumValueSet = enumTrait.enumDefinitionValues.joinToString(", ") + val message = "Value {} at '{}' failed to satisfy constraint: Member must satisfy enum value set: [$enumValueSet]" + + rustTemplate( + """ + impl $constraintViolationName { + pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { + crate::model::ValidationExceptionField { + message: format!(r##"$message"##, &self.0, &path), + path, + } + } + } + """, + *codegenScope, + ) + } + } writer.rustBlock("impl #T<&str> for $enumName", RuntimeType.TryFrom) { - write("type Error = $errorStruct;") - writer.rustBlock("fn try_from(s: &str) -> Result>::Error>", RuntimeType.TryFrom) { - writer.rustBlock("match s") { + rust("type Error = #T;", constraintViolationSymbol) + rustBlock("fn try_from(s: &str) -> Result>::Error>", RuntimeType.TryFrom) { + rustBlock("match s") { sortedMembers.forEach { member -> - write("${member.value.dq()} => Ok($enumName::${member.derivedName()}),") + rust("${member.value.dq()} => Ok($enumName::${member.derivedName()}),") } - write("_ => Err($errorStruct(s.to_owned()))") + rust("_ => Err(#T(s.to_owned()))", constraintViolationSymbol) } } } writer.rustTemplate( """ - impl #{From}<$errorStruct> for #{RequestRejection} { - fn from(e: $errorStruct) -> Self { - Self::EnumVariantNotFound(Box::new(e)) - } - } - impl #{StdError} for $errorStruct { } - impl #{Display} for $errorStruct { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.0.fmt(f) + impl #{TryFrom}<#{String}> for $enumName { + type Error = #{UnknownVariantSymbol}; + fn try_from(s: #{String}) -> std::result::Result>::Error> { + s.as_str().try_into() } } """, - "Display" to RuntimeType.Display, - "From" to RuntimeType.From, - "StdError" to RuntimeType.StdError, - "RequestRejection" to ServerRuntimeType.RequestRejection(runtimeConfig), + "String" to RuntimeType.String, + "TryFrom" to RuntimeType.TryFrom, + "UnknownVariantSymbol" to constraintViolationSymbol, ) } override fun renderFromStr() { - writer.rust( + writer.rustTemplate( """ impl std::str::FromStr for $enumName { - type Err = $errorStruct; - fn from_str(s: &str) -> std::result::Result { - $enumName::try_from(s) + type Err = #{UnknownVariantSymbol}; + fn from_str(s: &str) -> std::result::Result::Err> { + Self::try_from(s) } } """, + "UnknownVariantSymbol" to constraintViolationSymbol, ) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt index 9189380dde..13901e5213 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt @@ -6,11 +6,15 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.generators.Instantiator +import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput /** * Server enums do not have an `Unknown` variant like client enums do, so constructing an enum from @@ -24,11 +28,30 @@ private fun enumFromStringFn(enumSymbol: Symbol, data: String): Writable = writa ) } +class ServerBuilderKindBehavior(val codegenContext: CodegenContext) : Instantiator.BuilderKindBehavior { + override fun hasFallibleBuilder(shape: StructureShape): Boolean { + // Only operation input builders take in unconstrained types. + val takesInUnconstrainedTypes = shape.isReachableFromOperationInput() + return ServerBuilderGenerator.hasFallibleBuilder( + shape, + codegenContext.model, + codegenContext.symbolProvider, + takesInUnconstrainedTypes, + ) + } + + override fun setterName(memberShape: MemberShape): String = codegenContext.symbolProvider.toMemberName(memberShape) + + override fun doesSetterTakeInOption(memberShape: MemberShape): Boolean = + codegenContext.symbolProvider.toSymbol(memberShape).isOptional() +} + fun serverInstantiator(codegenContext: CodegenContext) = Instantiator( codegenContext.symbolProvider, codegenContext.model, codegenContext.runtimeConfig, + ServerBuilderKindBehavior(codegenContext), ::enumFromStringFn, defaultsForRequiredFields = true, ) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationGenerator.kt index b61b1baa45..b34685f7b8 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationGenerator.kt @@ -38,7 +38,8 @@ class ServerOperationGenerator( if (operation.errors.isEmpty()) { rust("std::convert::Infallible") } else { - rust("crate::error::${operationName}Error") + // Name comes from [ServerCombinedErrorGenerator]. + rust("crate::error::${symbolProvider.toSymbol(operation).name}Error") } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerStructureConstrainedTraitImpl.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerStructureConstrainedTraitImpl.kt new file mode 100644 index 0000000000..2812c59d74 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerStructureConstrainedTraitImpl.kt @@ -0,0 +1,32 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider + +class ServerStructureConstrainedTraitImpl( + private val symbolProvider: RustSymbolProvider, + private val publicConstrainedTypes: Boolean, + private val shape: StructureShape, + private val writer: RustWriter, +) { + fun render() { + writer.rustTemplate( + """ + impl #{ConstrainedTrait} for #{Structure} { + type Unconstrained = #{Builder}; + } + """, + "ConstrainedTrait" to RuntimeType.ConstrainedTrait(), + "Structure" to symbolProvider.toSymbol(shape), + "Builder" to shape.serverBuilderSymbol(symbolProvider, !publicConstrainedTypes), + ) + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGenerator.kt new file mode 100644 index 0000000000..602cbb7aaf --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGenerator.kt @@ -0,0 +1,139 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained +import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput + +/** + * Generates a Rust type for a constrained collection shape that is able to hold values for the corresponding + * _unconstrained_ shape. This type is a [RustType.Opaque] wrapper tuple newtype holding a `Vec`. Upon request parsing, + * server deserializers use this type to store the incoming values without enforcing the modeled constraints. Only after + * the full request has been parsed are constraints enforced, via the `impl TryFrom for + * ConstrainedSymbol`. + * + * This type is never exposed to the user; it is always `pub(crate)`. Only the deserializers use it. + * + * Consult [UnconstrainedShapeSymbolProvider] for more details and for an example. + */ +class UnconstrainedCollectionGenerator( + val codegenContext: ServerCodegenContext, + private val unconstrainedModuleWriter: RustWriter, + private val modelsModuleWriter: RustWriter, + val shape: CollectionShape, +) { + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + private val unconstrainedShapeSymbolProvider = codegenContext.unconstrainedShapeSymbolProvider + private val pubCrateConstrainedShapeSymbolProvider = codegenContext.pubCrateConstrainedShapeSymbolProvider + private val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes + private val constraintViolationSymbolProvider = + with(codegenContext.constraintViolationSymbolProvider) { + if (publicConstrainedTypes) { + this + } else { + PubCrateConstraintViolationSymbolProvider(this) + } + } + + fun render() { + check(shape.canReachConstrainedShape(model, symbolProvider)) + + val symbol = unconstrainedShapeSymbolProvider.toSymbol(shape) + val module = symbol.namespace.split(symbol.namespaceDelimiter).last() + val name = symbol.name + val innerShape = model.expectShape(shape.member.target) + val innerUnconstrainedSymbol = unconstrainedShapeSymbolProvider.toSymbol(innerShape) + val constrainedSymbol = pubCrateConstrainedShapeSymbolProvider.toSymbol(shape) + val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape) + val constraintViolationName = constraintViolationSymbol.name + val innerConstraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(innerShape) + + unconstrainedModuleWriter.withModule(RustModule(module, RustMetadata(visibility = Visibility.PUBCRATE))) { + rustTemplate( + """ + ##[derive(Debug, Clone)] + pub(crate) struct $name(pub(crate) std::vec::Vec<#{InnerUnconstrainedSymbol}>); + + impl From<$name> for #{MaybeConstrained} { + fn from(value: $name) -> Self { + Self::Unconstrained(value) + } + } + + impl #{TryFrom}<$name> for #{ConstrainedSymbol} { + type Error = #{ConstraintViolationSymbol}; + + fn try_from(value: $name) -> Result { + let res: Result<_, (usize, #{InnerConstraintViolationSymbol})> = value + .0 + .into_iter() + .enumerate() + .map(|(idx, inner)| inner.try_into().map_err(|inner_violation| (idx, inner_violation))) + .collect(); + res.map(Self) + .map_err(|(idx, inner_violation)| #{ConstraintViolationSymbol}(idx, inner_violation)) + } + } + """, + "InnerUnconstrainedSymbol" to innerUnconstrainedSymbol, + "InnerConstraintViolationSymbol" to innerConstraintViolationSymbol, + "ConstrainedSymbol" to constrainedSymbol, + "ConstraintViolationSymbol" to constraintViolationSymbol, + "MaybeConstrained" to constrainedSymbol.makeMaybeConstrained(), + "TryFrom" to RuntimeType.TryFrom, + ) + } + + val constraintViolationVisibility = if (publicConstrainedTypes) { + Visibility.PUBLIC + } else { + Visibility.PUBCRATE + } + modelsModuleWriter.withModule( + RustModule( + constraintViolationSymbol.namespace.split(constraintViolationSymbol.namespaceDelimiter).last(), + RustMetadata(visibility = constraintViolationVisibility), + ), + ) { + // The first component of the tuple struct is the index in the collection where the first constraint + // violation was found. + rustTemplate( + """ + ##[derive(Debug, PartialEq)] + pub struct $constraintViolationName( + pub(crate) usize, + pub(crate) #{InnerConstraintViolationSymbol} + ); + """, + "InnerConstraintViolationSymbol" to innerConstraintViolationSymbol, + ) + + if (shape.isReachableFromOperationInput()) { + rustTemplate( + """ + impl $constraintViolationName { + pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { + self.1.as_validation_exception_field(format!("{}/{}", path, &self.0)) + } + } + """, + "String" to RuntimeType.String, + ) + } + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGenerator.kt new file mode 100644 index 0000000000..4d47eb6229 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGenerator.kt @@ -0,0 +1,207 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.join +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained +import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained + +/** + * Generates a Rust type for a constrained map shape that is able to hold values for the corresponding + * _unconstrained_ shape. This type is a [RustType.Opaque] wrapper tuple newtype holding a `HashMap`. Upon request parsing, + * server deserializers use this type to store the incoming values without enforcing the modeled constraints. Only after + * the full request has been parsed are constraints enforced, via the `impl TryFrom for + * ConstrainedSymbol`. + * + * This type is never exposed to the user; it is always `pub(crate)`. Only the deserializers use it. + * + * Consult [UnconstrainedShapeSymbolProvider] for more details and for an example. + */ +class UnconstrainedMapGenerator( + val codegenContext: ServerCodegenContext, + private val unconstrainedModuleWriter: RustWriter, + val shape: MapShape, +) { + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + private val unconstrainedShapeSymbolProvider = codegenContext.unconstrainedShapeSymbolProvider + private val pubCrateConstrainedShapeSymbolProvider = codegenContext.pubCrateConstrainedShapeSymbolProvider + private val symbol = unconstrainedShapeSymbolProvider.toSymbol(shape) + private val name = symbol.name + private val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes + private val constraintViolationSymbolProvider = + with(codegenContext.constraintViolationSymbolProvider) { + if (publicConstrainedTypes) { + this + } else { + PubCrateConstraintViolationSymbolProvider(this) + } + } + private val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape) + private val constrainedShapeSymbolProvider = codegenContext.constrainedShapeSymbolProvider + private val constrainedSymbol = if (shape.isDirectlyConstrained(symbolProvider)) { + constrainedShapeSymbolProvider.toSymbol(shape) + } else { + pubCrateConstrainedShapeSymbolProvider.toSymbol(shape) + } + private val keyShape = model.expectShape(shape.key.target, StringShape::class.java) + private val valueShape = model.expectShape(shape.value.target) + + fun render() { + check(shape.canReachConstrainedShape(model, symbolProvider)) + + val module = symbol.namespace.split(symbol.namespaceDelimiter).last() + val keySymbol = unconstrainedShapeSymbolProvider.toSymbol(keyShape) + val valueSymbol = unconstrainedShapeSymbolProvider.toSymbol(valueShape) + + unconstrainedModuleWriter.withModule(RustModule(module, RustMetadata(visibility = Visibility.PUBCRATE))) { + rustTemplate( + """ + ##[derive(Debug, Clone)] + pub(crate) struct $name(pub(crate) std::collections::HashMap<#{KeySymbol}, #{ValueSymbol}>); + + impl From<$name> for #{MaybeConstrained} { + fn from(value: $name) -> Self { + Self::Unconstrained(value) + } + } + + """, + "KeySymbol" to keySymbol, + "ValueSymbol" to valueSymbol, + "MaybeConstrained" to constrainedSymbol.makeMaybeConstrained(), + ) + + renderTryFromUnconstrainedForConstrained(this) + } + } + + private fun renderTryFromUnconstrainedForConstrained(writer: RustWriter) { + writer.rustBlock("impl std::convert::TryFrom<$name> for #{T}", constrainedSymbol) { + rust("type Error = #T;", constraintViolationSymbol) + + rustBlock("fn try_from(value: $name) -> Result") { + if (isKeyConstrained(keyShape, symbolProvider) || isValueConstrained(valueShape, model, symbolProvider)) { + val resolveToNonPublicConstrainedValueType = + isValueConstrained(valueShape, model, symbolProvider) && + !valueShape.isDirectlyConstrained(symbolProvider) && + !valueShape.isStructureShape + val constrainedValueSymbol = if (resolveToNonPublicConstrainedValueType) { + pubCrateConstrainedShapeSymbolProvider.toSymbol(valueShape) + } else { + constrainedShapeSymbolProvider.toSymbol(valueShape) + } + + val constrainedKeySymbol = constrainedShapeSymbolProvider.toSymbol(keyShape) + val constrainKeyWritable = writable { + rustTemplate( + "let k: #{ConstrainedKeySymbol} = k.try_into().map_err(Self::Error::Key)?;", + "ConstrainedKeySymbol" to constrainedKeySymbol, + ) + } + val constrainValueWritable = writable { + rustTemplate( + """ + match #{ConstrainedValueSymbol}::try_from(v) { + Ok(v) => Ok((k, v)), + Err(inner_constraint_violation) => Err(Self::Error::Value(k, inner_constraint_violation)), + } + """, + "ConstrainedValueSymbol" to constrainedValueSymbol, + ) + } + val epilogueWritable = writable { rust("Ok((k, v))") } + + val constrainKVWritable = if ( + isKeyConstrained(keyShape, symbolProvider) && + isValueConstrained(valueShape, model, symbolProvider) + ) { + listOf(constrainKeyWritable, constrainValueWritable).join("\n") + } else if (isKeyConstrained(keyShape, symbolProvider)) { + listOf(constrainKeyWritable, epilogueWritable).join("\n") + } else if (isValueConstrained(valueShape, model, symbolProvider)) { + constrainValueWritable + } else { + epilogueWritable + } + + rustTemplate( + """ + let res: Result, Self::Error> = value.0 + .into_iter() + .map(|(k, v)| { + #{ConstrainKVWritable:W} + }) + .collect(); + let hm = res?; + """, + "ConstrainedKeySymbol" to constrainedKeySymbol, + "ConstrainedValueSymbol" to constrainedValueSymbol, + "ConstrainKVWritable" to constrainKVWritable, + ) + + val constrainedValueTypeIsNotFinalType = + resolveToNonPublicConstrainedValueType && shape.isDirectlyConstrained(symbolProvider) + if (constrainedValueTypeIsNotFinalType) { + // The map is constrained. Its value shape reaches a constrained shape, but the value shape itself + // is not directly constrained. The value shape must be an aggregate shape. But it is not a + // structure shape. So it must be a collection or map shape. In this case the type for the value + // shape that implements the `Constrained` trait _does not_ coincide with the regular type the user + // is exposed to. The former will be the `pub(crate)` wrapper tuple type created by a + // `Constrained*Generator`, whereas the latter will be an stdlib container type. Both types are + // isomorphic though, and we can convert between them using `From`, so that's what we do here. + // + // As a concrete example of this particular case, consider the model: + // + // ```smithy + // @length(min: 1) + // map Map { + // key: String, + // value: List, + // } + // + // list List { + // member: NiceString + // } + // + // @length(min: 1, max: 69) + // string NiceString + // ``` + rustTemplate( + """ + let hm: std::collections::HashMap<#{KeySymbol}, #{ValueSymbol}> = + hm.into_iter().map(|(k, v)| (k, v.into())).collect(); + """, + "KeySymbol" to symbolProvider.toSymbol(keyShape), + "ValueSymbol" to symbolProvider.toSymbol(valueShape), + ) + } + } else { + rust("let hm = value.0;") + } + + if (shape.isDirectlyConstrained(symbolProvider)) { + rust("Self::try_from(hm)") + } else { + rust("Ok(Self(hm))") + } + } + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGenerator.kt new file mode 100644 index 0000000000..dd470daf9e --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGenerator.kt @@ -0,0 +1,248 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.withBlock +import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained +import software.amazon.smithy.rust.codegen.core.smithy.makeRustBoxed +import software.amazon.smithy.rust.codegen.core.smithy.traits.RustBoxTrait +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.letIf +import software.amazon.smithy.rust.codegen.core.util.toPascalCase +import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained +import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput + +/** + * Generates a Rust type for a constrained union shape that is able to hold values for the corresponding _unconstrained_ + * shape. This type is a [RustType.Opaque] enum newtype, with each variant holding the corresponding unconstrained type. + * Upon request parsing, server deserializers use this type to store the incoming values without enforcing the modeled + * constraints. Only after the full request has been parsed are constraints enforced, via the `impl + * TryFrom for ConstrainedSymbol`. + * + * This type is never exposed to the user; it is always `pub(crate)`. Only the deserializers use it. + * + * Consult [UnconstrainedShapeSymbolProvider] for more details and for an example. + */ +class UnconstrainedUnionGenerator( + val codegenContext: ServerCodegenContext, + private val unconstrainedModuleWriter: RustWriter, + private val modelsModuleWriter: RustWriter, + val shape: UnionShape, +) { + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + private val pubCrateConstrainedShapeSymbolProvider = codegenContext.pubCrateConstrainedShapeSymbolProvider + private val unconstrainedShapeSymbolProvider = codegenContext.unconstrainedShapeSymbolProvider + private val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes + private val constraintViolationSymbolProvider = + with(codegenContext.constraintViolationSymbolProvider) { + if (publicConstrainedTypes) { + this + } else { + PubCrateConstraintViolationSymbolProvider(this) + } + } + private val symbol = unconstrainedShapeSymbolProvider.toSymbol(shape) + private val sortedMembers: List = shape.allMembers.values.sortedBy { symbolProvider.toMemberName(it) } + + fun render() { + check(shape.canReachConstrainedShape(model, symbolProvider)) + + val moduleName = symbol.namespace.split(symbol.namespaceDelimiter).last() + val name = symbol.name + val constrainedSymbol = pubCrateConstrainedShapeSymbolProvider.toSymbol(shape) + val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape) + val constraintViolationName = constraintViolationSymbol.name + + unconstrainedModuleWriter.withModule(RustModule(moduleName, RustMetadata(visibility = Visibility.PUBCRATE))) { + rustBlock( + """ + ##[allow(clippy::enum_variant_names)] + ##[derive(Debug, Clone)] + pub(crate) enum $name + """, + ) { + sortedMembers.forEach { member -> + rust( + "${unconstrainedShapeSymbolProvider.toMemberName(member)}(#T),", + unconstrainedShapeSymbolProvider.toSymbol(member), + ) + } + } + + rustTemplate( + """ + impl #{TryFrom}<$name> for #{ConstrainedSymbol} { + type Error = #{ConstraintViolationSymbol}; + + fn try_from(value: $name) -> Result { + #{body:W} + } + } + """, + "TryFrom" to RuntimeType.TryFrom, + "ConstrainedSymbol" to constrainedSymbol, + "ConstraintViolationSymbol" to constraintViolationSymbol, + "body" to generateTryFromUnconstrainedUnionImpl(), + ) + } + + modelsModuleWriter.rustTemplate( + """ + impl #{ConstrainedTrait} for #{ConstrainedSymbol} { + type Unconstrained = #{UnconstrainedSymbol}; + } + + impl From<#{UnconstrainedSymbol}> for #{MaybeConstrained} { + fn from(value: #{UnconstrainedSymbol}) -> Self { + Self::Unconstrained(value) + } + } + """, + "ConstrainedTrait" to RuntimeType.ConstrainedTrait(), + "MaybeConstrained" to constrainedSymbol.makeMaybeConstrained(), + "ConstrainedSymbol" to constrainedSymbol, + "UnconstrainedSymbol" to symbol, + ) + + val constraintViolationVisibility = if (publicConstrainedTypes) { + Visibility.PUBLIC + } else { + Visibility.PUBCRATE + } + modelsModuleWriter.withModule( + RustModule( + constraintViolationSymbol.namespace.split(constraintViolationSymbol.namespaceDelimiter).last(), + RustMetadata(visibility = constraintViolationVisibility), + ), + ) { + Attribute.Derives(setOf(RuntimeType.Debug, RuntimeType.PartialEq)).render(this) + rustBlock("pub${ if (constraintViolationVisibility == Visibility.PUBCRATE) " (crate)" else "" } enum $constraintViolationName") { + constraintViolations().forEach { renderConstraintViolation(this, it) } + } + + if (shape.isReachableFromOperationInput()) { + rustBlock("impl $constraintViolationName") { + rustBlockTemplate( + "pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField", + "String" to RuntimeType.String, + ) { + withBlock("match self {", "}") { + for (constraintViolation in constraintViolations()) { + rust("""Self::${constraintViolation.name()}(inner) => inner.as_validation_exception_field(path + "/${constraintViolation.forMember.memberName}"),""") + } + } + } + } + } + } + } + + data class ConstraintViolation(val forMember: MemberShape) { + fun name() = forMember.memberName.toPascalCase() + } + + private fun constraintViolations() = + sortedMembers + .filter { it.targetCanReachConstrainedShape(model, symbolProvider) } + .map { ConstraintViolation(it) } + + private fun renderConstraintViolation(writer: RustWriter, constraintViolation: ConstraintViolation) { + val targetShape = model.expectShape(constraintViolation.forMember.target) + + val constraintViolationSymbol = + constraintViolationSymbolProvider.toSymbol(targetShape) + // If the corresponding union's member is boxed, box this constraint violation symbol too. + .letIf(constraintViolation.forMember.hasTrait()) { + it.makeRustBoxed() + } + + writer.rust( + "${constraintViolation.name()}(#T),", + constraintViolationSymbol, + ) + } + + private fun generateTryFromUnconstrainedUnionImpl() = writable { + withBlock("Ok(", ")") { + withBlock("match value {", "}") { + sortedMembers.forEach { member -> + val memberName = unconstrainedShapeSymbolProvider.toMemberName(member) + withBlockTemplate( + "#{UnconstrainedUnion}::$memberName(unconstrained) => Self::$memberName(", + "),", + "UnconstrainedUnion" to symbol, + ) { + if (!member.canReachConstrainedShape(model, symbolProvider)) { + rust("unconstrained") + } else { + val targetShape = model.expectShape(member.target) + val resolveToNonPublicConstrainedType = + targetShape !is StructureShape && targetShape !is UnionShape && !targetShape.hasTrait() && + (!publicConstrainedTypes || !targetShape.isDirectlyConstrained(symbolProvider)) + + val (unconstrainedVar, boxIt) = if (member.hasTrait()) { + "(*unconstrained)" to ".map(Box::new).map_err(Box::new)" + } else { + "unconstrained" to "" + } + + if (resolveToNonPublicConstrainedType) { + val constrainedSymbol = + if (!publicConstrainedTypes && targetShape.isDirectlyConstrained(symbolProvider)) { + codegenContext.constrainedShapeSymbolProvider.toSymbol(targetShape) + } else { + pubCrateConstrainedShapeSymbolProvider.toSymbol(targetShape) + } + rustTemplate( + """ + { + let constrained: #{ConstrainedSymbol} = $unconstrainedVar + .try_into() + $boxIt + .map_err(Self::Error::${ConstraintViolation(member).name()})?; + constrained.into() + } + """, + "ConstrainedSymbol" to constrainedSymbol, + ) + } else { + rust( + """ + $unconstrainedVar + .try_into() + $boxIt + .map_err(Self::Error::${ConstraintViolation(member).name()})? + """, + ) + } + } + } + } + } + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt index 55ff0fbe73..f866d83e3a 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt @@ -5,21 +5,49 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators.http +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter +import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingSection import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType +import software.amazon.smithy.rust.codegen.core.smithy.mapRustType import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderSymbol +import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape class ServerRequestBindingGenerator( protocol: Protocol, - codegenContext: CodegenContext, + private val codegenContext: ServerCodegenContext, operationShape: OperationShape, ) { - private val httpBindingGenerator = HttpBindingGenerator(protocol, codegenContext, operationShape) + private fun serverBuilderSymbol(shape: StructureShape): Symbol = shape.serverBuilderSymbol( + codegenContext.symbolProvider, + !codegenContext.settings.codegenConfig.publicConstrainedTypes, + ) + private val httpBindingGenerator = + HttpBindingGenerator( + protocol, + codegenContext, + codegenContext.unconstrainedShapeSymbolProvider, + operationShape, + ::serverBuilderSymbol, + listOf( + ServerRequestAfterDeserializingIntoAHashMapOfHttpPrefixHeadersWrapInUnconstrainedMapHttpBindingCustomization( + codegenContext, + ), + ), + ) fun generateDeserializeHeaderFn(binding: HttpBindingDescriptor): RuntimeType = httpBindingGenerator.generateDeserializeHeaderFn(binding) @@ -39,3 +67,22 @@ class ServerRequestBindingGenerator( binding: HttpBindingDescriptor, ): RuntimeType = httpBindingGenerator.generateDeserializePrefixHeaderFn(binding) } + +/** + * A customization to, just after we've deserialized HTTP request headers bound to a map shape via `@httpPrefixHeaders`, + * wrap the `std::collections::HashMap` in an unconstrained type wrapper newtype. + */ +class ServerRequestAfterDeserializingIntoAHashMapOfHttpPrefixHeadersWrapInUnconstrainedMapHttpBindingCustomization(val codegenContext: ServerCodegenContext) : + HttpBindingCustomization() { + override fun section(section: HttpBindingSection): Writable = when (section) { + is HttpBindingSection.BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders -> emptySection + is HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders -> writable { + if (section.memberShape.targetCanReachConstrainedShape(codegenContext.model, codegenContext.unconstrainedShapeSymbolProvider)) { + rust( + "let out = out.map(#T);", + codegenContext.unconstrainedShapeSymbolProvider.toSymbol(section.memberShape).mapRustType { it.stripOuter() }, + ) + } + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt index 1967e4304c..020f558ce9 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt @@ -5,21 +5,67 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators.http +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape -import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol +import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingSection import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderSymbol +import software.amazon.smithy.rust.codegen.server.smithy.workingWithPublicConstrainedWrapperTupleType class ServerResponseBindingGenerator( protocol: Protocol, - codegenContext: CodegenContext, + private val codegenContext: ServerCodegenContext, operationShape: OperationShape, ) { - private val httpBindingGenerator = HttpBindingGenerator(protocol, codegenContext, operationShape) + private fun builderSymbol(shape: StructureShape): Symbol = shape.serverBuilderSymbol(codegenContext) + + private val httpBindingGenerator = + HttpBindingGenerator( + protocol, + codegenContext, + codegenContext.symbolProvider, + operationShape, + ::builderSymbol, + listOf( + ServerResponseBeforeIteratingOverMapBoundWithHttpPrefixHeadersUnwrapConstrainedMapHttpBindingCustomization( + codegenContext, + ), + ), + ) fun generateAddHeadersFn(shape: Shape): RuntimeType? = httpBindingGenerator.generateAddHeadersFn(shape, HttpMessageType.RESPONSE) } + +/** + * A customization to, just before we iterate over a _constrained_ map shape that is bound to HTTP response headers via + * `@httpPrefixHeaders`, unwrap the wrapper newtype and take a shared reference to the actual `std::collections::HashMap` + * within it. + */ +class ServerResponseBeforeIteratingOverMapBoundWithHttpPrefixHeadersUnwrapConstrainedMapHttpBindingCustomization(val codegenContext: ServerCodegenContext) : + HttpBindingCustomization() { + override fun section(section: HttpBindingSection): Writable = when (section) { + is HttpBindingSection.BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders -> writable { + if (workingWithPublicConstrainedWrapperTupleType( + section.shape, + codegenContext.model, + codegenContext.settings.codegenConfig.publicConstrainedTypes, + ) + ) { + rust("let ${section.variableName} = &${section.variableName}.0;") + } + } + is HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders -> emptySection + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index 2dea2153b2..86b33c617f 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -5,8 +5,11 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators.protocol +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.knowledge.TopDownIndex import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.asType import software.amazon.smithy.rust.codegen.core.rustlang.rust @@ -21,10 +24,22 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml +import software.amazon.smithy.rust.codegen.core.smithy.protocols.awsJsonFieldName +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserCustomization +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserSection +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.ReturnSymbolToParse +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.restJsonFieldName import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType +import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderSymbol import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerAwsJsonSerializerGenerator +import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerRestJsonSerializerGenerator +import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape private fun allOperations(codegenContext: CodegenContext): List { val index = TopDownIndex.of(codegenContext.model) @@ -79,9 +94,9 @@ interface ServerProtocol : Protocol { } class ServerAwsJsonProtocol( - codegenContext: CodegenContext, + private val serverCodegenContext: ServerCodegenContext, awsJsonVersion: AwsJsonVersion, -) : AwsJson(codegenContext, awsJsonVersion), ServerProtocol { +) : AwsJson(serverCodegenContext, awsJsonVersion), ServerProtocol { private val runtimeConfig = codegenContext.runtimeConfig private val codegenScope = arrayOf( "SmithyHttpServer" to ServerCargoDependency.SmithyHttpServer(runtimeConfig).asType(), @@ -89,11 +104,33 @@ class ServerAwsJsonProtocol( private val symbolProvider = codegenContext.symbolProvider private val service = codegenContext.serviceShape + override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { + fun builderSymbol(shape: StructureShape): Symbol = + shape.serverBuilderSymbol(serverCodegenContext) + fun returnSymbolToParse(shape: Shape): ReturnSymbolToParse = + if (shape.canReachConstrainedShape(codegenContext.model, symbolProvider)) { + ReturnSymbolToParse(serverCodegenContext.unconstrainedShapeSymbolProvider.toSymbol(shape), true) + } else { + ReturnSymbolToParse(codegenContext.symbolProvider.toSymbol(shape), false) + } + return JsonParserGenerator( + codegenContext, + httpBindingResolver, + ::awsJsonFieldName, + ::builderSymbol, + ::returnSymbolToParse, + listOf( + ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonParserCustomization(serverCodegenContext), + ), + ) + } + override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = - ServerAwsJsonSerializerGenerator(codegenContext, httpBindingResolver, awsJsonVersion) + ServerAwsJsonSerializerGenerator(serverCodegenContext, httpBindingResolver, awsJsonVersion) companion object { - fun fromCoreProtocol(awsJson: AwsJson): ServerAwsJsonProtocol = ServerAwsJsonProtocol(awsJson.codegenContext, awsJson.version) + fun fromCoreProtocol(awsJson: AwsJson): ServerAwsJsonProtocol = + ServerAwsJsonProtocol(awsJson.codegenContext as ServerCodegenContext, awsJson.version) } override fun markerStruct(): RuntimeType { @@ -203,12 +240,38 @@ private fun restRouterConstruction( } class ServerRestJsonProtocol( - codegenContext: CodegenContext, -) : RestJson(codegenContext), ServerProtocol { + private val serverCodegenContext: ServerCodegenContext, +) : RestJson(serverCodegenContext), ServerProtocol { val runtimeConfig = codegenContext.runtimeConfig + override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { + fun builderSymbol(shape: StructureShape): Symbol = + shape.serverBuilderSymbol(serverCodegenContext) + fun returnSymbolToParse(shape: Shape): ReturnSymbolToParse = + if (shape.canReachConstrainedShape(codegenContext.model, codegenContext.symbolProvider)) { + ReturnSymbolToParse(serverCodegenContext.unconstrainedShapeSymbolProvider.toSymbol(shape), true) + } else { + ReturnSymbolToParse(serverCodegenContext.symbolProvider.toSymbol(shape), false) + } + return JsonParserGenerator( + codegenContext, + httpBindingResolver, + ::restJsonFieldName, + ::builderSymbol, + ::returnSymbolToParse, + listOf( + ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonParserCustomization( + serverCodegenContext, + ), + ), + ) + } + + override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = + ServerRestJsonSerializerGenerator(serverCodegenContext, httpBindingResolver) + companion object { - fun fromCoreProtocol(restJson: RestJson): ServerRestJsonProtocol = ServerRestJsonProtocol(restJson.codegenContext) + fun fromCoreProtocol(restJson: RestJson): ServerRestJsonProtocol = ServerRestJsonProtocol(restJson.codegenContext as ServerCodegenContext) } override fun markerStruct() = ServerRuntimeType.Protocol("RestJson1", "rest_json_1", runtimeConfig) @@ -257,3 +320,22 @@ class ServerRestXmlProtocol( override fun serverContentTypeCheckNoModeledInput() = true } + +/** + * A customization to, just before we box a recursive member that we've deserialized into `Option`, convert it into + * `MaybeConstrained` if the target shape can reach a constrained shape. + */ +class ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonParserCustomization(val codegenContext: ServerCodegenContext) : + JsonParserCustomization() { + override fun section(section: JsonParserSection): Writable = when (section) { + is JsonParserSection.BeforeBoxingDeserializedMember -> writable { + // We're only interested in _structure_ member shapes that can reach constrained shapes. + if ( + codegenContext.model.expectShape(section.shape.container) is StructureShape && + section.shape.targetCanReachConstrainedShape(codegenContext.model, codegenContext.symbolProvider) + ) { + rust(".map(|x| x.into())") + } + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index ea8792e40a..3fbdc2ebc4 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -935,18 +935,8 @@ class ServerProtocolTestGenerator( FailingTest(RestJson, "RestJsonMalformedUnionNoFieldsSet", TestType.MalformedRequest), - // Tests involving constraint traits, which are not yet implemented. - // See https://github.com/awslabs/smithy-rs/pull/1342. - FailingTest(RestJsonValidation, "RestJsonMalformedEnumList_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumList_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumMapKey_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumMapKey_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumMapValue_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumMapValue_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumString_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumString_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumUnion_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumUnion_case1", TestType.MalformedRequest), + // Tests involving constraint traits, which are not yet fully implemented. + // See https://github.com/awslabs/smithy-rs/issues/1401. FailingTest(RestJsonValidation, "RestJsonMalformedLengthBlobOverride_case0", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedLengthBlobOverride_case1", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedLengthListOverride_case0", TestType.MalformedRequest), @@ -960,17 +950,8 @@ class ServerProtocolTestGenerator( FailingTest(RestJsonValidation, "RestJsonMalformedLengthBlob_case1", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedLengthList_case0", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedLengthList_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthListValue_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthListValue_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthMap_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthMap_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthMapKey_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthMapKey_case1", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedLengthMapValue_case0", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedLengthMapValue_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthString_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthString_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthString_case2", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedPatternListOverride_case0", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedPatternListOverride_case1", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedPatternMapKeyOverride_case0", TestType.MalformedRequest), @@ -1010,9 +991,6 @@ class ServerProtocolTestGenerator( FailingTest(RestJsonValidation, "RestJsonMalformedRangeFloat_case1", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedLengthMaxStringOverride", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedLengthMinStringOverride", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthQueryStringNoValue", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthMaxString", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthMinString", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedRangeMaxByteOverride", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedRangeMaxFloatOverride", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedRangeMinByteOverride", TestType.MalformedRequest), @@ -1021,10 +999,6 @@ class ServerProtocolTestGenerator( FailingTest(RestJsonValidation, "RestJsonMalformedRangeMaxFloat", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedRangeMinByte", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedRangeMinFloat", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedRequiredBodyExplicitNull", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedRequiredBodyUnset", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedRequiredHeaderUnset", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedRecursiveStructures", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedPatternSensitiveString", TestType.MalformedRequest), // Some tests for the S3 service (restXml). diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt index 8215d485ba..815608fb24 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt @@ -10,18 +10,18 @@ import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.escape import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.writable -import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolGeneratorFactory import software.amazon.smithy.rust.codegen.core.smithy.protocols.awsJsonFieldName -import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonCustomization -import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSection +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerCustomization import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerSection import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeIteratingOverMapJsonCustomization import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerAwsJsonProtocol import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol @@ -56,9 +56,9 @@ class ServerAwsJsonFactory(private val version: AwsJsonVersion) : * AwsJson requires errors to be serialized in server responses with an additional `__type` field. This * customization writes the right field depending on the version of the AwsJson protocol. */ -class ServerAwsJsonError(private val awsJsonVersion: AwsJsonVersion) : JsonCustomization() { - override fun section(section: JsonSection): Writable = when (section) { - is JsonSection.ServerError -> writable { +class ServerAwsJsonError(private val awsJsonVersion: AwsJsonVersion) : JsonSerializerCustomization() { + override fun section(section: JsonSerializerSection): Writable = when (section) { + is JsonSerializerSection.ServerError -> writable { if (section.structureShape.hasTrait()) { val typeId = when (awsJsonVersion) { // AwsJson 1.0 wants the whole shape ID (namespace#Shape). @@ -82,7 +82,7 @@ class ServerAwsJsonError(private val awsJsonVersion: AwsJsonVersion) : JsonCusto * https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#operation-error-serialization */ class ServerAwsJsonSerializerGenerator( - private val codegenContext: CodegenContext, + private val codegenContext: ServerCodegenContext, private val httpBindingResolver: HttpBindingResolver, private val awsJsonVersion: AwsJsonVersion, private val jsonSerializerGenerator: JsonSerializerGenerator = @@ -90,6 +90,6 @@ class ServerAwsJsonSerializerGenerator( codegenContext, httpBindingResolver, ::awsJsonFieldName, - customizations = listOf(ServerAwsJsonError(awsJsonVersion)), + customizations = listOf(ServerAwsJsonError(awsJsonVersion), BeforeIteratingOverMapJsonCustomization(codegenContext)), ), ) : StructuredDataSerializerGenerator by jsonSerializerGenerator diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 2cf86be5eb..4cc16278cb 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -14,6 +14,7 @@ import software.amazon.smithy.model.knowledge.HttpBindingIndex import software.amazon.smithy.model.node.ExpectationNotMetException import software.amazon.smithy.model.shapes.BooleanShape import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.NumberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape @@ -32,32 +33,29 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.asType import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock -import software.amazon.smithy.rust.codegen.core.rustlang.render import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization -import software.amazon.smithy.rust.codegen.core.smithy.extractSymbolFromOption -import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.TypeConversionGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.error.errorSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.MakeOperationGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTraitImplGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.smithy.mapRustType import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator -import software.amazon.smithy.rust.codegen.core.smithy.toOptional import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErrors import software.amazon.smithy.rust.codegen.core.smithy.wrapOptional @@ -74,16 +72,19 @@ import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType +import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.http.ServerRequestBindingGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.http.ServerResponseBindingGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderSymbol import java.util.logging.Logger /** * Implement operations' input parsing and output serialization. Protocols can plug their own implementations * and overrides by creating a protocol factory inheriting from this class and feeding it to the [ServerProtocolLoader]. - * See `ServerRestJsonFactory.kt` for more info. + * See `ServerRestJson.kt` for more info. */ class ServerHttpBoundProtocolGenerator( codegenContext: ServerCodegenContext, @@ -117,6 +118,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( ) : ProtocolTraitImplGenerator { private val logger = Logger.getLogger(javaClass.name) private val symbolProvider = codegenContext.symbolProvider + private val unconstrainedShapeSymbolProvider = codegenContext.unconstrainedShapeSymbolProvider private val model = codegenContext.model private val runtimeConfig = codegenContext.runtimeConfig private val httpBindingResolver = protocol.httpBindingResolver @@ -592,7 +594,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( * case it will generate response headers for the given error shape. * * It sets three groups of headers in order. Headers from one group take precedence over headers in a later group. - * 1. Headers bound by the `httpHeader` and `httpPrefixHeader` traits. + * 1. Headers bound by the `httpHeader` and `httpPrefixHeader` traits. = null * 2. The protocol-specific `Content-Type` header for the operation. * 3. Additional protocol-specific headers for errors, if [errorShape] is non-null. */ @@ -712,7 +714,10 @@ private class ServerHttpBoundProtocolTraitImplGenerator( val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape) val structuredDataParser = protocol.structuredDataParser(operationShape) Attribute.AllowUnusedMut.render(this) - rust("let mut input = #T::default();", inputShape.builderSymbol(symbolProvider)) + rust( + "let mut input = #T::default();", + inputShape.serverBuilderSymbol(codegenContext), + ) val parser = structuredDataParser.serverInputParser(operationShape) val noInputs = model.expectShape(operationShape.inputShape).expectTrait().originalId == null if (parser != null) { @@ -732,9 +737,21 @@ private class ServerHttpBoundProtocolTraitImplGenerator( val member = binding.member val parsedValue = serverRenderBindingParser(binding, operationShape, httpBindingGenerator, structuredDataParser) if (parsedValue != null) { - withBlock("input = input.${member.setterName()}(", ");") { - parsedValue(this) - } + rust("if let Some(value) = ") + parsedValue(this) + rust( + """ + { + input = input.${member.setterName()}(${ + if (symbolProvider.toSymbol(binding.member).isOptional()) { + "Some(value)" + } else { + "value" + } + }); + } + """, + ) } } serverRenderUriPathParser(this, operationShape) @@ -750,7 +767,13 @@ private class ServerHttpBoundProtocolTraitImplGenerator( ) } } - val err = if (StructureGenerator.hasFallibleBuilder(inputShape, symbolProvider)) { + val err = if (ServerBuilderGenerator.hasFallibleBuilder( + inputShape, + model, + symbolProvider, + takeInUnconstrainedTypes = true, + ) + ) { "?" } else "" rustTemplate("input.build()$err", *codegenScope) @@ -884,13 +907,13 @@ private class ServerHttpBoundProtocolTraitImplGenerator( .forEachIndexed { index, segment -> val binding = pathBindings.find { it.memberName == segment.content } if (binding != null && segment.isLabel) { - val deserializer = generateParseFn(binding, true) + val deserializer = generateParseStrFn(binding, true) rustTemplate( """ input = input.${binding.member.setterName()}( - ${symbolProvider.toOptional(binding.member, "#{deserializer}(m$index)?")} + #{deserializer}(m$index)? ); - """.trimIndent(), + """, *codegenScope, "deserializer" to deserializer, ) @@ -905,13 +928,6 @@ private class ServerHttpBoundProtocolTraitImplGenerator( // * a map of set of string. enum class QueryParamsTargetMapValueType { STRING, LIST, SET; - - fun asRustType(): RustType = - when (this) { - STRING -> RustType.String - LIST -> RustType.Vec(RustType.String) - SET -> RustType.HashSet(RustType.String) - } } private fun queryParamsTargetMapValueType(targetMapValue: Shape): QueryParamsTargetMapValueType = @@ -924,8 +940,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( } else { throw ExpectationNotMetException( """ - @httpQueryParams trait applied to non-supported target - $targetMapValue of type ${targetMapValue.type} + @httpQueryParams trait applied to non-supported target $targetMapValue of type ${targetMapValue.type} """.trimIndent(), targetMapValue.sourceLocation, ) @@ -947,9 +962,8 @@ private class ServerHttpBoundProtocolTraitImplGenerator( fun HttpBindingDescriptor.queryParamsBindingTargetMapValueType(): QueryParamsTargetMapValueType { check(this.location == HttpLocation.QUERY_PARAMS) - val queryParamsTarget = model.expectShape(this.member.target) - val mapTarget = queryParamsTarget.asMapShape().get() - return queryParamsTargetMapValueType(model.expectShape(mapTarget.value.target)) + val queryParamsTarget = model.expectShape(this.member.target, MapShape::class.java) + return queryParamsTargetMapValueType(model.expectShape(queryParamsTarget.value.target)) } with(writer) { @@ -962,11 +976,16 @@ private class ServerHttpBoundProtocolTraitImplGenerator( ) if (queryParamsBinding != null) { - rustTemplate( - "let mut query_params: #{HashMap} = #{HashMap}::new();", - "HashMap" to software.amazon.smithy.rust.codegen.core.rustlang.RustType.HashMap.RuntimeType, - ) + val target = model.expectShape(queryParamsBinding.member.target, MapShape::class.java) + val hasConstrainedTarget = target.canReachConstrainedShape(model, symbolProvider) + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) Here we only check the target shape; + // constraint traits on member shapes are not implemented yet. + val targetSymbol = unconstrainedShapeSymbolProvider.toSymbol(target) + withBlock("let mut query_params: #T = ", ";", targetSymbol) { + conditionalBlock("#T(", ")", conditional = hasConstrainedTarget, targetSymbol) { + rust("#T::new()", RustType.HashMap.RuntimeType) + } + } } val (queryBindingsTargettingCollection, queryBindingsTargettingSimple) = queryBindings.partition { model.expectShape(it.member.target) is CollectionShape } @@ -979,13 +998,13 @@ private class ServerHttpBoundProtocolTraitImplGenerator( rustBlock("for (k, v) in pairs") { queryBindingsTargettingSimple.forEach { - val deserializer = generateParseFn(it, false) + val deserializer = generateParseStrFn(it, false) val memberName = symbolProvider.toMemberName(it.member) rustTemplate( """ if !seen_$memberName && k == "${it.locationName}" { input = input.${it.member.setterName()}( - ${symbolProvider.toOptional(it.member, "#{deserializer}(&v)?")} + #{deserializer}(&v)? ); seen_$memberName = true; } @@ -993,22 +1012,20 @@ private class ServerHttpBoundProtocolTraitImplGenerator( "deserializer" to deserializer, ) } - queryBindingsTargettingCollection.forEach { - rustBlock("if k == ${it.locationName.dq()}") { + queryBindingsTargettingCollection.forEachIndexed { idx, it -> + rustBlock("${if (idx > 0) "else " else ""}if k == ${it.locationName.dq()}") { val targetCollectionShape = model.expectShape(it.member.target, CollectionShape::class.java) val memberShape = model.expectShape(targetCollectionShape.member.target) when { memberShape.isStringShape -> { - // NOTE: This path is traversed with or without @enum applied. The `try_from` is used - // as a common conversion. - rustTemplate( - """ - let v = <#{memberShape}>::try_from(v.as_ref())?; - """, - *codegenScope, - "memberShape" to symbolProvider.toSymbol(memberShape), - ) + if (queryParamsBinding != null) { + // If there's an `@httpQueryParams` binding, it will want to consume the parsed data + // too further down, so we need to clone it. + rust("let v = v.clone().into_owned();") + } else { + rust("let v = v.into_owned();") + } } memberShape.isTimestampShape -> { val index = HttpBindingIndex.of(model) @@ -1042,47 +1059,79 @@ private class ServerHttpBoundProtocolTraitImplGenerator( } if (queryParamsBinding != null) { + val target = model.expectShape(queryParamsBinding.member.target, MapShape::class.java) + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) Here we only check the target shape; + // constraint traits on member shapes are not implemented yet. + val hasConstrainedTarget = target.canReachConstrainedShape(model, symbolProvider) when (queryParamsBinding.queryParamsBindingTargetMapValueType()) { QueryParamsTargetMapValueType.STRING -> { - rust("query_params.entry(String::from(k)).or_insert_with(|| String::from(v));") - } else -> { - rustTemplate( - """ - let entry = query_params.entry(String::from(k)).or_default(); - entry.push(String::from(v)); - """.trimIndent(), - ) + rust("query_params.${if (hasConstrainedTarget) "0." else ""}entry(String::from(k)).or_insert_with(|| String::from(v));") + } + QueryParamsTargetMapValueType.LIST, QueryParamsTargetMapValueType.SET -> { + if (hasConstrainedTarget) { + val collectionShape = model.expectShape(target.value.target, CollectionShape::class.java) + val collectionSymbol = unconstrainedShapeSymbolProvider.toSymbol(collectionShape) + rust( + // `or_insert_with` instead of `or_insert` to avoid the allocation when the entry is + // not empty. + """ + let entry = query_params.0.entry(String::from(k)).or_insert_with(|| #T(std::vec::Vec::new())); + entry.0.push(String::from(v)); + """, + collectionSymbol, + ) + } else { + rust( + """ + let entry = query_params.entry(String::from(k)).or_default(); + entry.push(String::from(v)); + """, + ) + } } } } } if (queryParamsBinding != null) { - rust("input = input.${queryParamsBinding.member.setterName()}(Some(query_params));") + val isOptional = unconstrainedShapeSymbolProvider.toSymbol(queryParamsBinding.member).isOptional() + withBlock("input = input.${queryParamsBinding.member.setterName()}(", ");") { + conditionalBlock("Some(", ")", conditional = isOptional) { + write("query_params") + } + } } - queryBindingsTargettingCollection.forEach { - val memberName = symbolProvider.toMemberName(it.member) - rustTemplate( - """ - input = input.${it.member.setterName()}( - if $memberName.is_empty() { - None - } else { - Some($memberName) + queryBindingsTargettingCollection.forEach { binding -> + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) Constraint traits on member shapes are not + // implemented yet. + val hasConstrainedTarget = + model.expectShape(binding.member.target, CollectionShape::class.java).canReachConstrainedShape(model, symbolProvider) + val memberName = unconstrainedShapeSymbolProvider.toMemberName(binding.member) + val isOptional = unconstrainedShapeSymbolProvider.toSymbol(binding.member).isOptional() + rustBlock("if !$memberName.is_empty()") { + withBlock( + "input = input.${ + binding.member.setterName() + }(", + ");", + ) { + conditionalBlock("Some(", ")", conditional = isOptional) { + conditionalBlock( + "#T(", + ")", + conditional = hasConstrainedTarget, + unconstrainedShapeSymbolProvider.toSymbol(binding.member).mapRustType { it.stripOuter() }, + ) { + write(memberName) + } } - ); - """.trimIndent(), - ) + } + } } } } private fun serverRenderHeaderParser(writer: RustWriter, binding: HttpBindingDescriptor, operationShape: OperationShape) { - val httpBindingGenerator = - ServerRequestBindingGenerator( - protocol, - codegenContext, - operationShape, - ) + val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape) val deserializer = httpBindingGenerator.generateDeserializeHeaderFn(binding) writer.rustTemplate( """ @@ -1096,12 +1145,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( private fun serverRenderPrefixHeadersParser(writer: RustWriter, binding: HttpBindingDescriptor, operationShape: OperationShape) { check(binding.location == HttpLocation.PREFIX_HEADERS) - val httpBindingGenerator = - ServerRequestBindingGenerator( - protocol, - codegenContext, - operationShape, - ) + val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape) val deserializer = httpBindingGenerator.generateDeserializePrefixHeadersFn(binding) writer.rustTemplate( """ @@ -1112,10 +1156,9 @@ private class ServerHttpBoundProtocolTraitImplGenerator( ) } - private fun generateParseFn(binding: HttpBindingDescriptor, percentDecoding: Boolean): RuntimeType { - val output = symbolProvider.toSymbol(binding.member) + private fun generateParseStrFn(binding: HttpBindingDescriptor, percentDecoding: Boolean): RuntimeType { + val output = unconstrainedShapeSymbolProvider.toSymbol(binding.member) val fnName = generateParseStrFnName(binding) - val symbol = output.extractSymbolFromOption() return RuntimeType.forInlineFun(fnName, operationDeserModule) { rustBlockTemplate( "pub fn $fnName(value: &str) -> std::result::Result<#{O}, #{RequestRejection}>", @@ -1126,24 +1169,15 @@ private class ServerHttpBoundProtocolTraitImplGenerator( when { target.isStringShape -> { - // NOTE: This path is traversed with or without @enum applied. The `try_from` is used as a - // common conversion. if (percentDecoding) { rustTemplate( """ - let value = #{PercentEncoding}::percent_decode_str(value).decode_utf8()?; - let value = #{T}::try_from(value.as_ref())?; + let value = #{PercentEncoding}::percent_decode_str(value).decode_utf8()?.into_owned(); """, *codegenScope, - "T" to symbol, ) } else { - rustTemplate( - """ - let value = #{T}::try_from(value)?; - """, - "T" to symbol, - ) + rust("let value = value.to_owned();") } } target.isTimestampShape -> { @@ -1187,7 +1221,6 @@ private class ServerHttpBoundProtocolTraitImplGenerator( ) } } - rust( """ Ok(${symbolProvider.wrapOptional(binding.member, "value")}) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJsonFactory.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJson.kt similarity index 63% rename from codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJsonFactory.kt rename to codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJson.kt index d3d0fea631..a913b806d2 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJsonFactory.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJson.kt @@ -6,9 +6,14 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport +import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolGeneratorFactory +import software.amazon.smithy.rust.codegen.core.smithy.protocols.restJsonFieldName +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeIteratingOverMapJsonCustomization import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerRestJsonProtocol /** @@ -36,3 +41,15 @@ class ServerRestJsonFactory : ProtocolGeneratorFactory { + shape.hasTrait() + } else -> PANIC("this method does not support shape type ${shape.type}") +} + +fun StringShape.isReachableFromOperationInput() = isShapeReachableFromOperationInput(this) +fun StructureShape.isReachableFromOperationInput() = isShapeReachableFromOperationInput(this) +fun CollectionShape.isReachableFromOperationInput() = isShapeReachableFromOperationInput(this) +fun UnionShape.isReachableFromOperationInput() = isShapeReachableFromOperationInput(this) +fun MapShape.isReachableFromOperationInput() = isShapeReachableFromOperationInput(this) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/AttachValidationExceptionToConstrainedOperationInputsInAllowList.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/AttachValidationExceptionToConstrainedOperationInputsInAllowList.kt new file mode 100644 index 0000000000..257bce1f0e --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/AttachValidationExceptionToConstrainedOperationInputsInAllowList.kt @@ -0,0 +1,74 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.transformers + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.neighbor.Walker +import software.amazon.smithy.model.shapes.EnumShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.SetShape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.transform.ModelTransformer +import software.amazon.smithy.rust.codegen.core.util.inputShape +import software.amazon.smithy.rust.codegen.server.smithy.hasConstraintTrait + +/** + * Attach the `smithy.framework#ValidationException` error to operations whose inputs are constrained, if they belong + * to a service in an allowlist. + * + * Some of the models we generate in CI have constrained operation inputs, but the operations don't have + * `smithy.framework#ValidationException` in their list of errors. This is a codegen error, unless + * `disableDefaultValidation` is set to `true`, a code generation mode we don't support yet. See [1] for more details. + * Until we implement said mode, we manually attach the error to build these models, since we don't own them (they're + * either actual AWS service model excerpts, or they come from the `awslabs/smithy` library. + * + * [1]: https://github.com/awslabs/smithy-rs/pull/1199#discussion_r809424783 + * + * TODO(https://github.com/awslabs/smithy-rs/issues/1401): This transformer will go away once we add support for + * `disableDefaultValidation` set to `true`, allowing service owners to map from constraint violations to operation errors. + */ +object AttachValidationExceptionToConstrainedOperationInputsInAllowList { + private val sherviceShapeIdAllowList = + setOf( + // These we currently generate server SDKs for. + ShapeId.from("aws.protocoltests.restjson#RestJson"), + ShapeId.from("aws.protocoltests.json10#JsonRpc10"), + ShapeId.from("aws.protocoltests.json#JsonProtocol"), + ShapeId.from("com.amazonaws.s3#AmazonS3"), + ShapeId.from("com.amazonaws.ebs#Ebs"), + + // These are only loaded in the classpath and need this model transformer, but we don't generate server + // SDKs for them. Here they are for reference. + // ShapeId.from("aws.protocoltests.restxml#RestXml"), + // ShapeId.from("com.amazonaws.glacier#Glacier"), + // ShapeId.from("aws.protocoltests.ec2#AwsEc2"), + // ShapeId.from("aws.protocoltests.query#AwsQuery"), + // ShapeId.from("com.amazonaws.machinelearning#AmazonML_20141212"), + ) + + fun transform(model: Model): Model { + val walker = Walker(model) + + val operationsWithConstrainedInputWithoutValidationException = model.serviceShapes + .filter { sherviceShapeIdAllowList.contains(it.toShapeId()) } + .flatMap { it.operations } + .map { model.expectShape(it, OperationShape::class.java) } + .filter { operationShape -> + // Walk the shapes reachable via this operation input. + walker.walkShapes(operationShape.inputShape(model)) + .any { it is SetShape || it is EnumShape || it.hasConstraintTrait() } + } + .filter { !it.errors.contains(ShapeId.from("smithy.framework#ValidationException")) } + + return ModelTransformer.create().mapShapes(model) { shape -> + if (shape is OperationShape && operationsWithConstrainedInputWithoutValidationException.contains(shape)) { + shape.toBuilder().addError("smithy.framework#ValidationException").build() + } else { + shape + } + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/RemoveEbsModelValidationException.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/RemoveEbsModelValidationException.kt new file mode 100644 index 0000000000..d4c6feaed6 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/RemoveEbsModelValidationException.kt @@ -0,0 +1,38 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.transformers + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.transform.ModelTransformer +import software.amazon.smithy.rust.codegen.core.util.orNull + +/** + * The Amazon Elastic Block Store (Amazon EBS) model is one model that we generate in CI. + * Unfortunately, it defines its own `ValidationException` shape, that conflicts with + * `smithy.framework#ValidationException` [0]. + * + * So this is a model that a service owner would generate when "disabling default validation": in such a code generation + * mode, the service owner is responsible for mapping an operation input-level constraint violation into a modeled + * operation error. This mode, as well as what the end goal for validation exception responses looks like, is described + * in more detail in [1]. We don't support this mode yet. + * + * So this transformer simply removes the EBB model's `ValidationException`. A subsequent model transformer, + * [AttachValidationExceptionToConstrainedOperationInputsInAllowList], ensures that it is replaced by + * `smithy.framework#ValidationException`. + * + * [0]: https://github.com/awslabs/smithy-rs/blob/274adf155042cde49251a0e6b8842d6f56cd5b6d/codegen-core/common-test-models/ebs.json#L1270-L1288 + * [1]: https://github.com/awslabs/smithy-rs/pull/1199#discussion_r809424783 + * + * TODO(https://github.com/awslabs/smithy-rs/issues/1401): This transformer will go away once we implement + * `disableDefaultValidation` set to `true`, allowing service owners to map from constraint violations to operation errors. + */ +object RemoveEbsModelValidationException { + fun transform(model: Model): Model { + val shapeToRemove = model.getShape(ShapeId.from("com.amazonaws.ebs#ValidationException")).orNull() + return ModelTransformer.create().removeShapes(model, listOfNotNull(shapeToRemove)) + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ShapesReachableFromOperationInputTagger.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ShapesReachableFromOperationInputTagger.kt new file mode 100644 index 0000000000..cf58f3f9d9 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ShapesReachableFromOperationInputTagger.kt @@ -0,0 +1,72 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.transformers + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.neighbor.Walker +import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.transform.ModelTransformer +import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE +import software.amazon.smithy.rust.codegen.server.smithy.traits.ShapeReachableFromOperationInputTagTrait + +/** + * Tag shapes reachable from operation input with the + * [ShapeReachableFromOperationInputTagTrait] tag. + * + * This is useful to determine whether we need to generate code to + * enforce constraints upon request deserialization in the server. + * + * This needs to be a model transformer; it cannot be lazily calculated + * when needed. This is because other model transformers may transform + * the model such that shapes that were reachable from operation + * input are no longer so. For example, [EventStreamNormalizer] pulls + * event stream error variants out of the union shape where they are defined. + * As such, [ShapesReachableFromOperationInputTagger] needs to run + * before these model transformers. + * + * WARNING: This transformer tags _all_ [aggregate shapes], and _some_ [simple shapes], + * but not all of them. Read the implementation to find out what shape types it + * currently tags. + * + * [simple shapes]: https://awslabs.github.io/smithy/2.0/spec/simple-types.html + * [aggregate shapes]: https://awslabs.github.io/smithy/2.0/spec/aggregate-types.html#aggregate-types + */ +object ShapesReachableFromOperationInputTagger { + fun transform(model: Model): Model { + val inputShapes = model.operationShapes.map { + model.expectShape(it.inputShape, StructureShape::class.java) + } + val walker = Walker(model) + val shapesReachableFromOperationInputs = inputShapes + .flatMap { walker.walkShapes(it) } + .toSet() + + return ModelTransformer.create().mapShapes(model) { shape -> + when (shape) { + is StructureShape, is UnionShape, is ListShape, is MapShape, is StringShape -> { + if (shapesReachableFromOperationInputs.contains(shape)) { + val builder = when (shape) { + is StructureShape -> shape.toBuilder() + is UnionShape -> shape.toBuilder() + is ListShape -> shape.toBuilder() + is MapShape -> shape.toBuilder() + is StringShape -> shape.toBuilder() + else -> UNREACHABLE("the `when` is exhaustive") + } + builder.addTrait(ShapeReachableFromOperationInputTagTrait()).build() + } else { + shape + } + } + else -> shape + } + } + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProviderTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProviderTest.kt new file mode 100644 index 0000000000..bcf7fe34ce --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProviderTest.kt @@ -0,0 +1,98 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import io.kotest.matchers.shouldBe +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.smithy.rustType +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider + +const val baseModelString = + """ + namespace test + + service TestService { + version: "123", + operations: [TestOperation] + } + + operation TestOperation { + input: TestInputOutput, + output: TestInputOutput, + } + + structure TestInputOutput { + constrainedString: ConstrainedString, + constrainedMap: ConstrainedMap, + unconstrainedMap: TransitivelyConstrainedMap + } + + @length(min: 1, max: 69) + string ConstrainedString + + string UnconstrainedString + + @length(min: 1, max: 69) + map ConstrainedMap { + key: String, + value: String + } + + map TransitivelyConstrainedMap { + key: String, + value: ConstrainedMap + } + + @length(min: 1, max: 69) + list ConstrainedCollection { + member: String + } + """ + +class ConstrainedShapeSymbolProviderTest { + private val model = baseModelString.asSmithyModel() + private val serviceShape = model.lookup("test#TestService") + private val symbolProvider = serverTestSymbolProvider(model, serviceShape) + private val constrainedShapeSymbolProvider = ConstrainedShapeSymbolProvider(symbolProvider, model, serviceShape) + + private val constrainedMapShape = model.lookup("test#ConstrainedMap") + private val constrainedMapType = constrainedShapeSymbolProvider.toSymbol(constrainedMapShape).rustType() + + @Test + fun `it should return a constrained string type for a constrained string shape`() { + val constrainedStringShape = model.lookup("test#ConstrainedString") + val constrainedStringType = constrainedShapeSymbolProvider.toSymbol(constrainedStringShape).rustType() + + constrainedStringType shouldBe RustType.Opaque("ConstrainedString", "crate::model") + } + + @Test + fun `it should return a constrained map type for a constrained map shape`() { + constrainedMapType shouldBe RustType.Opaque("ConstrainedMap", "crate::model") + } + + @Test + fun `it should not blindly delegate to the base symbol provider when the shape is an aggregate shape and is not directly constrained`() { + val unconstrainedMapShape = model.lookup("test#TransitivelyConstrainedMap") + val unconstrainedMapType = constrainedShapeSymbolProvider.toSymbol(unconstrainedMapShape).rustType() + + unconstrainedMapType shouldBe RustType.HashMap(RustType.String, constrainedMapType) + } + + @Test + fun `it should delegate to the base symbol provider for unconstrained simple shapes`() { + val unconstrainedStringShape = model.lookup("test#UnconstrainedString") + val unconstrainedStringSymbol = constrainedShapeSymbolProvider.toSymbol(unconstrainedStringShape) + + unconstrainedStringSymbol shouldBe symbolProvider.toSymbol(unconstrainedStringShape) + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt new file mode 100644 index 0000000000..80e2d93dae --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt @@ -0,0 +1,135 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import io.kotest.inspectors.forAll +import io.kotest.matchers.shouldBe +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider + +class ConstraintsTest { + private val model = + """ + namespace test + + service TestService { + version: "123", + operations: [TestOperation] + } + + operation TestOperation { + input: TestInputOutput, + output: TestInputOutput, + } + + structure TestInputOutput { + map: MapA, + + recursive: RecursiveShape + } + + structure RecursiveShape { + shape: RecursiveShape, + mapB: MapB + } + + @length(min: 1, max: 69) + map MapA { + key: String, + value: MapB + } + + map MapB { + key: String, + value: StructureA + } + + @uniqueItems + list ListA { + member: MyString + } + + @pattern("\\w+") + string MyString + + @length(min: 1, max: 69) + string LengthString + + structure StructureA { + @range(min: 1, max: 69) + int: Integer, + + @required + string: String + } + + // This shape is not in the service closure. + structure StructureB { + @pattern("\\w+") + patternString: String, + + @required + requiredString: String, + + mapA: MapA, + + @length(min: 1, max: 5) + mapAPrecedence: MapA + } + """.asSmithyModel() + private val symbolProvider = serverTestSymbolProvider(model) + + private val testInputOutput = model.lookup("test#TestInputOutput") + private val recursiveShape = model.lookup("test#RecursiveShape") + private val mapA = model.lookup("test#MapA") + private val mapB = model.lookup("test#MapB") + private val listA = model.lookup("test#ListA") + private val myString = model.lookup("test#MyString") + private val lengthString = model.lookup("test#LengthString") + private val structA = model.lookup("test#StructureA") + private val structAInt = model.lookup("test#StructureA\$int") + private val structAString = model.lookup("test#StructureA\$string") + + @Test + fun `it should not recognize uniqueItems as a constraint trait because it's deprecated`() { + listA.isDirectlyConstrained(symbolProvider) shouldBe false + } + + @Test + fun `it should detect supported constrained traits as constrained`() { + listOf(mapA, structA, lengthString).forAll { + it.isDirectlyConstrained(symbolProvider) shouldBe true + } + } + + @Test + fun `it should not detect unsupported constrained traits as constrained`() { + listOf(structAInt, structAString, myString).forAll { + it.isDirectlyConstrained(symbolProvider) shouldBe false + } + } + + @Test + fun `it should evaluate reachability of constrained shapes`() { + mapA.canReachConstrainedShape(model, symbolProvider) shouldBe true + structAInt.canReachConstrainedShape(model, symbolProvider) shouldBe false + + // This should be true when we start supporting the `pattern` trait on string shapes. + listA.canReachConstrainedShape(model, symbolProvider) shouldBe false + + // All of these eventually reach `StructureA`, which is constrained because one of its members is `required`. + testInputOutput.canReachConstrainedShape(model, symbolProvider) shouldBe true + mapB.canReachConstrainedShape(model, symbolProvider) shouldBe true + recursiveShape.canReachConstrainedShape(model, symbolProvider) shouldBe true + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProviderTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProviderTest.kt new file mode 100644 index 0000000000..21baefe747 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProviderTest.kt @@ -0,0 +1,113 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import io.kotest.assertions.throwables.shouldThrow +import io.kotest.matchers.shouldBe +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.smithy.rustType +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProviders + +class PubCrateConstrainedShapeSymbolProviderTest { + private val model = """ + $baseModelString + + list TransitivelyConstrainedCollection { + member: Structure + } + + structure Structure { + @required + requiredMember: String + } + + structure StructureWithMemberTargetingAggregateShape { + member: TransitivelyConstrainedCollection + } + + union Union { + structure: Structure + } + """.asSmithyModel() + + private val serverTestSymbolProviders = serverTestSymbolProviders(model) + private val symbolProvider = serverTestSymbolProviders.symbolProvider + private val pubCrateConstrainedShapeSymbolProvider = serverTestSymbolProviders.pubCrateConstrainedShapeSymbolProvider + + @Test + fun `it should crash when provided with a shape that is directly constrained`() { + val constrainedStringShape = model.lookup("test#ConstrainedString") + shouldThrow { pubCrateConstrainedShapeSymbolProvider.toSymbol(constrainedStringShape) } + } + + @Test + fun `it should crash when provided with a shape that is unconstrained`() { + val unconstrainedStringShape = model.lookup("test#UnconstrainedString") + shouldThrow { pubCrateConstrainedShapeSymbolProvider.toSymbol(unconstrainedStringShape) } + } + + @Test + fun `it should return an opaque type for transitively constrained collection shapes`() { + val transitivelyConstrainedCollectionShape = model.lookup("test#TransitivelyConstrainedCollection") + val transitivelyConstrainedCollectionType = + pubCrateConstrainedShapeSymbolProvider.toSymbol(transitivelyConstrainedCollectionShape).rustType() + + transitivelyConstrainedCollectionType shouldBe RustType.Opaque( + "TransitivelyConstrainedCollectionConstrained", + "crate::constrained::transitively_constrained_collection_constrained", + ) + } + + @Test + fun `it should return an opaque type for transitively constrained map shapes`() { + val transitivelyConstrainedMapShape = model.lookup("test#TransitivelyConstrainedMap") + val transitivelyConstrainedMapType = + pubCrateConstrainedShapeSymbolProvider.toSymbol(transitivelyConstrainedMapShape).rustType() + + transitivelyConstrainedMapType shouldBe RustType.Opaque( + "TransitivelyConstrainedMapConstrained", + "crate::constrained::transitively_constrained_map_constrained", + ) + } + + @Test + fun `it should not blindly delegate to the base symbol provider when provided with a transitively constrained structure member shape targeting an aggregate shape`() { + val memberShape = model.lookup("test#StructureWithMemberTargetingAggregateShape\$member") + val memberType = pubCrateConstrainedShapeSymbolProvider.toSymbol(memberShape).rustType() + + memberType shouldBe RustType.Option( + RustType.Opaque( + "TransitivelyConstrainedCollectionConstrained", + "crate::constrained::transitively_constrained_collection_constrained", + ), + ) + } + + @Test + fun `it should delegate to the base symbol provider when provided with a structure shape`() { + val structureShape = model.lookup("test#TestInputOutput") + val structureSymbol = pubCrateConstrainedShapeSymbolProvider.toSymbol(structureShape) + + structureSymbol shouldBe symbolProvider.toSymbol(structureShape) + } + + @Test + fun `it should delegate to the base symbol provider when provided with a union shape`() { + val unionShape = model.lookup("test#Union") + val unionSymbol = pubCrateConstrainedShapeSymbolProvider.toSymbol(unionShape) + + unionSymbol shouldBe symbolProvider.toSymbol(unionShape) + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/UnconstrainedShapeSymbolProviderTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/UnconstrainedShapeSymbolProviderTest.kt new file mode 100644 index 0000000000..7c8efe9c17 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/UnconstrainedShapeSymbolProviderTest.kt @@ -0,0 +1,103 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import io.kotest.matchers.shouldBe +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.rustlang.render +import software.amazon.smithy.rust.codegen.core.smithy.rustType +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProviders + +/** + * While [UnconstrainedShapeSymbolProvider] _must_ be in the `codegen` subproject, these tests need to be in the + * `codegen-server` subproject, because they use [serverTestSymbolProvider]. + */ +class UnconstrainedShapeSymbolProviderTest { + private val baseModelString = + """ + namespace test + + service TestService { + version: "123", + operations: [TestOperation] + } + + operation TestOperation { + input: TestInputOutput, + output: TestInputOutput, + } + + structure TestInputOutput { + list: ListA + } + """ + + @Test + fun `it should adjust types for unconstrained shapes`() { + val model = + """ + $baseModelString + + list ListA { + member: ListB + } + + list ListB { + member: StructureC + } + + structure StructureC { + @required + string: String + } + """.asSmithyModel() + + val unconstrainedShapeSymbolProvider = serverTestSymbolProviders(model).unconstrainedShapeSymbolProvider + + val listAShape = model.lookup("test#ListA") + val listAType = unconstrainedShapeSymbolProvider.toSymbol(listAShape).rustType() + + val listBShape = model.lookup("test#ListB") + val listBType = unconstrainedShapeSymbolProvider.toSymbol(listBShape).rustType() + + val structureCShape = model.lookup("test#StructureC") + val structureCType = unconstrainedShapeSymbolProvider.toSymbol(structureCShape).rustType() + + listAType shouldBe RustType.Opaque("ListAUnconstrained", "crate::unconstrained::list_a_unconstrained") + listBType shouldBe RustType.Opaque("ListBUnconstrained", "crate::unconstrained::list_b_unconstrained") + structureCType shouldBe RustType.Opaque("Builder", "crate::model::structure_c") + } + + @Test + fun `it should delegate to the base symbol provider if called with a shape that cannot reach a constrained shape`() { + val model = + """ + $baseModelString + + list ListA { + member: StructureB + } + + structure StructureB { + string: String + } + """.asSmithyModel() + + val unconstrainedShapeSymbolProvider = serverTestSymbolProviders(model).unconstrainedShapeSymbolProvider + + val listAShape = model.lookup("test#ListA") + val structureBShape = model.lookup("test#StructureB") + + unconstrainedShapeSymbolProvider.toSymbol(structureBShape).rustType().render() shouldBe "crate::model::StructureB" + unconstrainedShapeSymbolProvider.toSymbol(listAShape).rustType().render() shouldBe "std::vec::Vec" + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraintsAreNotUsedTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraintsAreNotUsedTest.kt new file mode 100644 index 0000000000..0624358ba3 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraintsAreNotUsedTest.kt @@ -0,0 +1,254 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import io.kotest.inspectors.forSome +import io.kotest.inspectors.shouldForAll +import io.kotest.matchers.collections.shouldHaveAtLeastSize +import io.kotest.matchers.collections.shouldHaveSize +import io.kotest.matchers.shouldBe +import io.kotest.matchers.string.shouldContain +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.util.lookup +import java.util.logging.Level + +internal class ValidateUnsupportedConstraintsAreNotUsedTest { + private val baseModel = + """ + namespace test + + service TestService { + version: "123", + operations: [TestOperation] + } + + operation TestOperation { + input: TestInputOutput, + output: TestInputOutput, + } + """ + + private fun validateModel(model: Model, serverCodegenConfig: ServerCodegenConfig = ServerCodegenConfig()): ValidationResult { + val service = model.lookup("test#TestService") + return validateUnsupportedConstraints(model, service, serverCodegenConfig) + } + + @Test + fun `it should detect when an operation with constrained input but that does not have ValidationException attached in errors`() { + val model = + """ + $baseModel + + structure TestInputOutput { + @required + requiredString: String + } + """.asSmithyModel() + val service = model.lookup("test#TestService") + val validationResult = validateOperationsWithConstrainedInputHaveValidationExceptionAttached(model, service) + + validationResult.messages shouldHaveSize 1 + validationResult.messages[0].message shouldContain "Operation test#TestOperation takes in input that is constrained" + } + + @Test + fun `it should detect when unsupported constraint traits on member shapes are used`() { + val model = + """ + $baseModel + + structure TestInputOutput { + @length(min: 1, max: 69) + lengthString: String + } + """.asSmithyModel() + val validationResult = validateModel(model) + + validationResult.messages shouldHaveSize 1 + validationResult.messages[0].message shouldContain "The member shape `test#TestInputOutput\$lengthString` has the constraint trait `smithy.api#length` attached" + } + + @Test + fun `it should not detect when the required trait on a member shape is used`() { + val model = + """ + $baseModel + + structure TestInputOutput { + @required + string: String + } + """.asSmithyModel() + val validationResult = validateModel(model) + + validationResult.messages shouldHaveSize 0 + } + + private val constraintTraitOnStreamingBlobShapeModel = + """ + $baseModel + + structure TestInputOutput { + @required + streamingBlob: StreamingBlob + } + + @streaming + @length(min: 69) + blob StreamingBlob + """.asSmithyModel() + + @Test + fun `it should detect when constraint traits on streaming blob shapes are used`() { + val validationResult = validateModel(constraintTraitOnStreamingBlobShapeModel) + + validationResult.messages shouldHaveSize 2 + validationResult.messages.forSome { + it.message shouldContain + """ + The blob shape `test#StreamingBlob` has both the `smithy.api#length` and `smithy.api#streaming` constraint traits attached. + It is unclear what the semantics for streaming blob shapes are. + """.trimIndent().replace("\n", " ") + } + } + + @Test + fun `it should detect when constraint traits in event streams are used`() { + val model = + """ + $baseModel + + structure TestInputOutput { + eventStream: EventStream + } + + @streaming + union EventStream { + message: Message + } + + structure Message { + lengthString: LengthString + } + + @length(min: 1) + string LengthString + """.asSmithyModel() + val validationResult = validateModel(model) + + validationResult.messages shouldHaveSize 1 + validationResult.messages[0].message shouldContain + """ + The string shape `test#LengthString` has the constraint trait `smithy.api#length` attached. + This shape is also part of an event stream; it is unclear what the semantics for constrained shapes in event streams are. + """.trimIndent().replace("\n", " ") + } + + @Test + fun `it should detect when the length trait on collection shapes or on blob shapes is used`() { + val model = + """ + $baseModel + + structure TestInputOutput { + collection: LengthCollection, + blob: LengthBlob + } + + @length(min: 1) + list LengthCollection { + member: String + } + + @length(min: 1) + blob LengthBlob + """.asSmithyModel() + val validationResult = validateModel(model) + + validationResult.messages shouldHaveSize 2 + validationResult.messages.forSome { it.message shouldContain "The list shape `test#LengthCollection` has the constraint trait `smithy.api#length` attached" } + validationResult.messages.forSome { it.message shouldContain "The blob shape `test#LengthBlob` has the constraint trait `smithy.api#length` attached" } + } + + @Test + fun `it should detect when the pattern trait on string shapes is used`() { + val model = + """ + $baseModel + + structure TestInputOutput { + patternString: PatternString + } + + @pattern("^[A-Za-z]+$") + string PatternString + """.asSmithyModel() + val validationResult = validateModel(model) + + validationResult.messages shouldHaveSize 1 + validationResult.messages[0].message shouldContain "The string shape `test#PatternString` has the constraint trait `smithy.api#pattern` attached" + } + + @Test + fun `it should detect when the range trait is used`() { + val model = + """ + $baseModel + + structure TestInputOutput { + rangeInteger: RangeInteger + } + + @range(min: 1) + integer RangeInteger + """.asSmithyModel() + val validationResult = validateModel(model) + + validationResult.messages shouldHaveSize 1 + validationResult.messages[0].message shouldContain "The integer shape `test#RangeInteger` has the constraint trait `smithy.api#range` attached" + } + + @Test + fun `it should abort when ignoreUnsupportedConstraints is false and unsupported constraints are used`() { + val validationResult = validateModel(constraintTraitOnStreamingBlobShapeModel, ServerCodegenConfig()) + + validationResult.messages shouldHaveAtLeastSize 1 + validationResult.shouldAbort shouldBe true + } + + @Test + fun `it should not abort when ignoreUnsupportedConstraints is true and unsupported constraints are used`() { + val validationResult = validateModel( + constraintTraitOnStreamingBlobShapeModel, + ServerCodegenConfig().copy(ignoreUnsupportedConstraints = true), + ) + + validationResult.messages shouldHaveAtLeastSize 1 + validationResult.shouldAbort shouldBe false + } + + @Test + fun `it should set log level to error when ignoreUnsupportedConstraints is false and unsupported constraints are used`() { + val validationResult = validateModel(constraintTraitOnStreamingBlobShapeModel, ServerCodegenConfig()) + + validationResult.messages shouldHaveAtLeastSize 1 + validationResult.messages.shouldForAll { it.level shouldBe Level.SEVERE } + } + + @Test + fun `it should set log level to warn when ignoreUnsupportedConstraints is true and unsupported constraints are used`() { + val validationResult = validateModel( + constraintTraitOnStreamingBlobShapeModel, + ServerCodegenConfig().copy(ignoreUnsupportedConstraints = true), + ) + + validationResult.messages shouldHaveAtLeastSize 1 + validationResult.messages.shouldForAll { it.level shouldBe Level.WARNING } + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGeneratorTest.kt new file mode 100644 index 0000000000..3fa9559943 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGeneratorTest.kt @@ -0,0 +1,158 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import io.kotest.matchers.string.shouldContain +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.ExtensionContext +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.ArgumentsProvider +import org.junit.jupiter.params.provider.ArgumentsSource +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.node.ObjectNode +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.transformers.ShapesReachableFromOperationInputTagger +import java.util.stream.Stream + +class ConstrainedMapGeneratorTest { + + data class TestCase(val model: Model, val validMap: ObjectNode, val invalidMap: ObjectNode) + + class ConstrainedMapGeneratorTestProvider : ArgumentsProvider { + private val testCases = listOf( + // Min and max. + Triple("@length(min: 11, max: 12)", 11, 13), + // Min equal to max. + Triple("@length(min: 11, max: 11)", 11, 12), + // Only min. + Triple("@length(min: 11)", 15, 10), + // Only max. + Triple("@length(max: 11)", 11, 12), + ).map { + val validStringMap = List(it.second) { index -> index.toString() to "value" }.toMap() + val inValidStringMap = List(it.third) { index -> index.toString() to "value" }.toMap() + Triple(it.first, ObjectNode.fromStringMap(validStringMap), ObjectNode.fromStringMap(inValidStringMap)) + }.map { (trait, validMap, invalidMap) -> + TestCase( + """ + namespace test + + $trait + map ConstrainedMap { + key: String, + value: String + } + """.asSmithyModel().let(ShapesReachableFromOperationInputTagger::transform), + validMap, + invalidMap, + ) + } + + override fun provideArguments(context: ExtensionContext?): Stream = + testCases.map { Arguments.of(it) }.stream() + } + + @ParameterizedTest + @ArgumentsSource(ConstrainedMapGeneratorTestProvider::class) + fun `it should generate constrained map types`(testCase: TestCase) { + val constrainedMapShape = testCase.model.lookup("test#ConstrainedMap") + + val codegenContext = serverTestCodegenContext(testCase.model) + val symbolProvider = codegenContext.symbolProvider + + val project = TestWorkspace.testProject(symbolProvider) + + project.withModule(ModelsModule) { + render(codegenContext, this, constrainedMapShape) + + val instantiator = serverInstantiator(codegenContext) + rustBlock("##[cfg(test)] fn build_valid_map() -> std::collections::HashMap") { + instantiator.render(this, constrainedMapShape, testCase.validMap) + } + rustBlock("##[cfg(test)] fn build_invalid_map() -> std::collections::HashMap") { + instantiator.render(this, constrainedMapShape, testCase.invalidMap) + } + + unitTest( + name = "try_from_success", + test = """ + let map = build_valid_map(); + let _constrained: ConstrainedMap = map.try_into().unwrap(); + """, + ) + unitTest( + name = "try_from_fail", + test = """ + let map = build_invalid_map(); + let constrained_res: Result = map.try_into(); + constrained_res.unwrap_err(); + """, + ) + unitTest( + name = "inner", + test = """ + let map = build_valid_map(); + let constrained = ConstrainedMap::try_from(map.clone()).unwrap(); + + assert_eq!(constrained.inner(), &map); + """, + ) + unitTest( + name = "into_inner", + test = """ + let map = build_valid_map(); + let constrained = ConstrainedMap::try_from(map.clone()).unwrap(); + + assert_eq!(constrained.into_inner(), map); + """, + ) + } + + project.compileAndTest() + } + + @Test + fun `type should not be constructible without using a constructor`() { + val model = """ + namespace test + + @length(min: 1, max: 69) + map ConstrainedMap { + key: String, + value: String + } + """.asSmithyModel().let(ShapesReachableFromOperationInputTagger::transform) + val constrainedMapShape = model.lookup("test#ConstrainedMap") + + val writer = RustWriter.forModule(ModelsModule.name) + + val codegenContext = serverTestCodegenContext(model) + render(codegenContext, writer, constrainedMapShape) + + // Check that the wrapped type is `pub(crate)`. + writer.toString() shouldContain "pub struct ConstrainedMap(pub(crate) std::collections::HashMap);" + } + + private fun render( + codegenContext: ServerCodegenContext, + writer: RustWriter, + constrainedMapShape: MapShape, + ) { + ConstrainedMapGenerator(codegenContext, writer, constrainedMapShape).render() + MapConstraintViolationGenerator(codegenContext, writer, constrainedMapShape).render() + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGeneratorTest.kt new file mode 100644 index 0000000000..75db6303f7 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGeneratorTest.kt @@ -0,0 +1,179 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import io.kotest.matchers.string.shouldContain +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.ExtensionContext +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.ArgumentsProvider +import org.junit.jupiter.params.provider.ArgumentsSource +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext +import java.util.stream.Stream + +class ConstrainedStringGeneratorTest { + + data class TestCase(val model: Model, val validString: String, val invalidString: String) + + class ConstrainedStringGeneratorTestProvider : ArgumentsProvider { + private val testCases = listOf( + // Min and max. + Triple("@length(min: 11, max: 12)", "validString", "invalidString"), + // Min equal to max. + Triple("@length(min: 11, max: 11)", "validString", "invalidString"), + // Only min. + Triple("@length(min: 11)", "validString", ""), + // Only max. + Triple("@length(max: 11)", "", "invalidString"), + // Count Unicode scalar values, not `.len()`. + Triple( + "@length(min: 3, max: 3)", + "👍👍👍", // These three emojis are three Unicode scalar values. + "👍👍👍👍", + ), + ).map { + TestCase( + """ + namespace test + + ${it.first} + string ConstrainedString + """.asSmithyModel(), + it.second, + it.third, + ) + } + + override fun provideArguments(context: ExtensionContext?): Stream = + testCases.map { Arguments.of(it) }.stream() + } + + @ParameterizedTest + @ArgumentsSource(ConstrainedStringGeneratorTestProvider::class) + fun `it should generate constrained string types`(testCase: TestCase) { + val constrainedStringShape = testCase.model.lookup("test#ConstrainedString") + + val codegenContext = serverTestCodegenContext(testCase.model) + val symbolProvider = codegenContext.symbolProvider + + val project = TestWorkspace.testProject(symbolProvider) + + project.withModule(ModelsModule) { + ConstrainedStringGenerator(codegenContext, this, constrainedStringShape).render() + + unitTest( + name = "try_from_success", + test = """ + let string = "${testCase.validString}".to_owned(); + let _constrained: ConstrainedString = string.try_into().unwrap(); + """, + ) + unitTest( + name = "try_from_fail", + test = """ + let string = "${testCase.invalidString}".to_owned(); + let constrained_res: Result = string.try_into(); + constrained_res.unwrap_err(); + """, + ) + unitTest( + name = "inner", + test = """ + let string = "${testCase.validString}".to_owned(); + let constrained = ConstrainedString::try_from(string).unwrap(); + + assert_eq!(constrained.inner(), "${testCase.validString}"); + """, + ) + unitTest( + name = "into_inner", + test = """ + let string = "${testCase.validString}".to_owned(); + let constrained = ConstrainedString::try_from(string.clone()).unwrap(); + + assert_eq!(constrained.into_inner(), string); + """, + ) + } + + project.compileAndTest() + } + + @Test + fun `type should not be constructible without using a constructor`() { + val model = """ + namespace test + + @length(min: 1, max: 69) + string ConstrainedString + """.asSmithyModel() + val constrainedStringShape = model.lookup("test#ConstrainedString") + + val codegenContext = serverTestCodegenContext(model) + + val writer = RustWriter.forModule(ModelsModule.name) + + ConstrainedStringGenerator(codegenContext, writer, constrainedStringShape).render() + + // Check that the wrapped type is `pub(crate)`. + writer.toString() shouldContain "pub struct ConstrainedString(pub(crate) std::string::String);" + } + + @Test + fun `Display implementation`() { + val model = """ + namespace test + + @length(min: 1, max: 69) + string ConstrainedString + + @sensitive + @length(min: 1, max: 78) + string SensitiveConstrainedString + """.asSmithyModel() + val constrainedStringShape = model.lookup("test#ConstrainedString") + val sensitiveConstrainedStringShape = model.lookup("test#SensitiveConstrainedString") + + val codegenContext = serverTestCodegenContext(model) + + val project = TestWorkspace.testProject(codegenContext.symbolProvider) + + project.withModule(ModelsModule) { + ConstrainedStringGenerator(codegenContext, this, constrainedStringShape).render() + ConstrainedStringGenerator(codegenContext, this, sensitiveConstrainedStringShape).render() + + unitTest( + name = "non_sensitive_string_display_implementation", + test = """ + let string = "a non-sensitive string".to_owned(); + let constrained = ConstrainedString::try_from(string).unwrap(); + assert_eq!(format!("{}", constrained), "a non-sensitive string") + """, + ) + + unitTest( + name = "sensitive_string_display_implementation", + test = """ + let string = "That is how heavy a secret can become. It can make blood flow easier than ink.".to_owned(); + let constrained = SensitiveConstrainedString::try_from(string).unwrap(); + assert_eq!(format!("{}", constrained), "*** Sensitive Data Redacted ***") + """, + ) + } + + project.compileAndTest() + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerCombinedErrorGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerCombinedErrorGeneratorTest.kt index da3ea23d8e..d414aa63fc 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerCombinedErrorGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerCombinedErrorGeneratorTest.kt @@ -8,15 +8,14 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.RustModule -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ServerCombinedErrorGenerator import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest -import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverRenderWithModelBuilder import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider class ServerCombinedErrorGeneratorTest { @@ -55,16 +54,20 @@ class ServerCombinedErrorGeneratorTest { val project = TestWorkspace.testProject(symbolProvider) project.withModule(RustModule.public("error")) { listOf("FooException", "ComplexError", "InvalidGreeting", "Deprecated").forEach { - model.lookup("error#$it").renderWithModelBuilder(model, symbolProvider, this, CodegenTarget.SERVER) + model.lookup("error#$it").serverRenderWithModelBuilder(model, symbolProvider, this) } val errors = listOf("FooException", "ComplexError", "InvalidGreeting").map { model.lookup("error#$it") } - val generator = ServerCombinedErrorGenerator(model, symbolProvider, symbolProvider.toSymbol(model.lookup("error#Greeting")), errors) - generator.render(this) + ServerCombinedErrorGenerator( + model, + symbolProvider, + symbolProvider.toSymbol(model.lookup("error#Greeting")), + errors, + ).render(this) unitTest( name = "generates_combined_error_enums", test = """ - let variant = InvalidGreeting::builder().message("an error").build(); + let variant = InvalidGreeting { message: String::from("an error") }; assert_eq!(format!("{}", variant), "InvalidGreeting: an error"); assert_eq!(variant.message(), "an error"); assert_eq!( diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGeneratorTest.kt index a07447a715..0e813cebbd 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGeneratorTest.kt @@ -9,12 +9,10 @@ import io.kotest.matchers.string.shouldNotContain import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest -import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.lookup -import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext class ServerEnumGeneratorTest { private val model = """ @@ -36,30 +34,26 @@ class ServerEnumGeneratorTest { string InstanceType """.asSmithyModel() + private val codegenContext = serverTestCodegenContext(model) + private val writer = RustWriter.forModule("model") + private val shape = model.lookup("test#InstanceType") + @Test fun `it generates TryFrom, FromStr and errors for enums`() { - val provider = serverTestSymbolProvider(model) - val writer = RustWriter.forModule("model") - val shape = model.lookup("test#InstanceType") - val generator = ServerEnumGenerator(model, provider, writer, shape, shape.expectTrait(), TestRuntimeConfig) - generator.render() + ServerEnumGenerator(codegenContext, writer, shape).render() writer.compileAndTest( """ use std::str::FromStr; assert_eq!(InstanceType::try_from("t2.nano").unwrap(), InstanceType::T2Nano); assert_eq!(InstanceType::from_str("t2.nano").unwrap(), InstanceType::T2Nano); - assert_eq!(InstanceType::try_from("unknown").unwrap_err(), InstanceTypeUnknownVariantError("unknown".to_string())); + assert_eq!(InstanceType::try_from("unknown").unwrap_err(), crate::model::instance_type::ConstraintViolation(String::from("unknown"))); """, ) } @Test fun `it generates enums without the unknown variant`() { - val provider = serverTestSymbolProvider(model) - val writer = RustWriter.forModule("model") - val shape = model.lookup("test#InstanceType") - val generator = ServerEnumGenerator(model, provider, writer, shape, shape.expectTrait(), TestRuntimeConfig) - generator.render() + ServerEnumGenerator(codegenContext, writer, shape).render() writer.compileAndTest( """ // check no unknown @@ -74,11 +68,7 @@ class ServerEnumGeneratorTest { @Test fun `it generates enums without non_exhaustive`() { - val provider = serverTestSymbolProvider(model) - val writer = RustWriter.forModule("model") - val shape = model.lookup("test#InstanceType") - val generator = ServerEnumGenerator(model, provider, writer, shape, shape.expectTrait(), TestRuntimeConfig) - generator.render() + ServerEnumGenerator(codegenContext, writer, shape).render() writer.toString() shouldNotContain "#[non_exhaustive]" } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt index 945dac55dc..1bfbd26e2f 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt @@ -13,18 +13,17 @@ import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.withBlock -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest -import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverRenderWithModelBuilder import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext class ServerInstantiatorTest { @@ -140,9 +139,9 @@ class ServerInstantiatorTest { val project = TestWorkspace.testProject() project.withModule(RustModule.Model) { - structure.renderWithModelBuilder(model, symbolProvider, this, CodegenTarget.SERVER) - inner.renderWithModelBuilder(model, symbolProvider, this, CodegenTarget.SERVER) - nestedStruct.renderWithModelBuilder(model, symbolProvider, this, CodegenTarget.SERVER) + structure.serverRenderWithModelBuilder(model, symbolProvider, this) + inner.serverRenderWithModelBuilder(model, symbolProvider, this) + nestedStruct.serverRenderWithModelBuilder(model, symbolProvider, this) UnionGenerator(model, symbolProvider, this, union).render() unitTest("server_instantiator_test") { diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGeneratorTest.kt new file mode 100644 index 0000000000..42774274d9 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGeneratorTest.kt @@ -0,0 +1,124 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverRenderWithModelBuilder +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext + +class UnconstrainedCollectionGeneratorTest { + @Test + fun `it should generate unconstrained lists`() { + val model = + """ + namespace test + + list ListA { + member: ListB + } + + list ListB { + member: StructureC + } + + structure StructureC { + @required + int: Integer, + + @required + string: String + } + """.asSmithyModel() + val codegenContext = serverTestCodegenContext(model) + val symbolProvider = codegenContext.symbolProvider + + val listA = model.lookup("test#ListA") + val listB = model.lookup("test#ListB") + + val project = TestWorkspace.testProject(symbolProvider) + + project.withModule(RustModule.public("model")) { + model.lookup("test#StructureC").serverRenderWithModelBuilder(model, symbolProvider, this) + } + + project.withModule(RustModule.private("constrained")) { + listOf(listA, listB).forEach { + PubCrateConstrainedCollectionGenerator(codegenContext, this, it).render() + } + } + project.withModule(RustModule.private("unconstrained")) unconstrainedModuleWriter@{ + project.withModule(ModelsModule) modelsModuleWriter@{ + listOf(listA, listB).forEach { + UnconstrainedCollectionGenerator( + codegenContext, + this@unconstrainedModuleWriter, + this@modelsModuleWriter, + it, + ).render() + } + + this@unconstrainedModuleWriter.unitTest( + name = "list_a_unconstrained_fail_to_constrain_with_first_error", + test = """ + let c_builder1 = crate::model::StructureC::builder().int(69); + let c_builder2 = crate::model::StructureC::builder().string("david".to_owned()); + let list_b_unconstrained = list_b_unconstrained::ListBUnconstrained(vec![c_builder1, c_builder2]); + let list_a_unconstrained = list_a_unconstrained::ListAUnconstrained(vec![list_b_unconstrained]); + + let expected_err = + crate::model::list_a::ConstraintViolation(0, crate::model::list_b::ConstraintViolation( + 0, crate::model::structure_c::ConstraintViolation::MissingString, + )); + + assert_eq!( + expected_err, + crate::constrained::list_a_constrained::ListAConstrained::try_from(list_a_unconstrained).unwrap_err() + ); + """, + ) + + this@unconstrainedModuleWriter.unitTest( + name = "list_a_unconstrained_succeed_to_constrain", + test = """ + let c_builder = crate::model::StructureC::builder().int(69).string(String::from("david")); + let list_b_unconstrained = list_b_unconstrained::ListBUnconstrained(vec![c_builder]); + let list_a_unconstrained = list_a_unconstrained::ListAUnconstrained(vec![list_b_unconstrained]); + + let expected: Vec> = vec![vec![crate::model::StructureC { + string: "david".to_owned(), + int: 69 + }]]; + let actual: Vec> = + crate::constrained::list_a_constrained::ListAConstrained::try_from(list_a_unconstrained).unwrap().into(); + + assert_eq!(expected, actual); + """, + ) + + this@unconstrainedModuleWriter.unitTest( + name = "list_a_unconstrained_converts_into_constrained", + test = """ + let c_builder = crate::model::StructureC::builder(); + let list_b_unconstrained = list_b_unconstrained::ListBUnconstrained(vec![c_builder]); + let list_a_unconstrained = list_a_unconstrained::ListAUnconstrained(vec![list_b_unconstrained]); + + let _list_a: crate::constrained::MaybeConstrained = list_a_unconstrained.into(); + """, + ) + project.compileAndTest() + } + } + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGeneratorTest.kt new file mode 100644 index 0000000000..a5877b7c00 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGeneratorTest.kt @@ -0,0 +1,164 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverRenderWithModelBuilder +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext + +class UnconstrainedMapGeneratorTest { + @Test + fun `it should generate unconstrained maps`() { + val model = + """ + namespace test + + map MapA { + key: String, + value: MapB + } + + map MapB { + key: String, + value: StructureC + } + + structure StructureC { + @required + int: Integer, + + @required + string: String + } + """.asSmithyModel() + val codegenContext = serverTestCodegenContext(model) + val symbolProvider = codegenContext.symbolProvider + + val mapA = model.lookup("test#MapA") + val mapB = model.lookup("test#MapB") + + val project = TestWorkspace.testProject(symbolProvider) + + project.withModule(RustModule.public("model")) { + model.lookup("test#StructureC").serverRenderWithModelBuilder(model, symbolProvider, this) + } + + project.withModule(RustModule.private("constrained")) { + listOf(mapA, mapB).forEach { + PubCrateConstrainedMapGenerator(codegenContext, this, it).render() + } + } + project.withModule(RustModule.private("unconstrained")) unconstrainedModuleWriter@{ + project.withModule(ModelsModule) modelsModuleWriter@{ + listOf(mapA, mapB).forEach { + UnconstrainedMapGenerator(codegenContext, this@unconstrainedModuleWriter, it).render() + + MapConstraintViolationGenerator(codegenContext, this@modelsModuleWriter, it).render() + } + + this@unconstrainedModuleWriter.unitTest( + name = "map_a_unconstrained_fail_to_constrain_with_some_error", + test = """ + let c_builder1 = crate::model::StructureC::builder().int(69); + let c_builder2 = crate::model::StructureC::builder().string(String::from("david")); + let map_b_unconstrained = map_b_unconstrained::MapBUnconstrained( + std::collections::HashMap::from([ + (String::from("KeyB1"), c_builder1), + (String::from("KeyB2"), c_builder2), + ]) + ); + let map_a_unconstrained = map_a_unconstrained::MapAUnconstrained( + std::collections::HashMap::from([ + (String::from("KeyA"), map_b_unconstrained), + ]) + ); + + // Any of these two errors could be returned; it depends on the order in which the maps are visited. + let missing_string_expected_err = crate::model::map_a::ConstraintViolation::Value( + "KeyA".to_owned(), + crate::model::map_b::ConstraintViolation::Value( + "KeyB1".to_owned(), + crate::model::structure_c::ConstraintViolation::MissingString, + ) + ); + let missing_int_expected_err = crate::model::map_a::ConstraintViolation::Value( + "KeyA".to_owned(), + crate::model::map_b::ConstraintViolation::Value( + "KeyB2".to_owned(), + crate::model::structure_c::ConstraintViolation::MissingInt, + ) + ); + + let actual_err = crate::constrained::map_a_constrained::MapAConstrained::try_from(map_a_unconstrained).unwrap_err(); + + assert!(actual_err == missing_string_expected_err || actual_err == missing_int_expected_err); + """, + ) + + this@unconstrainedModuleWriter.unitTest( + name = "map_a_unconstrained_succeed_to_constrain", + test = """ + let c_builder = crate::model::StructureC::builder().int(69).string(String::from("david")); + let map_b_unconstrained = map_b_unconstrained::MapBUnconstrained( + std::collections::HashMap::from([ + (String::from("KeyB"), c_builder), + ]) + ); + let map_a_unconstrained = map_a_unconstrained::MapAUnconstrained( + std::collections::HashMap::from([ + (String::from("KeyA"), map_b_unconstrained), + ]) + ); + + let expected = std::collections::HashMap::from([ + (String::from("KeyA"), std::collections::HashMap::from([ + (String::from("KeyB"), crate::model::StructureC { + int: 69, + string: String::from("david") + }), + ])) + ]); + + assert_eq!( + expected, + crate::constrained::map_a_constrained::MapAConstrained::try_from(map_a_unconstrained).unwrap().into() + ); + """, + ) + + this@unconstrainedModuleWriter.unitTest( + name = "map_a_unconstrained_converts_into_constrained", + test = """ + let c_builder = crate::model::StructureC::builder(); + let map_b_unconstrained = map_b_unconstrained::MapBUnconstrained( + std::collections::HashMap::from([ + (String::from("KeyB"), c_builder), + ]) + ); + let map_a_unconstrained = map_a_unconstrained::MapAUnconstrained( + std::collections::HashMap::from([ + (String::from("KeyA"), map_b_unconstrained), + ]) + ); + + let _map_a: crate::constrained::MaybeConstrained = map_a_unconstrained.into(); + """, + ) + + project.compileAndTest() + } + } + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGeneratorTest.kt new file mode 100644 index 0000000000..f31285de98 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGeneratorTest.kt @@ -0,0 +1,102 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule +import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverRenderWithModelBuilder +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext + +class UnconstrainedUnionGeneratorTest { + @Test + fun `it should generate unconstrained unions`() { + val model = + """ + namespace test + + union Union { + structure: Structure + } + + structure Structure { + @required + requiredMember: String + } + """.asSmithyModel() + val codegenContext = serverTestCodegenContext(model) + val symbolProvider = codegenContext.symbolProvider + + val unionShape = model.lookup("test#Union") + + val project = TestWorkspace.testProject(symbolProvider) + + project.withModule(RustModule.public("model")) { + model.lookup("test#Structure").serverRenderWithModelBuilder(model, symbolProvider, this) + } + + project.withModule(ModelsModule) { + UnionGenerator(model, symbolProvider, this, unionShape, renderUnknownVariant = false).render() + } + project.withModule(RustModule.private("unconstrained")) unconstrainedModuleWriter@{ + project.withModule(ModelsModule) modelsModuleWriter@{ + UnconstrainedUnionGenerator(codegenContext, this@unconstrainedModuleWriter, this@modelsModuleWriter, unionShape).render() + + this@unconstrainedModuleWriter.unitTest( + name = "unconstrained_union_fail_to_constrain", + test = """ + let builder = crate::model::Structure::builder(); + let union_unconstrained = union_unconstrained::UnionUnconstrained::Structure(builder); + + let expected_err = crate::model::union::ConstraintViolation::Structure( + crate::model::structure::ConstraintViolation::MissingRequiredMember, + ); + + assert_eq!( + expected_err, + crate::model::Union::try_from(union_unconstrained).unwrap_err() + ); + """, + ) + + this@unconstrainedModuleWriter.unitTest( + name = "unconstrained_union_succeed_to_constrain", + test = """ + let builder = crate::model::Structure::builder().required_member(String::from("david")); + let union_unconstrained = union_unconstrained::UnionUnconstrained::Structure(builder); + + let expected: crate::model::Union = crate::model::Union::Structure(crate::model::Structure { + required_member: String::from("david"), + }); + let actual: crate::model::Union = crate::model::Union::try_from(union_unconstrained).unwrap(); + + assert_eq!(expected, actual); + """, + ) + + this@unconstrainedModuleWriter.unitTest( + name = "unconstrained_union_converts_into_constrained", + test = """ + let builder = crate::model::Structure::builder(); + let union_unconstrained = union_unconstrained::UnionUnconstrained::Structure(builder); + + let _union: crate::constrained::MaybeConstrained = + union_unconstrained.into(); + """, + ) + project.compileAndTest() + } + } + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/EventStreamTestTools.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/EventStreamTestTools.kt index 3a917af0dc..52d312ffa9 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/EventStreamTestTools.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/EventStreamTestTools.kt @@ -313,7 +313,11 @@ object EventStreamTestModels { """.trimIndent(), ) { Ec2QueryProtocol(it) }, - ).flatMap { listOf(it, it.copy(target = CodegenTarget.SERVER)) } + ) + // TODO(https://github.com/awslabs/smithy-rs/issues/1442) Server tests + // should be run from the server subproject using the + // `serverTestSymbolProvider()`. + // .flatMap { listOf(it, it.copy(target = CodegenTarget.SERVER)) } class UnmarshallTestCasesProvider : ArgumentsProvider { override fun provideArguments(context: ExtensionContext?): Stream = diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/parse/EventStreamUnmarshallerGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/parse/EventStreamUnmarshallerGeneratorTest.kt index 0b219a569d..5094426d75 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/parse/EventStreamUnmarshallerGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/parse/EventStreamUnmarshallerGeneratorTest.kt @@ -7,12 +7,14 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols.parse import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.ArgumentsSource +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.EventStreamUnmarshallerGenerator -import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.testRustSettings import software.amazon.smithy.rust.codegen.core.testutil.unitTest @@ -34,14 +36,13 @@ class EventStreamUnmarshallerGeneratorTest { target = testCase.target, ) val protocol = testCase.protocolBuilder(codegenContext) + fun builderSymbol(shape: StructureShape): Symbol = shape.builderSymbol(codegenContext.symbolProvider) val generator = EventStreamUnmarshallerGenerator( protocol, - test.model, - TestRuntimeConfig, - test.symbolProvider, + codegenContext, test.operationShape, test.streamShape, - target = testCase.target, + ::builderSymbol, ) test.project.lib { diff --git a/rust-runtime/aws-smithy-http-server/Cargo.toml b/rust-runtime/aws-smithy-http-server/Cargo.toml index 10888a09e5..46d5c52237 100644 --- a/rust-runtime/aws-smithy-http-server/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server/Cargo.toml @@ -34,7 +34,8 @@ once_cell = "1.13" regex = "1.5.5" serde_urlencoded = "0.7" strum_macros = "0.24" -thiserror = "1" +# TODO Investigate. +thiserror = "<=1.0.36" tracing = "0.1.35" tokio = { version = "1.8.4", features = ["full"] } tower = { version = "0.4.11", features = ["util", "make"], default-features = false } diff --git a/rust-runtime/aws-smithy-http-server/examples/pokemon-service/src/lib.rs b/rust-runtime/aws-smithy-http-server/examples/pokemon-service/src/lib.rs index 5804a3fb7f..318dac9091 100644 --- a/rust-runtime/aws-smithy-http-server/examples/pokemon-service/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server/examples/pokemon-service/src/lib.rs @@ -214,7 +214,7 @@ pub async fn capture_pokemon( ) -> Result { if input.region != "Kanto" { return Err(error::CapturePokemonError::UnsupportedRegionError( - error::UnsupportedRegionError::builder().build(), + error::UnsupportedRegionError { region: input.region }, )); } let output_stream = stream! { @@ -230,7 +230,9 @@ pub async fn capture_pokemon( if ! matches!(pokeball, "Master Ball" | "Great Ball" | "Fast Ball") { yield Err( crate::error::CapturePokemonEventsError::InvalidPokeballError( - crate::error::InvalidPokeballError::builder().pokeball(pokeball).build() + crate::error::InvalidPokeballError { + pokeball: pokeball.to_owned() + } ) ); } else { @@ -253,11 +255,12 @@ pub async fn capture_pokemon( .to_string(); let pokedex: Vec = (0..255).collect(); yield Ok(crate::model::CapturePokemonEvents::Event( - crate::model::CaptureEvent::builder() - .name(pokemon) - .shiny(shiny) - .pokedex_update(Blob::new(pokedex)) - .build(), + crate::model::CaptureEvent { + name: Some(pokemon), + shiny: Some(shiny), + pokedex_update: Some(Blob::new(pokedex)), + captured: Some(true), + } )); } } diff --git a/rust-runtime/aws-smithy-http-server/src/rejection.rs b/rust-runtime/aws-smithy-http-server/src/rejection.rs index 0d6e656dba..65d4497ffe 100644 --- a/rust-runtime/aws-smithy-http-server/src/rejection.rs +++ b/rust-runtime/aws-smithy-http-server/src/rejection.rs @@ -179,15 +179,12 @@ pub enum RequestRejection { FloatParse(crate::Error), BoolParse(crate::Error), - // TODO(https://github.com/awslabs/smithy-rs/issues/1243): In theory, we could get rid of this - // error, but it would be a lot of effort for comparatively low benefit. - /// Used when consuming the input struct builder. - Build(crate::Error), - - /// Used by the server when the enum variant sent by a client is not known. - /// Unlike the rejections above, the inner type is code generated, - /// with each enum having its own generated error type. - EnumVariantNotFound(Box), + /// Used when consuming the input struct builder, and constraint violations occur. + // Unlike the rejections above, this does not take in `crate::Error`, since it is constructed + // directly in the code-generated SDK instead of in this crate. + // TODO(https://github.com/awslabs/smithy-rs/issues/1703): this will hold a type that can be + // rendered into a protocol-specific response later on. + ConstraintViolation(String), } #[derive(Debug, Display)] @@ -237,7 +234,6 @@ impl From for RequestRejection { convert_to_request_rejection!(aws_smithy_json::deserialize::Error, JsonDeserialize); convert_to_request_rejection!(aws_smithy_xml::decode::XmlError, XmlDeserialize); -convert_to_request_rejection!(aws_smithy_http::operation::BuildError, Build); convert_to_request_rejection!(aws_smithy_http::header::ParseError, HeaderParse); convert_to_request_rejection!(aws_smithy_types::date_time::DateTimeParseError, DateTimeParse); convert_to_request_rejection!(aws_smithy_types::primitive::PrimitiveParseError, PrimitiveParse); diff --git a/rust-runtime/aws-smithy-http-server/src/runtime_error.rs b/rust-runtime/aws-smithy-http-server/src/runtime_error.rs index e389240f8e..7503af1925 100644 --- a/rust-runtime/aws-smithy-http-server/src/runtime_error.rs +++ b/rust-runtime/aws-smithy-http-server/src/runtime_error.rs @@ -21,14 +21,13 @@ //! and converts into the corresponding `RuntimeError`, and then it uses the its //! [`RuntimeError::into_response`] method to render and send a response. -use http::StatusCode; - use crate::extension::RuntimeErrorExtension; use crate::proto::aws_json_10::AwsJson1_0; use crate::proto::aws_json_11::AwsJson1_1; use crate::proto::rest_json_1::RestJson1; use crate::proto::rest_xml::RestXml; use crate::response::IntoResponse; +use http::StatusCode; #[derive(Debug)] pub enum RuntimeError { @@ -40,6 +39,10 @@ pub enum RuntimeError { // TODO(https://github.com/awslabs/smithy-rs/issues/1663) NotAcceptable, UnsupportedMediaType, + + // TODO(https://github.com/awslabs/smithy-rs/issues/1703): this will hold a type that can be + // rendered into a protocol-specific response later on. + Validation(String), } /// String representation of the runtime error type. @@ -52,6 +55,7 @@ impl RuntimeError { Self::InternalFailure(_) => "InternalFailureException", Self::NotAcceptable => "NotAcceptableException", Self::UnsupportedMediaType => "UnsupportedMediaTypeException", + Self::Validation(_) => "ValidationException", } } @@ -61,6 +65,7 @@ impl RuntimeError { Self::InternalFailure(_) => StatusCode::INTERNAL_SERVER_ERROR, Self::NotAcceptable => StatusCode::NOT_ACCEPTABLE, Self::UnsupportedMediaType => StatusCode::UNSUPPORTED_MEDIA_TYPE, + Self::Validation(_) => StatusCode::BAD_REQUEST, } } } @@ -93,48 +98,78 @@ impl IntoResponse for InternalFailureException { impl IntoResponse for RuntimeError { fn into_response(self) -> http::Response { - http::Response::builder() + let res = http::Response::builder() .status(self.status_code()) .header("Content-Type", "application/json") .header("X-Amzn-Errortype", self.name()) - .extension(RuntimeErrorExtension::new(self.name().to_string())) - // See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#empty-body-serialization - .body(crate::body::to_boxed("{}")) + .extension(RuntimeErrorExtension::new(self.name().to_string())); + + let body = match self { + RuntimeError::Validation(reason) => crate::body::to_boxed(reason), + _ => crate::body::to_boxed("{}"), + }; + + res + .body(body) .expect("invalid HTTP response for `RuntimeError`; please file a bug report under https://github.com/awslabs/smithy-rs/issues") } } impl IntoResponse for RuntimeError { fn into_response(self) -> http::Response { - http::Response::builder() + let res = http::Response::builder() .status(self.status_code()) .header("Content-Type", "application/xml") - .extension(RuntimeErrorExtension::new(self.name().to_string())) - .body(crate::body::to_boxed("")) + .extension(RuntimeErrorExtension::new(self.name().to_string())); + + let body = match self { + // TODO(https://github.com/awslabs/smithy/issues/1446) The Smithy spec does not yet + // define constraint violation HTTP body responses for RestXml. + RuntimeError::Validation(_reason) => todo!("https://github.com/awslabs/smithy/issues/1446"), + // See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#empty-body-serialization + _ => crate::body::to_boxed("{}"), + }; + + res + .body(body) .expect("invalid HTTP response for `RuntimeError`; please file a bug report under https://github.com/awslabs/smithy-rs/issues") } } impl IntoResponse for RuntimeError { fn into_response(self) -> http::Response { - http::Response::builder() + let res = http::Response::builder() .status(self.status_code()) .header("Content-Type", "application/x-amz-json-1.0") - .extension(RuntimeErrorExtension::new(self.name().to_string())) - // See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#empty-body-serialization - .body(crate::body::to_boxed("")) + .extension(RuntimeErrorExtension::new(self.name().to_string())); + + let body = match self { + RuntimeError::Validation(reason) => crate::body::to_boxed(reason), + // See https://awslabs.github.io/smithy/2.0/aws/protocols/aws-json-1_0-protocol.html#empty-body-serialization + _ => crate::body::to_boxed("{}"), + }; + + res + .body(body) .expect("invalid HTTP response for `RuntimeError`; please file a bug report under https://github.com/awslabs/smithy-rs/issues") } } impl IntoResponse for RuntimeError { fn into_response(self) -> http::Response { - http::Response::builder() + let res = http::Response::builder() .status(self.status_code()) .header("Content-Type", "application/x-amz-json-1.1") - .extension(RuntimeErrorExtension::new(self.name().to_string())) - // See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#empty-body-serialization - .body(crate::body::to_boxed("")) + .extension(RuntimeErrorExtension::new(self.name().to_string())); + + let body = match self { + RuntimeError::Validation(reason) => crate::body::to_boxed(reason), + // https://awslabs.github.io/smithy/2.0/aws/protocols/aws-json-1_1-protocol.html#empty-body-serialization + _ => crate::body::to_boxed(""), + }; + + res + .body(body) .expect("invalid HTTP response for `RuntimeError`; please file a bug report under https://github.com/awslabs/smithy-rs/issues") } } @@ -155,6 +190,7 @@ impl From for RuntimeError { fn from(err: crate::rejection::RequestRejection) -> Self { match err { crate::rejection::RequestRejection::MissingContentType(_reason) => Self::UnsupportedMediaType, + crate::rejection::RequestRejection::ConstraintViolation(reason) => Self::Validation(reason), _ => Self::Serialization(crate::Error::new(err)), } } diff --git a/rust-runtime/inlineable/src/constrained.rs b/rust-runtime/inlineable/src/constrained.rs new file mode 100644 index 0000000000..1276eccbcd --- /dev/null +++ b/rust-runtime/inlineable/src/constrained.rs @@ -0,0 +1,15 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +pub(crate) trait Constrained { + type Unconstrained; +} + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub(crate) enum MaybeConstrained { + Constrained(T), + Unconstrained(T::Unconstrained), +} diff --git a/rust-runtime/inlineable/src/lib.rs b/rust-runtime/inlineable/src/lib.rs index 3cbcd5e5ff..2c2634110c 100644 --- a/rust-runtime/inlineable/src/lib.rs +++ b/rust-runtime/inlineable/src/lib.rs @@ -3,6 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ +#[allow(unused)] +mod constrained; #[allow(dead_code)] mod ec2_query_errors; #[allow(dead_code)]