diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt index 8fd2a07f583..00ea40dc7aa 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt @@ -10,19 +10,16 @@ import software.amazon.smithy.build.PluginContext import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.NullableIndex -import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.EnumTrait -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerEnumGenerator -import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerOperationHandlerGenerator import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerServiceGenerator import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerStructureGenerator import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext @@ -167,11 +164,4 @@ class PythonServerCodegenVisitor( ) .render() } - - override fun operationShape(shape: OperationShape) { - super.operationShape(shape) - rustCrate.withModule(RustModule.public("python_operation_adaptor")) { - PythonServerOperationHandlerGenerator(codegenContext, shape).render(this) - } - } } diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt index 793118b4410..ef0591cf8ed 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt @@ -199,7 +199,7 @@ class PythonApplicationGenerator( let ${name}_locals = #{pyo3_asyncio}::TaskLocals::new(event_loop); let handler = self.handlers.get("$name").expect("Python handler for operation `$name` not found").clone(); let builder = builder.$name(move |input, state| { - #{pyo3_asyncio}::tokio::scope(${name}_locals.clone(), crate::python_operation_adaptor::$name(input, state, handler.clone())) + #{pyo3_asyncio}::tokio::scope(${name}_locals.clone(), crate::operation_handler::$name(input, state, handler.clone())) }); """, *codegenScope, diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt index f107806098d..60784f36b54 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt @@ -13,6 +13,8 @@ import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCargoDependency +import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerOperationHandlerGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol /** * The Rust code responsible to run the Python business logic on the Python interpreter @@ -30,8 +32,9 @@ import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCarg */ class PythonServerOperationHandlerGenerator( codegenContext: CodegenContext, - private val operation: OperationShape, -) { + protocol: ServerProtocol, + private val operations: List, +) : ServerOperationHandlerGenerator(codegenContext, protocol, operations) { private val symbolProvider = codegenContext.symbolProvider private val runtimeConfig = codegenContext.runtimeConfig private val codegenScope = @@ -44,39 +47,42 @@ class PythonServerOperationHandlerGenerator( "tracing" to PythonServerCargoDependency.Tracing.toType(), ) - fun render(writer: RustWriter) { + override fun render(writer: RustWriter) { + super.render(writer) renderPythonOperationHandlerImpl(writer) } private fun renderPythonOperationHandlerImpl(writer: RustWriter) { - val operationName = symbolProvider.toSymbol(operation).name - val input = "crate::input::${operationName}Input" - val output = "crate::output::${operationName}Output" - val error = "crate::error::${operationName}Error" - val fnName = operationName.toSnakeCase() + for (operation in operations) { + val operationName = symbolProvider.toSymbol(operation).name + val input = "crate::input::${operationName}Input" + val output = "crate::output::${operationName}Output" + val error = "crate::error::${operationName}Error" + val fnName = operationName.toSnakeCase() - writer.rustTemplate( - """ - /// Python handler for operation `$operationName`. - pub(crate) async fn $fnName( - input: $input, - state: #{SmithyServer}::Extension<#{SmithyPython}::context::PyContext>, - handler: #{SmithyPython}::PyHandler, - ) -> std::result::Result<$output, $error> { - // Async block used to run the handler and catch any Python error. - let result = if handler.is_coroutine { - #{PyCoroutine:W} - } else { - #{PyFunction:W} - }; - #{PyError:W} - } - """, - *codegenScope, - "PyCoroutine" to renderPyCoroutine(fnName, output), - "PyFunction" to renderPyFunction(fnName, output), - "PyError" to renderPyError(), - ) + writer.rustTemplate( + """ + /// Python handler for operation `$operationName`. + pub(crate) async fn $fnName( + input: $input, + state: #{SmithyServer}::Extension<#{SmithyPython}::context::PyContext>, + handler: #{SmithyPython}::PyHandler, + ) -> std::result::Result<$output, $error> { + // Async block used to run the handler and catch any Python error. + let result = if handler.is_coroutine { + #{PyCoroutine:W} + } else { + #{PyFunction:W} + }; + #{PyError:W} + } + """, + *codegenScope, + "PyCoroutine" to renderPyCoroutine(fnName, output), + "PyFunction" to renderPyFunction(fnName, output), + "PyError" to renderPyError(), + ) + } } private fun renderPyFunction(name: String, output: String): Writable = diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerServiceGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerServiceGenerator.kt index 58e87bc3888..544eaae66a4 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerServiceGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerServiceGenerator.kt @@ -33,6 +33,10 @@ class PythonServerServiceGenerator( PythonServerOperationErrorGenerator(context.model, context.symbolProvider, operation).render(writer) } + override fun renderOperationHandler(writer: RustWriter, operations: List) { + PythonServerOperationHandlerGenerator(context, protocol, operations).render(writer) + } + override fun renderExtras(operations: List) { rustCrate.withModule(RustModule.public("python_server_application", "Python server and application implementation.")) { PythonApplicationGenerator(context, protocol, operations) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt index cb4e62fe9a8..6b4df3204e6 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt @@ -7,6 +7,8 @@ package software.amazon.smithy.rust.codegen.server.smithy import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.CratesIo import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope +import software.amazon.smithy.rust.codegen.core.rustlang.InlineDependency +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig /** @@ -25,8 +27,32 @@ object ServerCargoDependency { val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4")) val TokioDev: CargoDependency = CargoDependency("tokio", CratesIo("1.8.4"), scope = DependencyScope.Dev) val Regex: CargoDependency = CargoDependency("regex", CratesIo("1.5.5")) - val HyperDev: CargoDependency = CargoDependency("hyper", CratesIo("0.14.12"), DependencyScope.Dev) fun smithyHttpServer(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http-server") fun smithyTypes(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-types") } + +/** + * A dependency on a snippet of code + * + * ServerInlineDependency should not be instantiated directly, rather, it should be constructed with + * [software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.forInlineFun] + * + * ServerInlineDependencies are created as private modules within the main crate. This is useful for any code that + * doesn't need to exist in a shared crate, but must still be generated exactly once during codegen. + * + * CodegenVisitor de-duplicates inline dependencies by (module, name) during code generation. + */ +object ServerInlineDependency { + fun serverOperationHandler(runtimeConfig: RuntimeConfig): InlineDependency = + InlineDependency.forRustFile( + RustModule.private("server_operation_handler_trait"), + "/inlineable/src/server_operation_handler_trait.rs", + ServerCargoDependency.smithyHttpServer(runtimeConfig), + CargoDependency.Http, + ServerCargoDependency.PinProjectLite, + ServerCargoDependency.Tower, + ServerCargoDependency.FuturesUtil, + ServerCargoDependency.AsyncTrait, + ) +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt index e31275fa333..a0a1baa23d9 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt @@ -19,6 +19,9 @@ object ServerRuntimeType { fun router(runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("routing::Router") + fun operationHandler(runtimeConfig: RuntimeConfig) = + forInlineDependency(ServerInlineDependency.serverOperationHandler(runtimeConfig)) + fun runtimeError(runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("runtime_error::RuntimeError") fun requestRejection(runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("rejection::RequestRejection") diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt new file mode 100644 index 00000000000..308f2f04708 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt @@ -0,0 +1,154 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.error.errorSymbol +import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErrors +import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember +import software.amazon.smithy.rust.codegen.core.util.inputShape +import software.amazon.smithy.rust.codegen.core.util.outputShape +import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency +import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType +import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol +import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpBoundProtocolGenerator + +/** + * ServerOperationHandlerGenerator + */ +open class ServerOperationHandlerGenerator( + codegenContext: CodegenContext, + val protocol: ServerProtocol, + private val operations: List, +) { + private val serverCrate = "aws_smithy_http_server" + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + private val runtimeConfig = codegenContext.runtimeConfig + private val codegenScope = arrayOf( + "AsyncTrait" to ServerCargoDependency.AsyncTrait.toType(), + "Tower" to ServerCargoDependency.Tower.toType(), + "FuturesUtil" to ServerCargoDependency.FuturesUtil.toType(), + "SmithyHttp" to RuntimeType.smithyHttp(runtimeConfig), + "SmithyHttpServer" to ServerCargoDependency.smithyHttpServer(runtimeConfig).toType(), + "Phantom" to RuntimeType.Phantom, + "ServerOperationHandler" to ServerRuntimeType.operationHandler(runtimeConfig), + "http" to RuntimeType.Http, + ) + + open fun render(writer: RustWriter) { + renderHandlerImplementations(writer, false) + renderHandlerImplementations(writer, true) + } + + /** + * Renders the implementation of the `Handler` trait for all operations. + * Handlers are implemented for `FnOnce` function types whose signatures take in state or not. + */ + private fun renderHandlerImplementations(writer: RustWriter, state: Boolean) { + operations.map { operation -> + val operationName = symbolProvider.toSymbol(operation).name + val inputName = symbolProvider.toSymbol(operation.inputShape(model)).fullName + val inputWrapperName = "crate::operation::$operationName${ServerHttpBoundProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}" + val outputWrapperName = "crate::operation::$operationName${ServerHttpBoundProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}" + val fnSignature = if (state) { + "impl #{ServerOperationHandler}::Handler, $inputName> for Fun" + } else { + "impl #{ServerOperationHandler}::Handler for Fun" + } + writer.rustBlockTemplate( + """ + ##[#{AsyncTrait}::async_trait] + $fnSignature + where + ${operationTraitBounds(operation, inputName, state)} + """.trimIndent(), + *codegenScope, + ) { + val callImpl = if (state) { + """ + let state = match $serverCrate::extension::extract_extension(&mut req).await { + Ok(v) => v, + Err(extension_not_found_rejection) => { + let extension = $serverCrate::extension::RuntimeErrorExtension::new(extension_not_found_rejection.to_string()); + let runtime_error = $serverCrate::runtime_error::RuntimeError::from(extension_not_found_rejection); + let mut response = #{SmithyHttpServer}::response::IntoResponse::<#{Protocol}>::into_response(runtime_error); + response.extensions_mut().insert(extension); + return response.map($serverCrate::body::boxed); + } + }; + let input_inner = input_wrapper.into(); + let output_inner = self(input_inner, state).await; + """.trimIndent() + } else { + """ + let input_inner = input_wrapper.into(); + let output_inner = self(input_inner).await; + """.trimIndent() + } + rustTemplate( + """ + type Sealed = #{ServerOperationHandler}::sealed::Hidden; + async fn call(self, req: #{http}::Request) -> #{http}::Response<#{SmithyHttpServer}::body::BoxBody> { + let mut req = #{SmithyHttpServer}::request::RequestParts::new(req); + let input_wrapper = match $inputWrapperName::from_request(&mut req).await { + Ok(v) => v, + Err(runtime_error) => { + let response = #{SmithyHttpServer}::response::IntoResponse::<#{Protocol}>::into_response(runtime_error); + return response.map($serverCrate::body::boxed); + } + }; + $callImpl + let output_wrapper: $outputWrapperName = output_inner.into(); + let mut response = output_wrapper.into_response(); + let operation_ext = #{SmithyHttpServer}::extension::OperationExtension::new("${operation.id.namespace}.$operationName").expect("malformed absolute shape ID"); + response.extensions_mut().insert(operation_ext); + response.map(#{SmithyHttpServer}::body::boxed) + } + """, + "Protocol" to protocol.markerStruct(), + *codegenScope, + ) + } + } + } + + /** + * Generates the trait bounds of the `Handler` trait implementation, depending on: + * - the presence of state; and + * - whether the operation is fallible or not. + */ + private fun operationTraitBounds(operation: OperationShape, inputName: String, state: Boolean): String { + val inputFn = if (state) { + """S: Send + Clone + Sync + 'static, + Fun: FnOnce($inputName, $serverCrate::Extension) -> Fut + Clone + Send + 'static,""" + } else { + "Fun: FnOnce($inputName) -> Fut + Clone + Send + 'static," + } + val outputType = if (operation.operationErrors(model).isNotEmpty()) { + "Result<${symbolProvider.toSymbol(operation.outputShape(model)).fullName}, ${operation.errorSymbol(symbolProvider).fullyQualifiedName()}>" + } else { + symbolProvider.toSymbol(operation.outputShape(model)).fullName + } + val streamingBodyTraitBounds = if (operation.inputShape(model).hasStreamingMember(model)) { + "\n B: Into<#{SmithyHttp}::byte_stream::ByteStream>," + } else { + "" + } + return """ + $inputFn + Fut: std::future::Future + Send, + B: $serverCrate::body::HttpBody + Send + 'static, $streamingBodyTraitBounds + B::Data: Send, + $serverCrate::rejection::RequestRejection: From<::Error> + """.trimIndent() + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationRegistryGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationRegistryGenerator.kt new file mode 100644 index 00000000000..10b11978d0a --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationRegistryGenerator.kt @@ -0,0 +1,407 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.traits.DocumentationTrait +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.derive +import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope +import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.withBlock +import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.ErrorsModule +import software.amazon.smithy.rust.codegen.core.smithy.InputsModule +import software.amazon.smithy.rust.codegen.core.smithy.OutputsModule +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.error.errorSymbol +import software.amazon.smithy.rust.codegen.core.util.getTrait +import software.amazon.smithy.rust.codegen.core.util.inputShape +import software.amazon.smithy.rust.codegen.core.util.outputShape +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase +import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency +import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType +import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol + +/** + * [ServerOperationRegistryGenerator] renders the `OperationRegistry` struct, a place where users can register their + * service's operation implementations. + * + * Users can construct the operation registry using a builder. They can subsequently convert the operation registry into + * the [`aws_smithy_http_server::Router`], a [`tower::Service`] that will route incoming requests to their operation + * handlers, invoking them and returning the response. + * + * [`aws_smithy_http_server::Router`]: https://docs.rs/aws-smithy-http-server/latest/aws_smithy_http_server/struct.Router.html + * [`tower::Service`]: https://docs.rs/tower/latest/tower/trait.Service.html + */ +class ServerOperationRegistryGenerator( + private val codegenContext: CodegenContext, + private val protocol: ServerProtocol, + private val operations: List, +) { + private val crateName = codegenContext.settings.moduleName + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + private val serviceName = codegenContext.serviceShape.toShapeId().name + private val operationNames = operations.map { RustReservedWords.escapeIfNeeded(symbolProvider.toSymbol(it).name.toSnakeCase()) } + private val runtimeConfig = codegenContext.runtimeConfig + private val codegenScope = arrayOf( + "Router" to ServerRuntimeType.router(runtimeConfig), + "SmithyHttpServer" to ServerCargoDependency.smithyHttpServer(runtimeConfig).toType(), + "ServerOperationHandler" to ServerRuntimeType.operationHandler(runtimeConfig), + "Tower" to ServerCargoDependency.Tower.toType(), + "Phantom" to RuntimeType.Phantom, + "StdError" to RuntimeType.StdError, + "Display" to RuntimeType.Display, + "From" to RuntimeType.From, + ) + private val operationRegistryName = "OperationRegistry" + private val operationRegistryBuilderName = "${operationRegistryName}Builder" + private val operationRegistryErrorName = "${operationRegistryBuilderName}Error" + private val genericArguments = "B, " + operations.mapIndexed { i, _ -> "Op$i, In$i" }.joinToString() + private val operationRegistryNameWithArguments = "$operationRegistryName<$genericArguments>" + private val operationRegistryBuilderNameWithArguments = "$operationRegistryBuilderName<$genericArguments>" + + fun render(writer: RustWriter) { + renderOperationRegistryRustDocs(writer) + renderOperationRegistryStruct(writer) + renderOperationRegistryBuilderStruct(writer) + renderOperationRegistryBuilderError(writer) + renderOperationRegistryBuilderDefault(writer) + renderOperationRegistryBuilderImplementation(writer) + renderRouterImplementationFromOperationRegistryBuilder(writer) + } + + private fun renderOperationRegistryRustDocs(writer: RustWriter) { + val inputOutputErrorsImport = if (operations.any { it.errors.isNotEmpty() }) { + "/// use ${crateName.toSnakeCase()}::{${InputsModule.name}, ${OutputsModule.name}, ${ErrorsModule.name}};" + } else { + "/// use ${crateName.toSnakeCase()}::{${InputsModule.name}, ${OutputsModule.name}};" + } + + writer.rustTemplate( +""" +##[allow(clippy::tabs_in_doc_comments)] +/// The `$operationRegistryName` is the place where you can register +/// your service's operation implementations. +/// +/// Use [`$operationRegistryBuilderName`] to construct the +/// `$operationRegistryName`. For each of the [operations] modeled in +/// your Smithy service, you need to provide an implementation in the +/// form of a Rust async function or closure that takes in the +/// operation's input as their first parameter, and returns the +/// operation's output. If your operation is fallible (i.e. it +/// contains the `errors` member in your Smithy model), the function +/// implementing the operation has to be fallible (i.e. return a +/// [`Result`]). **You must register an implementation for all +/// operations with the correct signature**, or your application +/// will fail to compile. +/// +/// The operation registry can be converted into an [`#{Router}`] for +/// your service. This router will take care of routing HTTP +/// requests to the matching operation implementation, adhering to +/// your service's protocol and the [HTTP binding traits] that you +/// used in your Smithy model. This router can be converted into a +/// type implementing [`tower::make::MakeService`], a _service +/// factory_. You can feed this value to a [Hyper server], and the +/// server will instantiate and [`serve`] your service. +/// +/// Here's a full example to get you started: +/// +/// ```rust +/// use std::net::SocketAddr; +$inputOutputErrorsImport +/// use ${crateName.toSnakeCase()}::operation_registry::$operationRegistryBuilderName; +/// use #{Router}; +/// +/// ##[#{Tokio}::main] +/// pub async fn main() { +/// let app: Router = $operationRegistryBuilderName::default() +${operationNames.map { ".$it($it)" }.joinToString("\n") { it.prependIndent("/// ") }} +/// .build() +/// .expect("unable to build operation registry") +/// .into(); +/// +/// let bind: SocketAddr = format!("{}:{}", "127.0.0.1", "6969") +/// .parse() +/// .expect("unable to parse the server bind address and port"); +/// +/// let server = #{Hyper}::Server::bind(&bind).serve(app.into_make_service()); +/// +/// // Run your service! +/// // if let Err(err) = server.await { +/// // eprintln!("server error: {}", err); +/// // } +/// } +/// +${operationImplementationStubs(operations)} +/// ``` +/// +/// [`serve`]: https://docs.rs/hyper/0.14.16/hyper/server/struct.Builder.html##method.serve +/// [`tower::make::MakeService`]: https://docs.rs/tower/latest/tower/make/trait.MakeService.html +/// [HTTP binding traits]: https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html +/// [operations]: https://awslabs.github.io/smithy/1.0/spec/core/model.html##operation +/// [Hyper server]: https://docs.rs/hyper/latest/hyper/server/index.html +""", + "Router" to ServerRuntimeType.router(runtimeConfig), + // These should be dev-dependencies. Not all sSDKs depend on `Hyper` (only those that convert the body + // `to_bytes`), and none depend on `tokio`. + "Tokio" to ServerCargoDependency.TokioDev.toType(), + "Hyper" to CargoDependency.Hyper.copy(scope = DependencyScope.Dev).toType(), + ) + } + + private fun renderOperationRegistryStruct(writer: RustWriter) { + writer.rust("""##[deprecated(since = "0.52.0", note = "`OperationRegistry` is part of the deprecated service builder API. Use `$serviceName::builder` instead.")]""") + writer.rustBlock("pub struct $operationRegistryNameWithArguments") { + val members = operationNames + .mapIndexed { i, operationName -> "$operationName: Op$i" } + .joinToString(separator = ",\n") + rustTemplate( + """ + $members, + _phantom: #{Phantom}<(B, ${phantomMembers()})>, + """, + *codegenScope, + ) + } + } + + /** + * Renders the `OperationRegistryBuilder` structure, used to build the `OperationRegistry`. + */ + private fun renderOperationRegistryBuilderStruct(writer: RustWriter) { + writer.rust("""##[deprecated(since = "0.52.0", note = "`OperationRegistryBuilder` is part of the deprecated service builder API. Use `$serviceName::builder` instead.")]""") + writer.rustBlock("pub struct $operationRegistryBuilderNameWithArguments") { + val members = operationNames + .mapIndexed { i, operationName -> "$operationName: Option" } + .joinToString(separator = ",\n") + rustTemplate( + """ + $members, + _phantom: #{Phantom}<(B, ${phantomMembers()})>, + """, + *codegenScope, + ) + } + } + + /** + * Renders the `OperationRegistryBuilderError` type, used to error out in case there are uninitialized fields. + * This is an enum deriving `Debug` and implementing `Display` and `std::error::Error`. + */ + private fun renderOperationRegistryBuilderError(writer: RustWriter) { + Attribute(derive(RuntimeType.Debug)).render(writer) + writer.rustTemplate( + """ + pub enum $operationRegistryErrorName { + UninitializedField(&'static str) + } + impl #{Display} for $operationRegistryErrorName { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::UninitializedField(v) => write!(f, "{}", v), + } + } + } + impl #{StdError} for $operationRegistryErrorName {} + """, + *codegenScope, + ) + } + + /** + * Renders the `OperationRegistryBuilder` `Default` implementation, used to create a new builder that can be + * populated with the service's operation implementations. + */ + private fun renderOperationRegistryBuilderDefault(writer: RustWriter) { + writer.rustBlockTemplate("impl<$genericArguments> std::default::Default for $operationRegistryBuilderNameWithArguments") { + val defaultOperations = operationNames.joinToString(separator = "\n,") { operationName -> + "$operationName: Default::default()" + } + rustTemplate( + """ + fn default() -> Self { + Self { + $defaultOperations, + _phantom: #{Phantom} + } + } + """, + *codegenScope, + ) + } + } + + /** + * Renders the `OperationRegistryBuilder`'s impl block, where operations are stored. + * The `build()` method converts the builder into an `OperationRegistry` instance. + */ + private fun renderOperationRegistryBuilderImplementation(writer: RustWriter) { + writer.rustBlock("impl<$genericArguments> $operationRegistryBuilderNameWithArguments") { + operationNames.forEachIndexed { i, operationName -> + rust( + """ + pub fn $operationName(self, value: Op$i) -> Self { + let mut new = self; + new.$operationName = Some(value); + new + } + """, + ) + } + + rustBlock("pub fn build(self) -> Result<$operationRegistryNameWithArguments, $operationRegistryErrorName>") { + withBlock("Ok( $operationRegistryName {", "})") { + for (operationName in operationNames) { + rust( + """ + $operationName: match self.$operationName { + Some(v) => v, + None => return Err($operationRegistryErrorName::UninitializedField("$operationName")), + }, + """, + ) + } + rustTemplate("_phantom: #{Phantom}", *codegenScope) + } + } + } + } + + /** + * Renders the converter between the `OperationRegistry` and the `Router` via the `std::convert::From` trait. + */ + private fun renderRouterImplementationFromOperationRegistryBuilder(writer: RustWriter) { + val operationTraitBounds = writable { + operations.forEachIndexed { i, operation -> + rustTemplate( + """ + Op$i: #{ServerOperationHandler}::Handler, + In$i: 'static + Send, + """, + *codegenScope, + "OperationInput" to symbolProvider.toSymbol(operation.inputShape(model)), + ) + } + } + + writer.rustBlockTemplate( + // The bound `B: Send` is required because of [`tower::util::BoxCloneService`]. + // [`tower::util::BoxCloneService`]: https://docs.rs/tower/latest/tower/util/struct.BoxCloneService.html#method.new + """ + impl<$genericArguments> #{From}<$operationRegistryNameWithArguments> for #{Router} + where + B: Send + 'static, + #{operationTraitBounds:W} + """, + *codegenScope, + "operationTraitBounds" to operationTraitBounds, + ) { + rustBlock("fn from(registry: $operationRegistryNameWithArguments) -> Self") { + val requestSpecsVarNames = operationNames.map { "${it}_request_spec" } + + requestSpecsVarNames.zip(operations).forEach { (requestSpecVarName, operation) -> + rustTemplate( + "let $requestSpecVarName = #{RequestSpec:W};", + "RequestSpec" to operation.requestSpec(), + ) + } + + val sensitivityGens = operations.map { + ServerHttpSensitivityGenerator(model, it, codegenContext.runtimeConfig) + } + + withBlockTemplate( + "#{Router}::${protocol.serverRouterRuntimeConstructor()}(vec![", + "])", + *codegenScope, + ) { + requestSpecsVarNames.zip(operationNames).zip(sensitivityGens).forEach { + val (inner, sensitivityGen) = it + val (requestSpecVarName, operationName) = inner + + rustBlock("") { + rustTemplate( + """ + let svc = #{ServerOperationHandler}::operation(registry.$operationName); + let request_fmt = #{RequestFmt:W}; + let response_fmt = #{ResponseFmt:W}; + let svc = #{SmithyHttpServer}::instrumentation::InstrumentOperation::new(svc, "$operationName").request_fmt(request_fmt).response_fmt(response_fmt); + (#{Tower}::util::BoxCloneService::new(svc), $requestSpecVarName) + """, + "RequestFmt" to sensitivityGen.requestFmt().value, + "ResponseFmt" to sensitivityGen.responseFmt().value, + *codegenScope, + ) + } + rust(",") + } + } + } + } + } + + /** + * Returns the `PhantomData` generic members in a comma-separated list. + */ + private fun phantomMembers() = operationNames.mapIndexed { i, _ -> "In$i" }.joinToString(separator = ",\n") + + private fun operationImplementationStubs(operations: List): String = + operations.joinToString("\n///\n") { + val operationDocumentation = it.getTrait()?.value + val ret = if (!operationDocumentation.isNullOrBlank()) { + operationDocumentation.replace("#", "##").prependIndent("/// /// ") + "\n" + } else "" + ret + + """ + /// ${it.signature()} { + /// todo!() + /// } + """.trimIndent() + } + + /** + * Returns the function signature for an operation handler implementation. Used in the documentation. + */ + private fun OperationShape.signature(): String { + val inputSymbol = symbolProvider.toSymbol(inputShape(model)) + val outputSymbol = symbolProvider.toSymbol(outputShape(model)) + val errorSymbol = errorSymbol(symbolProvider) + + // using module names here to avoid generating `crate::...` since we've already added the import + val inputT = "${InputsModule.name}::${inputSymbol.name}" + val t = "${OutputsModule.name}::${outputSymbol.name}" + val outputT = if (errors.isEmpty()) { + t + } else { + val e = "${ErrorsModule.name}::${errorSymbol.name}" + "Result<$t, $e>" + } + + val operationName = RustReservedWords.escapeIfNeeded(symbolProvider.toSymbol(this).name.toSnakeCase()) + return "async fn $operationName(input: $inputT) -> $outputT" + } + + /** + * Returns a writable for the `RequestSpec` for an operation based on the service's protocol. + */ + private fun OperationShape.requestSpec(): Writable = protocol.serverRouterRequestSpec( + this, + symbolProvider.toSymbol(this).name, + serviceName, + ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("routing::request_spec"), + ) +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt index cf405dc5e99..a8d605a1a49 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt @@ -7,6 +7,9 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.model.knowledge.TopDownIndex import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.deprecated +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter @@ -20,6 +23,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.InputsModule import software.amazon.smithy.rust.codegen.core.smithy.OutputsModule import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport +import software.amazon.smithy.rust.codegen.core.util.toPascalCase import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext @@ -82,7 +86,7 @@ open class ServerServiceGenerator( //! let server = app.into_make_service(); //! let bind: SocketAddr = "127.0.0.1:6969".parse() //! .expect("unable to parse the server bind address and port"); - //! #{Hyper}::Server::bind(&bind).serve(server).await.unwrap(); + //! hyper::Server::bind(&bind).serve(server).await.unwrap(); //! ## } //! ``` //! @@ -115,7 +119,7 @@ open class ServerServiceGenerator( //! ```rust //! ## use #{SmithyHttpServer}::plugin::IdentityPlugin as LoggingPlugin; //! ## use #{SmithyHttpServer}::plugin::IdentityPlugin as MetricsPlugin; - //! ## use #{Hyper}::Body; + //! ## use hyper::Body; //! use #{SmithyHttpServer}::plugin::PluginPipeline; //! use $crateName::{$serviceName, $builderName}; //! @@ -190,7 +194,7 @@ open class ServerServiceGenerator( //! ## use std::net::SocketAddr; //! use $crateName::$serviceName; //! - //! ##[#{Tokio}::main] + //! ##[tokio::main] //! pub async fn main() { //! let app = $serviceName::builder_without_plugins() ${builderFieldNames.values.joinToString("\n") { "//! .$it($it)" }} @@ -199,7 +203,7 @@ open class ServerServiceGenerator( //! //! let bind: SocketAddr = "127.0.0.1:6969".parse() //! .expect("unable to parse the server bind address and port"); - //! let server = #{Hyper}::Server::bind(&bind).serve(app.into_make_service()); + //! let server = hyper::Server::bind(&bind).serve(app.into_make_service()); //! ## let server = async { Ok::<_, ()>(()) }; //! //! // Run your service! @@ -225,8 +229,6 @@ open class ServerServiceGenerator( "Handlers" to handlers, "ExampleHandler" to operations.take(1).map { operation -> DocHandlerGenerator(codegenContext, operation, builderFieldNames[operation]!!, "//!").docSignature() }, "SmithyHttpServer" to ServerCargoDependency.smithyHttpServer(codegenContext.runtimeConfig).toType(), - "Hyper" to ServerCargoDependency.HyperDev.toType(), - "Tokio" to ServerCargoDependency.TokioDev.toType(), "Tower" to ServerCargoDependency.Tower.toType(), ) } @@ -253,6 +255,30 @@ open class ServerServiceGenerator( } } } + rustCrate.withModule(RustModule.private("operation_handler", "Operation handlers definition and implementation.")) { + renderOperationHandler(this, operations) + } + rustCrate.withModule( + RustModule.LeafModule( + "operation_registry", + RustMetadata( + visibility = Visibility.PUBLIC, + additionalAttributes = listOf( + Attribute(deprecated("0.52.0", "This module exports the deprecated `OperationRegistry`. Use the service builder exported from your root crate.")), + ), + ), + """ + Contains the [`operation_registry::OperationRegistry`], a place where + you can register your service's operation implementations. + + ## Deprecation + + This service builder is deprecated - use [`${codegenContext.serviceShape.id.name.toPascalCase()}::builder_with_plugins`] or [`${codegenContext.serviceShape.id.name.toPascalCase()}::builder_without_plugins`] instead. + """, + ), + ) { + renderOperationRegistry(this, operations) + } rustCrate.withModule( RustModule.public("operation_shape"), @@ -291,6 +317,16 @@ open class ServerServiceGenerator( /* Subclasses can override */ } + // Render operations handler. + open fun renderOperationHandler(writer: RustWriter, operations: List) { + ServerOperationHandlerGenerator(codegenContext, protocol, operations).render(writer) + } + + // Render operations registry. + private fun renderOperationRegistry(writer: RustWriter, operations: List) { + ServerOperationRegistryGenerator(codegenContext, protocol, operations).render(writer) + } + // Render `server` crate, re-exporting types. private fun renderServerReExports(writer: RustWriter) { ServerRuntimeTypesReExportsGenerator(codegenContext).render(writer) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGeneratorV2.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGeneratorV2.kt index e8483f4a206..80b1d9b6319 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGeneratorV2.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGeneratorV2.kt @@ -231,7 +231,7 @@ class ServerServiceGeneratorV2( #{Router}::from_iter([#{RoutesArrayElements:W}]) }; Ok($serviceName { - router: #{SmithyHttpServer}::routing::RoutingService::new(router), + router: #{SmithyHttpServer}::routers::RoutingService::new(router), }) } """, @@ -306,7 +306,7 @@ class ServerServiceGeneratorV2( { let router = #{Router}::from_iter([#{Pairs:W}]); $serviceName { - router: #{SmithyHttpServer}::routing::RoutingService::new(router), + router: #{SmithyHttpServer}::routers::RoutingService::new(router), } } """, @@ -387,7 +387,7 @@ class ServerServiceGeneratorV2( /// See the [root](crate) documentation for more information. ##[derive(Clone)] pub struct $serviceName { - router: #{SmithyHttpServer}::routing::RoutingService<#{Router}, #{Protocol}>, + router: #{SmithyHttpServer}::routers::RoutingService<#{Router}, #{Protocol}>, } impl $serviceName<()> { @@ -459,7 +459,7 @@ class ServerServiceGeneratorV2( { type Response = #{Http}::Response<#{SmithyHttpServer}::body::BoxBody>; type Error = S::Error; - type Future = #{SmithyHttpServer}::routing::RoutingFuture; + type Future = #{SmithyHttpServer}::routers::RoutingFuture; fn poll_ready(&mut self, cx: &mut std::task::Context) -> std::task::Poll> { self.router.poll_ready(cx) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 129199115a9..431eb99637d 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -37,6 +37,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.withBlock +import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType @@ -55,9 +56,12 @@ import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType import software.amazon.smithy.rust.codegen.server.smithy.generators.serverInstantiator +import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpBoundProtocolGenerator import java.util.logging.Logger import kotlin.reflect.KFunction1 +private const val PROTOCOL_TEST_HELPER_MODULE_NAME = "protocol_test_helper" + /** * Generate protocol tests for an operation */ @@ -137,12 +141,97 @@ class ServerProtocolTestGenerator( } fun render(writer: RustWriter) { + renderTestHelper(writer) + for (operation in operations) { protocolGenerator.renderOperation(writer, operation) renderOperationTestCases(operation, writer) } } + /** + * Render a test helper module to: + * + * - generate a dynamic builder for each handler, and + * - construct a Tower service to exercise each test case. + */ + private fun renderTestHelper(writer: RustWriter) { + val operationNames = operations.map { it.toName() } + val operationRegistryName = "OperationRegistry" + val operationRegistryBuilderName = "${operationRegistryName}Builder" + + fun renderRegistryBuilderTypeParams() = writable { + operations.forEach { + val (inputT, outputT) = operationInputOutputTypes[it]!! + writeInline("Fun<$inputT, $outputT>, (), ") + } + } + + fun renderRegistryBuilderMethods() = writable { + operations.withIndex().forEach { + val (inputT, outputT) = operationInputOutputTypes[it.value]!! + val operationName = operationNames[it.index] + rust(".$operationName((|_| Box::pin(async { todo!() })) as Fun<$inputT, $outputT> )") + } + } + + val module = RustModule.LeafModule( + PROTOCOL_TEST_HELPER_MODULE_NAME, + RustMetadata( + additionalAttributes = listOf( + Attribute.CfgTest, + Attribute.AllowDeadCode, + ), + visibility = Visibility.PUBCRATE, + ), + inline = true, + ) + + writer.withInlineModule(module) { + rustTemplate( + """ + use #{Tower}::Service as _; + + pub(crate) type Fun = fn(Input) -> std::pin::Pin + Send>>; + + type RegistryBuilder = crate::operation_registry::$operationRegistryBuilderName<#{Hyper}::Body, #{RegistryBuilderTypeParams:W}>; + + fn create_operation_registry_builder() -> RegistryBuilder { + crate::operation_registry::$operationRegistryBuilderName::default() + #{RegistryBuilderMethods:W} + } + + pub(crate) async fn build_router_and_make_request( + http_request: #{Http}::request::Request<#{SmithyHttpServer}::body::Body>, + f: &dyn Fn(RegistryBuilder) -> RegistryBuilder, + ) -> #{Http}::response::Response<#{SmithyHttpServer}::body::BoxBody> { + let mut router: #{Router} = f(create_operation_registry_builder()) + .build() + .expect("unable to build operation registry") + .into(); + let http_response = router + .call(http_request) + .await + .expect("unable to make an HTTP request"); + + http_response + } + + /// The operation full name is a concatenation of `.`. + pub(crate) fn check_operation_extension_was_set(http_response: #{Http}::response::Response<#{SmithyHttpServer}::body::BoxBody>, operation_full_name: &str) { + let operation_extension = http_response.extensions() + .get::<#{SmithyHttpServer}::extension::OperationExtension>() + .expect("extension `OperationExtension` not found"); + #{AssertEq}(operation_extension.absolute(), operation_full_name); + } + """, + "RegistryBuilderTypeParams" to renderRegistryBuilderTypeParams(), + "RegistryBuilderMethods" to renderRegistryBuilderMethods(), + *codegenScope, + ) + } + } + private fun renderOperationTestCases(operationShape: OperationShape, writer: RustWriter) { val outputShape = operationShape.outputShape(codegenContext.model) val operationSymbol = symbolProvider.toSymbol(operationShape) @@ -306,12 +395,22 @@ class ServerProtocolTestGenerator( return } + // Test against original `OperationRegistryBuilder`. with(httpRequestTestCase) { renderHttpRequest(uri, method, headers, body.orNull(), queryParams, host.orNull()) } if (protocolSupport.requestBodyDeserialization) { - makeRequest(operationShape, operationSymbol, this, checkRequestHandler(operationShape, httpRequestTestCase)) - checkHandlerWasEntered(this) + makeRequest(operationShape, this, checkRequestHandler(operationShape, httpRequestTestCase)) + checkHandlerWasEntered(operationShape, operationSymbol, this) + } + + // Test against new service builder. + with(httpRequestTestCase) { + renderHttpRequest(uri, method, headers, body.orNull(), queryParams, host.orNull()) + } + if (protocolSupport.requestBodyDeserialization) { + makeRequest2(operationShape, operationSymbol, this, checkRequestHandler(operationShape, httpRequestTestCase)) + checkHandlerWasEntered2(this) } // Explicitly warn if the test case defined parameters that we aren't doing anything with @@ -341,6 +440,8 @@ class ServerProtocolTestGenerator( operationShape: OperationShape, operationSymbol: Symbol, ) { + val operationImplementationName = + "${operationSymbol.name}${ServerHttpBoundProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}" val operationErrorName = "crate::error::${operationSymbol.name}Error" if (!protocolSupport.responseSerialization || ( @@ -353,13 +454,19 @@ class ServerProtocolTestGenerator( writeInline("let output =") instantiator.render(this, shape, testCase.params) rust(";") - if (operationShape.allErrors(model).isNotEmpty() && shape.hasTrait()) { - val variant = symbolProvider.toSymbol(shape).name - rust("let output = $operationErrorName::$variant(output);") + val operationImpl = if (operationShape.allErrors(model).isNotEmpty()) { + if (shape.hasTrait()) { + val variant = symbolProvider.toSymbol(shape).name + "$operationImplementationName::Error($operationErrorName::$variant(output))" + } else { + "$operationImplementationName::Output(output)" + } + } else { + "$operationImplementationName(output)" } rustTemplate( """ - use #{SmithyHttpServer}::response::IntoResponse; + let output = super::$operationImpl; let http_response = output.into_response(); """, *codegenScope, @@ -381,13 +488,23 @@ class ServerProtocolTestGenerator( val panicMessage = "request should have been rejected, but we accepted it; we parsed operation input `{:?}`" + rust("// Use the `OperationRegistryBuilder`") rustBlock("") { with(testCase.request) { // TODO(https://github.com/awslabs/smithy/issues/1102): `uri` should probably not be an `Optional`. renderHttpRequest(uri.get(), method, headers, body.orNull(), queryParams, host.orNull()) } + makeRequest(operationShape, this, writable("""panic!("$panicMessage", &input) as $outputT""")) + checkResponse(this, testCase.response) + } - makeRequest(operationShape, operationSymbol, this, writable("""panic!("$panicMessage", &input) as $outputT""")) + rust("// Use new service builder") + rustBlock("") { + with(testCase.request) { + // TODO(https://github.com/awslabs/smithy/issues/1102): `uri` should probably not be an `Optional`. + renderHttpRequest(uri.get(), method, headers, body.orNull(), queryParams, host.orNull()) + } + makeRequest2(operationShape, operationSymbol, this, writable("""panic!("$panicMessage", &input) as $outputT""")) checkResponse(this, testCase.response) } } @@ -469,8 +586,44 @@ class ServerProtocolTestGenerator( } } - /** Checks the request. */ + /** Checks the request using the `OperationRegistryBuilder`. */ private fun makeRequest( + operationShape: OperationShape, + rustWriter: RustWriter, + operationBody: Writable, + ) { + val (inputT, outputT) = operationInputOutputTypes[operationShape]!! + + rustWriter.withBlockTemplate( + """ + let http_response = super::$PROTOCOL_TEST_HELPER_MODULE_NAME::build_router_and_make_request( + http_request, + &|builder| { + builder.${operationShape.toName()}((|input| Box::pin(async move { + """, + + "})) as super::$PROTOCOL_TEST_HELPER_MODULE_NAME::Fun<$inputT, $outputT>)}).await;", + *codegenScope, + ) { + operationBody() + } + } + + private fun checkHandlerWasEntered( + operationShape: OperationShape, + operationSymbol: Symbol, + rustWriter: RustWriter, + ) { + val operationFullName = "${operationShape.id.namespace}.${operationSymbol.name}" + rustWriter.rust( + """ + super::$PROTOCOL_TEST_HELPER_MODULE_NAME::check_operation_extension_was_set(http_response, "$operationFullName"); + """, + ) + } + + /** Checks the request using the new service builder. */ + private fun makeRequest2( operationShape: OperationShape, operationSymbol: Symbol, rustWriter: RustWriter, @@ -501,7 +654,7 @@ class ServerProtocolTestGenerator( ) } - private fun checkHandlerWasEntered(rustWriter: RustWriter) { + private fun checkHandlerWasEntered2(rustWriter: RustWriter) { rustWriter.rust( """ assert!(receiver.recv().await.is_some()); diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 5e2ce595af0..eed3dc8433f 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -161,11 +161,13 @@ private class ServerHttpBoundProtocolTraitImplGenerator( operationShape: OperationShape, ) { val operationName = symbolProvider.toSymbol(operationShape).name + val inputName = "${operationName}${ServerHttpBoundProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}" + val verifyAcceptHeader = writable { httpBindingResolver.responseContentType(operationShape)?.also { contentType -> rustTemplate( """ - if ! #{SmithyHttpServer}::protocols::accept_header_classifier(request.headers(), ${contentType.dq()}) { + if ! #{SmithyHttpServer}::protocols::accept_header_classifier(req, ${contentType.dq()}) { return Err(#{RuntimeError}::NotAcceptable) } """, @@ -185,7 +187,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( ?.let { "Some(${it.dq()})" } ?: "None" rustTemplate( """ - if #{SmithyHttpServer}::protocols::content_type_header_classifier(request.headers(), $expectedRequestContentType).is_err() { + if #{SmithyHttpServer}::protocols::content_type_header_classifier(req, $expectedRequestContentType).is_err() { return Err(#{RuntimeError}::UnsupportedMediaType) } """, @@ -198,6 +200,25 @@ private class ServerHttpBoundProtocolTraitImplGenerator( // Implement `from_request` trait for input types. rustTemplate( """ + ##[derive(Debug)] + pub(crate) struct $inputName(#{I}); + impl $inputName + { + pub async fn from_request(req: &mut #{SmithyHttpServer}::request::RequestParts) -> Result + where + B: #{SmithyHttpServer}::body::HttpBody + Send, ${streamingBodyTraitBounds(operationShape)} + B::Data: Send, + #{RequestRejection} : From<::Error> + { + #{verifyAcceptHeader:W} + #{verifyRequestContentTypeHeader:W} + #{parse_request}(req) + .await + .map($inputName) + .map_err(Into::into) + } + } + impl #{SmithyHttpServer}::request::FromRequest<#{Marker}, B> for #{I} where B: #{SmithyHttpServer}::body::HttpBody + Send, @@ -211,11 +232,8 @@ private class ServerHttpBoundProtocolTraitImplGenerator( fn from_request(request: #{http}::Request) -> Self::Future { let fut = async move { - #{verifyAcceptHeader:W} - #{verifyRequestContentTypeHeader:W} - #{parse_request}(request) - .await - .map_err(Into::into) + let mut request_parts = #{SmithyHttpServer}::request::RequestParts::new(request); + $inputName::from_request(&mut request_parts).await.map(|x| x.0) }; Box::pin(fut) } @@ -231,46 +249,143 @@ private class ServerHttpBoundProtocolTraitImplGenerator( ) // Implement `into_response` for output types. - val errorSymbol = operationShape.errorSymbol(symbolProvider) - rustTemplate( - """ - impl #{SmithyHttpServer}::response::IntoResponse<#{Marker}> for #{O} { - fn into_response(self) -> #{SmithyHttpServer}::response::Response { - match #{serialize_response}(self) { - Ok(response) => response, - Err(e) => #{SmithyHttpServer}::response::IntoResponse::<#{Marker}>::into_response(#{RuntimeError}::from(e)) - } - } - } - """.trimIndent(), - *codegenScope, - "O" to outputSymbol, - "Marker" to protocol.markerStruct(), - "serialize_response" to serverSerializeResponse(operationShape), - ) + val outputName = "${operationName}${ServerHttpBoundProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}" + val errorSymbol = operationShape.errorSymbol(symbolProvider) if (operationShape.operationErrors(model).isNotEmpty()) { - rustTemplate( + // The output of fallible operations is a `Result` which we convert into an + // isomorphic `enum` type we control that can in turn be converted into a response. + val intoResponseImpl = """ - impl #{SmithyHttpServer}::response::IntoResponse<#{Marker}> for #{E} { - fn into_response(self) -> #{SmithyHttpServer}::response::Response { - match #{serialize_error}(&self) { + match self { + Self::Output(o) => { + match #{serialize_response}(o) { + Ok(response) => response, + Err(e) => #{SmithyHttpServer}::response::IntoResponse::<#{Marker}>::into_response(#{RuntimeError}::from(e)) + } + }, + Self::Error(err) => { + match #{serialize_error}(&err) { Ok(mut response) => { - response.extensions_mut().insert(#{SmithyHttpServer}::extension::ModeledErrorExtension::new(self.name())); + response.extensions_mut().insert(#{SmithyHttpServer}::extension::ModeledErrorExtension::new(err.name())); response }, Err(e) => #{SmithyHttpServer}::response::IntoResponse::<#{Marker}>::into_response(#{RuntimeError}::from(e)) } } } + """ + + rustTemplate( + """ + pub(crate) enum $outputName { + Output(#{O}), + Error(#{E}) + } + + impl $outputName { + pub fn into_response(self) -> #{SmithyHttpServer}::response::Response { + $intoResponseImpl + } + } + + impl #{SmithyHttpServer}::response::IntoResponse<#{Marker}> for #{O} { + fn into_response(self) -> #{SmithyHttpServer}::response::Response { + $outputName::Output(self).into_response() + } + } + + impl #{SmithyHttpServer}::response::IntoResponse<#{Marker}> for #{E} { + fn into_response(self) -> #{SmithyHttpServer}::response::Response { + $outputName::Error(self).into_response() + } + } """.trimIndent(), *codegenScope, + "O" to outputSymbol, "E" to errorSymbol, "Marker" to protocol.markerStruct(), + "serialize_response" to serverSerializeResponse(operationShape), "serialize_error" to serverSerializeError(operationShape), ) + } else { + // The output of non-fallible operations is a model type which we convert into + // a "wrapper" unit `struct` type we control that can in turn be converted into a response. + val intoResponseImpl = + """ + match #{serialize_response}(self.0) { + Ok(response) => response, + Err(e) => #{SmithyHttpServer}::response::IntoResponse::<#{Marker}>::into_response(#{RuntimeError}::from(e)) + } + """.trimIndent() + + rustTemplate( + """ + pub(crate) struct $outputName(#{O}); + + impl $outputName { + pub fn into_response(self) -> #{SmithyHttpServer}::response::Response { + $intoResponseImpl + } + } + + impl #{SmithyHttpServer}::response::IntoResponse<#{Marker}> for #{O} { + fn into_response(self) -> #{SmithyHttpServer}::response::Response { + $outputName(self).into_response() + } + } + """.trimIndent(), + *codegenScope, + "O" to outputSymbol, + "Marker" to protocol.markerStruct(), + "serialize_response" to serverSerializeResponse(operationShape), + ) + } + + // Implement conversion function to "wrap" from the model operation output types. + if (operationShape.operationErrors(model).isNotEmpty()) { + rustTemplate( + """ + impl #{From}> for $outputName { + fn from(res: Result<#{O}, #{E}>) -> Self { + match res { + Ok(v) => Self::Output(v), + Err(e) => Self::Error(e), + } + } + } + """.trimIndent(), + "O" to outputSymbol, + "E" to errorSymbol, + "From" to RuntimeType.From, + ) + } else { + rustTemplate( + """ + impl #{From}<#{O}> for $outputName { + fn from(o: #{O}) -> Self { + Self(o) + } + } + """.trimIndent(), + "O" to outputSymbol, + "From" to RuntimeType.From, + ) } + + // Implement conversion function to "unwrap" into the model operation input types. + rustTemplate( + """ + impl #{From}<$inputName> for #{I} { + fn from(i: $inputName) -> Self { + i.0 + } + } + """.trimIndent(), + "I" to inputSymbol, + "From" to RuntimeType.From, + ) } private fun serverParseRequest(operationShape: OperationShape): RuntimeType { @@ -284,7 +399,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( rustBlockTemplate( """ pub async fn $fnName( - ##[allow(unused_variables)] request: #{http}::Request + ##[allow(unused_variables)] request: &mut #{SmithyHttpServer}::request::RequestParts ) -> std::result::Result< #{I}, #{RequestRejection} @@ -597,13 +712,12 @@ private class ServerHttpBoundProtocolTraitImplGenerator( "let mut input = #T::default();", inputShape.serverBuilderSymbol(codegenContext), ) - Attribute.AllowUnusedVariables.render(this) - rust("let (parts, body) = request.into_parts();") val parser = structuredDataParser.serverInputParser(operationShape) val noInputs = model.expectShape(operationShape.inputShape).expectTrait().originalId == null if (parser != null) { rustTemplate( """ + let body = request.take_body().ok_or(#{RequestRejection}::BodyAlreadyExtracted)?; let bytes = #{Hyper}::body::to_bytes(body).await?; if !bytes.is_empty() { input = #{parser}(bytes.as_ref(), input)?; @@ -641,7 +755,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( conditionalBlock("if body.is_empty() {", "}", conditional = parser != null) { rustTemplate( """ - #{SmithyHttpServer}::protocols::content_type_header_empty_body_no_modeled_input(&parts.headers)?; + #{SmithyHttpServer}::protocols::content_type_header_empty_body_no_modeled_input(request)?; """, *codegenScope, ) @@ -683,6 +797,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( rustTemplate( """ { + let body = request.take_body().ok_or(#{RequestRejection}::BodyAlreadyExtracted)?; Some(#{Deserializer}(&mut body.into().into_inner())?) } """, @@ -693,6 +808,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( rustTemplate( """ { + let body = request.take_body().ok_or(#{RequestRejection}::BodyAlreadyExtracted)?; let bytes = #{Hyper}::body::to_bytes(body).await?; #{Deserializer}(&bytes)? } @@ -762,7 +878,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( }, ) with(writer) { - rustTemplate("let input_string = parts.uri.path();") + rustTemplate("let input_string = request.uri().path();") if (greedyLabelIndex >= 0 && greedyLabelIndex + 1 < httpTrait.uri.segments.size) { rustTemplate( """ @@ -847,7 +963,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( with(writer) { rustTemplate( """ - let query_string = parts.uri.query().unwrap_or(""); + let query_string = request.uri().query().unwrap_or(""); let pairs = #{FormUrlEncoded}::parse(query_string.as_bytes()); """.trimIndent(), *codegenScope, @@ -1013,7 +1129,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( val deserializer = httpBindingGenerator.generateDeserializeHeaderFn(binding) writer.rustTemplate( """ - #{deserializer}(&parts.headers)? + #{deserializer}(request.headers().ok_or(#{RequestRejection}::HeadersAlreadyExtracted)?)? """.trimIndent(), "deserializer" to deserializer, *codegenScope, @@ -1027,7 +1143,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( val deserializer = httpBindingGenerator.generateDeserializePrefixHeadersFn(binding) writer.rustTemplate( """ - #{deserializer}(&parts.headers)? + #{deserializer}(request.headers().ok_or(#{RequestRejection}::HeadersAlreadyExtracted)?)? """.trimIndent(), "deserializer" to deserializer, *codegenScope, diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationRegistryGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationRegistryGeneratorTest.kt new file mode 100644 index 00000000000..6103de903b7 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationRegistryGeneratorTest.kt @@ -0,0 +1,120 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import io.kotest.matchers.string.shouldContain +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.knowledge.TopDownIndex +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol +import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerProtocolLoader +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext + +class ServerOperationRegistryGeneratorTest { + private val model = """ + namespace test + + use aws.protocols#restJson1 + + @restJson1 + service Service { + operations: [ + Frobnify, + SayHello, + ], + } + + /// Only the Frobnify operation is documented, + /// over multiple lines. + /// And here are #hash #tags! + @http(method: "GET", uri: "/frobnify") + operation Frobnify { + input: FrobnifyInputOutput, + output: FrobnifyInputOutput, + errors: [FrobnifyFailure] + } + + @http(method: "GET", uri: "/hello") + operation SayHello { + input: SayHelloInputOutput, + output: SayHelloInputOutput, + } + + structure FrobnifyInputOutput {} + structure SayHelloInputOutput {} + + @error("server") + structure FrobnifyFailure {} + """.asSmithyModel() + + @Test + fun `it generates quickstart example`() { + val serviceShape = model.lookup("test#Service") + val (protocolShapeId, protocolGeneratorFactory) = ServerProtocolLoader(ServerProtocolLoader.DefaultProtocols).protocolFor( + model, + serviceShape, + ) + val serverCodegenContext = serverTestCodegenContext( + model, + serviceShape, + protocolShapeId = protocolShapeId, + ) + + val index = TopDownIndex.of(serverCodegenContext.model) + val operations = index.getContainedOperations(serverCodegenContext.serviceShape).sortedBy { it.id } + val protocol = protocolGeneratorFactory.protocol(serverCodegenContext) as ServerProtocol + + val generator = ServerOperationRegistryGenerator(serverCodegenContext, protocol, operations) + val writer = RustWriter.forModule("operation_registry") + generator.render(writer) + + writer.toString() shouldContain + """ + /// ```rust + /// use std::net::SocketAddr; + /// use test_module::{input, output, error}; + /// use test_module::operation_registry::OperationRegistryBuilder; + /// use aws_smithy_http_server::routing::Router; + /// + /// #[tokio::main] + /// pub async fn main() { + /// let app: Router = OperationRegistryBuilder::default() + /// .frobnify(frobnify) + /// .say_hello(say_hello) + /// .build() + /// .expect("unable to build operation registry") + /// .into(); + /// + /// let bind: SocketAddr = format!("{}:{}", "127.0.0.1", "6969") + /// .parse() + /// .expect("unable to parse the server bind address and port"); + /// + /// let server = hyper::Server::bind(&bind).serve(app.into_make_service()); + /// + /// // Run your service! + /// // if let Err(err) = server.await { + /// // eprintln!("server error: {}", err); + /// // } + /// } + /// + /// /// Only the Frobnify operation is documented, + /// /// over multiple lines. + /// /// And here are #hash #tags! + /// async fn frobnify(input: input::FrobnifyInputOutput) -> Result { + /// todo!() + /// } + /// + /// async fn say_hello(input: input::SayHelloInputOutput) -> output::SayHelloInputOutput { + /// todo!() + /// } + /// ``` + /// + """.trimIndent() + } +} diff --git a/rust-runtime/aws-smithy-http-server/src/extension.rs b/rust-runtime/aws-smithy-http-server/src/extension.rs index ce265bd7f2e..5c302a757af 100644 --- a/rust-runtime/aws-smithy-http-server/src/extension.rs +++ b/rust-runtime/aws-smithy-http-server/src/extension.rs @@ -28,6 +28,8 @@ use tower::{layer::util::Stack, Layer, Service}; use crate::operation::{Operation, OperationShape}; use crate::plugin::{plugin_from_operation_name_fn, OperationNameFn, Plugin, PluginPipeline, PluginStack}; +#[allow(deprecated)] +use crate::request::RequestParts; pub use crate::request::extension::{Extension, MissingExtension}; @@ -234,6 +236,37 @@ impl Deref for RuntimeErrorExtension { } } +/// Extract an [`Extension`] from a request. +/// This is essentially the implementation of `FromRequest` for `Extension`, but with a +/// protocol-agnostic rejection type. The actual code-generated implementation simply delegates to +/// this function and converts the rejection type into a [`crate::runtime_error::RuntimeError`]. +#[deprecated( + since = "0.52.0", + note = "This was used for extraction under the older service builder. The `FromParts::from_parts` method is now used instead." +)] +#[allow(deprecated)] +pub async fn extract_extension( + req: &mut RequestParts, +) -> Result, crate::rejection::RequestExtensionNotFoundRejection> +where + T: Clone + Send + Sync + 'static, + B: Send, +{ + let value = req + .extensions() + .ok_or(crate::rejection::RequestExtensionNotFoundRejection::ExtensionsAlreadyExtracted)? + .get::() + .ok_or_else(|| { + crate::rejection::RequestExtensionNotFoundRejection::MissingExtension(format!( + "Extension of type `{}` was not found. Perhaps you forgot to add it?", + std::any::type_name::() + )) + }) + .map(|x| x.clone())?; + + Ok(Extension(value)) +} + #[cfg(test)] mod tests { use super::*; diff --git a/rust-runtime/aws-smithy-http-server/src/lib.rs b/rust-runtime/aws-smithy-http-server/src/lib.rs index 804b00d8a30..8fefaf242db 100644 --- a/rust-runtime/aws-smithy-http-server/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server/src/lib.rs @@ -28,11 +28,16 @@ pub mod routing; #[doc(hidden)] pub mod runtime_error; +#[doc(hidden)] +pub mod routers; + #[doc(inline)] pub(crate) use self::error::Error; -#[doc(inline)] pub use self::request::extension::Extension; #[doc(inline)] +#[allow(deprecated)] +pub use self::routing::Router; +#[doc(inline)] pub use tower_http::add_extension::{AddExtension, AddExtensionLayer}; #[cfg(test)] diff --git a/rust-runtime/aws-smithy-http-server/src/proto/aws_json/router.rs b/rust-runtime/aws-smithy-http-server/src/proto/aws_json/router.rs index 4474dfb600f..67041a62fc9 100644 --- a/rust-runtime/aws-smithy-http-server/src/proto/aws_json/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/proto/aws_json/router.rs @@ -9,9 +9,9 @@ use tower::Layer; use tower::Service; use crate::body::BoxBody; +use crate::routers::Router; use crate::routing::tiny_map::TinyMap; use crate::routing::Route; -use crate::routing::Router; use http::header::ToStrError; use thiserror::Error; @@ -117,42 +117,3 @@ impl FromIterator<(String, S)> for AwsJsonRouter { } } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::{proto::test_helpers::req, routing::Router}; - - use http::{HeaderMap, HeaderValue, Method}; - use pretty_assertions::assert_eq; - - #[tokio::test] - async fn simple_routing() { - let routes = vec![("Service.Operation")]; - let router: AwsJsonRouter<_> = routes - .clone() - .into_iter() - .map(|operation| (operation.to_string(), ())) - .collect(); - - let mut headers = HeaderMap::new(); - headers.insert("x-amz-target", HeaderValue::from_static("Service.Operation")); - - // Valid request, should match. - router - .match_route(&req(&Method::POST, "/", Some(headers.clone()))) - .unwrap(); - - // No headers, should return `MissingHeader`. - let res = router.match_route(&req(&Method::POST, "/", None)); - assert_eq!(res.unwrap_err().to_string(), Error::MissingHeader.to_string()); - - // Wrong HTTP method, should return `MethodNotAllowed`. - let res = router.match_route(&req(&Method::GET, "/", Some(headers.clone()))); - assert_eq!(res.unwrap_err().to_string(), Error::MethodNotAllowed.to_string()); - - // Wrong URI, should return `NotRootUrl`. - let res = router.match_route(&req(&Method::POST, "/something", Some(headers))); - assert_eq!(res.unwrap_err().to_string(), Error::NotRootUrl.to_string()); - } -} diff --git a/rust-runtime/aws-smithy-http-server/src/proto/aws_json_10/router.rs b/rust-runtime/aws-smithy-http-server/src/proto/aws_json_10/router.rs index 31d5ce8a9e1..5c582f569cf 100644 --- a/rust-runtime/aws-smithy-http-server/src/proto/aws_json_10/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/proto/aws_json_10/router.rs @@ -7,7 +7,7 @@ use crate::body::{empty, BoxBody}; use crate::extension::RuntimeErrorExtension; use crate::proto::aws_json::router::Error; use crate::response::IntoResponse; -use crate::routing::{method_disallowed, UNKNOWN_OPERATION_EXCEPTION}; +use crate::routers::{method_disallowed, UNKNOWN_OPERATION_EXCEPTION}; use super::AwsJson1_0; diff --git a/rust-runtime/aws-smithy-http-server/src/proto/aws_json_11/router.rs b/rust-runtime/aws-smithy-http-server/src/proto/aws_json_11/router.rs index 18f3b4b3293..8d0f8c0a06d 100644 --- a/rust-runtime/aws-smithy-http-server/src/proto/aws_json_11/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/proto/aws_json_11/router.rs @@ -7,7 +7,7 @@ use crate::body::{empty, BoxBody}; use crate::extension::RuntimeErrorExtension; use crate::proto::aws_json::router::Error; use crate::response::IntoResponse; -use crate::routing::{method_disallowed, UNKNOWN_OPERATION_EXCEPTION}; +use crate::routers::{method_disallowed, UNKNOWN_OPERATION_EXCEPTION}; use super::AwsJson1_1; diff --git a/rust-runtime/aws-smithy-http-server/src/proto/mod.rs b/rust-runtime/aws-smithy-http-server/src/proto/mod.rs index 26fb17d8937..39344ab4f1d 100644 --- a/rust-runtime/aws-smithy-http-server/src/proto/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/proto/mod.rs @@ -9,27 +9,3 @@ pub mod aws_json_11; pub mod rest; pub mod rest_json_1; pub mod rest_xml; - -#[cfg(test)] -pub mod test_helpers { - use http::{HeaderMap, Method, Request}; - - /// Helper function to build a `Request`. Used in other test modules. - pub fn req(method: &Method, uri: &str, headers: Option) -> Request<()> { - let mut r = Request::builder().method(method).uri(uri).body(()).unwrap(); - if let Some(headers) = headers { - *r.headers_mut() = headers - } - r - } - - // Returns a `Response`'s body as a `String`, without consuming the response. - pub async fn get_body_as_string(body: B) -> String - where - B: http_body::Body + std::marker::Unpin, - B::Error: std::fmt::Debug, - { - let body_bytes = hyper::body::to_bytes(body).await.unwrap(); - String::from(std::str::from_utf8(&body_bytes).unwrap()) - } -} diff --git a/rust-runtime/aws-smithy-http-server/src/proto/rest/router.rs b/rust-runtime/aws-smithy-http-server/src/proto/rest/router.rs index 1d55f676d6e..4c73eda72bf 100644 --- a/rust-runtime/aws-smithy-http-server/src/proto/rest/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/proto/rest/router.rs @@ -6,17 +6,17 @@ use std::convert::Infallible; use crate::body::BoxBody; +use crate::routers::Router; use crate::routing::request_spec::Match; use crate::routing::request_spec::RequestSpec; use crate::routing::Route; -use crate::routing::Router; use tower::Layer; use tower::Service; use thiserror::Error; /// An AWS REST routing error. -#[derive(Debug, Error, PartialEq)] +#[derive(Debug, Error)] pub enum Error { /// Operation not found. #[error("operation not found")] @@ -108,166 +108,3 @@ impl FromIterator<(RequestSpec, S)> for RestRouter { Self { routes } } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::{proto::test_helpers::req, routing::request_spec::*}; - - use http::Method; - - // This test is a rewrite of `mux.spec.ts`. - // https://github.com/awslabs/smithy-typescript/blob/fbf97a9bf4c1d8cf7f285ea7c24e1f0ef280142a/smithy-typescript-ssdk-libs/server-common/src/httpbinding/mux.spec.ts - #[test] - fn simple_routing() { - let request_specs: Vec<(RequestSpec, &'static str)> = vec![ - ( - RequestSpec::from_parts( - Method::GET, - vec![ - PathSegment::Literal(String::from("a")), - PathSegment::Label, - PathSegment::Label, - ], - Vec::new(), - ), - "A", - ), - ( - RequestSpec::from_parts( - Method::GET, - vec![ - PathSegment::Literal(String::from("mg")), - PathSegment::Greedy, - PathSegment::Literal(String::from("z")), - ], - Vec::new(), - ), - "MiddleGreedy", - ), - ( - RequestSpec::from_parts( - Method::DELETE, - Vec::new(), - vec![ - QuerySegment::KeyValue(String::from("foo"), String::from("bar")), - QuerySegment::Key(String::from("baz")), - ], - ), - "Delete", - ), - ( - RequestSpec::from_parts( - Method::POST, - vec![PathSegment::Literal(String::from("query_key_only"))], - vec![QuerySegment::Key(String::from("foo"))], - ), - "QueryKeyOnly", - ), - ]; - - // Test both RestJson1 and RestXml routers. - let router: RestRouter<_> = request_specs - .into_iter() - .map(|(spec, svc_name)| (spec, svc_name)) - .collect(); - - let hits = vec![ - ("A", Method::GET, "/a/b/c"), - ("MiddleGreedy", Method::GET, "/mg/a/z"), - ("MiddleGreedy", Method::GET, "/mg/a/b/c/d/z?abc=def"), - ("Delete", Method::DELETE, "/?foo=bar&baz=quux"), - ("Delete", Method::DELETE, "/?foo=bar&baz"), - ("Delete", Method::DELETE, "/?foo=bar&baz=&"), - ("Delete", Method::DELETE, "/?foo=bar&baz=quux&baz=grault"), - ("QueryKeyOnly", Method::POST, "/query_key_only?foo=bar"), - ("QueryKeyOnly", Method::POST, "/query_key_only?foo"), - ("QueryKeyOnly", Method::POST, "/query_key_only?foo="), - ("QueryKeyOnly", Method::POST, "/query_key_only?foo=&"), - ]; - for (svc_name, method, uri) in &hits { - assert_eq!(router.match_route(&req(method, uri, None)).unwrap(), *svc_name); - } - - for (_, _, uri) in hits { - let res = router.match_route(&req(&Method::PATCH, uri, None)); - assert_eq!(res.unwrap_err(), Error::MethodNotAllowed); - } - - let misses = vec![ - (Method::GET, "/a"), - (Method::GET, "/a/b"), - (Method::GET, "/mg"), - (Method::GET, "/mg/q"), - (Method::GET, "/mg/z"), - (Method::GET, "/mg/a/b/z/c"), - (Method::DELETE, "/?foo=bar"), - (Method::DELETE, "/?foo=bar"), - (Method::DELETE, "/?baz=quux"), - (Method::POST, "/query_key_only?baz=quux"), - (Method::GET, "/"), - (Method::POST, "/"), - ]; - for (method, miss) in misses { - let res = router.match_route(&req(&method, miss, None)); - assert_eq!(res.unwrap_err(), Error::NotFound); - } - } - - #[tokio::test] - async fn basic_pattern_conflict_avoidance() { - let request_specs: Vec<(RequestSpec, &'static str)> = vec![ - ( - RequestSpec::from_parts( - Method::GET, - vec![PathSegment::Literal(String::from("a")), PathSegment::Label], - Vec::new(), - ), - "A1", - ), - ( - RequestSpec::from_parts( - Method::GET, - vec![ - PathSegment::Literal(String::from("a")), - PathSegment::Label, - PathSegment::Literal(String::from("a")), - ], - Vec::new(), - ), - "A2", - ), - ( - RequestSpec::from_parts( - Method::GET, - vec![PathSegment::Literal(String::from("b")), PathSegment::Greedy], - Vec::new(), - ), - "B1", - ), - ( - RequestSpec::from_parts( - Method::GET, - vec![PathSegment::Literal(String::from("b")), PathSegment::Greedy], - vec![QuerySegment::Key(String::from("q"))], - ), - "B2", - ), - ]; - - let router: RestRouter<_> = request_specs - .into_iter() - .map(|(spec, svc_name)| (spec, svc_name)) - .collect(); - - let hits = vec![ - ("A1", Method::GET, "/a/foo"), - ("A2", Method::GET, "/a/foo/a"), - ("B1", Method::GET, "/b/foo/bar/baz"), - ("B2", Method::GET, "/b/foo?q=baz"), - ]; - for (svc_name, method, uri) in hits { - assert_eq!(router.match_route(&req(&method, uri, None)).unwrap(), svc_name); - } - } -} diff --git a/rust-runtime/aws-smithy-http-server/src/proto/rest_json_1/router.rs b/rust-runtime/aws-smithy-http-server/src/proto/rest_json_1/router.rs index 189658d317a..c737b665c49 100644 --- a/rust-runtime/aws-smithy-http-server/src/proto/rest_json_1/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/proto/rest_json_1/router.rs @@ -7,7 +7,7 @@ use crate::body::BoxBody; use crate::extension::RuntimeErrorExtension; use crate::proto::rest::router::Error; use crate::response::IntoResponse; -use crate::routing::{method_disallowed, UNKNOWN_OPERATION_EXCEPTION}; +use crate::routers::{method_disallowed, UNKNOWN_OPERATION_EXCEPTION}; use super::RestJson1; diff --git a/rust-runtime/aws-smithy-http-server/src/proto/rest_xml/router.rs b/rust-runtime/aws-smithy-http-server/src/proto/rest_xml/router.rs index b771884b04d..1b1b21742fc 100644 --- a/rust-runtime/aws-smithy-http-server/src/proto/rest_xml/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/proto/rest_xml/router.rs @@ -8,7 +8,8 @@ use crate::body::BoxBody; use crate::extension::RuntimeErrorExtension; use crate::proto::rest::router::Error; use crate::response::IntoResponse; -use crate::routing::{method_disallowed, UNKNOWN_OPERATION_EXCEPTION}; +use crate::routers::method_disallowed; +use crate::routers::UNKNOWN_OPERATION_EXCEPTION; use super::RestXml; diff --git a/rust-runtime/aws-smithy-http-server/src/protocols.rs b/rust-runtime/aws-smithy-http-server/src/protocols.rs index 2267c0384c3..bf56eea89fc 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocols.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocols.rs @@ -5,11 +5,20 @@ //! Protocol helpers. use crate::rejection::MissingContentTypeReason; +#[allow(deprecated)] +use crate::request::RequestParts; use http::HeaderMap; /// When there are no modeled inputs, /// a request body is empty and the content-type request header must not be set -pub fn content_type_header_empty_body_no_modeled_input(headers: &HeaderMap) -> Result<(), MissingContentTypeReason> { +#[allow(deprecated)] +pub fn content_type_header_empty_body_no_modeled_input( + req: &RequestParts, +) -> Result<(), MissingContentTypeReason> { + if req.headers().is_none() { + return Ok(()); + } + let headers = req.headers().unwrap(); if headers.contains_key(http::header::CONTENT_TYPE) { let found_mime = parse_content_type(headers)?; Err(MissingContentTypeReason::UnexpectedMimeType { @@ -33,10 +42,15 @@ fn parse_content_type(headers: &HeaderMap) -> Result( + req: &RequestParts, expected_content_type: Option<&'static str>, ) -> Result<(), MissingContentTypeReason> { + // Allow no CONTENT-TYPE header + if req.headers().is_none() { + return Ok(()); + } + let headers = req.headers().unwrap(); // Headers are present, `unwrap` will not panic. if !headers.contains_key(http::header::CONTENT_TYPE) { return Ok(()); } @@ -65,7 +79,12 @@ pub fn content_type_header_classifier( } #[allow(deprecated)] -pub fn accept_header_classifier(headers: &HeaderMap, content_type: &'static str) -> bool { +pub fn accept_header_classifier(req: &RequestParts, content_type: &'static str) -> bool { + // Allow no ACCEPT header + if req.headers().is_none() { + return true; + } + let headers = req.headers().unwrap(); if !headers.contains_key(http::header::ACCEPT) { return true; } @@ -107,25 +126,28 @@ pub fn accept_header_classifier(headers: &HeaderMap, content_type: &'static str) #[cfg(test)] mod tests { use super::*; - use http::header::{HeaderValue, ACCEPT, CONTENT_TYPE}; + use http::Request; - fn req_content_type(content_type: &'static str) -> HeaderMap { - let mut headers = HeaderMap::new(); - headers.insert(CONTENT_TYPE, HeaderValue::from_str(content_type).unwrap()); - headers + fn req_content_type(content_type: &str) -> RequestParts<&str> { + let request = Request::builder() + .header("content-type", content_type) + .body("") + .unwrap(); + RequestParts::new(request) } - fn req_accept(accept: &'static str) -> HeaderMap { - let mut headers = HeaderMap::new(); - headers.insert(ACCEPT, HeaderValue::from_static(accept)); - headers + fn req_accept(content_type: &str) -> RequestParts<&str> { + let request = Request::builder().header("accept", content_type).body("").unwrap(); + RequestParts::new(request) } const EXPECTED_MIME_APPLICATION_JSON: Option<&'static str> = Some("application/json"); #[test] fn check_content_type_header_empty_body_no_modeled_input() { - assert!(content_type_header_empty_body_no_modeled_input(&HeaderMap::new()).is_ok()); + let request = Request::builder().body("").unwrap(); + let request = RequestParts::new(request); + assert!(content_type_header_empty_body_no_modeled_input(&request).is_ok()); } #[test] @@ -171,7 +193,8 @@ mod tests { #[test] fn check_missing_content_type_is_allowed() { - let result = content_type_header_classifier(&HeaderMap::new(), EXPECTED_MIME_APPLICATION_JSON); + let request = RequestParts::new(Request::builder().body("").unwrap()); + let result = content_type_header_classifier(&request, EXPECTED_MIME_APPLICATION_JSON); assert!(result.is_ok()); } @@ -218,7 +241,9 @@ mod tests { #[test] fn valid_empty_accept_header_classifier() { - assert!(accept_header_classifier(&HeaderMap::new(), "application/json")); + let valid_request = Request::builder().body("").unwrap(); + let valid_request = RequestParts::new(valid_request); + assert!(accept_header_classifier(&valid_request, "application/json")); } #[test] diff --git a/rust-runtime/aws-smithy-http-server/src/rejection.rs b/rust-runtime/aws-smithy-http-server/src/rejection.rs index f01344f7657..65fb8e2e620 100644 --- a/rust-runtime/aws-smithy-http-server/src/rejection.rs +++ b/rust-runtime/aws-smithy-http-server/src/rejection.rs @@ -9,11 +9,13 @@ //! handle requests and responses that return `Result` throughout the framework. These //! include functions to deserialize incoming requests and serialize outgoing responses. //! -//! All types end with `Rejection`. There are two types: +//! All types end with `Rejection`. There are three types: //! //! 1. [`RequestRejection`]s are used when the framework fails to deserialize the request into the //! corresponding operation input. -//! 2. [`ResponseRejection`]s are used when the framework fails to serialize the operation +//! 1. [`RequestExtensionNotFoundRejection`]s are used when the framework fails to deserialize from +//! the request's extensions a particular [`crate::Extension`] that was expected to be found. +//! 1. [`ResponseRejection`]s are used when the framework fails to serialize the operation //! output into a response. //! //! They are called _rejection_ types and not _error_ types to signal that the input was _rejected_ @@ -39,10 +41,35 @@ //! [`crate::runtime_error::RuntimeError`], thus allowing us to represent the full //! error chain. +// For some reason `deprecated(deprecated)` warns of its own deprecation. Putting `allow(deprecated)` at the module +// level remedies it. +#![allow(deprecated)] + use strum_macros::Display; use crate::response::IntoResponse; +/// Rejection used for when failing to extract an [`crate::Extension`] from an incoming [request's +/// extensions]. Contains one variant for each way the extractor can fail. +/// +/// [request's extensions]: https://docs.rs/http/latest/http/struct.Extensions.html +#[deprecated( + since = "0.52.0", + note = "This was used for extraction under the older service builder. The `MissingExtension` struct returned by `FromParts::from_parts` is now used." +)] +#[derive(Debug, Display)] +pub enum RequestExtensionNotFoundRejection { + /// Used when a particular [`crate::Extension`] was expected to be found in the request but we + /// did not find it. + /// This most likely means the service implementer simply forgot to add a [`tower::Layer`] that + /// registers the particular extension in their service to incoming requests. + MissingExtension(String), + // Used when the request extensions have already been taken by another extractor. + ExtensionsAlreadyExtracted, +} + +impl std::error::Error for RequestExtensionNotFoundRejection {} + /// Errors that can occur when serializing the operation output provided by the service implementer /// into an HTTP response. #[derive(Debug, Display)] @@ -77,7 +104,8 @@ convert_to_response_rejection!(aws_smithy_http::operation::error::SerializationE convert_to_response_rejection!(http::Error, Http); /// Errors that can occur when deserializing an HTTP request into an _operation input_, the input -/// that is passed as the first argument to operation handlers. +/// that is passed as the first argument to operation handlers. To deserialize into the service's +/// registered state, a different rejection type is used, [`RequestExtensionNotFoundRejection`]. /// /// This type allows us to easily keep track of all the possible errors that can occur in the /// lifecycle of an incoming HTTP request. @@ -103,6 +131,10 @@ convert_to_response_rejection!(http::Error, Http); // The variants are _roughly_ sorted in the order in which the HTTP request is processed. #[derive(Debug, Display)] pub enum RequestRejection { + /// Used when attempting to take the request's body, and it has already been taken (presumably + /// by an outer `Service` that handled the request before us). + BodyAlreadyExtracted, + /// Used when failing to convert non-streaming requests into a byte slab with /// `hyper::body::to_bytes`. HttpBody(crate::Error), @@ -117,6 +149,10 @@ pub enum RequestRejection { /// input it should represent. XmlDeserialize(crate::Error), + /// Used when attempting to take the request's headers, and they have already been taken (presumably + /// by an outer `Service` that handled the request before us). + HeadersAlreadyExtracted, + /// Used when failing to parse HTTP headers that are bound to input members with the `httpHeader` /// or the `httpPrefixHeaders` traits. HeaderParse(crate::Error), diff --git a/rust-runtime/aws-smithy-http-server/src/request/mod.rs b/rust-runtime/aws-smithy-http-server/src/request/mod.rs index be507b804fc..1b08c632895 100644 --- a/rust-runtime/aws-smithy-http-server/src/request/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/request/mod.rs @@ -54,7 +54,7 @@ use futures_util::{ future::{try_join, MapErr, MapOk, TryJoin}, TryFutureExt, }; -use http::{request::Parts, Request, StatusCode}; +use http::{request::Parts, Extensions, HeaderMap, Request, StatusCode, Uri}; use crate::{ body::{empty, BoxBody}, @@ -77,6 +77,77 @@ fn internal_server_error() -> http::Response { response } +#[doc(hidden)] +#[deprecated( + since = "0.52.0", + note = "This is not used by the new service builder. We use the `http::Parts` struct directly." +)] +#[derive(Debug)] +pub struct RequestParts { + uri: Uri, + headers: Option, + extensions: Option, + body: Option, +} + +#[allow(deprecated)] +impl RequestParts { + /// Create a new `RequestParts`. + /// + /// You generally shouldn't need to construct this type yourself, unless + /// using extractors outside of axum for example to implement a + /// [`tower::Service`]. + /// + /// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html + #[doc(hidden)] + pub fn new(req: Request) -> Self { + let ( + Parts { + uri, + headers, + extensions, + .. + }, + body, + ) = req.into_parts(); + + RequestParts { + uri, + headers: Some(headers), + extensions: Some(extensions), + body: Some(body), + } + } + + /// Gets a reference to the request headers. + /// + /// Returns `None` if the headers has been taken by another extractor. + #[doc(hidden)] + pub fn headers(&self) -> Option<&HeaderMap> { + self.headers.as_ref() + } + + /// Takes the body out of the request, leaving a `None` in its place. + #[doc(hidden)] + pub fn take_body(&mut self) -> Option { + self.body.take() + } + + /// Gets a reference the request URI. + #[doc(hidden)] + pub fn uri(&self) -> &Uri { + &self.uri + } + + /// Gets a reference to the request extensions. + /// + /// Returns `None` if the extensions has been taken by another extractor. + #[doc(hidden)] + pub fn extensions(&self) -> Option<&Extensions> { + self.extensions.as_ref() + } +} + /// Provides a protocol aware extraction from a [`Request`]. This borrows the [`Parts`], in contrast to /// [`FromRequest`] which consumes the entire [`http::Request`] including the body. pub trait FromParts: Sized { diff --git a/rust-runtime/aws-smithy-http-server/src/routers.rs b/rust-runtime/aws-smithy-http-server/src/routers.rs new file mode 100644 index 00000000000..ecffe36e0cd --- /dev/null +++ b/rust-runtime/aws-smithy-http-server/src/routers.rs @@ -0,0 +1,176 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use std::{ + error::Error, + fmt, + future::{ready, Future, Ready}, + marker::PhantomData, + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use futures_util::{ + future::{Either, MapOk}, + TryFutureExt, +}; +use http::Response; +use http_body::Body as HttpBody; +use tower::{util::Oneshot, Service, ServiceExt}; +use tracing::debug; + +use crate::{ + body::{boxed, BoxBody}, + error::BoxError, + response::IntoResponse, +}; + +pub(crate) const UNKNOWN_OPERATION_EXCEPTION: &str = "UnknownOperationException"; + +/// Constructs common response to method disallowed. +pub(crate) fn method_disallowed() -> http::Response { + let mut responses = http::Response::default(); + *responses.status_mut() = http::StatusCode::METHOD_NOT_ALLOWED; + responses +} + +/// An interface for retrieving an inner [`Service`] given a [`http::Request`]. +pub trait Router { + type Service; + type Error; + + /// Matches a [`http::Request`] to a target [`Service`]. + fn match_route(&self, request: &http::Request) -> Result; +} + +/// A [`Service`] using the [`Router`] `R` to redirect messages to specific routes. +/// +/// The `Protocol` parameter is used to determine the serialization of errors. +pub struct RoutingService { + router: R, + _protocol: PhantomData, +} + +impl fmt::Debug for RoutingService +where + R: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RoutingService") + .field("router", &self.router) + .field("_protocol", &self._protocol) + .finish() + } +} + +impl Clone for RoutingService +where + R: Clone, +{ + fn clone(&self) -> Self { + Self { + router: self.router.clone(), + _protocol: PhantomData, + } + } +} + +impl RoutingService { + /// Creates a [`RoutingService`] from a [`Router`]. + pub fn new(router: R) -> Self { + Self { + router, + _protocol: PhantomData, + } + } + + /// Maps a [`Router`] using a closure. + pub fn map(self, f: F) -> RoutingService + where + F: FnOnce(R) -> RNew, + { + RoutingService { + router: f(self.router), + _protocol: PhantomData, + } + } +} + +type EitherOneshotReady = Either< + MapOk>, fn(>>::Response) -> http::Response>, + Ready, >>::Error>>, +>; + +pin_project_lite::pin_project! { + pub struct RoutingFuture where S: Service> { + #[pin] + inner: EitherOneshotReady + } +} + +impl RoutingFuture +where + S: Service>, +{ + /// Creates a [`RoutingFuture`] from [`ServiceExt::oneshot`]. + pub(super) fn from_oneshot(future: Oneshot>) -> Self + where + S: Service, Response = http::Response>, + RespB: HttpBody + Send + 'static, + RespB::Error: Into, + { + Self { + inner: Either::Left(future.map_ok(|x| x.map(boxed))), + } + } + + /// Creates a [`RoutingFuture`] from [`Service::Response`]. + pub(super) fn from_response(response: http::Response) -> Self { + Self { + inner: Either::Right(ready(Ok(response))), + } + } +} + +impl Future for RoutingFuture +where + S: Service>, +{ + type Output = Result, S::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project().inner.poll(cx) + } +} + +impl Service> for RoutingService +where + R: Router, + R::Service: Service, Response = http::Response> + Clone, + R::Error: IntoResponse

+ Error, + RespB: HttpBody + Send + 'static, + RespB::Error: Into, +{ + type Response = Response; + type Error = >>::Error; + type Future = RoutingFuture; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: http::Request) -> Self::Future { + match self.router.match_route(&req) { + // Successfully routed, use the routes `Service::call`. + Ok(ok) => RoutingFuture::from_oneshot(ok.oneshot(req)), + // Failed to route, use the `R::Error`s `IntoResponse

`. + Err(error) => { + debug!(%error, "failed to route"); + RoutingFuture::from_response(error.into_response()) + } + } + } +} diff --git a/rust-runtime/aws-smithy-http-server/src/routing/future.rs b/rust-runtime/aws-smithy-http-server/src/routing/future.rs new file mode 100644 index 00000000000..dcb1d83c532 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server/src/routing/future.rs @@ -0,0 +1,51 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +// This code was copied and then modified from Tokio's Axum. + +/* Copyright (c) 2021 Tower Contributors + * + * Permission is hereby granted, free of charge, to any + * person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the + * Software without restriction, including without + * limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of + * the Software, and to permit persons to whom the Software + * is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice + * shall be included in all copies or substantial portions + * of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF + * ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED + * TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A + * PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT + * SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION + * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR + * IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#![allow(deprecated)] + +//! Future types. + +use crate::routers::RoutingFuture; + +use super::Route; +pub use super::{into_make_service::IntoMakeService, route::RouteFuture}; + +opaque_future! { + #[deprecated( + since = "0.52.0", + note = "`OperationRegistry` is part of the deprecated service builder API. This type no longer appears in the public API." + )] + /// Response future for [`Router`](super::Router). + pub type RouterFuture = RoutingFuture, B>; +} diff --git a/rust-runtime/aws-smithy-http-server/src/routing/mod.rs b/rust-runtime/aws-smithy-http-server/src/routing/mod.rs index 4f0cb8c0fa7..2ff9c82ef5c 100644 --- a/rust-runtime/aws-smithy-http-server/src/routing/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/routing/mod.rs @@ -7,6 +7,27 @@ //! //! [Smithy specification]: https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html +use std::{ + convert::Infallible, + task::{Context, Poll}, +}; + +use self::request_spec::RequestSpec; +use crate::{ + body::{boxed, Body, BoxBody, HttpBody}, + proto::{ + aws_json::router::AwsJsonRouter, aws_json_10::AwsJson1_0, aws_json_11::AwsJson1_1, rest::router::RestRouter, + rest_json_1::RestJson1, rest_xml::RestXml, + }, +}; +use crate::{error::BoxError, routers::RoutingService}; + +use http::{Request, Response}; +use tower::layer::Layer; +use tower::{Service, ServiceBuilder}; +use tower_http::map_response_body::MapResponseBodyLayer; + +mod future; mod into_make_service; mod into_make_service_with_connect_info; #[cfg(feature = "aws-lambda")] @@ -20,185 +41,571 @@ mod route; pub(crate) mod tiny_map; -use std::{ - error::Error, - fmt, - future::{ready, Future, Ready}, - marker::PhantomData, - pin::Pin, - task::{Context, Poll}, -}; - -use bytes::Bytes; -use futures_util::{ - future::{Either, MapOk}, - TryFutureExt, -}; -use http::Response; -use http_body::Body as HttpBody; -use tower::{util::Oneshot, Service, ServiceExt}; -use tracing::debug; - -use crate::{ - body::{boxed, BoxBody}, - error::BoxError, - response::IntoResponse, -}; - #[cfg(feature = "aws-lambda")] #[cfg_attr(docsrs, doc(cfg(feature = "aws-lambda")))] pub use self::lambda_handler::LambdaHandler; #[allow(deprecated)] pub use self::{ + future::RouterFuture, into_make_service::IntoMakeService, into_make_service_with_connect_info::{Connected, IntoMakeServiceWithConnectInfo}, route::Route, }; -pub(crate) const UNKNOWN_OPERATION_EXCEPTION: &str = "UnknownOperationException"; - -/// Constructs common response to method disallowed. -pub(crate) fn method_disallowed() -> http::Response { - let mut responses = http::Response::default(); - *responses.status_mut() = http::StatusCode::METHOD_NOT_ALLOWED; - responses -} - -/// An interface for retrieving an inner [`Service`] given a [`http::Request`]. -pub trait Router { - type Service; - type Error; - - /// Matches a [`http::Request`] to a target [`Service`]. - fn match_route(&self, request: &http::Request) -> Result; +/// The router is a [`tower::Service`] that routes incoming requests to other `Service`s +/// based on the request's URI and HTTP method or on some specific header setting the target operation. +/// The former is adhering to the [Smithy specification], while the latter is adhering to +/// the [AwsJson specification]. +/// +/// The router is also [Protocol] aware and currently supports REST based protocols like [restJson1] or [restXml] +/// and RPC based protocols like [awsJson1.0] or [awsJson1.1]. +/// It currently does not support Smithy's [endpoint trait]. +/// +/// You should not **instantiate** this router directly; it will be created for you from the +/// code generated from your Smithy model by `smithy-rs`. +/// +/// [Smithy specification]: https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html +/// [AwsJson specification]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#protocol-behaviors +/// [Protocol]: https://awslabs.github.io/smithy/1.0/spec/aws/index.html#aws-protocols +/// [restJson1]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-restjson1-protocol.html +/// [restXml]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-restxml-protocol.html +/// [awsJson1.0]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html +/// [awsJson1.1]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html +/// [endpoint trait]: https://awslabs.github.io/smithy/1.0/spec/core/endpoint-traits.html#endpoint-trait +#[derive(Debug)] +#[deprecated( + since = "0.52.0", + note = "`OperationRegistry` is part of the deprecated service builder API. This type no longer appears in the public API." +)] +pub struct Router { + routes: Routes, } -/// A [`Service`] using the [`Router`] `R` to redirect messages to specific routes. +/// Protocol-aware routes types. +/// +/// RestJson1 and RestXml routes are stored in a `Vec` because there can be multiple matches on the +/// request URI and we thus need to iterate the whole list and use a ranking mechanism to choose. /// -/// The `Protocol` parameter is used to determine the serialization of errors. -pub struct RoutingService { - router: R, - _protocol: PhantomData, +/// AwsJson 1.0 and 1.1 routes can be stored in a `HashMap` since the requested operation can be +/// directly found in the `X-Amz-Target` HTTP header. +#[derive(Debug)] +enum Routes { + RestXml(RoutingService>, RestXml>), + RestJson1(RoutingService>, RestJson1>), + AwsJson1_0(RoutingService>, AwsJson1_0>), + AwsJson1_1(RoutingService>, AwsJson1_1>), } -impl fmt::Debug for RoutingService -where - R: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("RoutingService") - .field("router", &self.router) - .field("_protocol", &self._protocol) - .finish() +#[allow(deprecated)] +impl Clone for Router { + fn clone(&self) -> Self { + match &self.routes { + Routes::RestJson1(routes) => Router { + routes: Routes::RestJson1(routes.clone()), + }, + Routes::RestXml(routes) => Router { + routes: Routes::RestXml(routes.clone()), + }, + Routes::AwsJson1_0(routes) => Router { + routes: Routes::AwsJson1_0(routes.clone()), + }, + Routes::AwsJson1_1(routes) => Router { + routes: Routes::AwsJson1_1(routes.clone()), + }, + } } } -impl Clone for RoutingService +#[allow(deprecated)] +impl Router where - R: Clone, + B: Send + 'static, { - fn clone(&self) -> Self { + /// Convert this router into a [`MakeService`], that is a [`Service`] whose + /// response is another service. + /// + /// This is useful when running your application with hyper's + /// [`Server`]. + /// + /// [`Server`]: hyper::server::Server + /// [`MakeService`]: tower::make::MakeService + pub fn into_make_service(self) -> IntoMakeService { + IntoMakeService::new(self) + } + + /// Apply a [`tower::Layer`] to the router. + /// + /// All requests to the router will be processed by the layer's + /// corresponding middleware. + /// + /// This can be used to add additional processing to all routes. + pub fn layer(self, layer: L) -> Router + where + L: Layer>, + L::Service: + Service, Response = Response, Error = Infallible> + Clone + Send + 'static, + >>::Future: Send + 'static, + NewResBody: HttpBody + Send + 'static, + NewResBody::Error: Into, + { + let layer = ServiceBuilder::new() + .layer(MapResponseBodyLayer::new(boxed)) + .layer(layer); + match self.routes { + Routes::RestJson1(routes) => Router { + routes: Routes::RestJson1(routes.map(|router| router.layer(layer).boxed())), + }, + Routes::RestXml(routes) => Router { + routes: Routes::RestXml(routes.map(|router| router.layer(layer).boxed())), + }, + Routes::AwsJson1_0(routes) => Router { + routes: Routes::AwsJson1_0(routes.map(|router| router.layer(layer).boxed())), + }, + Routes::AwsJson1_1(routes) => Router { + routes: Routes::AwsJson1_1(routes.map(|router| router.layer(layer).boxed())), + }, + } + } + + /// Create a new RestJson1 `Router` from an iterator over pairs of [`RequestSpec`]s and services. + /// + /// If the iterator is empty the router will respond `404 Not Found` to all requests. + #[doc(hidden)] + pub fn new_rest_json_router(routes: T) -> Self + where + T: IntoIterator< + Item = ( + tower::util::BoxCloneService, Response, Infallible>, + RequestSpec, + ), + >, + { + let svc = RoutingService::new( + routes + .into_iter() + .map(|(svc, request_spec)| (request_spec, Route::from_box_clone_service(svc))) + .collect(), + ); Self { - router: self.router.clone(), - _protocol: PhantomData, + routes: Routes::RestJson1(svc), } } -} -impl RoutingService { - /// Creates a [`RoutingService`] from a [`Router`]. - pub fn new(router: R) -> Self { + /// Create a new RestXml `Router` from an iterator over pairs of [`RequestSpec`]s and services. + /// + /// If the iterator is empty the router will respond `404 Not Found` to all requests. + #[doc(hidden)] + pub fn new_rest_xml_router(routes: T) -> Self + where + T: IntoIterator< + Item = ( + tower::util::BoxCloneService, Response, Infallible>, + RequestSpec, + ), + >, + { + let svc = RoutingService::new( + routes + .into_iter() + .map(|(svc, request_spec)| (request_spec, Route::from_box_clone_service(svc))) + .collect(), + ); Self { - router, - _protocol: PhantomData, + routes: Routes::RestXml(svc), } } - /// Maps a [`Router`] using a closure. - pub fn map(self, f: F) -> RoutingService + /// Create a new AwsJson 1.0 `Router` from an iterator over pairs of operation names and services. + /// + /// If the iterator is empty the router will respond `404 Not Found` to all requests. + #[doc(hidden)] + pub fn new_aws_json_10_router(routes: T) -> Self where - F: FnOnce(R) -> RNew, + T: IntoIterator< + Item = ( + tower::util::BoxCloneService, Response, Infallible>, + String, + ), + >, { - RoutingService { - router: f(self.router), - _protocol: PhantomData, + let svc = RoutingService::new( + routes + .into_iter() + .map(|(svc, operation)| (operation, Route::from_box_clone_service(svc))) + .collect(), + ); + + Self { + routes: Routes::AwsJson1_0(svc), } } -} -type EitherOneshotReady = Either< - MapOk>, fn(>>::Response) -> http::Response>, - Ready, >>::Error>>, ->; + /// Create a new AwsJson 1.1 `Router` from a vector of pairs of operations and services. + /// + /// If the vector is empty the router will respond `404 Not Found` to all requests. + #[doc(hidden)] + pub fn new_aws_json_11_router(routes: T) -> Self + where + T: IntoIterator< + Item = ( + tower::util::BoxCloneService, Response, Infallible>, + String, + ), + >, + { + let svc = RoutingService::new( + routes + .into_iter() + .map(|(svc, operation)| (operation, Route::from_box_clone_service(svc))) + .collect(), + ); -pin_project_lite::pin_project! { - pub struct RoutingFuture where S: Service> { - #[pin] - inner: EitherOneshotReady + Self { + routes: Routes::AwsJson1_1(svc), + } } } -impl RoutingFuture +#[allow(deprecated)] +impl Service> for Router where - S: Service>, + B: Send + 'static, { - /// Creates a [`RoutingFuture`] from [`ServiceExt::oneshot`]. - pub(super) fn from_oneshot(future: Oneshot>) -> Self + type Response = Response; + type Error = Infallible; + type Future = RouterFuture; + + #[inline] + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + #[inline] + fn call(&mut self, req: Request) -> Self::Future { + let fut = match &mut self.routes { + // REST routes. + Routes::RestJson1(routes) => routes.call(req), + Routes::RestXml(routes) => routes.call(req), + // AwsJson routes. + Routes::AwsJson1_0(routes) => routes.call(req), + Routes::AwsJson1_1(routes) => routes.call(req), + }; + RouterFuture::new(fut) + } +} + +#[cfg(test)] +#[allow(deprecated)] +mod rest_tests { + use super::*; + use crate::{ + body::{boxed, BoxBody}, + routing::request_spec::*, + }; + use futures_util::Future; + use http::{HeaderMap, Method, StatusCode}; + use std::pin::Pin; + + /// Helper function to build a `Request`. Used in other test modules. + pub fn req(method: &Method, uri: &str, headers: Option) -> Request<()> { + let mut r = Request::builder().method(method).uri(uri).body(()).unwrap(); + if let Some(headers) = headers { + *r.headers_mut() = headers + } + r + } + + // Returns a `Response`'s body as a `String`, without consuming the response. + pub async fn get_body_as_string(res: &mut Response) -> String where - S: Service, Response = http::Response>, - RespB: HttpBody + Send + 'static, - RespB::Error: Into, + B: http_body::Body + std::marker::Unpin, + B::Error: std::fmt::Debug, { - Self { - inner: Either::Left(future.map_ok(|x| x.map(boxed))), + let body_mut = res.body_mut(); + let body_bytes = hyper::body::to_bytes(body_mut).await.unwrap(); + String::from(std::str::from_utf8(&body_bytes).unwrap()) + } + + /// A service that returns its name and the request's URI in the response body. + #[derive(Clone)] + struct NamedEchoUriService(String); + + impl Service> for NamedEchoUriService { + type Response = Response; + type Error = Infallible; + type Future = Pin> + Send>>; + + #[inline] + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + #[inline] + fn call(&mut self, req: Request) -> Self::Future { + let body = boxed(Body::from(format!("{} :: {}", self.0, req.uri()))); + let fut = async { Ok(Response::builder().status(&http::StatusCode::OK).body(body).unwrap()) }; + Box::pin(fut) } } - /// Creates a [`RoutingFuture`] from [`Service::Response`]. - pub(super) fn from_response(response: http::Response) -> Self { - Self { - inner: Either::Right(ready(Ok(response))), + // This test is a rewrite of `mux.spec.ts`. + // https://github.com/awslabs/smithy-typescript/blob/fbf97a9bf4c1d8cf7f285ea7c24e1f0ef280142a/smithy-typescript-ssdk-libs/server-common/src/httpbinding/mux.spec.ts + #[tokio::test] + async fn simple_routing() { + let request_specs: Vec<(RequestSpec, &str)> = vec![ + ( + RequestSpec::from_parts( + Method::GET, + vec![ + PathSegment::Literal(String::from("a")), + PathSegment::Label, + PathSegment::Label, + ], + Vec::new(), + ), + "A", + ), + ( + RequestSpec::from_parts( + Method::GET, + vec![ + PathSegment::Literal(String::from("mg")), + PathSegment::Greedy, + PathSegment::Literal(String::from("z")), + ], + Vec::new(), + ), + "MiddleGreedy", + ), + ( + RequestSpec::from_parts( + Method::DELETE, + Vec::new(), + vec![ + QuerySegment::KeyValue(String::from("foo"), String::from("bar")), + QuerySegment::Key(String::from("baz")), + ], + ), + "Delete", + ), + ( + RequestSpec::from_parts( + Method::POST, + vec![PathSegment::Literal(String::from("query_key_only"))], + vec![QuerySegment::Key(String::from("foo"))], + ), + "QueryKeyOnly", + ), + ]; + + // Test both RestJson1 and RestXml routers. + let router_json = Router::new_rest_json_router(request_specs.clone().into_iter().map(|(spec, svc_name)| { + ( + tower::util::BoxCloneService::new(NamedEchoUriService(String::from(svc_name))), + spec, + ) + })); + let router_xml = Router::new_rest_xml_router(request_specs.into_iter().map(|(spec, svc_name)| { + ( + tower::util::BoxCloneService::new(NamedEchoUriService(String::from(svc_name))), + spec, + ) + })); + + for mut router in [router_json, router_xml] { + let hits = vec![ + ("A", Method::GET, "/a/b/c"), + ("MiddleGreedy", Method::GET, "/mg/a/z"), + ("MiddleGreedy", Method::GET, "/mg/a/b/c/d/z?abc=def"), + ("Delete", Method::DELETE, "/?foo=bar&baz=quux"), + ("Delete", Method::DELETE, "/?foo=bar&baz"), + ("Delete", Method::DELETE, "/?foo=bar&baz=&"), + ("Delete", Method::DELETE, "/?foo=bar&baz=quux&baz=grault"), + ("QueryKeyOnly", Method::POST, "/query_key_only?foo=bar"), + ("QueryKeyOnly", Method::POST, "/query_key_only?foo"), + ("QueryKeyOnly", Method::POST, "/query_key_only?foo="), + ("QueryKeyOnly", Method::POST, "/query_key_only?foo=&"), + ]; + for (svc_name, method, uri) in &hits { + let mut res = router.call(req(method, uri, None)).await.unwrap(); + let actual_body = get_body_as_string(&mut res).await; + + assert_eq!(format!("{} :: {}", svc_name, uri), actual_body); + } + + for (_, _, uri) in hits { + let res = router.call(req(&Method::PATCH, uri, None)).await.unwrap(); + assert_eq!(StatusCode::METHOD_NOT_ALLOWED, res.status()); + } + + let misses = vec![ + (Method::GET, "/a"), + (Method::GET, "/a/b"), + (Method::GET, "/mg"), + (Method::GET, "/mg/q"), + (Method::GET, "/mg/z"), + (Method::GET, "/mg/a/b/z/c"), + (Method::DELETE, "/?foo=bar"), + (Method::DELETE, "/?foo=bar"), + (Method::DELETE, "/?baz=quux"), + (Method::POST, "/query_key_only?baz=quux"), + (Method::GET, "/"), + (Method::POST, "/"), + ]; + for (method, miss) in misses { + let res = router.call(req(&method, miss, None)).await.unwrap(); + assert_eq!(StatusCode::NOT_FOUND, res.status()); + } } } -} -impl Future for RoutingFuture -where - S: Service>, -{ - type Output = Result, S::Error>; + #[tokio::test] + async fn basic_pattern_conflict_avoidance() { + let request_specs: Vec<(RequestSpec, &str)> = vec![ + ( + RequestSpec::from_parts( + Method::GET, + vec![PathSegment::Literal(String::from("a")), PathSegment::Label], + Vec::new(), + ), + "A1", + ), + ( + RequestSpec::from_parts( + Method::GET, + vec![ + PathSegment::Literal(String::from("a")), + PathSegment::Label, + PathSegment::Literal(String::from("a")), + ], + Vec::new(), + ), + "A2", + ), + ( + RequestSpec::from_parts( + Method::GET, + vec![PathSegment::Literal(String::from("b")), PathSegment::Greedy], + Vec::new(), + ), + "B1", + ), + ( + RequestSpec::from_parts( + Method::GET, + vec![PathSegment::Literal(String::from("b")), PathSegment::Greedy], + vec![QuerySegment::Key(String::from("q"))], + ), + "B2", + ), + ]; + + let mut router = Router::new_rest_json_router(request_specs.into_iter().map(|(spec, svc_name)| { + ( + tower::util::BoxCloneService::new(NamedEchoUriService(String::from(svc_name))), + spec, + ) + })); - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.project().inner.poll(cx) + let hits = vec![ + ("A1", Method::GET, "/a/foo"), + ("A2", Method::GET, "/a/foo/a"), + ("B1", Method::GET, "/b/foo/bar/baz"), + ("B2", Method::GET, "/b/foo?q=baz"), + ]; + for (svc_name, method, uri) in &hits { + let mut res = router.call(req(method, uri, None)).await.unwrap(); + let actual_body = get_body_as_string(&mut res).await; + + assert_eq!(format!("{} :: {}", svc_name, uri), actual_body); + } } } -impl Service> for RoutingService -where - R: Router, - R::Service: Service, Response = http::Response> + Clone, - R::Error: IntoResponse

+ Error, - RespB: HttpBody + Send + 'static, - RespB::Error: Into, -{ - type Response = Response; - type Error = >>::Error; - type Future = RoutingFuture; +#[allow(deprecated)] +#[cfg(test)] +mod awsjson_tests { + use super::rest_tests::{get_body_as_string, req}; + use super::*; + use crate::body::boxed; + use futures_util::Future; + use http::{HeaderMap, HeaderValue, Method, StatusCode}; + use pretty_assertions::assert_eq; + use std::pin::Pin; - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) + /// A service that returns its name and the request's URI in the response body. + #[derive(Clone)] + struct NamedEchoOperationService(String); + + impl Service> for NamedEchoOperationService { + type Response = Response; + type Error = Infallible; + type Future = Pin> + Send>>; + + #[inline] + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + #[inline] + fn call(&mut self, req: Request) -> Self::Future { + let target = req + .headers() + .get("x-amz-target") + .map(|x| x.to_str().unwrap()) + .unwrap_or("unknown"); + let body = boxed(Body::from(format!("{} :: {}", self.0, target))); + let fut = async { Ok(Response::builder().status(&http::StatusCode::OK).body(body).unwrap()) }; + Box::pin(fut) + } } - fn call(&mut self, req: http::Request) -> Self::Future { - match self.router.match_route(&req) { - // Successfully routed, use the routes `Service::call`. - Ok(ok) => RoutingFuture::from_oneshot(ok.oneshot(req)), - // Failed to route, use the `R::Error`s `IntoResponse

`. - Err(error) => { - debug!(%error, "failed to route"); - RoutingFuture::from_response(error.into_response()) - } + #[tokio::test] + async fn simple_routing() { + let routes = vec![("Service.Operation", "A")]; + let router_json10 = Router::new_aws_json_10_router(routes.clone().into_iter().map(|(operation, svc_name)| { + ( + tower::util::BoxCloneService::new(NamedEchoOperationService(String::from(svc_name))), + operation.to_string(), + ) + })); + let router_json11 = Router::new_aws_json_11_router(routes.into_iter().map(|(operation, svc_name)| { + ( + tower::util::BoxCloneService::new(NamedEchoOperationService(String::from(svc_name))), + operation.to_string(), + ) + })); + + for mut router in [router_json10, router_json11] { + let mut headers = HeaderMap::new(); + headers.insert("x-amz-target", HeaderValue::from_static("Service.Operation")); + + // Valid request, should return a valid body. + let mut res = router + .call(req(&Method::POST, "/", Some(headers.clone()))) + .await + .unwrap(); + let actual_body = get_body_as_string(&mut res).await; + assert_eq!(format!("{} :: {}", "A", "Service.Operation"), actual_body); + + // No headers, should return NOT_FOUND. + let res = router.call(req(&Method::POST, "/", None)).await.unwrap(); + assert_eq!(res.status(), StatusCode::NOT_FOUND); + + // Wrong HTTP method, should return METHOD_NOT_ALLOWED. + let res = router + .call(req(&Method::GET, "/", Some(headers.clone()))) + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED); + + // Wrong URI, should return NOT_FOUND. + let res = router + .call(req(&Method::POST, "/something", Some(headers))) + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::NOT_FOUND); } } } diff --git a/rust-runtime/aws-smithy-http-server/src/routing/request_spec.rs b/rust-runtime/aws-smithy-http-server/src/routing/request_spec.rs index 84f431c24d5..bbabefb0193 100644 --- a/rust-runtime/aws-smithy-http-server/src/routing/request_spec.rs +++ b/rust-runtime/aws-smithy-http-server/src/routing/request_spec.rs @@ -250,9 +250,8 @@ impl RequestSpec { #[cfg(test)] mod tests { + use super::super::rest_tests::req; use super::*; - use crate::proto::test_helpers::req; - use http::Method; #[test] diff --git a/rust-runtime/aws-smithy-http-server/src/routing/route.rs b/rust-runtime/aws-smithy-http-server/src/routing/route.rs index 7eda401f61b..67d7ec53530 100644 --- a/rust-runtime/aws-smithy-http-server/src/routing/route.rs +++ b/rust-runtime/aws-smithy-http-server/src/routing/route.rs @@ -64,6 +64,10 @@ impl Route { service: BoxCloneService::new(svc), } } + + pub(super) fn from_box_clone_service(svc: BoxCloneService, Response, Infallible>) -> Self { + Self { service: svc } + } } impl Clone for Route { diff --git a/rust-runtime/aws-smithy-http-server/src/runtime_error.rs b/rust-runtime/aws-smithy-http-server/src/runtime_error.rs index 9fc1b8e1a77..1201ca61aae 100644 --- a/rust-runtime/aws-smithy-http-server/src/runtime_error.rs +++ b/rust-runtime/aws-smithy-http-server/src/runtime_error.rs @@ -174,6 +174,13 @@ impl IntoResponse for RuntimeError { } } +#[allow(deprecated)] +impl From for RuntimeError { + fn from(err: crate::rejection::RequestExtensionNotFoundRejection) -> Self { + Self::InternalFailure(crate::Error::new(err)) + } +} + impl From for RuntimeError { fn from(err: crate::rejection::ResponseRejection) -> Self { Self::Serialization(crate::Error::new(err)) diff --git a/rust-runtime/inlineable/src/lib.rs b/rust-runtime/inlineable/src/lib.rs index 41af3589197..2c2634110cc 100644 --- a/rust-runtime/inlineable/src/lib.rs +++ b/rust-runtime/inlineable/src/lib.rs @@ -15,6 +15,8 @@ mod json_errors; mod rest_xml_unwrapped_errors; #[allow(unused)] mod rest_xml_wrapped_errors; +#[allow(unused)] +mod server_operation_handler_trait; #[allow(unused)] mod endpoint_lib; diff --git a/rust-runtime/inlineable/src/server_operation_handler_trait.rs b/rust-runtime/inlineable/src/server_operation_handler_trait.rs new file mode 100644 index 00000000000..633120912eb --- /dev/null +++ b/rust-runtime/inlineable/src/server_operation_handler_trait.rs @@ -0,0 +1,104 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +use async_trait::async_trait; +use aws_smithy_http_server::{body::BoxBody, opaque_future}; +use futures_util::{ + future::{BoxFuture, Map}, + FutureExt, +}; +use http::{Request, Response}; +use std::marker::PhantomData; +use tower::Service; + +/// Struct that holds a handler, that is, a function provided by the user that implements the +/// Smithy operation. +#[deprecated( + since = "0.52.0", + note = "`OperationHandler` is part of the older service builder API. This type no longer appears in the public API." +)] +pub struct OperationHandler { + handler: H, + #[allow(clippy::type_complexity)] + _marker: PhantomData<(B, R, I)>, +} + +#[allow(deprecated)] +impl Clone for OperationHandler +where + H: Clone, +{ + fn clone(&self) -> Self { + Self { + handler: self.handler.clone(), + _marker: PhantomData, + } + } +} + +/// Construct an [`OperationHandler`] out of a function implementing the operation. +#[allow(deprecated)] +#[deprecated( + since = "0.52.0", + note = "`OperationHandler` is part of the older service builder API. This type no longer appears in the public API." +)] +pub fn operation(handler: H) -> OperationHandler { + OperationHandler { + handler, + _marker: PhantomData, + } +} + +#[allow(deprecated)] +impl Service> for OperationHandler +where + H: Handler, + B: Send + 'static, +{ + type Response = Response; + type Error = std::convert::Infallible; + type Future = OperationHandlerFuture; + + #[inline] + fn poll_ready( + &mut self, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + let future = + Handler::call(self.handler.clone(), req).map(Ok::<_, std::convert::Infallible> as _); + OperationHandlerFuture::new(future) + } +} + +type WrapResultInResponseFn = + fn(Response) -> Result, std::convert::Infallible>; + +opaque_future! { + /// Response future for [`OperationHandler`]. + pub type OperationHandlerFuture = + Map>, WrapResultInResponseFn>; +} + +pub(crate) mod sealed { + #![allow(unreachable_pub, missing_docs, missing_debug_implementations)] + pub trait HiddenTrait {} + pub struct Hidden; + impl HiddenTrait for Hidden {} +} + +#[deprecated( + since = "0.52.0", + note = "The inlineable `Handler` is part of the deprecated service builder API. This type no longer appears in the public API." +)] +#[async_trait] +pub trait Handler: Clone + Send + Sized + 'static { + #[doc(hidden)] + type Sealed: sealed::HiddenTrait; + + async fn call(self, req: Request) -> Response; +}