diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index 469b82d380..53cf679909 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -242,13 +242,14 @@ message = """ * The `length` trait on `string` shapes. * The `length` trait on `map` shapes. +* The `range` trait on `integer` shapes. * The `pattern` trait on `string` shapes. Upon receiving a request that violates the modeled constraints, the server SDK will reject it with a message indicating why. Unsupported (constraint trait, target shape) combinations will now fail at code generation time, whereas previously they were just ignored. This is a breaking change to raise awareness in service owners of their server SDKs behaving differently than what was modeled. To continue generating a server SDK with unsupported constraint traits, set `codegenConfig.ignoreUnsupportedConstraints` to `true` in your `smithy-build.json`. """ -references = ["smithy-rs#1199", "smithy-rs#1342", "smithy-rs#1401", "smithy-rs#1998"] +references = ["smithy-rs#1199", "smithy-rs#1342", "smithy-rs#1401", "smithy-rs#2005", "smithy-rs#1998"] meta = { "breaking" = true, "tada" = true, "bug" = false, "target" = "server" } author = "david-perez" diff --git a/codegen-core/common-test-models/constraints.smithy b/codegen-core/common-test-models/constraints.smithy index 0bdfe6d2a1..ccef90c147 100644 --- a/codegen-core/common-test-models/constraints.smithy +++ b/codegen-core/common-test-models/constraints.smithy @@ -19,7 +19,6 @@ service ConstraintsService { // combination. QueryParamsTargetingLengthMapOperation, QueryParamsTargetingMapOfLengthStringOperation, - QueryParamsTargetingMapOfEnumStringOperation, QueryParamsTargetingMapOfListOfLengthStringOperation, QueryParamsTargetingMapOfSetOfLengthStringOperation, QueryParamsTargetingMapOfListOfEnumStringOperation, @@ -30,6 +29,9 @@ service ConstraintsService { QueryParamsTargetingMapOfListOfLengthPatternStringOperation, HttpPrefixHeadersTargetingLengthMapOperation, + + QueryParamsTargetingMapOfEnumStringOperation, + QueryParamsTargetingMapOfListOfEnumStringOperation, // TODO(https://github.com/awslabs/smithy-rs/issues/1431) // HttpPrefixHeadersTargetingMapOfEnumStringOperation, @@ -47,7 +49,7 @@ operation ConstrainedShapesOperation { errors: [ValidationException] } -@http(uri: "/constrained-http-bound-shapes-operation/{lengthStringLabel}/{enumStringLabel}", method: "POST") +@http(uri: "/constrained-http-bound-shapes-operation/{rangeIntegerLabel}/{lengthStringLabel}/{enumStringLabel}", method: "POST") operation ConstrainedHttpBoundShapesOperation { input: ConstrainedHttpBoundShapesOperationInputOutput, output: ConstrainedHttpBoundShapesOperationInputOutput, @@ -173,18 +175,24 @@ structure ConstrainedHttpBoundShapesOperationInputOutput { @httpLabel lengthStringLabel: LengthString, + @required + @httpLabel + rangeIntegerLabel: RangeInteger, + @required @httpLabel enumStringLabel: EnumString, - // TODO(https://github.com/awslabs/smithy-rs/issues/1394) `@required` not working - // @required - @httpPrefixHeaders("X-Prefix-Headers-") + @required + @httpPrefixHeaders("X-Length-String-Prefix-Headers-") lengthStringHeaderMap: MapOfLengthString, @httpHeader("X-Length") lengthStringHeader: LengthString, + @httpHeader("X-Range-Integer") + rangeIntegerHeader: RangeInteger, + // @httpHeader("X-Length-MediaType") // lengthStringHeaderWithMediaType: MediaTypeLengthString, @@ -196,6 +204,14 @@ structure ConstrainedHttpBoundShapesOperationInputOutput { @httpHeader("X-Length-List") lengthStringListHeader: ListOfLengthString, + // TODO(https://github.com/awslabs/smithy-rs/issues/1401): a `set` shape is + // just a `list` shape with `uniqueItems`, which hasn't been implemented yet. + // @httpHeader("X-Range-Integer-Set") + // rangeIntegerSetHeader: SetOfRangeInteger, + + @httpHeader("X-Range-Integer-List") + rangeIntegerListHeader: ListOfRangeInteger, + // TODO(https://github.com/awslabs/smithy-rs/issues/1431) // @httpHeader("X-Enum") //enumStringHeader: EnumString, @@ -206,6 +222,9 @@ structure ConstrainedHttpBoundShapesOperationInputOutput { @httpQuery("lengthString") lengthStringQuery: LengthString, + @httpQuery("rangeInteger") + rangeIntegerQuery: RangeInteger, + @httpQuery("enumString") enumStringQuery: EnumString, @@ -217,6 +236,14 @@ structure ConstrainedHttpBoundShapesOperationInputOutput { // @httpQuery("lengthStringSet") // lengthStringSetQuery: SetOfLengthString, + @httpQuery("rangeIntegerList") + rangeIntegerListQuery: ListOfRangeInteger, + + // TODO(https://github.com/awslabs/smithy-rs/issues/1401): a `set` shape is + // just a `list` shape with `uniqueItems`, which hasn't been implemented yet. + // @httpQuery("rangeIntegerSet") + // rangeIntegerSetQuery: SetOfRangeInteger, + @httpQuery("enumStringList") enumStringListQuery: ListOfEnumString, } @@ -332,6 +359,11 @@ structure ConA { maxLengthString: MaxLengthString, fixedLengthString: FixedLengthString, + rangeInteger: RangeInteger, + minRangeInteger: MinRangeInteger, + maxRangeInteger: MaxRangeInteger, + fixedValueInteger: FixedValueInteger, + conBList: ConBList, conBList2: ConBList2, @@ -352,7 +384,13 @@ structure ConA { // setOfLengthString: SetOfLengthString, mapOfLengthString: MapOfLengthString, - nonStreamingBlob: NonStreamingBlob, + listOfRangeInteger: ListOfRangeInteger, + // TODO(https://github.com/awslabs/smithy-rs/issues/1401): a `set` shape is + // just a `list` shape with `uniqueItems`, which hasn't been implemented yet. + // setOfRangeInteger: SetOfRangeInteger, + mapOfRangeInteger: MapOfRangeInteger, + + nonStreamingBlob: NonStreamingBlob patternString: PatternString, mapOfPatternString: MapOfPatternString, @@ -374,6 +412,11 @@ map MapOfLengthString { value: LengthString, } +map MapOfRangeInteger { + key: String, + value: RangeInteger, +} + map MapOfEnumString { key: EnumString, value: EnumString, @@ -407,6 +450,13 @@ map MapOfSetOfLengthString { value: ListOfLengthString } +// TODO(https://github.com/awslabs/smithy-rs/issues/1401): a `set` shape is +// just a `list` shape with `uniqueItems`, which hasn't been implemented yet. +// map MapOfSetOfRangeInteger { +// key: LengthString, +// value: SetOfRangeInteger, +// } + @length(min: 2, max: 8) list LengthListOfLengthString { member: LengthString @@ -418,7 +468,7 @@ string LengthString @length(min: 2) string MinLengthString -@length(min: 69) +@length(max: 69) string MaxLengthString @length(min: 69, max: 69) @@ -435,6 +485,18 @@ string LengthPatternString @length(min: 1, max: 69) string MediaTypeLengthString +@range(min: -0, max: 69) +integer RangeInteger + +@range(min: -10) +integer MinRangeInteger + +@range(max: 69) +integer MaxRangeInteger + +@range(min: 69, max: 69) +integer FixedValueInteger + /// A union with constrained members. union ConstrainedUnion { enumString: EnumString, @@ -480,6 +542,16 @@ list ListOfLengthString { member: LengthString } +// TODO(https://github.com/awslabs/smithy-rs/issues/1401): a `set` shape is +// just a `list` shape with `uniqueItems`, which hasn't been implemented yet. +// set SetOfRangeInteger { +// member: RangeInteger +// } + +list ListOfRangeInteger { + member: RangeInteger +} + list ListOfEnumString { member: EnumString } diff --git a/codegen-core/common-test-models/malformed-range-extras.smithy b/codegen-core/common-test-models/malformed-range-extras.smithy new file mode 100644 index 0000000000..8fd9d93c11 --- /dev/null +++ b/codegen-core/common-test-models/malformed-range-extras.smithy @@ -0,0 +1,662 @@ +$version: "2.0" + +namespace aws.protocoltests.extras.restjson.validation + +use aws.api#service +use aws.protocols#restJson1 +use smithy.test#httpMalformedRequestTests +use smithy.framework#ValidationException + +/// A REST JSON service that sends JSON requests and responses with validation applied +@service(sdkId: "Rest Json Validation Protocol") +@restJson1 +service MalformedRangeValidation { + version: "2022-11-23", + operations: [ + MalformedRange, + MalformedRangeOverride, + ] +} + +@suppress(["UnstableTrait"]) +@http(uri: "/MalformedRange", method: "POST") +operation MalformedRange { + input: MalformedRangeInput, + errors: [ValidationException] +} + +@suppress(["UnstableTrait"]) +@http(uri: "/MalformedRangeOverride", method: "POST") +operation MalformedRangeOverride { + input: MalformedRangeOverrideInput, + errors: [ValidationException] +} + +apply MalformedRange @httpMalformedRequestTests([ + { + id: "RestJsonMalformedRangeShort", + documentation: """ + When a short member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRange", + body: """ + { "short" : $value:L }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value $value:L at '/short' failed to satisfy constraint: Member must be between 2 and 8, inclusive", + "fieldList" : [{"message": "Value $value:L at '/short' failed to satisfy constraint: Member must be between 2 and 8, inclusive", "path": "/short"}]}""" + } + } + }, + testParameters: { + value: ["1", "9"] + } + }, + { + id: "RestJsonMalformedRangeMinShort", + documentation: """ + When a short member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRange", + body: """ + { "minShort" : 1 }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value 1 at '/minShort' failed to satisfy constraint: Member must be greater than or equal to 2", + "fieldList" : [{"message": "Value 1 at '/minShort' failed to satisfy constraint: Member must be greater than or equal to 2", "path": "/minShort"}]}""" + } + } + } + }, + { + id: "RestJsonMalformedRangeMaxShort", + documentation: """ + When a short member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRange", + body: """ + { "maxShort" : 9 }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value 9 at '/maxShort' failed to satisfy constraint: Member must be less than or equal to 8", + "fieldList" : [{"message": "Value 9 at '/maxShort' failed to satisfy constraint: Member must be less than or equal to 8", "path": "/maxShort"}]}""" + } + } + } + }, + { + id: "RestJsonMalformedRangeInteger", + documentation: """ + When a integer member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRange", + body: """ + { "integer" : $value:L }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value $value:L at '/integer' failed to satisfy constraint: Member must be between 2 and 8, inclusive", + "fieldList" : [{"message": "Value $value:L at '/integer' failed to satisfy constraint: Member must be between 2 and 8, inclusive", "path": "/integer"}]}""" + } + } + }, + testParameters: { + value: ["1", "9"] + } + }, + { + id: "RestJsonMalformedRangeMinInteger", + documentation: """ + When a integer member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRange", + body: """ + { "minInteger" : 1 }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value 1 at '/minInteger' failed to satisfy constraint: Member must be greater than or equal to 2", + "fieldList" : [{"message": "Value 1 at '/minInteger' failed to satisfy constraint: Member must be greater than or equal to 2", "path": "/minInteger"}]}""" + } + } + } + }, + { + id: "RestJsonMalformedRangeMaxInteger", + documentation: """ + When a integer member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRange", + body: """ + { "maxInteger" : 9 }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value 9 at '/maxInteger' failed to satisfy constraint: Member must be less than or equal to 8", + "fieldList" : [{"message": "Value 9 at '/maxInteger' failed to satisfy constraint: Member must be less than or equal to 8", "path": "/maxInteger"}]}""" + } + } + } + }, + { + id: "RestJsonMalformedRangeLong", + documentation: """ + When a long member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRange", + body: """ + { "long" : $value:L }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value $value:L at '/long' failed to satisfy constraint: Member must be between 2 and 8, inclusive", + "fieldList" : [{"message": "Value $value:L at '/long' failed to satisfy constraint: Member must be between 2 and 8, inclusive", "path": "/long"}]}""" + } + } + }, + testParameters: { + value: ["1", "9"] + } + }, + { + id: "RestJsonMalformedRangeMinLong", + documentation: """ + When a long member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRange", + body: """ + { "minLong" : 1 }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value 1 at '/minLong' failed to satisfy constraint: Member must be greater than or equal to 2", + "fieldList" : [{"message": "Value 1 at '/minLong' failed to satisfy constraint: Member must be greater than or equal to 2", "path": "/minLong"}]}""" + } + } + } + }, + { + id: "RestJsonMalformedRangeMaxLong", + documentation: """ + When a long member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRange", + body: """ + { "maxLong" : 9 }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value 9 at '/maxLong' failed to satisfy constraint: Member must be less than or equal to 8", + "fieldList" : [{"message": "Value 9 at '/maxLong' failed to satisfy constraint: Member must be less than or equal to 8", "path": "/maxLong"}]}""" + } + } + } + }, +]) + +// now repeat the above tests, but for the more specific constraints applied to the input member +apply MalformedRangeOverride @httpMalformedRequestTests([ + { + id: "RestJsonMalformedRangeShortOverride", + documentation: """ + When a short member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRangeOverride", + body: """ + { "short" : $value:L }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value $value:L at '/short' failed to satisfy constraint: Member must be between 4 and 6, inclusive", + "fieldList" : [{"message": "Value $value:L at '/short' failed to satisfy constraint: Member must be between 4 and 6, inclusive", "path": "/short"}]}""" + } + } + }, + testParameters: { + value: ["3", "7"] + } + }, + { + id: "RestJsonMalformedRangeMinShortOverride", + documentation: """ + When a short member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRangeOverride", + body: """ + { "minShort" : 3 }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value 3 at '/minShort' failed to satisfy constraint: Member must be greater than or equal to 4", + "fieldList" : [{"message": "Value 3 at '/minShort' failed to satisfy constraint: Member must be greater than or equal to 4", "path": "/minShort"}]}""" + } + } + } + }, + { + id: "RestJsonMalformedRangeMaxShortOverride", + documentation: """ + When a short member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRangeOverride", + body: """ + { "maxShort" : 7 }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value 7 at '/maxShort' failed to satisfy constraint: Member must be less than or equal to 6", + "fieldList" : [{"message": "Value 7 at '/maxShort' failed to satisfy constraint: Member must be less than or equal to 6", "path": "/maxShort"}]}""" + } + } + } + }, + { + id: "RestJsonMalformedRangeIntegerOverride", + documentation: """ + When a integer member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRangeOverride", + body: """ + { "integer" : $value:L }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value $value:L at '/integer' failed to satisfy constraint: Member must be between 4 and 6, inclusive", + "fieldList" : [{"message": "Value $value:L at '/integer' failed to satisfy constraint: Member must be between 4 and 6, inclusive", "path": "/integer"}]}""" + } + } + }, + testParameters: { + value: ["3", "7"] + } + }, + { + id: "RestJsonMalformedRangeMinIntegerOverride", + documentation: """ + When a integer member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRangeOverride", + body: """ + { "minInteger" : 3 }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value 3 at '/minInteger' failed to satisfy constraint: Member must be greater than or equal to 4", + "fieldList" : [{"message": "Value 3 at '/minInteger' failed to satisfy constraint: Member must be greater than or equal to 4", "path": "/minInteger"}]}""" + } + } + } + }, + { + id: "RestJsonMalformedRangeMaxIntegerOverride", + documentation: """ + When a integer member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRangeOverride", + body: """ + { "maxInteger" : 7 }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value 7 at '/maxInteger' failed to satisfy constraint: Member must be less than or equal to 6", + "fieldList" : [{"message": "Value 7 at '/maxInteger' failed to satisfy constraint: Member must be less than or equal to 6", "path": "/maxInteger"}]}""" + } + } + } + }, + { + id: "RestJsonMalformedRangeLongOverride", + documentation: """ + When a long member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRangeOverride", + body: """ + { "long" : $value:L }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value $value:L at '/long' failed to satisfy constraint: Member must be between 4 and 6, inclusive", + "fieldList" : [{"message": "Value $value:L at '/long' failed to satisfy constraint: Member must be between 4 and 6, inclusive", "path": "/long"}]}""" + } + } + }, + testParameters: { + value: ["3", "7"] + } + }, + { + id: "RestJsonMalformedRangeMinLongOverride", + documentation: """ + When a long member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRangeOverride", + body: """ + { "minLong" : 3 }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value 3 at '/minLong' failed to satisfy constraint: Member must be greater than or equal to 4", + "fieldList" : [{"message": "Value 3 at '/minLong' failed to satisfy constraint: Member must be greater than or equal to 4", "path": "/minLong"}]}""" + } + } + } + }, + { + id: "RestJsonMalformedRangeMaxLongOverride", + documentation: """ + When a long member does not fit within range bounds, + the response should be a 400 ValidationException.""", + protocol: restJson1, + request: { + method: "POST", + uri: "/MalformedRangeOverride", + body: """ + { "maxLong" : 7 }""", + headers: { + "content-type": "application/json" + } + }, + response: { + code: 400, + headers: { + "x-amzn-errortype": "ValidationException" + }, + body: { + mediaType: "application/json", + assertion: { + contents: """ + { "message" : "1 validation error detected. Value 7 at '/maxLong' failed to satisfy constraint: Member must be less than or equal to 6", + "fieldList" : [{"message": "Value 7 at '/maxLong' failed to satisfy constraint: Member must be less than or equal to 6", "path": "/maxLong"}]}""" + } + } + } + }, +]) + +structure MalformedRangeInput { + short: RangeShort, + minShort: MinShort, + maxShort: MaxShort, + + integer: RangeInteger, + minInteger: MinInteger, + maxInteger: MaxInteger, + + long: RangeLong, + minLong: MinLong, + maxLong: MaxLong, +} + +structure MalformedRangeOverrideInput { + @range(min: 4, max: 6) + short: RangeShort, + @range(min: 4) + minShort: MinShort, + @range(max: 6) + maxShort: MaxShort, + + @range(min: 4, max: 6) + integer: RangeInteger, + @range(min: 4) + minInteger: MinInteger, + @range(max: 6) + maxInteger: MaxInteger, + + @range(min: 4, max: 6) + long: RangeLong, + @range(min: 4) + minLong: MinLong, + @range(max: 6) + maxLong: MaxLong, +} + +@range(min: 2, max: 8) +short RangeShort + +@range(min: 2) +short MinShort + +@range(max: 8) +short MaxShort + +@range(min: 2, max: 8) +integer RangeInteger + +@range(min: 2) +integer MinInteger + +@range(max: 8) +integer MaxInteger + +@range(min: 2, max: 8) +long RangeLong + +@range(min: 2) +long MinLong + +@range(max: 8) +long MaxLong diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt index ebc4acd604..d1b14e9d8c 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt @@ -15,6 +15,8 @@ import software.amazon.smithy.codegen.core.SymbolWriter.Factory import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.BooleanShape import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.DoubleShape +import software.amazon.smithy.model.shapes.FloatShape import software.amazon.smithy.model.shapes.NumberShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.ShapeId @@ -22,6 +24,7 @@ import software.amazon.smithy.model.traits.DeprecatedTrait import software.amazon.smithy.model.traits.DocumentationTrait import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.ValueExpression import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.letIf @@ -465,33 +468,82 @@ class RustWriter private constructor( } /** - * Generate a wrapping if statement around a field. - * - * - If the field is optional, it will only be called if the field is present - * - If the field is an unboxed primitive, it will only be called if the field is non-zero - * + * Generate a wrapping if statement around a nullable value. + * The provided code block will only be called if the value is not `None`. */ - fun ifSet(shape: Shape, member: Symbol, outerField: String, block: RustWriter.(field: String) -> Unit) { + fun ifSome(member: Symbol, value: ValueExpression, block: RustWriter.(value: ValueExpression) -> Unit) { when { member.isOptional() -> { - val derefName = safeName("inner") - rustBlock("if let Some($derefName) = $outerField") { - block(derefName) + val innerValue = ValueExpression.Reference(safeName("inner")) + rustBlock("if let Some(${innerValue.name}) = ${value.asRef()}") { + block(innerValue) } } - shape is NumberShape -> rustBlock("if ${outerField.removePrefix("&")} != 0") { - block(outerField) + else -> this.block(value) + } + } + + /** + * Generate a wrapping if statement around a primitive field. + * The specified block will only be called if the field is not set to its default value - `0` for + * numbers, `false` for booleans. + */ + fun ifNotDefault(shape: Shape, variable: ValueExpression, block: RustWriter.(field: ValueExpression) -> Unit) { + when (shape) { + is FloatShape, is DoubleShape -> rustBlock("if ${variable.asValue()} != 0.0") { + block(variable) + } + + is NumberShape -> rustBlock("if ${variable.asValue()} != 0") { + block(variable) } - shape is BooleanShape -> rustBlock("if ${outerField.removePrefix("&")}") { - block(outerField) + is BooleanShape -> rustBlock("if ${variable.asValue()}") { + block(variable) } - else -> this.block(outerField) + else -> rustBlock("") { + this.block(variable) + } } } + /** + * Generate a wrapping if statement around a field. + * + * - If the field is optional, it will only be called if the field is present + * - If the field is an unboxed primitive, it will only be called if the field is non-zero + * + * # Example + * + * For a nullable structure shape (e.g. `Option`), the following code will be generated: + * + * ``` + * if let Some(v) = my_nullable_struct { + * /* {block(variable)} */ + * } + * ``` + * + * # Example + * + * For a non-nullable integer shape, the following code will be generated: + * + * ``` + * if my_int != 0 { + * /* {block(variable)} */ + * } + * ``` + */ + fun ifSet( + shape: Shape, + member: Symbol, + variable: ValueExpression, + block: RustWriter.(field: ValueExpression) -> Unit, + ) { + ifSome(member, variable) { inner -> ifNotDefault(shape, inner, block) } + } + fun listForEach( target: Shape, outerField: String, @@ -550,7 +602,8 @@ class RustWriter private constructor( inner class RustWriteableInjector : BiFunction { override fun apply(t: Any, u: String): String { @Suppress("UNCHECKED_CAST") - val func = t as? Writable ?: throw CodegenException("RustWriteableInjector.apply choked on non-function t ($t)") + val func = + t as? Writable ?: throw CodegenException("RustWriteableInjector.apply choked on non-function t ($t)") val innerWriter = RustWriter(filename, namespace, printWarning = false) func(innerWriter) innerWriter.dependencies.forEach { addDependency(it) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorGenerator.kt index 4303fa1a68..071a5bd89a 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorGenerator.kt @@ -25,6 +25,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.Std import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.mapRustType +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.ValueExpression import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.util.REDACTION import software.amazon.smithy.rust.codegen.core.util.dq @@ -144,8 +145,8 @@ class ErrorGenerator( if (it.shouldRedact(model)) { write("""write!(f, ": {}", $REDACTION)?;""") } else { - ifSet(it, symbolProvider.toSymbol(it), "&self.message") { field -> - write("""write!(f, ": {}", $field)?;""") + ifSet(it, symbolProvider.toSymbol(it), ValueExpression.Reference("&self.message")) { field -> + write("""write!(f, ": {}", ${field.asRef()})?;""") } } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt index 219784eff3..8ff3880c17 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt @@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators.http import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.codegen.core.SymbolProvider +import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.HttpBinding import software.amazon.smithy.model.knowledge.HttpBindingIndex import software.amazon.smithy.model.shapes.BlobShape @@ -29,8 +30,8 @@ import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.asOptional -import software.amazon.smithy.rust.codegen.core.rustlang.autoDeref import software.amazon.smithy.rust.codegen.core.rustlang.render import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock @@ -45,12 +46,14 @@ import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedSectionGen import software.amazon.smithy.rust.codegen.core.smithy.customize.Section import software.amazon.smithy.rust.codegen.core.smithy.generators.OperationBuildError import software.amazon.smithy.rust.codegen.core.smithy.generators.operationBuildError +import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.makeOptional import software.amazon.smithy.rust.codegen.core.smithy.mapRustType import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.EventStreamUnmarshallerGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.ValueExpression import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE import software.amazon.smithy.rust.codegen.core.util.dq @@ -80,6 +83,10 @@ enum class HttpMessageType { sealed class HttpBindingSection(name: String) : Section(name) { data class BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders(val variableName: String, val shape: MapShape) : HttpBindingSection("BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders") + + data class BeforeRenderingHeaderValue(var context: HttpBindingGenerator.HeaderValueSerializationContext) : + HttpBindingSection("BeforeRenderingHeaderValue") + data class AfterDeserializingIntoAHashMapOfHttpPrefixHeaders(val memberShape: MemberShape) : HttpBindingSection("AfterDeserializingIntoAHashMapOfHttpPrefixHeaders") } @@ -299,6 +306,7 @@ class HttpBindingGenerator( "error_symbol" to errorSymbol, ) } + HttpMessageType.REQUEST -> { rust("let body_str = std::str::from_utf8(body)?;") } @@ -319,6 +327,7 @@ class HttpBindingGenerator( rust("Ok(body_str.to_string())") } } + is BlobShape -> rust( "Ok(#T::new(body))", symbolProvider.toSymbol(targetShape), @@ -401,6 +410,7 @@ class HttpBindingGenerator( }) """, ) + is RustType.HashSet -> rust( """ @@ -411,6 +421,7 @@ class HttpBindingGenerator( }) """, ) + else -> { if (targetShape is ListShape) { // This is a constrained list shape and we must therefore be generating a server SDK. @@ -449,7 +460,9 @@ class HttpBindingGenerator( */ // Rename here technically not required, operations and members cannot be renamed. private fun fnName(operationShape: OperationShape, binding: HttpBindingDescriptor) = - "${operationShape.id.getName(service).toSnakeCase()}_${binding.member.container.name.toSnakeCase()}_${binding.memberName.toSnakeCase()}" + "${ + operationShape.id.getName(service).toSnakeCase() + }_${binding.member.container.name.toSnakeCase()}_${binding.memberName.toSnakeCase()}" /** * Returns a function to set headers on an HTTP message for the given [shape]. @@ -467,6 +480,7 @@ class HttpBindingGenerator( // Only a single structure member can be bound by `httpPrefixHeaders`, hence the `getOrNull(0)`. HttpMessageType.REQUEST -> index.getRequestBindings(shape, HttpLocation.HEADER) to index.getRequestBindings(shape, HttpLocation.PREFIX_HEADERS).getOrNull(0) + HttpMessageType.RESPONSE -> index.getResponseBindings(shape, HttpLocation.HEADER) to index.getResponseBindings(shape, HttpLocation.PREFIX_HEADERS).getOrNull(0) } @@ -517,50 +531,135 @@ class HttpBindingGenerator( check(httpBinding.location == HttpLocation.HEADER) val memberShape = httpBinding.member val targetShape = model.expectShape(memberShape.target) - val memberSymbol = symbolProvider.toSymbol(memberShape) val memberName = symbolProvider.toMemberName(memberShape) - ifSet(targetShape, memberSymbol, "&input.$memberName") { field -> - listForEach(targetShape, field) { innerField, targetId -> - val innerMemberType = model.expectShape(targetId) - if (innerMemberType.isPrimitive()) { - val encoder = CargoDependency.smithyTypes(runtimeConfig).toType().member("primitive::Encoder") - rust("let mut encoder = #T::from(${autoDeref(innerField)});", encoder) - } - val formatted = headerFmtFun( - this, - innerMemberType, - memberShape, - innerField, - isListHeader = targetShape is CollectionShape, - ) - val safeName = safeName("formatted") - write("let $safeName = $formatted;") - rustBlock("if !$safeName.is_empty()") { - rustTemplate( - """ - let header_value = $safeName; - let header_value = http::header::HeaderValue::try_from(&*header_value).map_err(|err| { - #{invalid_field_error:W} - })?; - builder = builder.header("${httpBinding.locationName}", header_value); - """, - "invalid_field_error" to OperationBuildError(runtimeConfig).invalidField(memberName) { - rust( - """ - format!( - "`{}` cannot be used as a header value: {}", - &${memberShape.redactIfNecessary(model, "header_value")}, - err - ) - """, - ) - }, + val headerName = httpBinding.locationName + val timestampFormat = + index.determineTimestampFormat(memberShape, HttpBinding.Location.HEADER, defaultTimestampFormat) + val renderErrorMessage = { headerValueVariableName: String -> + OperationBuildError(runtimeConfig).invalidField(memberName) { + rust( + """ + format!( + "`{}` cannot be used as a header value: {}", + &${memberShape.redactIfNecessary(model, headerValueVariableName)}, + err ) - } + """, + ) + } + } + + val memberSymbol = symbolProvider.toSymbol(memberShape) + // If a header is of a primitive type and required (e.g. `bool`), we do not serialize it on the + // wire if it's set to the default value for that primitive type (e.g. `false` for `bool`). + // If the header is optional, instead, we want to serialize it if it has been set by the user to the + // default value for that primitive type (e.g. `Some(false)` for an `Option` header). + // If a header is multivalued, we always want to serialize its primitive members, regardless of their + // values. + val serializePrimitiveValuesIfDefault = memberSymbol.isOptional() || (targetShape is CollectionShape) + ifSome(memberSymbol, ValueExpression.Reference("&input.$memberName")) { variableName -> + if (targetShape is CollectionShape) { + renderMultiValuedHeader( + model, + headerName, + variableName, + targetShape, + timestampFormat, + renderErrorMessage, + ) + } else { + renderHeaderValue( + headerName, + variableName, + targetShape, + false, + timestampFormat, + renderErrorMessage, + serializePrimitiveValuesIfDefault, + ) } } } + private fun RustWriter.renderMultiValuedHeader( + model: Model, + headerName: String, + value: ValueExpression, + shape: CollectionShape, + timestampFormat: TimestampFormatTrait.Format, + renderErrorMessage: (String) -> Writable, + ) { + val loopVariable = ValueExpression.Reference(safeName("inner")) + rustBlock("for ${loopVariable.name} in ${value.asRef()}") { + this.renderHeaderValue( + headerName, + loopVariable, + model.expectShape(shape.member.target), + isMultiValuedHeader = true, + timestampFormat, + renderErrorMessage, + serializeIfDefault = true, + ) + } + } + + data class HeaderValueSerializationContext( + /** Expression representing the value to write to the JsonValueWriter */ + var valueExpression: ValueExpression, + /** Path in the JSON to get here, used for errors */ + val shape: Shape, + ) + + private fun RustWriter.renderHeaderValue( + headerName: String, + value: ValueExpression, + shape: Shape, + isMultiValuedHeader: Boolean, + timestampFormat: TimestampFormatTrait.Format, + renderErrorMessage: (String) -> Writable, + serializeIfDefault: Boolean, + ) { + val context = HeaderValueSerializationContext(value, shape) + for (customization in customizations) { + customization.section( + HttpBindingSection.BeforeRenderingHeaderValue(context), + )(this) + } + + val block: RustWriter.(value: ValueExpression) -> Unit = { variableName -> + if (shape.isPrimitive()) { + val encoder = CargoDependency.smithyTypes(runtimeConfig).toType().member("primitive::Encoder") + rust("let mut encoder = #T::from(${variableName.asValue()});", encoder) + } + val formatted = headerFmtFun( + this, + shape, + timestampFormat, + context.valueExpression.name, + isMultiValuedHeader = isMultiValuedHeader, + ) + val safeName = safeName("formatted") + rustTemplate( + """ + let $safeName = $formatted; + if !$safeName.is_empty() { + let header_value = $safeName; + let header_value = http::header::HeaderValue::try_from(&*header_value).map_err(|err| { + #{invalid_field_error:W} + })?; + builder = builder.header("$headerName", header_value); + } + """, + "invalid_field_error" to renderErrorMessage("header_value"), + ) + } + if (serializeIfDefault) { + block(context.valueExpression) + } else { + ifNotDefault(context.shape, context.valueExpression, block) + } + } + private fun RustWriter.renderPrefixHeader(httpBinding: HttpBinding) { check(httpBinding.location == HttpLocation.PREFIX_HEADERS) val memberShape = httpBinding.member @@ -568,21 +667,31 @@ class HttpBindingGenerator( val memberSymbol = symbolProvider.toSymbol(memberShape) val memberName = symbolProvider.toMemberName(memberShape) val valueTargetShape = model.expectShape(targetShape.value.target) + val timestampFormat = + index.determineTimestampFormat(memberShape, HttpBinding.Location.HEADER, defaultTimestampFormat) - ifSet(targetShape, memberSymbol, "&input.$memberName") { field -> + ifSet(targetShape, memberSymbol, ValueExpression.Reference("&input.$memberName")) { local -> for (customization in customizations) { customization.section( - HttpBindingSection.BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders(field, targetShape), + HttpBindingSection.BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders(local.name, targetShape), )(this) } rustTemplate( """ - for (k, v) in $field { + for (k, v) in ${local.asRef()} { use std::str::FromStr; let header_name = http::header::HeaderName::from_str(&format!("{}{}", "${httpBinding.locationName}", &k)).map_err(|err| { #{invalid_header_name:W} })?; - let header_value = ${headerFmtFun(this, valueTargetShape, memberShape, "v", isListHeader = false)}; + let header_value = ${ + headerFmtFun( + this, + valueTargetShape, + timestampFormat, + "v", + isMultiValuedHeader = false, + ) + }; let header_value = http::header::HeaderValue::try_from(&*header_value).map_err(|err| { #{invalid_header_value:W} })?; @@ -611,10 +720,16 @@ class HttpBindingGenerator( /** * Format [member] when used as an HTTP header. */ - private fun headerFmtFun(writer: RustWriter, target: Shape, member: MemberShape, targetName: String, isListHeader: Boolean): String { + private fun headerFmtFun( + writer: RustWriter, + target: Shape, + timestampFormat: TimestampFormatTrait.Format, + targetName: String, + isMultiValuedHeader: Boolean, + ): String { fun quoteValue(value: String): String { // Timestamp shapes are not quoted in header lists - return if (isListHeader && !target.isTimestampShape) { + return if (isMultiValuedHeader && !target.isTimestampShape) { val quoteFn = writer.format(headerUtil.member("quote_header_value")) "$quoteFn($value)" } else { @@ -630,18 +745,20 @@ class HttpBindingGenerator( quoteValue("$targetName.as_str()") } } + target.isTimestampShape -> { - val timestampFormat = - index.determineTimestampFormat(member, HttpBinding.Location.HEADER, defaultTimestampFormat) val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat) quoteValue("$targetName.fmt(${writer.format(timestampFormatType)})?") } + target.isListShape || target.isMemberShape -> { throw IllegalArgumentException("lists should be handled at a higher level") } + target.isPrimitive() -> { "encoder.encode()" } + else -> throw CodegenException("unexpected shape: $target") } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/RequestBindingGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/RequestBindingGenerator.kt index 0a1421c938..e5d0a74c8e 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/RequestBindingGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/RequestBindingGenerator.kt @@ -31,6 +31,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.operationBuildError import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.ValueExpression import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.expectMember import software.amazon.smithy.rust.codegen.core.util.hasTrait @@ -193,8 +194,12 @@ class RequestBindingGenerator( val memberName = symbolProvider.toMemberName(memberShape) val targetShape = model.expectShape(memberShape.target, MapShape::class.java) val stringFormatter = RuntimeType.QueryFormat(runtimeConfig, "fmt_string") - ifSet(model.expectShape(param.member.target), memberSymbol, "&_input.$memberName") { field -> - rustBlock("for (k, v) in $field") { + ifSet( + model.expectShape(param.member.target), + memberSymbol, + ValueExpression.Reference("&_input.$memberName"), + ) { value -> + rustBlock("for (k, v) in ${value.asRef()}") { // if v is a list, generate another level of iteration listForEach(model.expectShape(targetShape.value.target), "v") { innerField, _ -> rustBlock("if !protected_params.contains(&k.as_str())") { @@ -236,9 +241,9 @@ class RequestBindingGenerator( paramList(target, derefName, param, writer, memberShape) } else { - ifSet(target, memberSymbol, "&_input.$memberName") { field -> + ifSet(target, memberSymbol, ValueExpression.Reference("&_input.$memberName")) { field -> // if `param` is a list, generate another level of iteration - paramList(target, field, param, writer, memberShape) + paramList(target, field.name, param, writer, memberShape) } } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt index 1d7a099272..4693cc31f7 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt @@ -334,7 +334,7 @@ class JsonParserGenerator( .map(#{NumberType}::try_from) .transpose()? """, - "NumberType" to symbolProvider.toSymbol(target), + "NumberType" to returnSymbolToParse(target).symbol, *codegenScope, ) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt index 82be9e7818..c7805fc468 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt @@ -7,14 +7,20 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.BooleanShape +import software.amazon.smithy.model.shapes.ByteShape import software.amazon.smithy.model.shapes.CollectionShape import software.amazon.smithy.model.shapes.DocumentShape +import software.amazon.smithy.model.shapes.DoubleShape +import software.amazon.smithy.model.shapes.FloatShape +import software.amazon.smithy.model.shapes.IntegerShape +import software.amazon.smithy.model.shapes.LongShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.NumberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.ShortShape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.TimestampShape @@ -22,7 +28,6 @@ import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.TimestampFormatTrait.Format.EPOCH_SECONDS import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustModule -import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock @@ -42,7 +47,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.serializeFunctionName -import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.expectTrait @@ -54,16 +58,24 @@ import software.amazon.smithy.rust.codegen.core.util.outputShape */ sealed class JsonSerializerSection(name: String) : Section(name) { /** Mutate the server error object prior to finalization. Eg: this can be used to inject `__type` to record the error type. */ - data class ServerError(val structureShape: StructureShape, val jsonObject: String) : JsonSerializerSection("ServerError") + data class ServerError(val structureShape: StructureShape, val jsonObject: String) : + JsonSerializerSection("ServerError") - /** Mutate a map prior to it being serialized. **/ - data class BeforeIteratingOverMap(val shape: MapShape, val valueExpression: ValueExpression) : JsonSerializerSection("BeforeIteratingOverMap") + /** Manipulate the serializer context for a map prior to it being serialized. **/ + data class BeforeIteratingOverMap(val shape: MapShape, val context: JsonSerializerGenerator.Context) : + JsonSerializerSection("BeforeIteratingOverMap") + + /** Manipulate the serializer context for a non-null member prior to it being serialized. **/ + data class BeforeSerializingNonNullMember(val shape: Shape, val context: JsonSerializerGenerator.MemberContext) : + JsonSerializerSection("BeforeSerializingNonNullMember") /** Mutate the input object prior to finalization. */ - data class InputStruct(val structureShape: StructureShape, val jsonObject: String) : JsonSerializerSection("InputStruct") + data class InputStruct(val structureShape: StructureShape, val jsonObject: String) : + JsonSerializerSection("InputStruct") /** Mutate the output object prior to finalization. */ - data class OutputStruct(val structureShape: StructureShape, val jsonObject: String) : JsonSerializerSection("OutputStruct") + data class OutputStruct(val structureShape: StructureShape, val jsonObject: String) : + JsonSerializerSection("OutputStruct") } /** @@ -78,20 +90,20 @@ class JsonSerializerGenerator( private val jsonName: (MemberShape) -> String, private val customizations: List = listOf(), ) : StructuredDataSerializerGenerator { - private data class Context( + data class Context( /** Expression that retrieves a JsonValueWriter from either a JsonObjectWriter or JsonArrayWriter */ val writerExpression: String, /** Expression representing the value to write to the JsonValueWriter */ - val valueExpression: ValueExpression, + var valueExpression: ValueExpression, /** Path in the JSON to get here, used for errors */ val shape: T, ) - private data class MemberContext( + data class MemberContext( /** Expression that retrieves a JsonValueWriter from either a JsonObjectWriter or JsonArrayWriter */ val writerExpression: String, /** Expression representing the value to write to the JsonValueWriter */ - val valueExpression: ValueExpression, + var valueExpression: ValueExpression, val shape: MemberShape, /** Whether to serialize null values if the type is optional */ val writeNulls: Boolean = false, @@ -144,7 +156,7 @@ class JsonSerializerGenerator( } // Specialized since it holds a JsonObjectWriter expression rather than a JsonValueWriter - private data class StructContext( + data class StructContext( /** Name of the JsonObjectWriter */ val objectName: String, /** Name of the variable that holds the struct */ @@ -337,8 +349,16 @@ class JsonSerializerGenerator( if (symbolProvider.toSymbol(context.shape).isOptional()) { safeName().also { local -> rustBlock("if let Some($local) = ${context.valueExpression.asRef()}") { - val innerContext = context.copy(valueExpression = ValueExpression.Reference(local)) - serializeMemberValue(innerContext, targetShape) + context.valueExpression = ValueExpression.Reference(local) + for (customization in customizations) { + customization.section( + JsonSerializerSection.BeforeSerializingNonNullMember( + targetShape, + context, + ), + )(this) + } + serializeMemberValue(context, targetShape) } if (context.writeNulls) { rustBlock("else") { @@ -347,6 +367,12 @@ class JsonSerializerGenerator( } } } else { + for (customization in customizations) { + customization.section(JsonSerializerSection.BeforeSerializingNonNullMember(targetShape, context))( + this, + ) + } + with(serializerUtil) { ignoreZeroValues(context.shape, context.valueExpression) { serializeMemberValue(context, targetShape) @@ -363,10 +389,9 @@ class JsonSerializerGenerator( is StringShape -> rust("$writer.string(${value.name}.as_str());") is BooleanShape -> rust("$writer.boolean(${value.asValue()});") is NumberShape -> { - val numberType = when (symbolProvider.toSymbol(target).rustType()) { - is RustType.Float -> "Float" - // NegInt takes an i64 while PosInt takes u64. We need this to be signed here - is RustType.Integer -> "NegInt" + val numberType = when (target) { + is IntegerShape, is ByteShape, is LongShape, is ShortShape -> "NegInt" + is DoubleShape, is FloatShape -> "Float" else -> throw IllegalStateException("unreachable") } rust( @@ -374,10 +399,12 @@ class JsonSerializerGenerator( smithyTypes.member("Number"), ) } + is BlobShape -> rust( "$writer.string_unchecked(&#T(${value.asRef()}));", RuntimeType.Base64Encode(runtimeConfig), ) + is TimestampShape -> { val timestampFormat = httpBindingResolver.timestampFormat(context.shape, HttpLocation.DOCUMENT, EPOCH_SECONDS) @@ -388,18 +415,23 @@ class JsonSerializerGenerator( "ConvertInto" to typeConversionGenerator.convertViaInto(target), ) } + is CollectionShape -> jsonArrayWriter(context) { arrayName -> serializeCollection(Context(arrayName, value, target)) } + is MapShape -> jsonObjectWriter(context) { objectName -> serializeMap(Context(objectName, value, target)) } + is StructureShape -> jsonObjectWriter(context) { objectName -> serializeStructure(StructContext(objectName, value.asRef(), target)) } + is UnionShape -> jsonObjectWriter(context) { objectName -> serializeUnion(Context(objectName, value, target)) } + is DocumentShape -> rust("$writer.document(${value.asRef()});") else -> TODO(target.toString()) } @@ -432,7 +464,9 @@ class JsonSerializerGenerator( val keyName = safeName("key") val valueName = safeName("value") for (customization in customizations) { - customization.section(JsonSerializerSection.BeforeIteratingOverMap(context.shape, context.valueExpression))(this) + customization.section(JsonSerializerSection.BeforeIteratingOverMap(context.shape, context))( + this, + ) } rustBlock("for ($keyName, $valueName) in ${context.valueExpression.asRef()}") { val keyExpression = "$keyName.as_str()" diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/SerializerUtil.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/SerializerUtil.kt index 98bdc80aa0..78012b1d6d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/SerializerUtil.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/SerializerUtil.kt @@ -6,11 +6,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize import software.amazon.smithy.model.Model -import software.amazon.smithy.model.shapes.BooleanShape -import software.amazon.smithy.model.shapes.DoubleShape -import software.amazon.smithy.model.shapes.FloatShape import software.amazon.smithy.model.shapes.MemberShape -import software.amazon.smithy.model.shapes.NumberShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable @@ -18,16 +14,9 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock class SerializerUtil(private val model: Model) { fun RustWriter.ignoreZeroValues(shape: MemberShape, value: ValueExpression, inner: Writable) { - val expr = when (model.expectShape(shape.target)) { - is FloatShape, is DoubleShape -> "${value.asValue()} != 0.0" - is NumberShape -> "${value.asValue()} != 0" - is BooleanShape -> value.asValue() - else -> null - } - - if (expr == null || - // Required shapes should always be serialized - // See https://github.com/awslabs/smithy-rs/issues/230 and https://github.com/aws/aws-sdk-go-v2/pull/1129 + // Required shapes should always be serialized + // See https://github.com/awslabs/smithy-rs/issues/230 and https://github.com/aws/aws-sdk-go-v2/pull/1129 + if ( shape.isRequired || // Zero values are always serialized in lists and collections, this only applies to structures model.expectShape(shape.container) !is StructureShape @@ -36,9 +25,7 @@ class SerializerUtil(private val model: Model) { inner(this) } } else { - rustBlock("if $expr") { - inner(this) - } + this.ifNotDefault(model.expectShape(shape.target), value) { inner(this) } } } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/ValueExpression.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/ValueExpression.kt index 6ab3aea1e3..00bc8ba74c 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/ValueExpression.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/ValueExpression.kt @@ -22,4 +22,6 @@ sealed class ValueExpression { is Reference -> name is Value -> "&$name" } + + override fun toString(): String = this.name } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt index 5bc0744ba1..3fff4b7804 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt @@ -67,7 +67,7 @@ fun testRustSettings( private const val SmithyVersion = "1.0" fun String.asSmithyModel(sourceLocation: String? = null, smithyVersion: String = SmithyVersion): Model { - val processed = letIf(!this.startsWith("\$version")) { "\$version: ${smithyVersion.dq()}\n$it" } + val processed = letIf(!this.trimStart().startsWith("\$version")) { "\$version: ${smithyVersion.dq()}\n$it" } return Model.assembler().discoverModels().addUnparsedModel(sourceLocation ?: "test.smithy", processed).assemble() .unwrap() } diff --git a/codegen-server-test/build.gradle.kts b/codegen-server-test/build.gradle.kts index a49bbcce50..82187c2db6 100644 --- a/codegen-server-test/build.gradle.kts +++ b/codegen-server-test/build.gradle.kts @@ -69,6 +69,11 @@ val allCodegenTests = "../codegen-core/common-test-models".let { commonModels -> "aws.protocoltests.restjson.validation#RestJsonValidation", "rest_json_validation", extraConfig = """, "codegen": { "ignoreUnsupportedConstraints": true } """, ), + CodegenTest( + "aws.protocoltests.extras.restjson.validation#MalformedRangeValidation", "malformed_range_extras", + extraConfig = """, "codegen": { "ignoreUnsupportedConstraints": true } """, + imports = listOf("$commonModels/malformed-range-extras.smithy"), + ), CodegenTest("aws.protocoltests.json10#JsonRpc10", "json_rpc10"), CodegenTest( "aws.protocoltests.json#JsonProtocol", diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProvider.kt index 625f3b5fa6..240d583065 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProvider.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProvider.kt @@ -9,6 +9,7 @@ import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.NullableIndex import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.IntegerShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.ServiceShape @@ -95,7 +96,7 @@ class ConstrainedShapeSymbolProvider( symbolBuilder(shape, RustType.Vec(inner.rustType())).addReference(inner).build() } } - is StringShape -> { + is StringShape, is IntegerShape -> { if (shape.isDirectlyConstrained(base)) { val rustType = RustType.Opaque(shape.contextName(serviceShape).toPascalCase()) symbolBuilder(shape, rustType).locatedIn(ModelsModule).build() diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintViolationSymbolProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintViolationSymbolProvider.kt index e2b6e23f01..264edf545b 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintViolationSymbolProvider.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintViolationSymbolProvider.kt @@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.server.smithy import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.IntegerShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.Shape @@ -112,8 +113,7 @@ class ConstraintViolationSymbolProvider( .locatedIn(module) .build() } - - is StringShape -> { + is StringShape, is IntegerShape -> { val module = shape.shapeModule() val rustType = RustType.Opaque(constraintViolationName, module.fullyQualifiedPath()) Symbol.builder() diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/Constraints.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/Constraints.kt index fe5b66ad82..592fe74282 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/Constraints.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/Constraints.kt @@ -9,6 +9,7 @@ import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.model.Model import software.amazon.smithy.model.neighbor.Walker import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.IntegerShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.Shape @@ -71,6 +72,7 @@ fun Shape.isDirectlyConstrained(symbolProvider: SymbolProvider): Boolean = when is MapShape -> this.hasTrait() is StringShape -> this.hasTrait() || supportedStringConstraintTraits.any { this.hasTrait(it) } + is IntegerShape -> this.hasTrait() else -> false } @@ -96,6 +98,7 @@ fun MemberShape.targetCanReachConstrainedShape(model: Model, symbolProvider: Sym fun Shape.hasPublicConstrainedWrapperTupleType(model: Model, publicConstrainedTypes: Boolean): Boolean = when (this) { is MapShape -> publicConstrainedTypes && this.hasTrait() is StringShape -> !this.hasTrait() && (publicConstrainedTypes && supportedStringConstraintTraits.any(this::hasTrait)) + is IntegerShape -> publicConstrainedTypes && this.hasTrait() is MemberShape -> model.expectShape(this.target).hasPublicConstrainedWrapperTupleType(model, publicConstrainedTypes) else -> false } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RangeTraitValidationErrorMessage.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RangeTraitValidationErrorMessage.kt new file mode 100644 index 0000000000..5512da8470 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RangeTraitValidationErrorMessage.kt @@ -0,0 +1,21 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import software.amazon.smithy.model.traits.RangeTrait + +fun RangeTrait.validationErrorMessage(): String { + val beginning = "Value {} at '{}' failed to satisfy constraint: Member must be " + val ending = if (this.min.isPresent && this.max.isPresent) { + "between ${this.min.get()} and ${this.max.get()}, inclusive" + } else if (this.min.isPresent) ( + "greater than or equal to ${this.min.get()}" + ) else { + check(this.max.isPresent) + "less than or equal to ${this.max.get()}" + } + return "$beginning$ending" +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt index 92f9bafa3b..d4aa6eccd1 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt @@ -11,6 +11,7 @@ import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.NullableIndex import software.amazon.smithy.model.neighbor.Walker import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.IntegerShape import software.amazon.smithy.model.shapes.ListShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.ServiceShape @@ -42,6 +43,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveSha import software.amazon.smithy.rust.codegen.core.util.CommandFailed import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.runCommand +import software.amazon.smithy.rust.codegen.server.smithy.generators.ConstrainedIntegerGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.ConstrainedMapGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.ConstrainedStringGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.ConstrainedTraitForEnumGenerator @@ -350,6 +352,15 @@ open class ServerCodegenVisitor( stringShape(shape, ::serverEnumGeneratorFactory) } + override fun integerShape(shape: IntegerShape) { + if (shape.isDirectlyConstrained(codegenContext.symbolProvider)) { + logger.info("[rust-server-codegen] Generating a constrained integer $shape") + rustCrate.withModule(ModelsModule) { + ConstrainedIntegerGenerator(codegenContext, this, shape).render() + } + } + } + protected fun stringShape( shape: StringShape, enumShapeGeneratorFactory: (codegenContext: ServerCodegenContext, writer: RustWriter, shape: StringShape) -> ServerEnumGenerator, diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraints.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraints.kt index 4535616bff..76f5fd1d25 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraints.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraints.kt @@ -10,6 +10,7 @@ import software.amazon.smithy.model.neighbor.Walker import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.CollectionShape import software.amazon.smithy.model.shapes.EnumShape +import software.amazon.smithy.model.shapes.IntegerShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape @@ -39,13 +40,17 @@ private sealed class UnsupportedConstraintMessageKind { """ $intro This is not supported in the smithy-rs server SDK. - ${ if (willSupport) "It will be supported in the future." else "" } + ${if (willSupport) "It will be supported in the future." else ""} See the tracking issue ($trackingIssue). If you want to go ahead and generate the server SDK ignoring unsupported constraint traits, set the key `ignoreUnsupportedConstraints` inside the `runtimeConfig.codegenConfig` JSON object in your `smithy-build.json` to `true`. """.trimIndent().replace("\n", " ") - fun buildMessageShapeHasUnsupportedConstraintTrait(shape: Shape, constraintTrait: Trait, trackingIssue: String) = + fun buildMessageShapeHasUnsupportedConstraintTrait( + shape: Shape, + constraintTrait: Trait, + trackingIssue: String, + ) = buildMessage( "The ${shape.type} shape `${shape.id}` has the constraint trait `${constraintTrait.toShapeId()}` attached.", willSupport = true, @@ -59,6 +64,7 @@ private sealed class UnsupportedConstraintMessageKind { level, buildMessageShapeHasUnsupportedConstraintTrait(shape, constraintTrait, constraintTraitsUberIssue), ) + is UnsupportedConstraintOnShapeReachableViaAnEventStream -> LogMessage( level, buildMessage( @@ -70,6 +76,7 @@ private sealed class UnsupportedConstraintMessageKind { "https://github.com/awslabs/smithy/issues/1388", ), ) + is UnsupportedLengthTraitOnStreamingBlobShape -> LogMessage( level, buildMessage( @@ -81,14 +88,17 @@ private sealed class UnsupportedConstraintMessageKind { "https://github.com/awslabs/smithy/issues/1389", ), ) + is UnsupportedLengthTraitOnCollectionOrOnBlobShape -> LogMessage( level, buildMessageShapeHasUnsupportedConstraintTrait(shape, lengthTrait, constraintTraitsUberIssue), ) + is UnsupportedRangeTraitOnShape -> LogMessage( level, buildMessageShapeHasUnsupportedConstraintTrait(shape, rangeTrait, constraintTraitsUberIssue), ) + is UnsupportedUniqueItemsTraitOnShape -> LogMessage( level, buildMessageShapeHasUnsupportedConstraintTrait(shape, uniqueItemsTrait, constraintTraitsUberIssue), @@ -96,13 +106,28 @@ private sealed class UnsupportedConstraintMessageKind { } } } + private data class OperationWithConstrainedInputWithoutValidationException(val shape: OperationShape) -private data class UnsupportedConstraintOnMemberShape(val shape: MemberShape, val constraintTrait: Trait) : UnsupportedConstraintMessageKind() -private data class UnsupportedConstraintOnShapeReachableViaAnEventStream(val shape: Shape, val constraintTrait: Trait) : UnsupportedConstraintMessageKind() -private data class UnsupportedLengthTraitOnStreamingBlobShape(val shape: BlobShape, val lengthTrait: LengthTrait, val streamingTrait: StreamingTrait) : UnsupportedConstraintMessageKind() -private data class UnsupportedLengthTraitOnCollectionOrOnBlobShape(val shape: Shape, val lengthTrait: LengthTrait) : UnsupportedConstraintMessageKind() -private data class UnsupportedRangeTraitOnShape(val shape: Shape, val rangeTrait: RangeTrait) : UnsupportedConstraintMessageKind() -private data class UnsupportedUniqueItemsTraitOnShape(val shape: Shape, val uniqueItemsTrait: UniqueItemsTrait) : UnsupportedConstraintMessageKind() +private data class UnsupportedConstraintOnMemberShape(val shape: MemberShape, val constraintTrait: Trait) : + UnsupportedConstraintMessageKind() + +private data class UnsupportedConstraintOnShapeReachableViaAnEventStream(val shape: Shape, val constraintTrait: Trait) : + UnsupportedConstraintMessageKind() + +private data class UnsupportedLengthTraitOnStreamingBlobShape( + val shape: BlobShape, + val lengthTrait: LengthTrait, + val streamingTrait: StreamingTrait, +) : UnsupportedConstraintMessageKind() + +private data class UnsupportedLengthTraitOnCollectionOrOnBlobShape(val shape: Shape, val lengthTrait: LengthTrait) : + UnsupportedConstraintMessageKind() + +private data class UnsupportedRangeTraitOnShape(val shape: Shape, val rangeTrait: RangeTrait) : + UnsupportedConstraintMessageKind() + +private data class UnsupportedUniqueItemsTraitOnShape(val shape: Shape, val uniqueItemsTrait: UniqueItemsTrait) : + UnsupportedConstraintMessageKind() data class LogMessage(val level: Level, val message: String) data class ValidationResult(val shouldAbort: Boolean, val messages: List) @@ -117,7 +142,10 @@ private val allConstraintTraits = setOf( ) private val unsupportedConstraintsOnMemberShapes = allConstraintTraits - RequiredTrait::class.java -fun validateOperationsWithConstrainedInputHaveValidationExceptionAttached(model: Model, service: ServiceShape): ValidationResult { +fun validateOperationsWithConstrainedInputHaveValidationExceptionAttached( + model: Model, + service: ServiceShape, +): ValidationResult { // Traverse the model and error out if an operation uses constrained input, but it does not have // `ValidationException` attached in `errors`. https://github.com/awslabs/smithy-rs/pull/1199#discussion_r809424783 // TODO(https://github.com/awslabs/smithy-rs/issues/1401): This check will go away once we add support for @@ -161,7 +189,11 @@ fun validateOperationsWithConstrainedInputHaveValidationExceptionAttached(model: return ValidationResult(shouldAbort = messages.any { it.level == Level.SEVERE }, messages) } -fun validateUnsupportedConstraints(model: Model, service: ServiceShape, codegenConfig: ServerCodegenConfig): ValidationResult { +fun validateUnsupportedConstraints( + model: Model, + service: ServiceShape, + codegenConfig: ServerCodegenConfig, +): ValidationResult { // Traverse the model and error out if: val walker = Walker(model) @@ -208,22 +240,28 @@ fun validateUnsupportedConstraints(model: Model, service: ServiceShape, codegenC .map { UnsupportedLengthTraitOnCollectionOrOnBlobShape(it, it.expectTrait()) } .toSet() - // 5. Range trait on any shape is used. It has not been implemented yet. + // 5. Range trait used on a non-integer shape. It has not been implemented yet. // TODO(https://github.com/awslabs/smithy-rs/issues/1401) val unsupportedRangeTraitOnShapeSet = walker .walkShapes(service) .asSequence() + .filterNot { it is IntegerShape } .filterMapShapesToTraits(setOf(RangeTrait::class.java)) .map { (shape, rangeTrait) -> UnsupportedRangeTraitOnShape(shape, rangeTrait as RangeTrait) } .toSet() - // 7. UniqueItems trait on any shape is used. It has not been implemented yet. + // 6. UniqueItems trait on any shape is used. It has not been implemented yet. // TODO(https://github.com/awslabs/smithy-rs/issues/1401) val unsupportedUniqueItemsTraitOnShapeSet = walker .walkShapes(service) .asSequence() .filterMapShapesToTraits(setOf(UniqueItemsTrait::class.java)) - .map { (shape, uniqueItemsTrait) -> UnsupportedUniqueItemsTraitOnShape(shape, uniqueItemsTrait as UniqueItemsTrait) } + .map { (shape, uniqueItemsTrait) -> + UnsupportedUniqueItemsTraitOnShape( + shape, + uniqueItemsTrait as UniqueItemsTrait, + ) + } .toSet() val messages = diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeIteratingOverMapJsonCustomization.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeIteratingOverMapJsonCustomization.kt index 820f5bc8b7..9fad74c044 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeIteratingOverMapJsonCustomization.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeIteratingOverMapJsonCustomization.kt @@ -10,6 +10,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerCustomization import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerSection +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.ValueExpression import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.workingWithPublicConstrainedWrapperTupleType @@ -26,11 +27,8 @@ class BeforeIteratingOverMapJsonCustomization(private val codegenContext: Server codegenContext.settings.codegenConfig.publicConstrainedTypes, ) ) { - // Note that this particular implementation just so happens to work because when the customization - // is invoked in the JSON serializer, the value expression is guaranteed to be a variable binding name. - // If the expression in the future were to be more complex, we wouldn't be able to write the left-hand - // side of this assignment. - rust("""let ${section.valueExpression.name} = &${section.valueExpression.name}.0;""") + section.context.valueExpression = + ValueExpression.Reference("&${section.context.valueExpression.name}.0") } } else -> emptySection diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberJsonCustomization.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberJsonCustomization.kt new file mode 100644 index 0000000000..0308a9dd5f --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberJsonCustomization.kt @@ -0,0 +1,38 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.customizations + +import software.amazon.smithy.model.shapes.IntegerShape +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerCustomization +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerSection +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.ValueExpression +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.workingWithPublicConstrainedWrapperTupleType + +/** + * A customization to, just before we serialize a _constrained_ shape in a JSON serializer, unwrap the wrapper + * newtype and take a shared reference to the actual unconstrained value within it. + */ +class BeforeSerializingMemberJsonCustomization(private val codegenContext: ServerCodegenContext) : JsonSerializerCustomization() { + override fun section(section: JsonSerializerSection): Writable = when (section) { + is JsonSerializerSection.BeforeSerializingNonNullMember -> writable { + if (workingWithPublicConstrainedWrapperTupleType( + section.shape, + codegenContext.model, + codegenContext.settings.codegenConfig.publicConstrainedTypes, + ) + ) { + if (section.shape is IntegerShape) { + section.context.valueExpression = + ValueExpression.Reference("&${section.context.valueExpression.name}.0") + } + } + } + else -> emptySection + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedIntegerGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedIntegerGenerator.kt new file mode 100644 index 0000000000..10c5e25396 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedIntegerGenerator.kt @@ -0,0 +1,216 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.IntegerShape +import software.amazon.smithy.model.traits.RangeTrait +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.docs +import software.amazon.smithy.rust.codegen.core.rustlang.documentShape +import software.amazon.smithy.rust.codegen.core.rustlang.render +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained +import software.amazon.smithy.rust.codegen.core.smithy.module +import software.amazon.smithy.rust.codegen.core.util.expectTrait +import software.amazon.smithy.rust.codegen.core.util.redactIfNecessary +import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput +import software.amazon.smithy.rust.codegen.server.smithy.validationErrorMessage + +/** + * [ConstrainedIntegerGenerator] generates a wrapper newtype holding a constrained `i32`. + * This type can be built from unconstrained values, yielding a `ConstraintViolation` when the input does not satisfy + * the constraints. + */ +class ConstrainedIntegerGenerator( + val codegenContext: ServerCodegenContext, + val writer: RustWriter, + val shape: IntegerShape, +) { + val model = codegenContext.model + val constrainedShapeSymbolProvider = codegenContext.constrainedShapeSymbolProvider + val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes + private val constraintViolationSymbolProvider = + with(codegenContext.constraintViolationSymbolProvider) { + if (publicConstrainedTypes) { + this + } else { + PubCrateConstraintViolationSymbolProvider(this) + } + } + + fun render() { + val rangeTrait = shape.expectTrait() + + val symbol = constrainedShapeSymbolProvider.toSymbol(shape) + val constrainedTypeName = symbol.name + val unconstrainedTypeName = RustType.Integer(32).render() + val constraintViolation = constraintViolationSymbolProvider.toSymbol(shape) + val constraintsInfo = listOf(Range(rangeTrait).toTraitInfo(unconstrainedTypeName)) + + val constrainedTypeVisibility = if (publicConstrainedTypes) { + Visibility.PUBLIC + } else { + Visibility.PUBCRATE + } + val constrainedTypeMetadata = RustMetadata( + Attribute.Derives( + setOf( + RuntimeType.Debug, + RuntimeType.Clone, + RuntimeType.PartialEq, + RuntimeType.Eq, + RuntimeType.Hash, + ), + ), + visibility = constrainedTypeVisibility, + ) + + writer.documentShape(shape, model, note = rustDocsNote(constrainedTypeName)) + constrainedTypeMetadata.render(writer) + writer.rust("struct $constrainedTypeName(pub(crate) $unconstrainedTypeName);") + + if (constrainedTypeVisibility == Visibility.PUBCRATE) { + Attribute.AllowUnused.render(writer) + } + writer.rustTemplate( + """ + impl $constrainedTypeName { + /// ${rustDocsInnerMethod(unconstrainedTypeName)} + pub fn inner(&self) -> &$unconstrainedTypeName { + &self.0 + } + + /// ${rustDocsIntoInnerMethod(unconstrainedTypeName)} + pub fn into_inner(self) -> $unconstrainedTypeName { + self.0 + } + } + + impl #{ConstrainedTrait} for $constrainedTypeName { + type Unconstrained = $unconstrainedTypeName; + } + + impl #{From}<$unconstrainedTypeName> for #{MaybeConstrained} { + fn from(value: $unconstrainedTypeName) -> Self { + Self::Unconstrained(value) + } + } + + impl #{Display} for $constrainedTypeName { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + ${shape.redactIfNecessary(model, "self.0")}.fmt(f) + } + } + + impl #{From}<$constrainedTypeName> for $unconstrainedTypeName { + fn from(value: $constrainedTypeName) -> Self { + value.into_inner() + } + } + """, + "ConstrainedTrait" to RuntimeType.ConstrainedTrait(), + "ConstraintViolation" to constraintViolation, + "MaybeConstrained" to symbol.makeMaybeConstrained(), + "Display" to RuntimeType.Display, + "From" to RuntimeType.From, + "TryFrom" to RuntimeType.TryFrom, + "AsRef" to RuntimeType.AsRef, + ) + + writer.renderTryFrom(unconstrainedTypeName, constrainedTypeName, constraintViolation, constraintsInfo) + + writer.withInlineModule(constraintViolation.module()) { + rust( + """ + ##[derive(Debug, PartialEq)] + pub enum ${constraintViolation.name} { + Range($unconstrainedTypeName), + } + """, + ) + + if (shape.isReachableFromOperationInput()) { + rustBlock("impl ${constraintViolation.name}") { + rustBlockTemplate( + "pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField", + "String" to RuntimeType.String, + ) { + rustBlock("match self") { + rust( + """ + Self::Range(value) => crate::model::ValidationExceptionField { + message: format!("${rangeTrait.validationErrorMessage()}", value, &path), + path, + }, + """, + ) + } + } + } + } + } + } +} + +private data class Range(val rangeTrait: RangeTrait) { + fun toTraitInfo(unconstrainedTypeName: String): TraitInfo = TraitInfo( + { rust("Self::check_range(value)?;") }, + { + docs("Error when an integer doesn't satisfy its `@range` requirements.") + rust("Range($unconstrainedTypeName)") + }, + { + rust( + """ + Self::Range(value) => crate::model::ValidationExceptionField { + message: format!("${rangeTrait.validationErrorMessage()}", value, &path), + path, + }, + """, + ) + }, + this::renderValidationFunction, + ) + + /** + * Renders a `check_range` function to validate the integer matches the + * required range indicated by the `@range` trait. + */ + private fun renderValidationFunction(constraintViolation: Symbol, unconstrainedTypeName: String): Writable = { + val valueVariableName = "value" + val condition = if (rangeTrait.min.isPresent && rangeTrait.max.isPresent) { + "(${rangeTrait.min.get()}..=${rangeTrait.max.get()}).contains(&$valueVariableName)" + } else if (rangeTrait.min.isPresent) { + "${rangeTrait.min.get()} <= $valueVariableName" + } else { + "$valueVariableName <= ${rangeTrait.max.get()}" + } + + rust( + """ + fn check_range($valueVariableName: $unconstrainedTypeName) -> Result<(), $constraintViolation> { + if $condition { + Ok(()) + } else { + Err($constraintViolation::Range($valueVariableName)) + } + } + """, + ) + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt index 898f98f17d..38d1a2acc7 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt @@ -62,35 +62,6 @@ class ConstrainedStringGenerator( .map(StringTraitInfo::fromTrait) .map(StringTraitInfo::toTraitInfo) - private fun renderTryFrom(inner: String, name: String, constraintViolation: Symbol) { - writer.rustTemplate( - """ - impl $name { - #{ValidationFunctions:W} - } - """, - "ValidationFunctions" to constraintsInfo.map { it.validationFunctionDefinition(constraintViolation) }.join("\n"), - ) - - writer.rustTemplate( - """ - impl #{TryFrom}<$inner> for $name { - type Error = #{ConstraintViolation}; - - /// ${rustDocsTryFromMethod(name, inner)} - fn try_from(value: $inner) -> Result { - #{TryFromChecks:W} - - Ok(Self(value)) - } - } - """, - "TryFrom" to RuntimeType.TryFrom, - "ConstraintViolation" to constraintViolation, - "TryFromChecks" to constraintsInfo.map { it.tryFromCheck }.join("\n"), - ) - } - fun render() { val symbol = constrainedShapeSymbolProvider.toSymbol(shape) val name = symbol.name @@ -136,7 +107,7 @@ class ConstrainedStringGenerator( }""", ) - renderTryFrom(inner, name, constraintViolation) + writer.renderTryFrom(inner, name, constraintViolation, constraintsInfo) writer.rustTemplate( """ @@ -204,55 +175,51 @@ class ConstrainedStringGenerator( } } private data class Length(val lengthTrait: LengthTrait) : StringTraitInfo() { - override fun toTraitInfo(): TraitInfo { - return TraitInfo( - { rust("Self::check_length(&value)?;") }, - { - docs("Error when a string doesn't satisfy its `@length` requirements.") - rust("Length(usize)") - }, - { - rust( - """ - Self::Length(length) => crate::model::ValidationExceptionField { - message: format!("${lengthTrait.validationErrorMessage()}", length, &path), - path, - }, - """, - ) - }, - this::renderValidationFunction, - ) - } + override fun toTraitInfo(): TraitInfo = TraitInfo( + { rust("Self::check_length(&value)?;") }, + { + docs("Error when a string doesn't satisfy its `@length` requirements.") + rust("Length(usize)") + }, + { + rust( + """ + Self::Length(length) => crate::model::ValidationExceptionField { + message: format!("${lengthTrait.validationErrorMessage()}", length, &path), + path, + }, + """, + ) + }, + this::renderValidationFunction, + ) /** * Renders a `check_length` function to validate the string matches the * required length indicated by the `@length` trait. */ - private fun renderValidationFunction(constraintViolation: Symbol): Writable { - return { - val condition = if (lengthTrait.min.isPresent && lengthTrait.max.isPresent) { - "(${lengthTrait.min.get()}..=${lengthTrait.max.get()}).contains(&length)" - } else if (lengthTrait.min.isPresent) { - "${lengthTrait.min.get()} <= length" - } else { - "length <= ${lengthTrait.max.get()}" - } + private fun renderValidationFunction(constraintViolation: Symbol, unconstrainedTypeName: String): Writable = { + val condition = if (lengthTrait.min.isPresent && lengthTrait.max.isPresent) { + "(${lengthTrait.min.get()}..=${lengthTrait.max.get()}).contains(&length)" + } else if (lengthTrait.min.isPresent) { + "${lengthTrait.min.get()} <= length" + } else { + "length <= ${lengthTrait.max.get()}" + } - rust( - """ - fn check_length(string: &str) -> Result<(), $constraintViolation> { - let length = string.chars().count(); + rust( + """ + fn check_length(string: &str) -> Result<(), $constraintViolation> { + let length = string.chars().count(); - if $condition { - Ok(()) - } else { - Err($constraintViolation::Length(length)) - } + if $condition { + Ok(()) + } else { + Err($constraintViolation::Length(length)) } - """, - ) - } + } + """, + ) } } @@ -285,14 +252,15 @@ private data class Pattern(val patternTrait: PatternTrait) : StringTraitInfo() { * Renders a `check_pattern` function to validate the string matches the * supplied regex in the `@pattern` trait. */ - private fun renderValidationFunction(constraintViolation: Symbol): Writable { + private fun renderValidationFunction(constraintViolation: Symbol, unconstrainedTypeName: String): Writable { val pattern = patternTrait.pattern - val errorMessageForUnsupportedRegex = """The regular expression $pattern is not supported by the `regex` crate; feel free to file an issue under https://github.com/awslabs/smithy-rs/issues for support""" + val errorMessageForUnsupportedRegex = + """The regular expression $pattern is not supported by the `regex` crate; feel free to file an issue under https://github.com/awslabs/smithy-rs/issues for support""" return { rustTemplate( """ - fn check_pattern(string: String) -> Result { + fn check_pattern(string: $unconstrainedTypeName) -> Result<$unconstrainedTypeName, $constraintViolation> { let regex = Self::compile_regex(); if regex.is_match(&string) { diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/TraitInfo.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/TraitInfo.kt index 3390382ba1..afd8b55aae 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/TraitInfo.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/TraitInfo.kt @@ -6,8 +6,11 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable -import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.join +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType /** * Information needed to render a constraint trait as Rust code. @@ -16,5 +19,48 @@ data class TraitInfo( val tryFromCheck: Writable, val constraintViolationVariant: Writable, val asValidationExceptionField: Writable, - val validationFunctionDefinition: (constraintViolation: Symbol) -> Writable, + val validationFunctionDefinition: (constraintViolation: Symbol, unconstrainedTypeName: String) -> Writable, ) + +/** + * Render the implementation of `TryFrom` for a constrained type. + */ +fun RustWriter.renderTryFrom( + unconstrainedTypeName: String, + constrainedTypeName: String, + constraintViolationError: Symbol, + constraintsInfo: List, +) { + this.rustTemplate( + """ + impl $constrainedTypeName { + #{ValidationFunctions:W} + } + """, + "ValidationFunctions" to constraintsInfo.map { + it.validationFunctionDefinition( + constraintViolationError, + unconstrainedTypeName, + ) + } + .join("\n"), + ) + + this.rustTemplate( + """ + impl #{TryFrom}<$unconstrainedTypeName> for $constrainedTypeName { + type Error = #{ConstraintViolation}; + + /// ${rustDocsTryFromMethod(constrainedTypeName, unconstrainedTypeName)} + fn try_from(value: $unconstrainedTypeName) -> Result { + #{TryFromChecks:W} + + Ok(Self(value)) + } + } + """, + "TryFrom" to RuntimeType.TryFrom, + "ConstraintViolation" to constraintViolationError, + "TryFromChecks" to constraintsInfo.map { it.tryFromCheck }.join("\n"), + ) +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt index f866d83e3a..b01c2f633e 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt @@ -75,7 +75,9 @@ class ServerRequestBindingGenerator( class ServerRequestAfterDeserializingIntoAHashMapOfHttpPrefixHeadersWrapInUnconstrainedMapHttpBindingCustomization(val codegenContext: ServerCodegenContext) : HttpBindingCustomization() { override fun section(section: HttpBindingSection): Writable = when (section) { - is HttpBindingSection.BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders -> emptySection + is HttpBindingSection.BeforeRenderingHeaderValue, + is HttpBindingSection.BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders, + -> emptySection is HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders -> writable { if (section.memberShape.targetCanReachConstrainedShape(codegenContext.model, codegenContext.unconstrainedShapeSymbolProvider)) { rust( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt index 2a030d1070..20e9f97363 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt @@ -18,6 +18,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindi import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingSection import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.ValueExpression import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderSymbol import software.amazon.smithy.rust.codegen.server.smithy.workingWithPublicConstrainedWrapperTupleType @@ -40,6 +41,9 @@ class ServerResponseBindingGenerator( ServerResponseBeforeIteratingOverMapBoundWithHttpPrefixHeadersUnwrapConstrainedMapHttpBindingCustomization( codegenContext, ), + ServerResponseBeforeRenderingHeadersHttpBindingCustomization( + codegenContext, + ), ), ) @@ -65,6 +69,36 @@ class ServerResponseBeforeIteratingOverMapBoundWithHttpPrefixHeadersUnwrapConstr rust("let ${section.variableName} = &${section.variableName}.0;") } } - is HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders -> emptySection + + is HttpBindingSection.BeforeRenderingHeaderValue, + is HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders, + -> emptySection + } +} + +/** + * A customization to, just before we render a _constrained_ member shape to an HTTP response header, + * unwrap the wrapper newtype and take a shared reference to the actual inner type within it. + */ +class ServerResponseBeforeRenderingHeadersHttpBindingCustomization(val codegenContext: ServerCodegenContext) : + HttpBindingCustomization() { + override fun section(section: HttpBindingSection): Writable = when (section) { + is HttpBindingSection.BeforeRenderingHeaderValue -> writable { + if (workingWithPublicConstrainedWrapperTupleType( + section.context.shape, + codegenContext.model, + codegenContext.settings.codegenConfig.publicConstrainedTypes, + ) + ) { + if (section.context.shape.isIntegerShape) { + section.context.valueExpression = + ValueExpression.Reference("&${section.context.valueExpression.name.removePrefix("&")}.0") + } + } + } + + is HttpBindingSection.BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders, + is HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders, + -> emptySection } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index fdc1078a4f..60a26e54c2 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -909,6 +909,7 @@ class ServerProtocolTestGenerator( private const val AwsJson11 = "aws.protocoltests.json#JsonProtocol" private const val RestJson = "aws.protocoltests.restjson#RestJson" private const val RestJsonValidation = "aws.protocoltests.restjson.validation#RestJsonValidation" + private const val MalformedRangeValidation = "aws.protocoltests.extras.restjson.validation#MalformedRangeValidation" private val ExpectFail: Set = setOf( // Pending merge from the Smithy team: see https://github.com/awslabs/smithy/pull/1477. FailingTest(RestJson, "RestJsonWithPayloadExpectsImpliedContentType", TestType.MalformedRequest), @@ -997,6 +998,31 @@ class ServerProtocolTestGenerator( FailingTest(RestJsonValidation, "RestJsonMalformedRangeMinFloat", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedPatternSensitiveString", TestType.MalformedRequest), + // Tests involving using @range on bytes, shorts and longs. + // See https://github.com/awslabs/smithy-rs/issues/1968 + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeShort_case0", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeShort_case1", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeLong_case0", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeLong_case1", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeMaxShort", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeMaxLong", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeMinShort", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeMinLong", TestType.MalformedRequest), + + // See https://github.com/awslabs/smithy-rs/issues/1969 + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeShortOverride_case0", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeShortOverride_case1", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeIntegerOverride_case0", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeIntegerOverride_case1", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeLongOverride_case0", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeLongOverride_case1", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeMaxShortOverride", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeMaxIntegerOverride", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeMaxLongOverride", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeMinShortOverride", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeMinIntegerOverride", TestType.MalformedRequest), + FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeMinLongOverride", TestType.MalformedRequest), + // Some tests for the S3 service (restXml). FailingTest("com.amazonaws.s3#AmazonS3", "GetBucketLocationUnwrappedOutput", TestType.Response), FailingTest("com.amazonaws.s3#AmazonS3", "S3DefaultAddressing", TestType.Request), diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt index 815608fb24..ad50011ede 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt @@ -22,6 +22,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.Struc import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeIteratingOverMapJsonCustomization +import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeSerializingMemberJsonCustomization import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerAwsJsonProtocol import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol @@ -31,7 +32,8 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.Ser */ class ServerAwsJsonFactory(private val version: AwsJsonVersion) : ProtocolGeneratorFactory { - override fun protocol(codegenContext: ServerCodegenContext): ServerProtocol = ServerAwsJsonProtocol(codegenContext, version) + override fun protocol(codegenContext: ServerCodegenContext): ServerProtocol = + ServerAwsJsonProtocol(codegenContext, version) override fun buildProtocolGenerator(codegenContext: ServerCodegenContext): ServerHttpBoundProtocolGenerator = ServerHttpBoundProtocolGenerator(codegenContext, protocol(codegenContext)) @@ -71,6 +73,7 @@ class ServerAwsJsonError(private val awsJsonVersion: AwsJsonVersion) : JsonSeria rust("""${section.jsonObject}.key("__type").string("${escape(typeId)}");""") } } + else -> emptySection } } @@ -90,6 +93,10 @@ class ServerAwsJsonSerializerGenerator( codegenContext, httpBindingResolver, ::awsJsonFieldName, - customizations = listOf(ServerAwsJsonError(awsJsonVersion), BeforeIteratingOverMapJsonCustomization(codegenContext)), + customizations = listOf( + ServerAwsJsonError(awsJsonVersion), + BeforeIteratingOverMapJsonCustomization(codegenContext), + BeforeSerializingMemberJsonCustomization(codegenContext), + ), ), ) : StructuredDataSerializerGenerator by jsonSerializerGenerator diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJson.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJson.kt index a913b806d2..a53bb363da 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJson.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJson.kt @@ -14,6 +14,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonS import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeIteratingOverMapJsonCustomization +import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeSerializingMemberJsonCustomization import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerRestJsonProtocol /** @@ -50,6 +51,9 @@ class ServerRestJsonSerializerGenerator( codegenContext, httpBindingResolver, ::restJsonFieldName, - customizations = listOf(BeforeIteratingOverMapJsonCustomization(codegenContext)), + customizations = listOf( + BeforeIteratingOverMapJsonCustomization(codegenContext), + BeforeSerializingMemberJsonCustomization(codegenContext), + ), ), ) : StructuredDataSerializerGenerator by jsonSerializerGenerator diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/traits/ShapeReachableFromOperationInputTagTrait.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/traits/ShapeReachableFromOperationInputTagTrait.kt index 14a2f72092..b4bd9a255b 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/traits/ShapeReachableFromOperationInputTagTrait.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/traits/ShapeReachableFromOperationInputTagTrait.kt @@ -7,6 +7,7 @@ package software.amazon.smithy.rust.codegen.server.smithy.traits import software.amazon.smithy.model.node.Node import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.IntegerShape import software.amazon.smithy.model.shapes.ListShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.Shape @@ -30,7 +31,7 @@ class ShapeReachableFromOperationInputTagTrait : AnnotationTrait(ID, Node.object } private fun isShapeReachableFromOperationInput(shape: Shape) = when (shape) { - is StructureShape, is UnionShape, is MapShape, is ListShape, is StringShape -> { + is StructureShape, is UnionShape, is MapShape, is ListShape, is StringShape, is IntegerShape -> { shape.hasTrait() } else -> PANIC("this method does not support shape type ${shape.type}") } @@ -40,3 +41,4 @@ fun StructureShape.isReachableFromOperationInput() = isShapeReachableFromOperati fun CollectionShape.isReachableFromOperationInput() = isShapeReachableFromOperationInput(this) fun UnionShape.isReachableFromOperationInput() = isShapeReachableFromOperationInput(this) fun MapShape.isReachableFromOperationInput() = isShapeReachableFromOperationInput(this) +fun IntegerShape.isReachableFromOperationInput() = isShapeReachableFromOperationInput(this) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ShapesReachableFromOperationInputTagger.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ShapesReachableFromOperationInputTagger.kt index cf58f3f9d9..0036238e6b 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ShapesReachableFromOperationInputTagger.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ShapesReachableFromOperationInputTagger.kt @@ -7,6 +7,7 @@ package software.amazon.smithy.rust.codegen.server.smithy.transformers import software.amazon.smithy.model.Model import software.amazon.smithy.model.neighbor.Walker +import software.amazon.smithy.model.shapes.IntegerShape import software.amazon.smithy.model.shapes.ListShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.StringShape @@ -50,7 +51,7 @@ object ShapesReachableFromOperationInputTagger { return ModelTransformer.create().mapShapes(model) { shape -> when (shape) { - is StructureShape, is UnionShape, is ListShape, is MapShape, is StringShape -> { + is StructureShape, is UnionShape, is ListShape, is MapShape, is StringShape, is IntegerShape -> { if (shapesReachableFromOperationInputs.contains(shape)) { val builder = when (shape) { is StructureShape -> shape.toBuilder() @@ -58,6 +59,7 @@ object ShapesReachableFromOperationInputTagger { is ListShape -> shape.toBuilder() is MapShape -> shape.toBuilder() is StringShape -> shape.toBuilder() + is IntegerShape -> shape.toBuilder() else -> UNREACHABLE("the `when` is exhaustive") } builder.addTrait(ShapeReachableFromOperationInputTagTrait()).build() diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProviderTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProviderTest.kt index bcf7fe34ce..fb6c2ec1c7 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProviderTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProviderTest.kt @@ -7,14 +7,20 @@ package software.amazon.smithy.rust.codegen.server.smithy import io.kotest.matchers.shouldBe import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.MethodSource +import software.amazon.smithy.model.shapes.IntegerShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.util.lookup import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider +import java.util.stream.Stream const val baseModelString = """ @@ -32,13 +38,17 @@ const val baseModelString = structure TestInputOutput { constrainedString: ConstrainedString, + constrainedInteger: ConstrainedInteger, constrainedMap: ConstrainedMap, unconstrainedMap: TransitivelyConstrainedMap } @length(min: 1, max: 69) string ConstrainedString - + + @range(min: 10, max: 29) + integer ConstrainedInteger + string UnconstrainedString @length(min: 1, max: 69) @@ -64,24 +74,33 @@ class ConstrainedShapeSymbolProviderTest { private val symbolProvider = serverTestSymbolProvider(model, serviceShape) private val constrainedShapeSymbolProvider = ConstrainedShapeSymbolProvider(symbolProvider, model, serviceShape) - private val constrainedMapShape = model.lookup("test#ConstrainedMap") - private val constrainedMapType = constrainedShapeSymbolProvider.toSymbol(constrainedMapShape).rustType() - - @Test - fun `it should return a constrained string type for a constrained string shape`() { - val constrainedStringShape = model.lookup("test#ConstrainedString") - val constrainedStringType = constrainedShapeSymbolProvider.toSymbol(constrainedStringShape).rustType() - - constrainedStringType shouldBe RustType.Opaque("ConstrainedString", "crate::model") + companion object { + @JvmStatic + fun getConstrainedShapes(): Stream = + Stream.of( + Arguments.of("ConstrainedInteger", { s: Shape -> s is IntegerShape }), + Arguments.of("ConstrainedString", { s: Shape -> s is StringShape }), + Arguments.of("ConstrainedMap", { s: Shape -> s is MapShape }), + ) } - @Test - fun `it should return a constrained map type for a constrained map shape`() { - constrainedMapType shouldBe RustType.Opaque("ConstrainedMap", "crate::model") + @ParameterizedTest + @MethodSource("getConstrainedShapes") + fun `it should return a constrained type for a constrained shape`( + shapeName: String, + shapeCheck: (Shape) -> Boolean, + ) { + val constrainedShape = model.lookup("test#$shapeName") + assert(shapeCheck(constrainedShape)) + val constrainedType = constrainedShapeSymbolProvider.toSymbol(constrainedShape).rustType() + + constrainedType shouldBe RustType.Opaque(shapeName, "crate::model") } @Test fun `it should not blindly delegate to the base symbol provider when the shape is an aggregate shape and is not directly constrained`() { + val constrainedMapShape = model.lookup("test#ConstrainedMap") + val constrainedMapType = constrainedShapeSymbolProvider.toSymbol(constrainedMapShape).rustType() val unconstrainedMapShape = model.lookup("test#TransitivelyConstrainedMap") val unconstrainedMapType = constrainedShapeSymbolProvider.toSymbol(unconstrainedMapShape).rustType() diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProviderTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProviderTest.kt index 21baefe747..f0b339a485 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProviderTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProviderTest.kt @@ -23,24 +23,30 @@ import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymb class PubCrateConstrainedShapeSymbolProviderTest { private val model = """ $baseModelString - + + structure NonTransitivelyConstrainedStructureShape { + constrainedString: ConstrainedString, + constrainedMap: ConstrainedMap, + unconstrainedMap: TransitivelyConstrainedMap + } + list TransitivelyConstrainedCollection { member: Structure } - + structure Structure { @required requiredMember: String } - + structure StructureWithMemberTargetingAggregateShape { member: TransitivelyConstrainedCollection } - + union Union { structure: Structure } - """.asSmithyModel() + """.asSmithyModel() private val serverTestSymbolProviders = serverTestSymbolProviders(model) private val symbolProvider = serverTestSymbolProviders.symbolProvider @@ -97,7 +103,7 @@ class PubCrateConstrainedShapeSymbolProviderTest { @Test fun `it should delegate to the base symbol provider when provided with a structure shape`() { - val structureShape = model.lookup("test#TestInputOutput") + val structureShape = model.lookup("test#NonTransitivelyConstrainedStructureShape") val structureSymbol = pubCrateConstrainedShapeSymbolProvider.toSymbol(structureShape) structureSymbol shouldBe symbolProvider.toSymbol(structureShape) diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraintsAreNotUsedTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraintsAreNotUsedTest.kt index 4c3fa0efe1..25461f1875 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraintsAreNotUsedTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraintsAreNotUsedTest.kt @@ -177,22 +177,32 @@ internal class ValidateUnsupportedConstraintsAreNotUsedTest { } @Test - fun `it should detect when the range trait is used`() { + fun `it should detect when the range trait is used on a shape we do not support`() { val model = """ $baseModel structure TestInputOutput { - rangeInteger: RangeInteger + rangeByte: RangeByte + rangeShort: RangeShort + rangeLong: RangeLong } @range(min: 1) - integer RangeInteger + byte RangeByte + + @range(min: 1) + long RangeLong + + @range(min: 1) + short RangeShort """.asSmithyModel() val validationResult = validateModel(model) - validationResult.messages shouldHaveSize 1 - validationResult.messages[0].message shouldContain "The integer shape `test#RangeInteger` has the constraint trait `smithy.api#range` attached" + validationResult.messages shouldHaveSize 3 + validationResult.messages[0].message shouldContain "The long shape `test#RangeLong` has the constraint trait `smithy.api#range` attached" + validationResult.messages[1].message shouldContain "The short shape `test#RangeShort` has the constraint trait `smithy.api#range` attached" + validationResult.messages[2].message shouldContain "The byte shape `test#RangeByte` has the constraint trait `smithy.api#range` attached" } @Test @@ -200,11 +210,11 @@ internal class ValidateUnsupportedConstraintsAreNotUsedTest { val model = """ $baseModel - + structure TestInputOutput { uniqueItemsList: UniqueItemsList } - + @uniqueItems list UniqueItemsList { member: String diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedIntegerGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedIntegerGeneratorTest.kt new file mode 100644 index 0000000000..a889281327 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedIntegerGeneratorTest.kt @@ -0,0 +1,124 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import io.kotest.matchers.string.shouldContain +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.ExtensionContext +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.ArgumentsProvider +import org.junit.jupiter.params.provider.ArgumentsSource +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.IntegerShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext +import java.util.stream.Stream + +class ConstrainedIntegerGeneratorTest { + + data class TestCase(val model: Model, val validInteger: Int, val invalidInteger: Int) + + class ConstrainedIntGeneratorTestProvider : ArgumentsProvider { + private val testCases = listOf( + // Min and max. + Triple("@range(min: 10, max: 12)", 11, 13), + // Min equal to max. + Triple("@range(min: 11, max: 11)", 11, 12), + // Only min. + Triple("@range(min: 11)", 12, 2), + // Only max. + Triple("@range(max: 11)", 0, 12), + ).map { + TestCase( + """ + namespace test + + ${it.first} + integer ConstrainedInteger + """.asSmithyModel(), + it.second, + it.third, + ) + } + + override fun provideArguments(context: ExtensionContext?): Stream = + testCases.map { Arguments.of(it) }.stream() + } + + @ParameterizedTest + @ArgumentsSource(ConstrainedIntGeneratorTestProvider::class) + fun `it should generate constrained integer types`(testCase: TestCase) { + val constrainedIntegerShape = testCase.model.lookup("test#ConstrainedInteger") + + val codegenContext = serverTestCodegenContext(testCase.model) + val symbolProvider = codegenContext.symbolProvider + + val project = TestWorkspace.testProject(symbolProvider) + + project.withModule(ModelsModule) { + ConstrainedIntegerGenerator(codegenContext, this, constrainedIntegerShape).render() + + unitTest( + name = "try_from_success", + test = """ + let _constrained: ConstrainedInteger = ${testCase.validInteger}.try_into().unwrap(); + """, + ) + unitTest( + name = "try_from_fail", + test = """ + let constrained_res: Result = ${testCase.invalidInteger}.try_into(); + constrained_res.unwrap_err(); + """, + ) + unitTest( + name = "inner", + test = """ + let constrained = ConstrainedInteger::try_from(${testCase.validInteger}).unwrap(); + assert_eq!(constrained.inner(), &${testCase.validInteger}); + """, + ) + unitTest( + name = "into_inner", + test = """ + let int = ${testCase.validInteger}; + let constrained = ConstrainedInteger::try_from(int).unwrap(); + + assert_eq!(constrained.into_inner(), int); + """, + ) + } + + project.compileAndTest() + } + + @Test + fun `type should not be constructible without using a constructor`() { + val model = """ + namespace test + + @range(min: -1, max: 69) + integer ConstrainedInteger + """.asSmithyModel() + val constrainedIntegerShape = model.lookup("test#ConstrainedInteger") + + val codegenContext = serverTestCodegenContext(model) + + val writer = RustWriter.forModule(ModelsModule.name) + + ConstrainedIntegerGenerator(codegenContext, writer, constrainedIntegerShape).render() + + // Check that the wrapped type is `pub(crate)`. + writer.toString() shouldContain "pub struct ConstrainedInteger(pub(crate) i32);" + } +}