diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index 7a81809aec..1ed5cdfb59 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -370,3 +370,15 @@ set_credentials_cache( references = ["smithy-rs#2122"] meta = { "breaking" = true, "tada" = false, "bug" = false } author = "ysaito1001" + +[[smithy-rs]] +message = "`aws_smithy_types::date_time::DateTime`, `aws_smithy_types::Blob` now implement the `Eq` and `Hash` traits" +references = ["smithy-rs#2223"] +meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "all"} +author = "david-perez" + +[[smithy-rs]] +message = "Code-generated types for server SDKs now implement the `Eq` and `Hash` traits when possible" +references = ["smithy-rs#2223"] +meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "server"} +author = "david-perez" diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt index 8750fa0c96..0189a4cb65 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt @@ -80,7 +80,7 @@ class RustClientCodegenPlugin : DecoratableBuildPlugin() { .let { StreamingShapeSymbolProvider(it, model) } // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes .let { BaseSymbolMetadataProvider(it, model, additionalAttributes = listOf(NonExhaustive)) } - // Streaming shapes need different derives (e.g. they cannot derive Eq) + // Streaming shapes need different derives (e.g. they cannot derive `PartialEq`) .let { StreamingShapeMetadataProvider(it, model) } // Rename shapes that clash with Rust reserved words & and other SDK specific features e.g. `send()` cannot // be the name of an operation input diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/StreamingTraitSymbolProvider.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/StreamingTraitSymbolProvider.kt index 2e91ee5fba..3e1d082627 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/StreamingTraitSymbolProvider.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/StreamingTraitSymbolProvider.kt @@ -8,7 +8,10 @@ package software.amazon.smithy.rust.codegen.core.smithy import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.NumberShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape @@ -27,7 +30,7 @@ class StreamingShapeSymbolProvider(private val base: RustSymbolProvider, private WrappingSymbolProvider(base) { override fun toSymbol(shape: Shape): Symbol { val initial = base.toSymbol(shape) - // We are only targetting member shapes + // We are only targeting member shapes if (shape !is MemberShape) { return initial } @@ -49,7 +52,7 @@ class StreamingShapeSymbolProvider(private val base: RustSymbolProvider, private } /** - * SymbolProvider to drop the clone and PartialEq bounds in streaming shapes + * SymbolProvider to drop the `Clone` and `PartialEq` bounds in streaming shapes. * * Streaming shapes cannot be cloned and equality cannot be checked without reading the body. Because of this, these shapes * do not implement `Clone` or `PartialEq`. @@ -60,10 +63,6 @@ class StreamingShapeMetadataProvider( private val base: RustSymbolProvider, private val model: Model, ) : SymbolMetadataProvider(base) { - override fun memberMeta(memberShape: MemberShape): RustMetadata { - return base.toSymbol(memberShape).expectRustMetadata() - } - override fun structureMeta(structureShape: StructureShape): RustMetadata { val baseMetadata = base.toSymbol(structureShape).expectRustMetadata() return if (structureShape.hasStreamingMember(model)) { @@ -78,7 +77,12 @@ class StreamingShapeMetadataProvider( } else baseMetadata } - override fun enumMeta(stringShape: StringShape): RustMetadata { - return base.toSymbol(stringShape).expectRustMetadata() - } + override fun memberMeta(memberShape: MemberShape) = base.toSymbol(memberShape).expectRustMetadata() + override fun enumMeta(stringShape: StringShape) = base.toSymbol(stringShape).expectRustMetadata() + + override fun listMeta(listShape: ListShape) = base.toSymbol(listShape).expectRustMetadata() + override fun mapMeta(mapShape: MapShape) = base.toSymbol(mapShape).expectRustMetadata() + override fun stringMeta(stringShape: StringShape) = base.toSymbol(stringShape).expectRustMetadata() + override fun numberMeta(numberShape: NumberShape) = base.toSymbol(numberShape).expectRustMetadata() + override fun blobMeta(blobShape: BlobShape) = base.toSymbol(blobShape).expectRustMetadata() } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolMetadataProvider.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolMetadataProvider.kt index 4f6f095e94..a7017b504e 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolMetadataProvider.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolMetadataProvider.kt @@ -8,7 +8,12 @@ package software.amazon.smithy.rust.codegen.core.smithy import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.NumberShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape @@ -55,9 +60,15 @@ abstract class SymbolMetadataProvider(private val base: RustSymbolProvider) : Wr is MemberShape -> memberMeta(shape) is StructureShape -> structureMeta(shape) is UnionShape -> unionMeta(shape) + is ListShape -> listMeta(shape) + is MapShape -> mapMeta(shape) + is NumberShape -> numberMeta(shape) + is BlobShape -> blobMeta(shape) is StringShape -> if (shape.hasTrait()) { enumMeta(shape) - } else null + } else { + stringMeta(shape) + } else -> null } @@ -68,98 +79,101 @@ abstract class SymbolMetadataProvider(private val base: RustSymbolProvider) : Wr abstract fun structureMeta(structureShape: StructureShape): RustMetadata abstract fun unionMeta(unionShape: UnionShape): RustMetadata abstract fun enumMeta(stringShape: StringShape): RustMetadata + + abstract fun listMeta(listShape: ListShape): RustMetadata + abstract fun mapMeta(mapShape: MapShape): RustMetadata + abstract fun stringMeta(stringShape: StringShape): RustMetadata + abstract fun numberMeta(numberShape: NumberShape): RustMetadata + abstract fun blobMeta(blobShape: BlobShape): RustMetadata +} + +fun containerDefaultMetadata( + shape: Shape, + model: Model, + additionalAttributes: List = emptyList(), +): RustMetadata { + val defaultDerives = setOf(RuntimeType.Debug, RuntimeType.PartialEq, RuntimeType.Clone) + + val isSensitive = shape.hasTrait() || + // Checking the shape's direct members for the sensitive trait should suffice. + // Whether their descendants, i.e. a member's member, is sensitive does not + // affect the inclusion/exclusion of the derived `Debug` trait of _this_ container + // shape; any sensitive descendant should still be printed as redacted. + shape.members().any { it.getMemberTrait(model, SensitiveTrait::class.java).isPresent } + + val setOfDerives = if (isSensitive) { + defaultDerives - RuntimeType.Debug + } else { + defaultDerives + } + return RustMetadata( + setOfDerives, + additionalAttributes, + Visibility.PUBLIC, + ) } /** - * The base metadata supports a list of attributes that are used by generators to decorate code. - * By default, we apply ```#[non_exhaustive]``` only to client structures since model changes should - * be considered as breaking only when generating server code. + * The base metadata supports a set of attributes that are used by generators to decorate code. + * + * By default we apply `#[non_exhaustive]` in [additionalAttributes] only to client structures since breaking model + * changes are fine when generating server code. */ class BaseSymbolMetadataProvider( base: RustSymbolProvider, private val model: Model, private val additionalAttributes: List, ) : SymbolMetadataProvider(base) { - private fun containerDefault(shape: Shape): RustMetadata { - val isSensitive = shape.hasTrait() || - // Checking the shape's direct members for the sensitive trait should suffice. - // Whether their descendants, i.e. a member's member, is sensitive does not - // affect the inclusion/exclusion of the derived Debug trait of _this_ container - // shape; any sensitive descendant should still be printed as redacted. - shape.members().any { it.getMemberTrait(model, SensitiveTrait::class.java).isPresent } - - val derives = if (isSensitive) { - defaultDerives - RuntimeType.Debug - } else { - defaultDerives - } - return RustMetadata( - derives, - additionalAttributes, - Visibility.PUBLIC, - ) - } - override fun memberMeta(memberShape: MemberShape): RustMetadata { - val container = model.expectShape(memberShape.container) - return when { - container.isStructureShape -> { + override fun memberMeta(memberShape: MemberShape): RustMetadata = + when (val container = model.expectShape(memberShape.container)) { + is StructureShape -> { // TODO(https://github.com/awslabs/smithy-rs/issues/943): Once streaming accessors are usable, // then also make streaming members `#[doc(hidden)]` if (memberShape.getMemberTrait(model, StreamingTrait::class.java).isPresent) { RustMetadata(visibility = Visibility.PUBLIC) } else { RustMetadata( - // At some point, visibility will be made PRIVATE, so make these `#[doc(hidden)]` for now + // At some point, visibility _may_ be made `PRIVATE`, so make these `#[doc(hidden)]` for now. visibility = Visibility.PUBLIC, additionalAttributes = listOf(Attribute.DocHidden), ) } } - container.isUnionShape || - container.isListShape || - container.isSetShape || - container.isMapShape - -> RustMetadata(visibility = Visibility.PUBLIC) - + is UnionShape, is CollectionShape, is MapShape -> RustMetadata(visibility = Visibility.PUBLIC) else -> TODO("Unrecognized container type: $container") } - } - - override fun structureMeta(structureShape: StructureShape): RustMetadata { - return containerDefault(structureShape) - } - override fun unionMeta(unionShape: UnionShape): RustMetadata { - return containerDefault(unionShape) - } + override fun structureMeta(structureShape: StructureShape) = containerDefaultMetadata(structureShape, model, additionalAttributes) + override fun unionMeta(unionShape: UnionShape) = containerDefaultMetadata(unionShape, model, additionalAttributes) - override fun enumMeta(stringShape: StringShape): RustMetadata { - return containerDefault(stringShape).withDerives( - RuntimeType.Hash, - // enums can be Eq because they can only contain ints and strings + override fun enumMeta(stringShape: StringShape): RustMetadata = + containerDefaultMetadata(stringShape, model, additionalAttributes).withDerives( + // Smithy's `enum` shapes can additionally be `Eq`, `PartialOrd`, `Ord`, and `Hash` because they can + // only contain strings. RuntimeType.Eq, - // enums can be PartialOrd/Ord because they can only contain ints and strings RuntimeType.PartialOrd, RuntimeType.Ord, + RuntimeType.Hash, ) - } - companion object { - private val defaultDerives by lazy { - setOf(RuntimeType.Debug, RuntimeType.PartialEq, RuntimeType.Clone) - } - } + // Only the server subproject uses these, so we provide a sane and conservative default implementation here so that + // the rest of symbol metadata providers can just delegate to it. + private val defaultRustMetadata = RustMetadata(visibility = Visibility.PRIVATE) + + override fun listMeta(listShape: ListShape) = defaultRustMetadata + override fun mapMeta(mapShape: MapShape) = defaultRustMetadata + override fun stringMeta(stringShape: StringShape) = defaultRustMetadata + override fun numberMeta(numberShape: NumberShape) = defaultRustMetadata + override fun blobMeta(blobShape: BlobShape) = defaultRustMetadata } private const val META_KEY = "meta" -fun Symbol.Builder.meta(rustMetadata: RustMetadata?): Symbol.Builder { - return this.putProperty(META_KEY, rustMetadata) -} +fun Symbol.Builder.meta(rustMetadata: RustMetadata?): Symbol.Builder = this.putProperty(META_KEY, rustMetadata) fun Symbol.expectRustMetadata(): RustMetadata = this.getProperty(META_KEY, RustMetadata::class.java).orElseThrow { CodegenException( - "Expected $this to have metadata attached but it did not. ", + "Expected `$this` to have metadata attached but it did not.", ) } diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonCodegenServerPlugin.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonCodegenServerPlugin.kt index 3d8a57ef6b..7e277185f7 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonCodegenServerPlugin.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonCodegenServerPlugin.kt @@ -17,7 +17,9 @@ import software.amazon.smithy.rust.codegen.core.smithy.EventStreamSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitor import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig import software.amazon.smithy.rust.codegen.server.python.smithy.customizations.DECORATORS +import software.amazon.smithy.rust.codegen.server.smithy.ConstrainedShapeSymbolMetadataProvider import software.amazon.smithy.rust.codegen.server.smithy.ConstrainedShapeSymbolProvider +import software.amazon.smithy.rust.codegen.server.smithy.DeriveEqAndHashSymbolMetadataProvider import software.amazon.smithy.rust.codegen.server.smithy.customizations.ServerRequiredCustomizations import software.amazon.smithy.rust.codegen.server.smithy.customize.CombinedServerCodegenDecorator import java.util.logging.Level @@ -54,7 +56,7 @@ class PythonCodegenServerPlugin : SmithyBuildPlugin { } companion object { - /** SymbolProvider + /** * When generating code, smithy types need to be converted into Rust types—that is the core role of the symbol provider * * The Symbol provider is composed of a base [SymbolVisitor] which handles the core functionality, then is layered @@ -77,8 +79,12 @@ class PythonCodegenServerPlugin : SmithyBuildPlugin { .let { EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model, CodegenTarget.SERVER) } // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes .let { BaseSymbolMetadataProvider(it, model, additionalAttributes = listOf()) } + // Constrained shapes generate newtypes that need the same derives we place on types generated from aggregate shapes. + .let { ConstrainedShapeSymbolMetadataProvider(it, model, constrainedTypes) } // Streaming shapes need different derives (e.g. they cannot derive Eq) .let { PythonStreamingShapeMetadataProvider(it, model) } + // Derive `Eq` and `Hash` if possible. + .let { DeriveEqAndHashSymbolMetadataProvider(it, model) } // Rename shapes that clash with Rust reserved words & and other SDK specific features e.g. `send()` cannot // be the name of an operation input .let { RustReservedWordSymbolProvider(it, model) } diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerSymbolProvider.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerSymbolProvider.kt index 1247d1e064..57e77ffe03 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerSymbolProvider.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerSymbolProvider.kt @@ -8,7 +8,10 @@ package software.amazon.smithy.rust.codegen.server.python.smithy import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.NumberShape import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StringShape @@ -55,7 +58,7 @@ class PythonServerSymbolVisitor( val target = model.expectShape(shape.target) val container = model.expectShape(shape.container) - // We are only targetting non syntetic inputs and outputs. + // We are only targeting non-synthetic inputs and outputs. if (!container.hasTrait() && !container.hasTrait()) { return initial } @@ -88,10 +91,6 @@ class PythonServerSymbolVisitor( * Note that since streaming members can only be used on the root shape, this can only impact input and output shapes. */ class PythonStreamingShapeMetadataProvider(private val base: RustSymbolProvider, private val model: Model) : SymbolMetadataProvider(base) { - override fun memberMeta(memberShape: MemberShape): RustMetadata { - return base.toSymbol(memberShape).expectRustMetadata() - } - override fun structureMeta(structureShape: StructureShape): RustMetadata { val baseMetadata = base.toSymbol(structureShape).expectRustMetadata() return if (structureShape.hasStreamingMember(model)) { @@ -106,7 +105,12 @@ class PythonStreamingShapeMetadataProvider(private val base: RustSymbolProvider, } else baseMetadata } - override fun enumMeta(stringShape: StringShape): RustMetadata { - return base.toSymbol(stringShape).expectRustMetadata() - } + override fun memberMeta(memberShape: MemberShape) = base.toSymbol(memberShape).expectRustMetadata() + override fun enumMeta(stringShape: StringShape) = base.toSymbol(stringShape).expectRustMetadata() + + override fun listMeta(listShape: ListShape) = base.toSymbol(listShape).expectRustMetadata() + override fun mapMeta(mapShape: MapShape) = base.toSymbol(mapShape).expectRustMetadata() + override fun stringMeta(stringShape: StringShape) = base.toSymbol(stringShape).expectRustMetadata() + override fun numberMeta(numberShape: NumberShape) = base.toSymbol(numberShape).expectRustMetadata() + override fun blobMeta(blobShape: BlobShape) = base.toSymbol(blobShape).expectRustMetadata() } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolMetadataProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolMetadataProvider.kt new file mode 100644 index 0000000000..01e8255ccc --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolMetadataProvider.kt @@ -0,0 +1,61 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.NumberShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.SymbolMetadataProvider +import software.amazon.smithy.rust.codegen.core.smithy.containerDefaultMetadata +import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata + +/** + * This symbol metadata provider adds the usual derives on shapes that are constrained and hence generate newtypes. + * + * It also makes the newtypes `pub(crate)` when `publicConstrainedTypes` is disabled. + */ +class ConstrainedShapeSymbolMetadataProvider( + private val base: RustSymbolProvider, + private val model: Model, + private val constrainedTypes: Boolean, +) : SymbolMetadataProvider(base) { + + override fun memberMeta(memberShape: MemberShape) = base.toSymbol(memberShape).expectRustMetadata() + override fun structureMeta(structureShape: StructureShape) = base.toSymbol(structureShape).expectRustMetadata() + override fun unionMeta(unionShape: UnionShape) = base.toSymbol(unionShape).expectRustMetadata() + override fun enumMeta(stringShape: StringShape) = base.toSymbol(stringShape).expectRustMetadata() + + private fun addDerivesAndAdjustVisibilityIfConstrained(shape: Shape): RustMetadata { + check(shape is ListShape || shape is MapShape || shape is StringShape || shape is NumberShape || shape is BlobShape) + + val baseMetadata = base.toSymbol(shape).expectRustMetadata() + val derives = baseMetadata.derives.toMutableSet() + val additionalAttributes = baseMetadata.additionalAttributes.toMutableList() + + if (shape.canReachConstrainedShape(model, base)) { + derives += containerDefaultMetadata(shape, model).derives + } + + val visibility = Visibility.publicIf(constrainedTypes, Visibility.PUBCRATE) + return RustMetadata(derives, additionalAttributes, visibility) + } + + override fun listMeta(listShape: ListShape): RustMetadata = addDerivesAndAdjustVisibilityIfConstrained(listShape) + override fun mapMeta(mapShape: MapShape): RustMetadata = addDerivesAndAdjustVisibilityIfConstrained(mapShape) + override fun stringMeta(stringShape: StringShape): RustMetadata = addDerivesAndAdjustVisibilityIfConstrained(stringShape) + override fun numberMeta(numberShape: NumberShape): RustMetadata = addDerivesAndAdjustVisibilityIfConstrained(numberShape) + override fun blobMeta(blobShape: BlobShape) = addDerivesAndAdjustVisibilityIfConstrained(blobShape) +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/DeriveEqAndHashSymbolMetadataProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/DeriveEqAndHashSymbolMetadataProvider.kt new file mode 100644 index 0000000000..5438447ed5 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/DeriveEqAndHashSymbolMetadataProvider.kt @@ -0,0 +1,96 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.DocumentShape +import software.amazon.smithy.model.shapes.DoubleShape +import software.amazon.smithy.model.shapes.FloatShape +import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.NumberShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.StreamingTrait +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.smithy.DirectedWalker +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.SymbolMetadataProvider +import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata +import software.amazon.smithy.rust.codegen.core.util.hasTrait + +/** + * This symbol metadata provider adds derives to implement the [`Eq`] and [`Hash`] traits for shapes, whenever + * possible. + * + * These traits can be implemented by any shape _except_ if the shape's closure contains: + * + * 1. A `float`, `double`, or `document` shape: floating point types in Rust do not implement `Eq`. Similarly, + * [`document` shapes] may contain arbitrary JSON-like data containing floating point values. + * 2. A [@streaming] shape: all the streaming data would need to be buffered first to compare it. + * + * Additionally, the `Hash` trait cannot be implemented by shapes whose closure contains: + * + * 1. A `map` shape: we render `map` shapes as `std::collections::HashMap`, which _do not_ implement `Hash`. + * See https://github.com/awslabs/smithy/issues/1567. + * + * [`Eq`]: https://doc.rust-lang.org/std/cmp/trait.Eq.html + * [`Hash`]: https://doc.rust-lang.org/std/hash/trait.Hash.html + * [`document` shapes]: https://smithy.io/2.0/spec/simple-types.html#document + * [@streaming]: https://smithy.io/2.0/spec/streaming.html + */ +class DeriveEqAndHashSymbolMetadataProvider( + private val base: RustSymbolProvider, + val model: Model, +) : SymbolMetadataProvider(base) { + private val walker = DirectedWalker(model) + + private fun addDeriveEqAndHashIfPossible(shape: Shape): RustMetadata { + check(shape !is MemberShape) + val baseMetadata = base.toSymbol(shape).expectRustMetadata() + // See class-level documentation for why we filter these out. + return if (walker.walkShapes(shape) + .any { it is FloatShape || it is DoubleShape || it is DocumentShape || it.hasTrait() } + ) { + baseMetadata + } else { + var ret = baseMetadata + if (ret.derives.contains(RuntimeType.PartialEq)) { + // We can only derive `Eq` if the type implements `PartialEq`. Not every shape that does not reach a + // floating point or a document shape does; for example, streaming shapes cannot be `PartialEq`, see + // [StreamingShapeMetadataProvider]. This is just a defensive check in case other symbol providers + // want to remove the `PartialEq` trait, since we've also just checked that we do not reach a streaming + // shape. + ret = ret.withDerives(RuntimeType.Eq) + } + + // `std::collections::HashMap` does not implement `std::hash::Hash`: + // https://github.com/awslabs/smithy/issues/1567 + if (walker.walkShapes(shape).none { it is MapShape }) { + ret = ret.withDerives(RuntimeType.Hash) + } + + return ret + } + } + + override fun memberMeta(memberShape: MemberShape) = base.toSymbol(memberShape).expectRustMetadata() + + override fun structureMeta(structureShape: StructureShape) = addDeriveEqAndHashIfPossible(structureShape) + override fun unionMeta(unionShape: UnionShape) = addDeriveEqAndHashIfPossible(unionShape) + override fun enumMeta(stringShape: StringShape) = addDeriveEqAndHashIfPossible(stringShape) + + override fun listMeta(listShape: ListShape): RustMetadata = addDeriveEqAndHashIfPossible(listShape) + override fun mapMeta(mapShape: MapShape): RustMetadata = addDeriveEqAndHashIfPossible(mapShape) + override fun stringMeta(stringShape: StringShape): RustMetadata = addDeriveEqAndHashIfPossible(stringShape) + override fun numberMeta(numberShape: NumberShape): RustMetadata = addDeriveEqAndHashIfPossible(numberShape) + override fun blobMeta(blobShape: BlobShape): RustMetadata = addDeriveEqAndHashIfPossible(blobShape) +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenServerPlugin.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenServerPlugin.kt index 98e875f3cc..8a1dc17e54 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenServerPlugin.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenServerPlugin.kt @@ -73,8 +73,12 @@ class RustCodegenServerPlugin : SmithyBuildPlugin { .let { StreamingShapeSymbolProvider(it, model) } // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes .let { BaseSymbolMetadataProvider(it, model, additionalAttributes = listOf()) } - // Streaming shapes need different derives (e.g. they cannot derive Eq) + // Constrained shapes generate newtypes that need the same derives we place on types generated from aggregate shapes. + .let { ConstrainedShapeSymbolMetadataProvider(it, model, constrainedTypes) } + // Streaming shapes need different derives (e.g. they cannot derive `PartialEq`) .let { StreamingShapeMetadataProvider(it, model) } + // Derive `Eq` and `Hash` if possible. + .let { DeriveEqAndHashSymbolMetadataProvider(it, model) } // Rename shapes that clash with Rust reserved words & and other SDK specific features e.g. `send()` cannot // be the name of an operation input .let { RustReservedWordSymbolProvider(it, model) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedBlobGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedBlobGenerator.kt index 0afeac0690..41fec1cec7 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedBlobGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedBlobGenerator.kt @@ -8,8 +8,6 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.traits.LengthTrait -import software.amazon.smithy.rust.codegen.core.rustlang.Attribute -import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.rustlang.Writable @@ -18,8 +16,10 @@ import software.amazon.smithy.rust.codegen.core.rustlang.documentShape import software.amazon.smithy.rust.codegen.core.rustlang.join import software.amazon.smithy.rust.codegen.core.rustlang.render import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.core.smithy.rustType @@ -55,37 +55,31 @@ class ConstrainedBlobGenerator( val inner = RuntimeType.blob(codegenContext.runtimeConfig).toSymbol().rustType().render() val constraintViolation = constraintViolationSymbolProvider.toSymbol(shape) - val constrainedTypeVisibility = if (publicConstrainedTypes) { - Visibility.PUBLIC - } else { - Visibility.PUBCRATE - } - val constrainedTypeMetadata = RustMetadata( - setOf(RuntimeType.Debug, RuntimeType.Clone, RuntimeType.PartialEq), - visibility = constrainedTypeVisibility, - ) - writer.documentShape(shape, model) writer.docs(rustDocsConstrainedTypeEpilogue(name)) - constrainedTypeMetadata.render(writer) + val metadata = symbol.expectRustMetadata() + metadata.render(writer) writer.rust("struct $name(pub(crate) $inner);") - if (constrainedTypeVisibility == Visibility.PUBCRATE) { - Attribute.AllowDeadCode.render(writer) - } - writer.rust( - """ - impl $name { - /// ${rustDocsInnerMethod(inner)} - pub fn inner(&self) -> &$inner { - &self.0 - } - + writer.rustBlock("impl $name") { + if (metadata.visibility == Visibility.PUBLIC) { + writer.rust( + """ + /// ${rustDocsInnerMethod(inner)} + pub fn inner(&self) -> &$inner { + &self.0 + } + """, + ) + } + writer.rust( + """ /// ${rustDocsIntoInnerMethod(inner)} pub fn into_inner(self) -> $inner { self.0 } - }""", - ) + """, + ) + } writer.renderTryFrom(inner, name, constraintViolation, constraintsInfo) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedCollectionGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedCollectionGenerator.kt index 8dc227e273..9f74839dd3 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedCollectionGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedCollectionGenerator.kt @@ -11,16 +11,16 @@ import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.LengthTrait import software.amazon.smithy.model.traits.Trait -import software.amazon.smithy.rust.codegen.core.rustlang.Attribute -import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.rustlang.docs import software.amazon.smithy.rust.codegen.core.rustlang.documentShape import software.amazon.smithy.rust.codegen.core.rustlang.join import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.util.PANIC import software.amazon.smithy.rust.codegen.core.util.orNull import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider @@ -72,11 +72,7 @@ class ConstrainedCollectionGenerator( val name = constrainedShapeSymbolProvider.toSymbol(shape).name val inner = "std::vec::Vec<#{ValueMemberSymbol}>" val constraintViolation = constraintViolationSymbolProvider.toSymbol(shape) - val constrainedTypeVisibility = Visibility.publicIf(publicConstrainedTypes, Visibility.PUBCRATE) - val constrainedTypeMetadata = RustMetadata( - setOf(RuntimeType.Debug, RuntimeType.Clone, RuntimeType.PartialEq), - visibility = constrainedTypeVisibility, - ) + val constrainedSymbol = symbolProvider.toSymbol(shape) val codegenScope = arrayOf( "ValueMemberSymbol" to constrainedShapeSymbolProvider.toSymbol(shape.member), @@ -87,33 +83,43 @@ class ConstrainedCollectionGenerator( writer.documentShape(shape, model) writer.docs(rustDocsConstrainedTypeEpilogue(name)) - constrainedTypeMetadata.render(writer) + val metadata = constrainedSymbol.expectRustMetadata() + metadata.render(writer) writer.rustTemplate( """ struct $name(pub(crate) $inner); """, *codegenScope, ) - if (constrainedTypeVisibility == Visibility.PUBCRATE) { - Attribute.AllowDeadCode.render(writer) - } - - writer.rustTemplate( - """ - impl $name { - /// ${rustDocsInnerMethod(inner)} - pub fn inner(&self) -> &$inner { - &self.0 - } + writer.rustBlock("impl $name") { + if (metadata.visibility == Visibility.PUBLIC) { + writer.rustTemplate( + """ + /// ${rustDocsInnerMethod(inner)} + pub fn inner(&self) -> &$inner { + &self.0 + } + """, + *codegenScope, + ) + } + writer.rustTemplate( + """ /// ${rustDocsIntoInnerMethod(inner)} pub fn into_inner(self) -> $inner { self.0 } #{ValidationFunctions:W} - } + """, + *codegenScope, + "ValidationFunctions" to constraintsInfo.map { it.validationFunctionDefinition(constraintViolation, inner) }.join("\n"), + ) + } + writer.rustTemplate( + """ impl #{TryFrom}<$inner> for $name { type Error = #{ConstraintViolation}; diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGenerator.kt index a8a891dc3a..e5721e7741 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGenerator.kt @@ -10,14 +10,14 @@ import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.LengthTrait -import software.amazon.smithy.rust.codegen.core.rustlang.Attribute -import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.rustlang.docs import software.amazon.smithy.rust.codegen.core.rustlang.documentShape +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext @@ -58,16 +58,12 @@ class ConstrainedMapGenerator( val lengthTrait = shape.expectTrait() val name = constrainedShapeSymbolProvider.toSymbol(shape).name - val inner = "std::collections::HashMap<#{KeySymbol}, #{ValueMemberSymbol}>" + val inner = "#{HashMap}<#{KeySymbol}, #{ValueMemberSymbol}>" val constraintViolation = constraintViolationSymbolProvider.toSymbol(shape) - - val constrainedTypeVisibility = Visibility.publicIf(publicConstrainedTypes, Visibility.PUBCRATE) - val constrainedTypeMetadata = RustMetadata( - setOf(RuntimeType.Debug, RuntimeType.Clone, RuntimeType.PartialEq), - visibility = constrainedTypeVisibility, - ) + val constrainedSymbol = symbolProvider.toSymbol(shape) val codegenScope = arrayOf( + "HashMap" to RuntimeType.HashMap, "KeySymbol" to constrainedShapeSymbolProvider.toSymbol(model.expectShape(shape.key.target)), "ValueMemberSymbol" to constrainedShapeSymbolProvider.toSymbol(shape.value), "From" to RuntimeType.From, @@ -77,25 +73,34 @@ class ConstrainedMapGenerator( writer.documentShape(shape, model) writer.docs(rustDocsConstrainedTypeEpilogue(name)) - constrainedTypeMetadata.render(writer) + val metadata = constrainedSymbol.expectRustMetadata() + metadata.render(writer) writer.rustTemplate("struct $name(pub(crate) $inner);", *codegenScope) - if (constrainedTypeVisibility == Visibility.PUBCRATE) { - Attribute.AllowDeadCode.render(writer) - } - writer.rustTemplate( - """ - impl $name { - /// ${rustDocsInnerMethod(inner)} - pub fn inner(&self) -> &$inner { - &self.0 - } - + writer.rustBlockTemplate("impl $name", *codegenScope) { + if (metadata.visibility == Visibility.PUBLIC) { + writer.rustTemplate( + """ + /// ${rustDocsInnerMethod(inner)} + pub fn inner(&self) -> &$inner { + &self.0 + } + """, + *codegenScope, + ) + } + writer.rustTemplate( + """ /// ${rustDocsIntoInnerMethod(inner)} pub fn into_inner(self) -> $inner { self.0 } - } + """, + *codegenScope, + ) + } + writer.rustTemplate( + """ impl #{TryFrom}<$inner> for $name { type Error = #{ConstraintViolation}; @@ -149,6 +154,7 @@ class ConstrainedMapGenerator( type Unconstrained = #{UnconstrainedSymbol}; } """, + *codegenScope, "ConstrainedTrait" to RuntimeType.ConstrainedTrait, "UnconstrainedSymbol" to unconstrainedSymbol, ) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedNumberGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedNumberGenerator.kt index 863f20af9b..281f0005c1 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedNumberGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedNumberGenerator.kt @@ -13,7 +13,6 @@ import software.amazon.smithy.model.shapes.NumberShape import software.amazon.smithy.model.shapes.ShortShape import software.amazon.smithy.model.traits.RangeTrait import software.amazon.smithy.rust.codegen.core.rustlang.Attribute -import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Visibility @@ -26,6 +25,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE @@ -76,30 +76,13 @@ class ConstrainedNumberGenerator( val constraintViolation = constraintViolationSymbolProvider.toSymbol(shape) val constraintsInfo = listOf(Range(rangeTrait).toTraitInfo(unconstrainedTypeName)) - val constrainedTypeVisibility = if (publicConstrainedTypes) { - Visibility.PUBLIC - } else { - Visibility.PUBCRATE - } - val constrainedTypeMetadata = RustMetadata( - - setOf( - RuntimeType.Debug, - RuntimeType.Clone, - RuntimeType.PartialEq, - RuntimeType.Eq, - RuntimeType.Hash, - ), - - visibility = constrainedTypeVisibility, - ) - writer.documentShape(shape, model) writer.docs(rustDocsConstrainedTypeEpilogue(name)) - constrainedTypeMetadata.render(writer) + val metadata = symbol.expectRustMetadata() + metadata.render(writer) writer.rust("struct $name(pub(crate) $unconstrainedTypeName);") - if (constrainedTypeVisibility == Visibility.PUBCRATE) { + if (metadata.visibility == Visibility.PUBCRATE) { Attribute.AllowDeadCode.render(writer) } writer.rustTemplate( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt index 8c2876f53f..20a6746aa8 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt @@ -12,7 +12,6 @@ import software.amazon.smithy.model.traits.LengthTrait import software.amazon.smithy.model.traits.PatternTrait import software.amazon.smithy.model.traits.Trait import software.amazon.smithy.rust.codegen.core.rustlang.Attribute -import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Visibility @@ -24,6 +23,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.render import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.core.smithy.testModuleForShape @@ -71,24 +71,12 @@ class ConstrainedStringGenerator( val inner = RustType.String.render() val constraintViolation = constraintViolationSymbolProvider.toSymbol(shape) - val constrainedTypeVisibility = if (publicConstrainedTypes) { - Visibility.PUBLIC - } else { - Visibility.PUBCRATE - } - val constrainedTypeMetadata = RustMetadata( - setOf(RuntimeType.Debug, RuntimeType.Clone, RuntimeType.PartialEq, RuntimeType.Eq, RuntimeType.Hash), - visibility = constrainedTypeVisibility, - ) - - // Note that we're using the linear time check `chars().count()` instead of `len()` on the input value, since the - // Smithy specification says the `length` trait counts the number of Unicode code points when applied to string shapes. - // https://awslabs.github.io/smithy/1.0/spec/core/constraint-traits.html#length-trait writer.documentShape(shape, model) writer.docs(rustDocsConstrainedTypeEpilogue(name)) - constrainedTypeMetadata.render(writer) + val metadata = symbol.expectRustMetadata() + metadata.render(writer) writer.rust("struct $name(pub(crate) $inner);") - if (constrainedTypeVisibility == Visibility.PUBCRATE) { + if (metadata.visibility == Visibility.PUBCRATE) { Attribute.AllowDeadCode.render(writer) } writer.rust( @@ -157,7 +145,7 @@ class ConstrainedStringGenerator( """ ##[derive(Debug, PartialEq)] pub enum ${constraintViolation.name} { - #{Variants:W} + #{Variants:W} } """, "Variants" to constraintsInfo.map { it.constraintViolationVariant }.join(",\n"), @@ -198,12 +186,12 @@ class ConstrainedStringGenerator( } private data class Length(val lengthTrait: LengthTrait) : StringTraitInfo() { override fun toTraitInfo(): TraitInfo = TraitInfo( - { rust("Self::check_length(&value)?;") }, - { + tryFromCheck = { rust("Self::check_length(&value)?;") }, + constraintViolationVariant = { docs("Error when a string doesn't satisfy its `@length` requirements.") rust("Length(usize)") }, - { + asValidationExceptionField = { rust( """ Self::Length(length) => crate::model::ValidationExceptionField { @@ -213,7 +201,7 @@ private data class Length(val lengthTrait: LengthTrait) : StringTraitInfo() { """, ) }, - this::renderValidationFunction, + validationFunctionDefinition = this::renderValidationFunction, ) /** @@ -222,6 +210,9 @@ private data class Length(val lengthTrait: LengthTrait) : StringTraitInfo() { */ @Suppress("UNUSED_PARAMETER") private fun renderValidationFunction(constraintViolation: Symbol, unconstrainedTypeName: String): Writable = { + // Note that we're using the linear time check `chars().count()` instead of `len()` on the input value, since the + // Smithy specification says the `length` trait counts the number of Unicode code points when applied to string shapes. + // https://awslabs.github.io/smithy/1.0/spec/core/constraint-traits.html#length-trait rust( """ fn check_length(string: &str) -> Result<(), $constraintViolation> { @@ -243,13 +234,13 @@ private data class Pattern(val symbol: Symbol, val patternTrait: PatternTrait) : val pattern = patternTrait.pattern return TraitInfo( - { rust("let value = Self::check_pattern(value)?;") }, - { + tryFromCheck = { rust("let value = Self::check_pattern(value)?;") }, + constraintViolationVariant = { docs("Error when a string doesn't satisfy its `@pattern`.") docs("Contains the String that failed the pattern.") rust("Pattern(String)") }, - { + asValidationExceptionField = { rust( """ Self::Pattern(string) => crate::model::ValidationExceptionField { diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/DeriveEqAndHashSymbolMetadataProviderTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/DeriveEqAndHashSymbolMetadataProviderTest.kt new file mode 100644 index 0000000000..e347197276 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/DeriveEqAndHashSymbolMetadataProviderTest.kt @@ -0,0 +1,236 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import io.kotest.matchers.collections.shouldContainAll +import io.kotest.matchers.collections.shouldNotContain +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.MethodSource +import org.junit.jupiter.params.provider.ValueSource +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.rust.codegen.core.smithy.BaseSymbolMetadataProvider +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider +import java.util.stream.Stream + +internal class DeriveEqAndHashSymbolMetadataProviderTest { + private val model = + """ + namespace test + + service TestService { + version: "123" + operations: [TestOperation, StreamingOperation, EventStreamOperation] + } + + operation TestOperation { + input: TestInputOutput + output: TestInputOutput + } + + operation StreamingOperation { + input: StreamingOperationInputOutput + output: StreamingOperationInputOutput + } + + operation EventStreamOperation { + input: EventStreamOperationInputOutput + output: EventStreamOperationInputOutput + } + + structure EventStreamOperationInputOutput { + @httpPayload + @required + union: StreamingUnion + } + + structure StreamingOperationInputOutput { + @httpPayload + @required + blobStream: BlobStream + } + + @streaming + blob BlobStream + + structure TestInputOutput { + hasFloat: HasFloat + hasDouble: HasDouble + hasDocument: HasDocument + containsFloat: ContainsFloat + containsDouble: ContainsDouble + containsDocument: ContainsDocument + + hasList: HasList + hasListWithMap: HasListWithMap + hasMap: HasMap + + eqAndHashStruct: EqAndHashStruct + } + + structure EqAndHashStruct { + blob: Blob + boolean: Boolean + string: String + byte: Byte + short: Short + integer: Integer + long: Long + enum: Enum + timestamp: Timestamp + + list: List + union: EqAndHashUnion + + // bigInteger: BigInteger + // bigDecimal: BigDecimal + } + + list List { + member: String + } + + list ListWithMap { + member: Map + } + + map Map { + key: String + value: String + } + + union EqAndHashUnion { + blob: Blob + boolean: Boolean + string: String + byte: Byte + short: Short + integer: Integer + long: Long + enum: Enum + timestamp: Timestamp + + list: List + } + + @streaming + union StreamingUnion { + eqAndHashStruct: EqAndHashStruct + } + + structure HasFloat { + float: Float + } + + structure HasDouble { + double: Double + } + + structure HasDocument { + document: Document + } + + structure HasList { + list: List + } + + structure HasListWithMap { + list: ListWithMap + } + + structure HasMap { + map: Map + } + + structure ContainsFloat { + hasFloat: HasFloat + } + + structure ContainsDouble { + hasDouble: HasDouble + } + + structure ContainsDocument { + containsDocument: HasDocument + } + + enum Enum { + DIAMOND + CLUB + HEART + SPADE + } + """.asSmithyModel(smithyVersion = "2.0") + private val serviceShape = model.lookup("test#TestService") + private val deriveEqAndHashSymbolMetadataProvider = serverTestSymbolProvider(model, serviceShape) + .let { BaseSymbolMetadataProvider(it, model, additionalAttributes = listOf()) } + .let { DeriveEqAndHashSymbolMetadataProvider(it, model) } + + companion object { + @JvmStatic + fun getShapes(): Stream { + val shapesWithNeitherEqNorHash = listOf( + "test#StreamingOperationInputOutput", + "test#EventStreamOperationInputOutput", + "test#StreamingUnion", + "test#BlobStream", + "test#TestInputOutput", + "test#HasFloat", + "test#HasDouble", + "test#HasDocument", + "test#ContainsFloat", + "test#ContainsDouble", + "test#ContainsDocument", + ) + + val shapesWithEqAndHash = listOf( + "test#EqAndHashStruct", + "test#EqAndHashUnion", + "test#Enum", + "test#HasList", + ) + + val shapesWithOnlyEq = listOf( + "test#HasListWithMap", + "test#HasMap", + ) + + return ( + shapesWithNeitherEqNorHash.map { Arguments.of(it, emptyList()) } + + shapesWithEqAndHash.map { Arguments.of(it, listOf(RuntimeType.Eq, RuntimeType.Hash)) } + + shapesWithOnlyEq.map { Arguments.of(it, listOf(RuntimeType.Eq)) } + ).stream() + } + } + + @ParameterizedTest(name = "(#{index}) Derive `Eq` and `Hash` when possible. Params = shape: {0}, expectedTraits: {1}") + @MethodSource("getShapes") + fun `it should derive Eq and Hash when possible`( + shapeId: String, + expectedTraits: Collection, + ) { + val shape = model.lookup(shapeId) + val derives = deriveEqAndHashSymbolMetadataProvider.toSymbol(shape).expectRustMetadata().derives + derives shouldContainAll expectedTraits + } + + @ParameterizedTest + // These don't implement `PartialEq` because they are not constrained, so they don't generate newtypes. If the + // symbol provider wrapped `ConstrainedShapeSymbolProvider` and they were constrained, they would generate + // newtypes, and they would hence implement `PartialEq`. + @ValueSource(strings = ["test#List", "test#Map", "test#ListWithMap", "smithy.api#Blob"]) + fun `it should not derive Eq if shape does not implement PartialEq`(shapeId: String) { + val shape = model.lookup(shapeId) + val derives = deriveEqAndHashSymbolMetadataProvider.toSymbol(shape).expectRustMetadata().derives + derives shouldNotContain RuntimeType.PartialEq + derives shouldNotContain RuntimeType.Eq + } +} diff --git a/rust-runtime/aws-smithy-http-server-python/src/types.rs b/rust-runtime/aws-smithy-http-server-python/src/types.rs index 1ce7779482..2fb127fb70 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/types.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/types.rs @@ -26,7 +26,7 @@ use crate::PyError; /// Python Wrapper for [aws_smithy_types::Blob]. #[pyclass] -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Blob(aws_smithy_types::Blob); impl Blob { @@ -88,7 +88,7 @@ impl<'blob> From<&'blob Blob> for &'blob aws_smithy_types::Blob { /// Python Wrapper for [aws_smithy_types::date_time::DateTime]. #[pyclass] -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct DateTime(aws_smithy_types::date_time::DateTime); #[pyclass] diff --git a/rust-runtime/aws-smithy-types/src/date_time/mod.rs b/rust-runtime/aws-smithy-types/src/date_time/mod.rs index 6c2600dc9f..accb788046 100644 --- a/rust-runtime/aws-smithy-types/src/date_time/mod.rs +++ b/rust-runtime/aws-smithy-types/src/date_time/mod.rs @@ -47,7 +47,7 @@ const NANOS_PER_SECOND_U32: u32 = 1_000_000_000; /// The [`aws-smithy-types-convert`](https://crates.io/crates/aws-smithy-types-convert) crate /// can be used for conversions to/from other libraries, such as /// [`time`](https://crates.io/crates/time) or [`chrono`](https://crates.io/crates/chrono). -#[derive(Debug, PartialEq, Clone, Copy)] +#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] pub struct DateTime { seconds: i64, subsecond_nanos: u32, diff --git a/rust-runtime/aws-smithy-types/src/lib.rs b/rust-runtime/aws-smithy-types/src/lib.rs index f2729cefa1..1e45b94183 100644 --- a/rust-runtime/aws-smithy-types/src/lib.rs +++ b/rust-runtime/aws-smithy-types/src/lib.rs @@ -30,7 +30,7 @@ pub use error::Error; /// Binary Blob Type /// /// Blobs represent protocol-agnostic binary content. -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Eq, Hash, Clone)] pub struct Blob { inner: Vec, }