Skip to content

Commit

Permalink
Support @sparse constrained map shapes and list shapes
Browse files Browse the repository at this point in the history
Turns out we've never supported them, neither directly constrained nor
with constrained members, because of a lack of tests. Yet another data
point to prioritize working on code-generating `constraints.smithy` (see
#2101).

The implementation is simple: we just need to call the symbol provider
on the member symbols instead of on the target symbols so we get
`Option<T>` list members / map values if applicable, and handle the
wrapper when converting between unconstrained and constrained types with
help from `match` and `Option<T>::map`.
  • Loading branch information
david-perez committed Jan 13, 2023
1 parent f5c56b6 commit 6c835a6
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 70 deletions.
28 changes: 28 additions & 0 deletions codegen-core/common-test-models/constraints.smithy
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,10 @@ structure ConA {
lengthMap: LengthMap,

mapOfMapOfListOfListOfConB: MapOfMapOfListOfListOfConB,
sparseMap: SparseMap,
sparseList: SparseList,
sparseLengthMap: SparseLengthMap,
sparseLengthList: SparseLengthList,

constrainedUnion: ConstrainedUnion,
enumString: EnumString,
Expand Down Expand Up @@ -543,6 +547,30 @@ structure ConA {
// lengthSetOfPatternString: LengthSetOfPatternString,
}

@sparse
map SparseMap {
key: String,
value: LengthString
}

@sparse
list SparseList {
member: LengthString
}

@sparse
@length(min: 69)
map SparseLengthMap {
key: String,
value: String
}

@sparse
@length(min: 69)
list SparseLengthList {
member: String
}

map MapOfLengthBlob {
key: String,
value: LengthBlob,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,6 @@ class PubCrateConstrainedShapeSymbolProvider(
}

is MemberShape -> {
require(model.expectShape(shape.container).isStructureShape) {
"This arm is only exercised by `ServerBuilderGenerator`"
}
require(!shape.hasConstraintTraitOrTargetHasConstraintTrait(model, base)) { errorMessage(shape) }

val targetShape = model.expectShape(shape.target)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class ConstrainedCollectionGenerator(
}

val name = constrainedShapeSymbolProvider.toSymbol(shape).name
val inner = "std::vec::Vec<#{ValueSymbol}>"
val inner = "std::vec::Vec<#{ValueMemberSymbol}>"
val constraintViolation = constraintViolationSymbolProvider.toSymbol(shape)
val constrainedTypeVisibility = Visibility.publicIf(publicConstrainedTypes, Visibility.PUBCRATE)
val constrainedTypeMetadata = RustMetadata(
Expand All @@ -79,7 +79,7 @@ class ConstrainedCollectionGenerator(
)

val codegenScope = arrayOf(
"ValueSymbol" to constrainedShapeSymbolProvider.toSymbol(model.expectShape(shape.member.target)),
"ValueMemberSymbol" to constrainedShapeSymbolProvider.toSymbol(shape.member),
"From" to RuntimeType.From,
"TryFrom" to RuntimeType.TryFrom,
"ConstraintViolation" to constraintViolation,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class ConstrainedMapGenerator(
val lengthTrait = shape.expectTrait<LengthTrait>()

val name = constrainedShapeSymbolProvider.toSymbol(shape).name
val inner = "std::collections::HashMap<#{KeySymbol}, #{ValueSymbol}>"
val inner = "std::collections::HashMap<#{KeySymbol}, #{ValueMemberSymbol}>"
val constraintViolation = constraintViolationSymbolProvider.toSymbol(shape)

val constrainedTypeVisibility = Visibility.publicIf(publicConstrainedTypes, Visibility.PUBCRATE)
Expand All @@ -69,7 +69,7 @@ class ConstrainedMapGenerator(

val codegenScope = arrayOf(
"KeySymbol" to constrainedShapeSymbolProvider.toSymbol(model.expectShape(shape.key.target)),
"ValueSymbol" to constrainedShapeSymbolProvider.toSymbol(model.expectShape(shape.value.target)),
"ValueMemberSymbol" to constrainedShapeSymbolProvider.toSymbol(shape.value),
"From" to RuntimeType.From,
"TryFrom" to RuntimeType.TryFrom,
"ConstraintViolation" to constraintViolation,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,14 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators
import software.amazon.smithy.model.shapes.CollectionShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.module
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape
Expand Down Expand Up @@ -54,14 +60,14 @@ class PubCrateConstrainedCollectionGenerator(
val unconstrainedSymbol = unconstrainedShapeSymbolProvider.toSymbol(shape)
val name = constrainedSymbol.name
val innerShape = model.expectShape(shape.member.target)
val innerConstrainedSymbol = if (innerShape.isTransitivelyButNotDirectlyConstrained(model, symbolProvider)) {
pubCrateConstrainedShapeSymbolProvider.toSymbol(innerShape)
val innerMemberSymbol = if (innerShape.isTransitivelyButNotDirectlyConstrained(model, symbolProvider)) {
pubCrateConstrainedShapeSymbolProvider.toSymbol(shape.member)
} else {
constrainedShapeSymbolProvider.toSymbol(innerShape)
constrainedShapeSymbolProvider.toSymbol(shape.member)
}

val codegenScope = arrayOf(
"InnerConstrainedSymbol" to innerConstrainedSymbol,
"InnerMemberSymbol" to innerMemberSymbol,
"ConstrainedTrait" to RuntimeType.ConstrainedTrait,
"UnconstrainedSymbol" to unconstrainedSymbol,
"Symbol" to symbol,
Expand All @@ -72,7 +78,7 @@ class PubCrateConstrainedCollectionGenerator(
rustTemplate(
"""
##[derive(Debug, Clone)]
pub(crate) struct $name(pub(crate) std::vec::Vec<#{InnerConstrainedSymbol}>);
pub(crate) struct $name(pub(crate) std::vec::Vec<#{InnerMemberSymbol}>);
impl #{ConstrainedTrait} for $name {
type Unconstrained = #{UnconstrainedSymbol};
Expand Down Expand Up @@ -130,22 +136,19 @@ class PubCrateConstrainedCollectionGenerator(
val innerNeedsConversion =
innerShape.typeNameContainsNonPublicType(model, symbolProvider, publicConstrainedTypes)

rustTemplate(
"""
impl #{From}<$name> for #{Symbol} {
fn from(v: $name) -> Self {
${
if (innerNeedsConversion) {
"v.0.into_iter().map(|item| item.into()).collect()"
} else {
"v.0"
}
}
rustBlockTemplate("impl #{From}<$name> for #{Symbol}", *codegenScope) {
rustBlock("fn from(v: $name) -> Self") {
if (innerNeedsConversion) {
withBlock("v.0.into_iter().map(|item| ", ").collect()") {
conditionalBlock("item.map(|item| ", ")", innerMemberSymbol.isOptional()) {
rust("item.into()")
}
}
} else {
rust("v.0")
}
}
""",
*codegenScope,
)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,14 @@ import software.amazon.smithy.model.shapes.CollectionShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.module
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape
Expand Down Expand Up @@ -54,15 +60,15 @@ class PubCrateConstrainedMapGenerator(
val keyShape = model.expectShape(shape.key.target, StringShape::class.java)
val valueShape = model.expectShape(shape.value.target)
val keySymbol = constrainedShapeSymbolProvider.toSymbol(keyShape)
val valueSymbol = if (valueShape.isTransitivelyButNotDirectlyConstrained(model, symbolProvider)) {
pubCrateConstrainedShapeSymbolProvider.toSymbol(valueShape)
val valueMemberSymbol = if (valueShape.isTransitivelyButNotDirectlyConstrained(model, symbolProvider)) {
pubCrateConstrainedShapeSymbolProvider.toSymbol(shape.value)
} else {
constrainedShapeSymbolProvider.toSymbol(valueShape)
constrainedShapeSymbolProvider.toSymbol(shape.value)
}

val codegenScope = arrayOf(
"KeySymbol" to keySymbol,
"ValueSymbol" to valueSymbol,
"ValueMemberSymbol" to valueMemberSymbol,
"ConstrainedTrait" to RuntimeType.ConstrainedTrait,
"UnconstrainedSymbol" to unconstrainedSymbol,
"Symbol" to symbol,
Expand All @@ -73,7 +79,7 @@ class PubCrateConstrainedMapGenerator(
rustTemplate(
"""
##[derive(Debug, Clone)]
pub(crate) struct $name(pub(crate) std::collections::HashMap<#{KeySymbol}, #{ValueSymbol}>);
pub(crate) struct $name(pub(crate) std::collections::HashMap<#{KeySymbol}, #{ValueMemberSymbol}>);
impl #{ConstrainedTrait} for $name {
type Unconstrained = #{UnconstrainedSymbol};
Expand Down Expand Up @@ -117,22 +123,27 @@ class PubCrateConstrainedMapGenerator(
val keyNeedsConversion = keyShape.typeNameContainsNonPublicType(model, symbolProvider, publicConstrainedTypes)
val valueNeedsConversion = valueShape.typeNameContainsNonPublicType(model, symbolProvider, publicConstrainedTypes)

rustTemplate(
"""
impl #{From}<$name> for #{Symbol} {
fn from(v: $name) -> Self {
${ if (keyNeedsConversion || valueNeedsConversion) {
val keyConversion = if (keyNeedsConversion) { ".into()" } else { "" }
val valueConversion = if (valueNeedsConversion) { ".into()" } else { "" }
"v.0.into_iter().map(|(k, v)| (k$keyConversion, v$valueConversion)).collect()"
} else {
"v.0"
} }
rustBlockTemplate("impl #{From}<$name> for #{Symbol}", *codegenScope) {
rustBlock("fn from(v: $name) -> Self") {
if (keyNeedsConversion || valueNeedsConversion) {
withBlock("v.0.into_iter().map(|(k, v)| {", "}).collect()") {
if (keyNeedsConversion) {
rust("let k = k.into();")
}
if (valueNeedsConversion) {
withBlock("let v = {", "};") {
conditionalBlock("v.map(|v| ", ")", valueMemberSymbol.isOptional()) {
rust("v.into()")
}
}
}
rust("(k, v)")
}
} else {
rust("v.0")
}
}
""",
*codegenScope,
)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@ import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained
import software.amazon.smithy.rust.codegen.core.smithy.module
import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider
Expand Down Expand Up @@ -65,22 +68,21 @@ class UnconstrainedCollectionGenerator(
fun render() {
check(shape.canReachConstrainedShape(model, symbolProvider))

val innerShape = model.expectShape(shape.member.target)
val innerUnconstrainedSymbol = unconstrainedShapeSymbolProvider.toSymbol(innerShape)
val innerMemberSymbol = unconstrainedShapeSymbolProvider.toSymbol(shape.member)

unconstrainedModuleWriter.withInlineModule(symbol.module()) {
rustTemplate(
"""
##[derive(Debug, Clone)]
pub(crate) struct $name(pub(crate) std::vec::Vec<#{InnerUnconstrainedSymbol}>);
pub(crate) struct $name(pub(crate) std::vec::Vec<#{InnerMemberSymbol}>);
impl From<$name> for #{MaybeConstrained} {
fn from(value: $name) -> Self {
Self::Unconstrained(value)
}
}
""",
"InnerUnconstrainedSymbol" to innerUnconstrainedSymbol,
"InnerMemberSymbol" to innerMemberSymbol,
"MaybeConstrained" to constrainedSymbol.makeMaybeConstrained(),
)

Expand All @@ -99,26 +101,35 @@ class UnconstrainedCollectionGenerator(
!innerShape.isDirectlyConstrained(symbolProvider) &&
innerShape !is StructureShape &&
innerShape !is UnionShape
val innerConstrainedSymbol = if (resolvesToNonPublicConstrainedValueType) {
pubCrateConstrainedShapeSymbolProvider.toSymbol(innerShape)
val constrainedMemberSymbol = if (resolvesToNonPublicConstrainedValueType) {
pubCrateConstrainedShapeSymbolProvider.toSymbol(shape.member)
} else {
constrainedShapeSymbolProvider.toSymbol(innerShape)
constrainedShapeSymbolProvider.toSymbol(shape.member)
}
val innerConstraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(innerShape)

val constrainValueWritable = writable {
conditionalBlock("inner.map(|inner| ", ").transpose()", constrainedMemberSymbol.isOptional()) {
rust("inner.try_into().map_err(|inner_violation| (idx, inner_violation))")
}
}

rustTemplate(
"""
let res: Result<std::vec::Vec<#{InnerConstrainedSymbol}>, (usize, #{InnerConstraintViolationSymbol})> = value
let res: Result<#{Vec}<#{ConstrainedMemberSymbol}>, (usize, #{InnerConstraintViolationSymbol}) > = value
.0
.into_iter()
.enumerate()
.map(|(idx, inner)| inner.try_into().map_err(|inner_violation| (idx, inner_violation)))
.map(|(idx, inner)| {
#{ConstrainValueWritable:W}
})
.collect();
let inner = res.map_err(|(idx, inner_violation)| Self::Error::Member(idx, inner_violation))?;
""",
"InnerConstrainedSymbol" to innerConstrainedSymbol,
"Vec" to RuntimeType.Vec,
"ConstrainedMemberSymbol" to constrainedMemberSymbol,
"InnerConstraintViolationSymbol" to innerConstraintViolationSymbol,
"TryFrom" to RuntimeType.TryFrom,
"ConstrainValueWritable" to constrainValueWritable,
)
} else {
rust("let inner = value.0;")
Expand Down

0 comments on commit 6c835a6

Please sign in to comment.