diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt index 00ea40dc7a..8fd2a07f58 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt @@ -10,16 +10,19 @@ import software.amazon.smithy.build.PluginContext import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.NullableIndex +import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerEnumGenerator +import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerOperationHandlerGenerator import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerServiceGenerator import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerStructureGenerator import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext @@ -164,4 +167,11 @@ class PythonServerCodegenVisitor( ) .render() } + + override fun operationShape(shape: OperationShape) { + super.operationShape(shape) + rustCrate.withModule(RustModule.public("python_operation_adaptor")) { + PythonServerOperationHandlerGenerator(codegenContext, shape).render(this) + } + } } diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt index ef0591cf8e..793118b441 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt @@ -199,7 +199,7 @@ class PythonApplicationGenerator( let ${name}_locals = #{pyo3_asyncio}::TaskLocals::new(event_loop); let handler = self.handlers.get("$name").expect("Python handler for operation `$name` not found").clone(); let builder = builder.$name(move |input, state| { - #{pyo3_asyncio}::tokio::scope(${name}_locals.clone(), crate::operation_handler::$name(input, state, handler.clone())) + #{pyo3_asyncio}::tokio::scope(${name}_locals.clone(), crate::python_operation_adaptor::$name(input, state, handler.clone())) }); """, *codegenScope, diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt index 60784f36b5..f107806098 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt @@ -13,8 +13,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCargoDependency -import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerOperationHandlerGenerator -import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol /** * The Rust code responsible to run the Python business logic on the Python interpreter @@ -32,9 +30,8 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.Ser */ class PythonServerOperationHandlerGenerator( codegenContext: CodegenContext, - protocol: ServerProtocol, - private val operations: List, -) : ServerOperationHandlerGenerator(codegenContext, protocol, operations) { + private val operation: OperationShape, +) { private val symbolProvider = codegenContext.symbolProvider private val runtimeConfig = codegenContext.runtimeConfig private val codegenScope = @@ -47,42 +44,39 @@ class PythonServerOperationHandlerGenerator( "tracing" to PythonServerCargoDependency.Tracing.toType(), ) - override fun render(writer: RustWriter) { - super.render(writer) + fun render(writer: RustWriter) { renderPythonOperationHandlerImpl(writer) } private fun renderPythonOperationHandlerImpl(writer: RustWriter) { - for (operation in operations) { - val operationName = symbolProvider.toSymbol(operation).name - val input = "crate::input::${operationName}Input" - val output = "crate::output::${operationName}Output" - val error = "crate::error::${operationName}Error" - val fnName = operationName.toSnakeCase() + val operationName = symbolProvider.toSymbol(operation).name + val input = "crate::input::${operationName}Input" + val output = "crate::output::${operationName}Output" + val error = "crate::error::${operationName}Error" + val fnName = operationName.toSnakeCase() - writer.rustTemplate( - """ - /// Python handler for operation `$operationName`. - pub(crate) async fn $fnName( - input: $input, - state: #{SmithyServer}::Extension<#{SmithyPython}::context::PyContext>, - handler: #{SmithyPython}::PyHandler, - ) -> std::result::Result<$output, $error> { - // Async block used to run the handler and catch any Python error. - let result = if handler.is_coroutine { - #{PyCoroutine:W} - } else { - #{PyFunction:W} - }; - #{PyError:W} - } - """, - *codegenScope, - "PyCoroutine" to renderPyCoroutine(fnName, output), - "PyFunction" to renderPyFunction(fnName, output), - "PyError" to renderPyError(), - ) - } + writer.rustTemplate( + """ + /// Python handler for operation `$operationName`. + pub(crate) async fn $fnName( + input: $input, + state: #{SmithyServer}::Extension<#{SmithyPython}::context::PyContext>, + handler: #{SmithyPython}::PyHandler, + ) -> std::result::Result<$output, $error> { + // Async block used to run the handler and catch any Python error. + let result = if handler.is_coroutine { + #{PyCoroutine:W} + } else { + #{PyFunction:W} + }; + #{PyError:W} + } + """, + *codegenScope, + "PyCoroutine" to renderPyCoroutine(fnName, output), + "PyFunction" to renderPyFunction(fnName, output), + "PyError" to renderPyError(), + ) } private fun renderPyFunction(name: String, output: String): Writable = diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerServiceGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerServiceGenerator.kt index 544eaae66a..58e87bc388 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerServiceGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerServiceGenerator.kt @@ -33,10 +33,6 @@ class PythonServerServiceGenerator( PythonServerOperationErrorGenerator(context.model, context.symbolProvider, operation).render(writer) } - override fun renderOperationHandler(writer: RustWriter, operations: List) { - PythonServerOperationHandlerGenerator(context, protocol, operations).render(writer) - } - override fun renderExtras(operations: List) { rustCrate.withModule(RustModule.public("python_server_application", "Python server and application implementation.")) { PythonApplicationGenerator(context, protocol, operations) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt index 6b4df3204e..cb4e62fe9a 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt @@ -7,8 +7,6 @@ package software.amazon.smithy.rust.codegen.server.smithy import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.CratesIo import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope -import software.amazon.smithy.rust.codegen.core.rustlang.InlineDependency -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig /** @@ -27,32 +25,8 @@ object ServerCargoDependency { val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4")) val TokioDev: CargoDependency = CargoDependency("tokio", CratesIo("1.8.4"), scope = DependencyScope.Dev) val Regex: CargoDependency = CargoDependency("regex", CratesIo("1.5.5")) + val HyperDev: CargoDependency = CargoDependency("hyper", CratesIo("0.14.12"), DependencyScope.Dev) fun smithyHttpServer(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http-server") fun smithyTypes(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-types") } - -/** - * A dependency on a snippet of code - * - * ServerInlineDependency should not be instantiated directly, rather, it should be constructed with - * [software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.forInlineFun] - * - * ServerInlineDependencies are created as private modules within the main crate. This is useful for any code that - * doesn't need to exist in a shared crate, but must still be generated exactly once during codegen. - * - * CodegenVisitor de-duplicates inline dependencies by (module, name) during code generation. - */ -object ServerInlineDependency { - fun serverOperationHandler(runtimeConfig: RuntimeConfig): InlineDependency = - InlineDependency.forRustFile( - RustModule.private("server_operation_handler_trait"), - "/inlineable/src/server_operation_handler_trait.rs", - ServerCargoDependency.smithyHttpServer(runtimeConfig), - CargoDependency.Http, - ServerCargoDependency.PinProjectLite, - ServerCargoDependency.Tower, - ServerCargoDependency.FuturesUtil, - ServerCargoDependency.AsyncTrait, - ) -} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt index a0a1baa23d..e31275fa33 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt @@ -19,9 +19,6 @@ object ServerRuntimeType { fun router(runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("routing::Router") - fun operationHandler(runtimeConfig: RuntimeConfig) = - forInlineDependency(ServerInlineDependency.serverOperationHandler(runtimeConfig)) - fun runtimeError(runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("runtime_error::RuntimeError") fun requestRejection(runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("rejection::RequestRejection") diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt deleted file mode 100644 index 308f2f0470..0000000000 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt +++ /dev/null @@ -1,154 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.rust.codegen.server.smithy.generators - -import software.amazon.smithy.model.shapes.OperationShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate -import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.generators.error.errorSymbol -import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErrors -import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember -import software.amazon.smithy.rust.codegen.core.util.inputShape -import software.amazon.smithy.rust.codegen.core.util.outputShape -import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency -import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType -import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol -import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpBoundProtocolGenerator - -/** - * ServerOperationHandlerGenerator - */ -open class ServerOperationHandlerGenerator( - codegenContext: CodegenContext, - val protocol: ServerProtocol, - private val operations: List, -) { - private val serverCrate = "aws_smithy_http_server" - private val model = codegenContext.model - private val symbolProvider = codegenContext.symbolProvider - private val runtimeConfig = codegenContext.runtimeConfig - private val codegenScope = arrayOf( - "AsyncTrait" to ServerCargoDependency.AsyncTrait.toType(), - "Tower" to ServerCargoDependency.Tower.toType(), - "FuturesUtil" to ServerCargoDependency.FuturesUtil.toType(), - "SmithyHttp" to RuntimeType.smithyHttp(runtimeConfig), - "SmithyHttpServer" to ServerCargoDependency.smithyHttpServer(runtimeConfig).toType(), - "Phantom" to RuntimeType.Phantom, - "ServerOperationHandler" to ServerRuntimeType.operationHandler(runtimeConfig), - "http" to RuntimeType.Http, - ) - - open fun render(writer: RustWriter) { - renderHandlerImplementations(writer, false) - renderHandlerImplementations(writer, true) - } - - /** - * Renders the implementation of the `Handler` trait for all operations. - * Handlers are implemented for `FnOnce` function types whose signatures take in state or not. - */ - private fun renderHandlerImplementations(writer: RustWriter, state: Boolean) { - operations.map { operation -> - val operationName = symbolProvider.toSymbol(operation).name - val inputName = symbolProvider.toSymbol(operation.inputShape(model)).fullName - val inputWrapperName = "crate::operation::$operationName${ServerHttpBoundProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}" - val outputWrapperName = "crate::operation::$operationName${ServerHttpBoundProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}" - val fnSignature = if (state) { - "impl #{ServerOperationHandler}::Handler, $inputName> for Fun" - } else { - "impl #{ServerOperationHandler}::Handler for Fun" - } - writer.rustBlockTemplate( - """ - ##[#{AsyncTrait}::async_trait] - $fnSignature - where - ${operationTraitBounds(operation, inputName, state)} - """.trimIndent(), - *codegenScope, - ) { - val callImpl = if (state) { - """ - let state = match $serverCrate::extension::extract_extension(&mut req).await { - Ok(v) => v, - Err(extension_not_found_rejection) => { - let extension = $serverCrate::extension::RuntimeErrorExtension::new(extension_not_found_rejection.to_string()); - let runtime_error = $serverCrate::runtime_error::RuntimeError::from(extension_not_found_rejection); - let mut response = #{SmithyHttpServer}::response::IntoResponse::<#{Protocol}>::into_response(runtime_error); - response.extensions_mut().insert(extension); - return response.map($serverCrate::body::boxed); - } - }; - let input_inner = input_wrapper.into(); - let output_inner = self(input_inner, state).await; - """.trimIndent() - } else { - """ - let input_inner = input_wrapper.into(); - let output_inner = self(input_inner).await; - """.trimIndent() - } - rustTemplate( - """ - type Sealed = #{ServerOperationHandler}::sealed::Hidden; - async fn call(self, req: #{http}::Request) -> #{http}::Response<#{SmithyHttpServer}::body::BoxBody> { - let mut req = #{SmithyHttpServer}::request::RequestParts::new(req); - let input_wrapper = match $inputWrapperName::from_request(&mut req).await { - Ok(v) => v, - Err(runtime_error) => { - let response = #{SmithyHttpServer}::response::IntoResponse::<#{Protocol}>::into_response(runtime_error); - return response.map($serverCrate::body::boxed); - } - }; - $callImpl - let output_wrapper: $outputWrapperName = output_inner.into(); - let mut response = output_wrapper.into_response(); - let operation_ext = #{SmithyHttpServer}::extension::OperationExtension::new("${operation.id.namespace}.$operationName").expect("malformed absolute shape ID"); - response.extensions_mut().insert(operation_ext); - response.map(#{SmithyHttpServer}::body::boxed) - } - """, - "Protocol" to protocol.markerStruct(), - *codegenScope, - ) - } - } - } - - /** - * Generates the trait bounds of the `Handler` trait implementation, depending on: - * - the presence of state; and - * - whether the operation is fallible or not. - */ - private fun operationTraitBounds(operation: OperationShape, inputName: String, state: Boolean): String { - val inputFn = if (state) { - """S: Send + Clone + Sync + 'static, - Fun: FnOnce($inputName, $serverCrate::Extension) -> Fut + Clone + Send + 'static,""" - } else { - "Fun: FnOnce($inputName) -> Fut + Clone + Send + 'static," - } - val outputType = if (operation.operationErrors(model).isNotEmpty()) { - "Result<${symbolProvider.toSymbol(operation.outputShape(model)).fullName}, ${operation.errorSymbol(symbolProvider).fullyQualifiedName()}>" - } else { - symbolProvider.toSymbol(operation.outputShape(model)).fullName - } - val streamingBodyTraitBounds = if (operation.inputShape(model).hasStreamingMember(model)) { - "\n B: Into<#{SmithyHttp}::byte_stream::ByteStream>," - } else { - "" - } - return """ - $inputFn - Fut: std::future::Future + Send, - B: $serverCrate::body::HttpBody + Send + 'static, $streamingBodyTraitBounds - B::Data: Send, - $serverCrate::rejection::RequestRejection: From<::Error> - """.trimIndent() - } -} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationRegistryGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationRegistryGenerator.kt deleted file mode 100644 index 10b11978d0..0000000000 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationRegistryGenerator.kt +++ /dev/null @@ -1,407 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.rust.codegen.server.smithy.generators - -import software.amazon.smithy.model.shapes.OperationShape -import software.amazon.smithy.model.traits.DocumentationTrait -import software.amazon.smithy.rust.codegen.core.rustlang.Attribute -import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.derive -import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency -import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope -import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.rustlang.Writable -import software.amazon.smithy.rust.codegen.core.rustlang.rust -import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock -import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate -import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.core.rustlang.withBlock -import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate -import software.amazon.smithy.rust.codegen.core.rustlang.writable -import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext -import software.amazon.smithy.rust.codegen.core.smithy.ErrorsModule -import software.amazon.smithy.rust.codegen.core.smithy.InputsModule -import software.amazon.smithy.rust.codegen.core.smithy.OutputsModule -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.generators.error.errorSymbol -import software.amazon.smithy.rust.codegen.core.util.getTrait -import software.amazon.smithy.rust.codegen.core.util.inputShape -import software.amazon.smithy.rust.codegen.core.util.outputShape -import software.amazon.smithy.rust.codegen.core.util.toSnakeCase -import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency -import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType -import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol - -/** - * [ServerOperationRegistryGenerator] renders the `OperationRegistry` struct, a place where users can register their - * service's operation implementations. - * - * Users can construct the operation registry using a builder. They can subsequently convert the operation registry into - * the [`aws_smithy_http_server::Router`], a [`tower::Service`] that will route incoming requests to their operation - * handlers, invoking them and returning the response. - * - * [`aws_smithy_http_server::Router`]: https://docs.rs/aws-smithy-http-server/latest/aws_smithy_http_server/struct.Router.html - * [`tower::Service`]: https://docs.rs/tower/latest/tower/trait.Service.html - */ -class ServerOperationRegistryGenerator( - private val codegenContext: CodegenContext, - private val protocol: ServerProtocol, - private val operations: List, -) { - private val crateName = codegenContext.settings.moduleName - private val model = codegenContext.model - private val symbolProvider = codegenContext.symbolProvider - private val serviceName = codegenContext.serviceShape.toShapeId().name - private val operationNames = operations.map { RustReservedWords.escapeIfNeeded(symbolProvider.toSymbol(it).name.toSnakeCase()) } - private val runtimeConfig = codegenContext.runtimeConfig - private val codegenScope = arrayOf( - "Router" to ServerRuntimeType.router(runtimeConfig), - "SmithyHttpServer" to ServerCargoDependency.smithyHttpServer(runtimeConfig).toType(), - "ServerOperationHandler" to ServerRuntimeType.operationHandler(runtimeConfig), - "Tower" to ServerCargoDependency.Tower.toType(), - "Phantom" to RuntimeType.Phantom, - "StdError" to RuntimeType.StdError, - "Display" to RuntimeType.Display, - "From" to RuntimeType.From, - ) - private val operationRegistryName = "OperationRegistry" - private val operationRegistryBuilderName = "${operationRegistryName}Builder" - private val operationRegistryErrorName = "${operationRegistryBuilderName}Error" - private val genericArguments = "B, " + operations.mapIndexed { i, _ -> "Op$i, In$i" }.joinToString() - private val operationRegistryNameWithArguments = "$operationRegistryName<$genericArguments>" - private val operationRegistryBuilderNameWithArguments = "$operationRegistryBuilderName<$genericArguments>" - - fun render(writer: RustWriter) { - renderOperationRegistryRustDocs(writer) - renderOperationRegistryStruct(writer) - renderOperationRegistryBuilderStruct(writer) - renderOperationRegistryBuilderError(writer) - renderOperationRegistryBuilderDefault(writer) - renderOperationRegistryBuilderImplementation(writer) - renderRouterImplementationFromOperationRegistryBuilder(writer) - } - - private fun renderOperationRegistryRustDocs(writer: RustWriter) { - val inputOutputErrorsImport = if (operations.any { it.errors.isNotEmpty() }) { - "/// use ${crateName.toSnakeCase()}::{${InputsModule.name}, ${OutputsModule.name}, ${ErrorsModule.name}};" - } else { - "/// use ${crateName.toSnakeCase()}::{${InputsModule.name}, ${OutputsModule.name}};" - } - - writer.rustTemplate( -""" -##[allow(clippy::tabs_in_doc_comments)] -/// The `$operationRegistryName` is the place where you can register -/// your service's operation implementations. -/// -/// Use [`$operationRegistryBuilderName`] to construct the -/// `$operationRegistryName`. For each of the [operations] modeled in -/// your Smithy service, you need to provide an implementation in the -/// form of a Rust async function or closure that takes in the -/// operation's input as their first parameter, and returns the -/// operation's output. If your operation is fallible (i.e. it -/// contains the `errors` member in your Smithy model), the function -/// implementing the operation has to be fallible (i.e. return a -/// [`Result`]). **You must register an implementation for all -/// operations with the correct signature**, or your application -/// will fail to compile. -/// -/// The operation registry can be converted into an [`#{Router}`] for -/// your service. This router will take care of routing HTTP -/// requests to the matching operation implementation, adhering to -/// your service's protocol and the [HTTP binding traits] that you -/// used in your Smithy model. This router can be converted into a -/// type implementing [`tower::make::MakeService`], a _service -/// factory_. You can feed this value to a [Hyper server], and the -/// server will instantiate and [`serve`] your service. -/// -/// Here's a full example to get you started: -/// -/// ```rust -/// use std::net::SocketAddr; -$inputOutputErrorsImport -/// use ${crateName.toSnakeCase()}::operation_registry::$operationRegistryBuilderName; -/// use #{Router}; -/// -/// ##[#{Tokio}::main] -/// pub async fn main() { -/// let app: Router = $operationRegistryBuilderName::default() -${operationNames.map { ".$it($it)" }.joinToString("\n") { it.prependIndent("/// ") }} -/// .build() -/// .expect("unable to build operation registry") -/// .into(); -/// -/// let bind: SocketAddr = format!("{}:{}", "127.0.0.1", "6969") -/// .parse() -/// .expect("unable to parse the server bind address and port"); -/// -/// let server = #{Hyper}::Server::bind(&bind).serve(app.into_make_service()); -/// -/// // Run your service! -/// // if let Err(err) = server.await { -/// // eprintln!("server error: {}", err); -/// // } -/// } -/// -${operationImplementationStubs(operations)} -/// ``` -/// -/// [`serve`]: https://docs.rs/hyper/0.14.16/hyper/server/struct.Builder.html##method.serve -/// [`tower::make::MakeService`]: https://docs.rs/tower/latest/tower/make/trait.MakeService.html -/// [HTTP binding traits]: https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html -/// [operations]: https://awslabs.github.io/smithy/1.0/spec/core/model.html##operation -/// [Hyper server]: https://docs.rs/hyper/latest/hyper/server/index.html -""", - "Router" to ServerRuntimeType.router(runtimeConfig), - // These should be dev-dependencies. Not all sSDKs depend on `Hyper` (only those that convert the body - // `to_bytes`), and none depend on `tokio`. - "Tokio" to ServerCargoDependency.TokioDev.toType(), - "Hyper" to CargoDependency.Hyper.copy(scope = DependencyScope.Dev).toType(), - ) - } - - private fun renderOperationRegistryStruct(writer: RustWriter) { - writer.rust("""##[deprecated(since = "0.52.0", note = "`OperationRegistry` is part of the deprecated service builder API. Use `$serviceName::builder` instead.")]""") - writer.rustBlock("pub struct $operationRegistryNameWithArguments") { - val members = operationNames - .mapIndexed { i, operationName -> "$operationName: Op$i" } - .joinToString(separator = ",\n") - rustTemplate( - """ - $members, - _phantom: #{Phantom}<(B, ${phantomMembers()})>, - """, - *codegenScope, - ) - } - } - - /** - * Renders the `OperationRegistryBuilder` structure, used to build the `OperationRegistry`. - */ - private fun renderOperationRegistryBuilderStruct(writer: RustWriter) { - writer.rust("""##[deprecated(since = "0.52.0", note = "`OperationRegistryBuilder` is part of the deprecated service builder API. Use `$serviceName::builder` instead.")]""") - writer.rustBlock("pub struct $operationRegistryBuilderNameWithArguments") { - val members = operationNames - .mapIndexed { i, operationName -> "$operationName: Option" } - .joinToString(separator = ",\n") - rustTemplate( - """ - $members, - _phantom: #{Phantom}<(B, ${phantomMembers()})>, - """, - *codegenScope, - ) - } - } - - /** - * Renders the `OperationRegistryBuilderError` type, used to error out in case there are uninitialized fields. - * This is an enum deriving `Debug` and implementing `Display` and `std::error::Error`. - */ - private fun renderOperationRegistryBuilderError(writer: RustWriter) { - Attribute(derive(RuntimeType.Debug)).render(writer) - writer.rustTemplate( - """ - pub enum $operationRegistryErrorName { - UninitializedField(&'static str) - } - impl #{Display} for $operationRegistryErrorName { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::UninitializedField(v) => write!(f, "{}", v), - } - } - } - impl #{StdError} for $operationRegistryErrorName {} - """, - *codegenScope, - ) - } - - /** - * Renders the `OperationRegistryBuilder` `Default` implementation, used to create a new builder that can be - * populated with the service's operation implementations. - */ - private fun renderOperationRegistryBuilderDefault(writer: RustWriter) { - writer.rustBlockTemplate("impl<$genericArguments> std::default::Default for $operationRegistryBuilderNameWithArguments") { - val defaultOperations = operationNames.joinToString(separator = "\n,") { operationName -> - "$operationName: Default::default()" - } - rustTemplate( - """ - fn default() -> Self { - Self { - $defaultOperations, - _phantom: #{Phantom} - } - } - """, - *codegenScope, - ) - } - } - - /** - * Renders the `OperationRegistryBuilder`'s impl block, where operations are stored. - * The `build()` method converts the builder into an `OperationRegistry` instance. - */ - private fun renderOperationRegistryBuilderImplementation(writer: RustWriter) { - writer.rustBlock("impl<$genericArguments> $operationRegistryBuilderNameWithArguments") { - operationNames.forEachIndexed { i, operationName -> - rust( - """ - pub fn $operationName(self, value: Op$i) -> Self { - let mut new = self; - new.$operationName = Some(value); - new - } - """, - ) - } - - rustBlock("pub fn build(self) -> Result<$operationRegistryNameWithArguments, $operationRegistryErrorName>") { - withBlock("Ok( $operationRegistryName {", "})") { - for (operationName in operationNames) { - rust( - """ - $operationName: match self.$operationName { - Some(v) => v, - None => return Err($operationRegistryErrorName::UninitializedField("$operationName")), - }, - """, - ) - } - rustTemplate("_phantom: #{Phantom}", *codegenScope) - } - } - } - } - - /** - * Renders the converter between the `OperationRegistry` and the `Router` via the `std::convert::From` trait. - */ - private fun renderRouterImplementationFromOperationRegistryBuilder(writer: RustWriter) { - val operationTraitBounds = writable { - operations.forEachIndexed { i, operation -> - rustTemplate( - """ - Op$i: #{ServerOperationHandler}::Handler, - In$i: 'static + Send, - """, - *codegenScope, - "OperationInput" to symbolProvider.toSymbol(operation.inputShape(model)), - ) - } - } - - writer.rustBlockTemplate( - // The bound `B: Send` is required because of [`tower::util::BoxCloneService`]. - // [`tower::util::BoxCloneService`]: https://docs.rs/tower/latest/tower/util/struct.BoxCloneService.html#method.new - """ - impl<$genericArguments> #{From}<$operationRegistryNameWithArguments> for #{Router} - where - B: Send + 'static, - #{operationTraitBounds:W} - """, - *codegenScope, - "operationTraitBounds" to operationTraitBounds, - ) { - rustBlock("fn from(registry: $operationRegistryNameWithArguments) -> Self") { - val requestSpecsVarNames = operationNames.map { "${it}_request_spec" } - - requestSpecsVarNames.zip(operations).forEach { (requestSpecVarName, operation) -> - rustTemplate( - "let $requestSpecVarName = #{RequestSpec:W};", - "RequestSpec" to operation.requestSpec(), - ) - } - - val sensitivityGens = operations.map { - ServerHttpSensitivityGenerator(model, it, codegenContext.runtimeConfig) - } - - withBlockTemplate( - "#{Router}::${protocol.serverRouterRuntimeConstructor()}(vec![", - "])", - *codegenScope, - ) { - requestSpecsVarNames.zip(operationNames).zip(sensitivityGens).forEach { - val (inner, sensitivityGen) = it - val (requestSpecVarName, operationName) = inner - - rustBlock("") { - rustTemplate( - """ - let svc = #{ServerOperationHandler}::operation(registry.$operationName); - let request_fmt = #{RequestFmt:W}; - let response_fmt = #{ResponseFmt:W}; - let svc = #{SmithyHttpServer}::instrumentation::InstrumentOperation::new(svc, "$operationName").request_fmt(request_fmt).response_fmt(response_fmt); - (#{Tower}::util::BoxCloneService::new(svc), $requestSpecVarName) - """, - "RequestFmt" to sensitivityGen.requestFmt().value, - "ResponseFmt" to sensitivityGen.responseFmt().value, - *codegenScope, - ) - } - rust(",") - } - } - } - } - } - - /** - * Returns the `PhantomData` generic members in a comma-separated list. - */ - private fun phantomMembers() = operationNames.mapIndexed { i, _ -> "In$i" }.joinToString(separator = ",\n") - - private fun operationImplementationStubs(operations: List): String = - operations.joinToString("\n///\n") { - val operationDocumentation = it.getTrait()?.value - val ret = if (!operationDocumentation.isNullOrBlank()) { - operationDocumentation.replace("#", "##").prependIndent("/// /// ") + "\n" - } else "" - ret + - """ - /// ${it.signature()} { - /// todo!() - /// } - """.trimIndent() - } - - /** - * Returns the function signature for an operation handler implementation. Used in the documentation. - */ - private fun OperationShape.signature(): String { - val inputSymbol = symbolProvider.toSymbol(inputShape(model)) - val outputSymbol = symbolProvider.toSymbol(outputShape(model)) - val errorSymbol = errorSymbol(symbolProvider) - - // using module names here to avoid generating `crate::...` since we've already added the import - val inputT = "${InputsModule.name}::${inputSymbol.name}" - val t = "${OutputsModule.name}::${outputSymbol.name}" - val outputT = if (errors.isEmpty()) { - t - } else { - val e = "${ErrorsModule.name}::${errorSymbol.name}" - "Result<$t, $e>" - } - - val operationName = RustReservedWords.escapeIfNeeded(symbolProvider.toSymbol(this).name.toSnakeCase()) - return "async fn $operationName(input: $inputT) -> $outputT" - } - - /** - * Returns a writable for the `RequestSpec` for an operation based on the service's protocol. - */ - private fun OperationShape.requestSpec(): Writable = protocol.serverRouterRequestSpec( - this, - symbolProvider.toSymbol(this).name, - serviceName, - ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("routing::request_spec"), - ) -} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt index a8d605a1a4..cf405dc5e9 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt @@ -7,9 +7,6 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.model.knowledge.TopDownIndex import software.amazon.smithy.model.shapes.OperationShape -import software.amazon.smithy.rust.codegen.core.rustlang.Attribute -import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.deprecated -import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter @@ -23,7 +20,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.InputsModule import software.amazon.smithy.rust.codegen.core.smithy.OutputsModule import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport -import software.amazon.smithy.rust.codegen.core.util.toPascalCase import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext @@ -86,7 +82,7 @@ open class ServerServiceGenerator( //! let server = app.into_make_service(); //! let bind: SocketAddr = "127.0.0.1:6969".parse() //! .expect("unable to parse the server bind address and port"); - //! hyper::Server::bind(&bind).serve(server).await.unwrap(); + //! #{Hyper}::Server::bind(&bind).serve(server).await.unwrap(); //! ## } //! ``` //! @@ -119,7 +115,7 @@ open class ServerServiceGenerator( //! ```rust //! ## use #{SmithyHttpServer}::plugin::IdentityPlugin as LoggingPlugin; //! ## use #{SmithyHttpServer}::plugin::IdentityPlugin as MetricsPlugin; - //! ## use hyper::Body; + //! ## use #{Hyper}::Body; //! use #{SmithyHttpServer}::plugin::PluginPipeline; //! use $crateName::{$serviceName, $builderName}; //! @@ -194,7 +190,7 @@ open class ServerServiceGenerator( //! ## use std::net::SocketAddr; //! use $crateName::$serviceName; //! - //! ##[tokio::main] + //! ##[#{Tokio}::main] //! pub async fn main() { //! let app = $serviceName::builder_without_plugins() ${builderFieldNames.values.joinToString("\n") { "//! .$it($it)" }} @@ -203,7 +199,7 @@ open class ServerServiceGenerator( //! //! let bind: SocketAddr = "127.0.0.1:6969".parse() //! .expect("unable to parse the server bind address and port"); - //! let server = hyper::Server::bind(&bind).serve(app.into_make_service()); + //! let server = #{Hyper}::Server::bind(&bind).serve(app.into_make_service()); //! ## let server = async { Ok::<_, ()>(()) }; //! //! // Run your service! @@ -229,6 +225,8 @@ open class ServerServiceGenerator( "Handlers" to handlers, "ExampleHandler" to operations.take(1).map { operation -> DocHandlerGenerator(codegenContext, operation, builderFieldNames[operation]!!, "//!").docSignature() }, "SmithyHttpServer" to ServerCargoDependency.smithyHttpServer(codegenContext.runtimeConfig).toType(), + "Hyper" to ServerCargoDependency.HyperDev.toType(), + "Tokio" to ServerCargoDependency.TokioDev.toType(), "Tower" to ServerCargoDependency.Tower.toType(), ) } @@ -255,30 +253,6 @@ open class ServerServiceGenerator( } } } - rustCrate.withModule(RustModule.private("operation_handler", "Operation handlers definition and implementation.")) { - renderOperationHandler(this, operations) - } - rustCrate.withModule( - RustModule.LeafModule( - "operation_registry", - RustMetadata( - visibility = Visibility.PUBLIC, - additionalAttributes = listOf( - Attribute(deprecated("0.52.0", "This module exports the deprecated `OperationRegistry`. Use the service builder exported from your root crate.")), - ), - ), - """ - Contains the [`operation_registry::OperationRegistry`], a place where - you can register your service's operation implementations. - - ## Deprecation - - This service builder is deprecated - use [`${codegenContext.serviceShape.id.name.toPascalCase()}::builder_with_plugins`] or [`${codegenContext.serviceShape.id.name.toPascalCase()}::builder_without_plugins`] instead. - """, - ), - ) { - renderOperationRegistry(this, operations) - } rustCrate.withModule( RustModule.public("operation_shape"), @@ -317,16 +291,6 @@ open class ServerServiceGenerator( /* Subclasses can override */ } - // Render operations handler. - open fun renderOperationHandler(writer: RustWriter, operations: List) { - ServerOperationHandlerGenerator(codegenContext, protocol, operations).render(writer) - } - - // Render operations registry. - private fun renderOperationRegistry(writer: RustWriter, operations: List) { - ServerOperationRegistryGenerator(codegenContext, protocol, operations).render(writer) - } - // Render `server` crate, re-exporting types. private fun renderServerReExports(writer: RustWriter) { ServerRuntimeTypesReExportsGenerator(codegenContext).render(writer) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGeneratorV2.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGeneratorV2.kt index 80b1d9b631..e8483f4a20 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGeneratorV2.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGeneratorV2.kt @@ -231,7 +231,7 @@ class ServerServiceGeneratorV2( #{Router}::from_iter([#{RoutesArrayElements:W}]) }; Ok($serviceName { - router: #{SmithyHttpServer}::routers::RoutingService::new(router), + router: #{SmithyHttpServer}::routing::RoutingService::new(router), }) } """, @@ -306,7 +306,7 @@ class ServerServiceGeneratorV2( { let router = #{Router}::from_iter([#{Pairs:W}]); $serviceName { - router: #{SmithyHttpServer}::routers::RoutingService::new(router), + router: #{SmithyHttpServer}::routing::RoutingService::new(router), } } """, @@ -387,7 +387,7 @@ class ServerServiceGeneratorV2( /// See the [root](crate) documentation for more information. ##[derive(Clone)] pub struct $serviceName { - router: #{SmithyHttpServer}::routers::RoutingService<#{Router}, #{Protocol}>, + router: #{SmithyHttpServer}::routing::RoutingService<#{Router}, #{Protocol}>, } impl $serviceName<()> { @@ -459,7 +459,7 @@ class ServerServiceGeneratorV2( { type Response = #{Http}::Response<#{SmithyHttpServer}::body::BoxBody>; type Error = S::Error; - type Future = #{SmithyHttpServer}::routers::RoutingFuture; + type Future = #{SmithyHttpServer}::routing::RoutingFuture; fn poll_ready(&mut self, cx: &mut std::task::Context) -> std::task::Poll> { self.router.poll_ready(cx) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 431eb99637..129199115a 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -37,7 +37,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.withBlock -import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType @@ -56,12 +55,9 @@ import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType import software.amazon.smithy.rust.codegen.server.smithy.generators.serverInstantiator -import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpBoundProtocolGenerator import java.util.logging.Logger import kotlin.reflect.KFunction1 -private const val PROTOCOL_TEST_HELPER_MODULE_NAME = "protocol_test_helper" - /** * Generate protocol tests for an operation */ @@ -141,97 +137,12 @@ class ServerProtocolTestGenerator( } fun render(writer: RustWriter) { - renderTestHelper(writer) - for (operation in operations) { protocolGenerator.renderOperation(writer, operation) renderOperationTestCases(operation, writer) } } - /** - * Render a test helper module to: - * - * - generate a dynamic builder for each handler, and - * - construct a Tower service to exercise each test case. - */ - private fun renderTestHelper(writer: RustWriter) { - val operationNames = operations.map { it.toName() } - val operationRegistryName = "OperationRegistry" - val operationRegistryBuilderName = "${operationRegistryName}Builder" - - fun renderRegistryBuilderTypeParams() = writable { - operations.forEach { - val (inputT, outputT) = operationInputOutputTypes[it]!! - writeInline("Fun<$inputT, $outputT>, (), ") - } - } - - fun renderRegistryBuilderMethods() = writable { - operations.withIndex().forEach { - val (inputT, outputT) = operationInputOutputTypes[it.value]!! - val operationName = operationNames[it.index] - rust(".$operationName((|_| Box::pin(async { todo!() })) as Fun<$inputT, $outputT> )") - } - } - - val module = RustModule.LeafModule( - PROTOCOL_TEST_HELPER_MODULE_NAME, - RustMetadata( - additionalAttributes = listOf( - Attribute.CfgTest, - Attribute.AllowDeadCode, - ), - visibility = Visibility.PUBCRATE, - ), - inline = true, - ) - - writer.withInlineModule(module) { - rustTemplate( - """ - use #{Tower}::Service as _; - - pub(crate) type Fun = fn(Input) -> std::pin::Pin + Send>>; - - type RegistryBuilder = crate::operation_registry::$operationRegistryBuilderName<#{Hyper}::Body, #{RegistryBuilderTypeParams:W}>; - - fn create_operation_registry_builder() -> RegistryBuilder { - crate::operation_registry::$operationRegistryBuilderName::default() - #{RegistryBuilderMethods:W} - } - - pub(crate) async fn build_router_and_make_request( - http_request: #{Http}::request::Request<#{SmithyHttpServer}::body::Body>, - f: &dyn Fn(RegistryBuilder) -> RegistryBuilder, - ) -> #{Http}::response::Response<#{SmithyHttpServer}::body::BoxBody> { - let mut router: #{Router} = f(create_operation_registry_builder()) - .build() - .expect("unable to build operation registry") - .into(); - let http_response = router - .call(http_request) - .await - .expect("unable to make an HTTP request"); - - http_response - } - - /// The operation full name is a concatenation of `.`. - pub(crate) fn check_operation_extension_was_set(http_response: #{Http}::response::Response<#{SmithyHttpServer}::body::BoxBody>, operation_full_name: &str) { - let operation_extension = http_response.extensions() - .get::<#{SmithyHttpServer}::extension::OperationExtension>() - .expect("extension `OperationExtension` not found"); - #{AssertEq}(operation_extension.absolute(), operation_full_name); - } - """, - "RegistryBuilderTypeParams" to renderRegistryBuilderTypeParams(), - "RegistryBuilderMethods" to renderRegistryBuilderMethods(), - *codegenScope, - ) - } - } - private fun renderOperationTestCases(operationShape: OperationShape, writer: RustWriter) { val outputShape = operationShape.outputShape(codegenContext.model) val operationSymbol = symbolProvider.toSymbol(operationShape) @@ -395,22 +306,12 @@ class ServerProtocolTestGenerator( return } - // Test against original `OperationRegistryBuilder`. with(httpRequestTestCase) { renderHttpRequest(uri, method, headers, body.orNull(), queryParams, host.orNull()) } if (protocolSupport.requestBodyDeserialization) { - makeRequest(operationShape, this, checkRequestHandler(operationShape, httpRequestTestCase)) - checkHandlerWasEntered(operationShape, operationSymbol, this) - } - - // Test against new service builder. - with(httpRequestTestCase) { - renderHttpRequest(uri, method, headers, body.orNull(), queryParams, host.orNull()) - } - if (protocolSupport.requestBodyDeserialization) { - makeRequest2(operationShape, operationSymbol, this, checkRequestHandler(operationShape, httpRequestTestCase)) - checkHandlerWasEntered2(this) + makeRequest(operationShape, operationSymbol, this, checkRequestHandler(operationShape, httpRequestTestCase)) + checkHandlerWasEntered(this) } // Explicitly warn if the test case defined parameters that we aren't doing anything with @@ -440,8 +341,6 @@ class ServerProtocolTestGenerator( operationShape: OperationShape, operationSymbol: Symbol, ) { - val operationImplementationName = - "${operationSymbol.name}${ServerHttpBoundProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}" val operationErrorName = "crate::error::${operationSymbol.name}Error" if (!protocolSupport.responseSerialization || ( @@ -454,19 +353,13 @@ class ServerProtocolTestGenerator( writeInline("let output =") instantiator.render(this, shape, testCase.params) rust(";") - val operationImpl = if (operationShape.allErrors(model).isNotEmpty()) { - if (shape.hasTrait()) { - val variant = symbolProvider.toSymbol(shape).name - "$operationImplementationName::Error($operationErrorName::$variant(output))" - } else { - "$operationImplementationName::Output(output)" - } - } else { - "$operationImplementationName(output)" + if (operationShape.allErrors(model).isNotEmpty() && shape.hasTrait()) { + val variant = symbolProvider.toSymbol(shape).name + rust("let output = $operationErrorName::$variant(output);") } rustTemplate( """ - let output = super::$operationImpl; + use #{SmithyHttpServer}::response::IntoResponse; let http_response = output.into_response(); """, *codegenScope, @@ -488,23 +381,13 @@ class ServerProtocolTestGenerator( val panicMessage = "request should have been rejected, but we accepted it; we parsed operation input `{:?}`" - rust("// Use the `OperationRegistryBuilder`") rustBlock("") { with(testCase.request) { // TODO(https://github.com/awslabs/smithy/issues/1102): `uri` should probably not be an `Optional`. renderHttpRequest(uri.get(), method, headers, body.orNull(), queryParams, host.orNull()) } - makeRequest(operationShape, this, writable("""panic!("$panicMessage", &input) as $outputT""")) - checkResponse(this, testCase.response) - } - rust("// Use new service builder") - rustBlock("") { - with(testCase.request) { - // TODO(https://github.com/awslabs/smithy/issues/1102): `uri` should probably not be an `Optional`. - renderHttpRequest(uri.get(), method, headers, body.orNull(), queryParams, host.orNull()) - } - makeRequest2(operationShape, operationSymbol, this, writable("""panic!("$panicMessage", &input) as $outputT""")) + makeRequest(operationShape, operationSymbol, this, writable("""panic!("$panicMessage", &input) as $outputT""")) checkResponse(this, testCase.response) } } @@ -586,44 +469,8 @@ class ServerProtocolTestGenerator( } } - /** Checks the request using the `OperationRegistryBuilder`. */ + /** Checks the request. */ private fun makeRequest( - operationShape: OperationShape, - rustWriter: RustWriter, - operationBody: Writable, - ) { - val (inputT, outputT) = operationInputOutputTypes[operationShape]!! - - rustWriter.withBlockTemplate( - """ - let http_response = super::$PROTOCOL_TEST_HELPER_MODULE_NAME::build_router_and_make_request( - http_request, - &|builder| { - builder.${operationShape.toName()}((|input| Box::pin(async move { - """, - - "})) as super::$PROTOCOL_TEST_HELPER_MODULE_NAME::Fun<$inputT, $outputT>)}).await;", - *codegenScope, - ) { - operationBody() - } - } - - private fun checkHandlerWasEntered( - operationShape: OperationShape, - operationSymbol: Symbol, - rustWriter: RustWriter, - ) { - val operationFullName = "${operationShape.id.namespace}.${operationSymbol.name}" - rustWriter.rust( - """ - super::$PROTOCOL_TEST_HELPER_MODULE_NAME::check_operation_extension_was_set(http_response, "$operationFullName"); - """, - ) - } - - /** Checks the request using the new service builder. */ - private fun makeRequest2( operationShape: OperationShape, operationSymbol: Symbol, rustWriter: RustWriter, @@ -654,7 +501,7 @@ class ServerProtocolTestGenerator( ) } - private fun checkHandlerWasEntered2(rustWriter: RustWriter) { + private fun checkHandlerWasEntered(rustWriter: RustWriter) { rustWriter.rust( """ assert!(receiver.recv().await.is_some()); diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index eed3dc8433..5e2ce595af 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -161,13 +161,11 @@ private class ServerHttpBoundProtocolTraitImplGenerator( operationShape: OperationShape, ) { val operationName = symbolProvider.toSymbol(operationShape).name - val inputName = "${operationName}${ServerHttpBoundProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}" - val verifyAcceptHeader = writable { httpBindingResolver.responseContentType(operationShape)?.also { contentType -> rustTemplate( """ - if ! #{SmithyHttpServer}::protocols::accept_header_classifier(req, ${contentType.dq()}) { + if ! #{SmithyHttpServer}::protocols::accept_header_classifier(request.headers(), ${contentType.dq()}) { return Err(#{RuntimeError}::NotAcceptable) } """, @@ -187,7 +185,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( ?.let { "Some(${it.dq()})" } ?: "None" rustTemplate( """ - if #{SmithyHttpServer}::protocols::content_type_header_classifier(req, $expectedRequestContentType).is_err() { + if #{SmithyHttpServer}::protocols::content_type_header_classifier(request.headers(), $expectedRequestContentType).is_err() { return Err(#{RuntimeError}::UnsupportedMediaType) } """, @@ -200,25 +198,6 @@ private class ServerHttpBoundProtocolTraitImplGenerator( // Implement `from_request` trait for input types. rustTemplate( """ - ##[derive(Debug)] - pub(crate) struct $inputName(#{I}); - impl $inputName - { - pub async fn from_request(req: &mut #{SmithyHttpServer}::request::RequestParts) -> Result - where - B: #{SmithyHttpServer}::body::HttpBody + Send, ${streamingBodyTraitBounds(operationShape)} - B::Data: Send, - #{RequestRejection} : From<::Error> - { - #{verifyAcceptHeader:W} - #{verifyRequestContentTypeHeader:W} - #{parse_request}(req) - .await - .map($inputName) - .map_err(Into::into) - } - } - impl #{SmithyHttpServer}::request::FromRequest<#{Marker}, B> for #{I} where B: #{SmithyHttpServer}::body::HttpBody + Send, @@ -232,8 +211,11 @@ private class ServerHttpBoundProtocolTraitImplGenerator( fn from_request(request: #{http}::Request) -> Self::Future { let fut = async move { - let mut request_parts = #{SmithyHttpServer}::request::RequestParts::new(request); - $inputName::from_request(&mut request_parts).await.map(|x| x.0) + #{verifyAcceptHeader:W} + #{verifyRequestContentTypeHeader:W} + #{parse_request}(request) + .await + .map_err(Into::into) }; Box::pin(fut) } @@ -249,143 +231,46 @@ private class ServerHttpBoundProtocolTraitImplGenerator( ) // Implement `into_response` for output types. - - val outputName = "${operationName}${ServerHttpBoundProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}" val errorSymbol = operationShape.errorSymbol(symbolProvider) - if (operationShape.operationErrors(model).isNotEmpty()) { - // The output of fallible operations is a `Result` which we convert into an - // isomorphic `enum` type we control that can in turn be converted into a response. - val intoResponseImpl = - """ - match self { - Self::Output(o) => { - match #{serialize_response}(o) { - Ok(response) => response, - Err(e) => #{SmithyHttpServer}::response::IntoResponse::<#{Marker}>::into_response(#{RuntimeError}::from(e)) - } - }, - Self::Error(err) => { - match #{serialize_error}(&err) { - Ok(mut response) => { - response.extensions_mut().insert(#{SmithyHttpServer}::extension::ModeledErrorExtension::new(err.name())); - response - }, - Err(e) => #{SmithyHttpServer}::response::IntoResponse::<#{Marker}>::into_response(#{RuntimeError}::from(e)) - } + rustTemplate( + """ + impl #{SmithyHttpServer}::response::IntoResponse<#{Marker}> for #{O} { + fn into_response(self) -> #{SmithyHttpServer}::response::Response { + match #{serialize_response}(self) { + Ok(response) => response, + Err(e) => #{SmithyHttpServer}::response::IntoResponse::<#{Marker}>::into_response(#{RuntimeError}::from(e)) } } - """ + } + """.trimIndent(), + *codegenScope, + "O" to outputSymbol, + "Marker" to protocol.markerStruct(), + "serialize_response" to serverSerializeResponse(operationShape), + ) + if (operationShape.operationErrors(model).isNotEmpty()) { rustTemplate( """ - pub(crate) enum $outputName { - Output(#{O}), - Error(#{E}) - } - - impl $outputName { - pub fn into_response(self) -> #{SmithyHttpServer}::response::Response { - $intoResponseImpl - } - } - - impl #{SmithyHttpServer}::response::IntoResponse<#{Marker}> for #{O} { - fn into_response(self) -> #{SmithyHttpServer}::response::Response { - $outputName::Output(self).into_response() - } - } - impl #{SmithyHttpServer}::response::IntoResponse<#{Marker}> for #{E} { fn into_response(self) -> #{SmithyHttpServer}::response::Response { - $outputName::Error(self).into_response() + match #{serialize_error}(&self) { + Ok(mut response) => { + response.extensions_mut().insert(#{SmithyHttpServer}::extension::ModeledErrorExtension::new(self.name())); + response + }, + Err(e) => #{SmithyHttpServer}::response::IntoResponse::<#{Marker}>::into_response(#{RuntimeError}::from(e)) + } } } """.trimIndent(), *codegenScope, - "O" to outputSymbol, "E" to errorSymbol, "Marker" to protocol.markerStruct(), - "serialize_response" to serverSerializeResponse(operationShape), "serialize_error" to serverSerializeError(operationShape), ) - } else { - // The output of non-fallible operations is a model type which we convert into - // a "wrapper" unit `struct` type we control that can in turn be converted into a response. - val intoResponseImpl = - """ - match #{serialize_response}(self.0) { - Ok(response) => response, - Err(e) => #{SmithyHttpServer}::response::IntoResponse::<#{Marker}>::into_response(#{RuntimeError}::from(e)) - } - """.trimIndent() - - rustTemplate( - """ - pub(crate) struct $outputName(#{O}); - - impl $outputName { - pub fn into_response(self) -> #{SmithyHttpServer}::response::Response { - $intoResponseImpl - } - } - - impl #{SmithyHttpServer}::response::IntoResponse<#{Marker}> for #{O} { - fn into_response(self) -> #{SmithyHttpServer}::response::Response { - $outputName(self).into_response() - } - } - """.trimIndent(), - *codegenScope, - "O" to outputSymbol, - "Marker" to protocol.markerStruct(), - "serialize_response" to serverSerializeResponse(operationShape), - ) - } - - // Implement conversion function to "wrap" from the model operation output types. - if (operationShape.operationErrors(model).isNotEmpty()) { - rustTemplate( - """ - impl #{From}> for $outputName { - fn from(res: Result<#{O}, #{E}>) -> Self { - match res { - Ok(v) => Self::Output(v), - Err(e) => Self::Error(e), - } - } - } - """.trimIndent(), - "O" to outputSymbol, - "E" to errorSymbol, - "From" to RuntimeType.From, - ) - } else { - rustTemplate( - """ - impl #{From}<#{O}> for $outputName { - fn from(o: #{O}) -> Self { - Self(o) - } - } - """.trimIndent(), - "O" to outputSymbol, - "From" to RuntimeType.From, - ) } - - // Implement conversion function to "unwrap" into the model operation input types. - rustTemplate( - """ - impl #{From}<$inputName> for #{I} { - fn from(i: $inputName) -> Self { - i.0 - } - } - """.trimIndent(), - "I" to inputSymbol, - "From" to RuntimeType.From, - ) } private fun serverParseRequest(operationShape: OperationShape): RuntimeType { @@ -399,7 +284,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( rustBlockTemplate( """ pub async fn $fnName( - ##[allow(unused_variables)] request: &mut #{SmithyHttpServer}::request::RequestParts + ##[allow(unused_variables)] request: #{http}::Request ) -> std::result::Result< #{I}, #{RequestRejection} @@ -712,12 +597,13 @@ private class ServerHttpBoundProtocolTraitImplGenerator( "let mut input = #T::default();", inputShape.serverBuilderSymbol(codegenContext), ) + Attribute.AllowUnusedVariables.render(this) + rust("let (parts, body) = request.into_parts();") val parser = structuredDataParser.serverInputParser(operationShape) val noInputs = model.expectShape(operationShape.inputShape).expectTrait().originalId == null if (parser != null) { rustTemplate( """ - let body = request.take_body().ok_or(#{RequestRejection}::BodyAlreadyExtracted)?; let bytes = #{Hyper}::body::to_bytes(body).await?; if !bytes.is_empty() { input = #{parser}(bytes.as_ref(), input)?; @@ -755,7 +641,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( conditionalBlock("if body.is_empty() {", "}", conditional = parser != null) { rustTemplate( """ - #{SmithyHttpServer}::protocols::content_type_header_empty_body_no_modeled_input(request)?; + #{SmithyHttpServer}::protocols::content_type_header_empty_body_no_modeled_input(&parts.headers)?; """, *codegenScope, ) @@ -797,7 +683,6 @@ private class ServerHttpBoundProtocolTraitImplGenerator( rustTemplate( """ { - let body = request.take_body().ok_or(#{RequestRejection}::BodyAlreadyExtracted)?; Some(#{Deserializer}(&mut body.into().into_inner())?) } """, @@ -808,7 +693,6 @@ private class ServerHttpBoundProtocolTraitImplGenerator( rustTemplate( """ { - let body = request.take_body().ok_or(#{RequestRejection}::BodyAlreadyExtracted)?; let bytes = #{Hyper}::body::to_bytes(body).await?; #{Deserializer}(&bytes)? } @@ -878,7 +762,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( }, ) with(writer) { - rustTemplate("let input_string = request.uri().path();") + rustTemplate("let input_string = parts.uri.path();") if (greedyLabelIndex >= 0 && greedyLabelIndex + 1 < httpTrait.uri.segments.size) { rustTemplate( """ @@ -963,7 +847,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( with(writer) { rustTemplate( """ - let query_string = request.uri().query().unwrap_or(""); + let query_string = parts.uri.query().unwrap_or(""); let pairs = #{FormUrlEncoded}::parse(query_string.as_bytes()); """.trimIndent(), *codegenScope, @@ -1129,7 +1013,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( val deserializer = httpBindingGenerator.generateDeserializeHeaderFn(binding) writer.rustTemplate( """ - #{deserializer}(request.headers().ok_or(#{RequestRejection}::HeadersAlreadyExtracted)?)? + #{deserializer}(&parts.headers)? """.trimIndent(), "deserializer" to deserializer, *codegenScope, @@ -1143,7 +1027,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( val deserializer = httpBindingGenerator.generateDeserializePrefixHeadersFn(binding) writer.rustTemplate( """ - #{deserializer}(request.headers().ok_or(#{RequestRejection}::HeadersAlreadyExtracted)?)? + #{deserializer}(&parts.headers)? """.trimIndent(), "deserializer" to deserializer, *codegenScope, diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationRegistryGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationRegistryGeneratorTest.kt deleted file mode 100644 index 6103de903b..0000000000 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationRegistryGeneratorTest.kt +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.rust.codegen.server.smithy.generators - -import io.kotest.matchers.string.shouldContain -import org.junit.jupiter.api.Test -import software.amazon.smithy.model.knowledge.TopDownIndex -import software.amazon.smithy.model.shapes.ServiceShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel -import software.amazon.smithy.rust.codegen.core.util.lookup -import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol -import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerProtocolLoader -import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext - -class ServerOperationRegistryGeneratorTest { - private val model = """ - namespace test - - use aws.protocols#restJson1 - - @restJson1 - service Service { - operations: [ - Frobnify, - SayHello, - ], - } - - /// Only the Frobnify operation is documented, - /// over multiple lines. - /// And here are #hash #tags! - @http(method: "GET", uri: "/frobnify") - operation Frobnify { - input: FrobnifyInputOutput, - output: FrobnifyInputOutput, - errors: [FrobnifyFailure] - } - - @http(method: "GET", uri: "/hello") - operation SayHello { - input: SayHelloInputOutput, - output: SayHelloInputOutput, - } - - structure FrobnifyInputOutput {} - structure SayHelloInputOutput {} - - @error("server") - structure FrobnifyFailure {} - """.asSmithyModel() - - @Test - fun `it generates quickstart example`() { - val serviceShape = model.lookup("test#Service") - val (protocolShapeId, protocolGeneratorFactory) = ServerProtocolLoader(ServerProtocolLoader.DefaultProtocols).protocolFor( - model, - serviceShape, - ) - val serverCodegenContext = serverTestCodegenContext( - model, - serviceShape, - protocolShapeId = protocolShapeId, - ) - - val index = TopDownIndex.of(serverCodegenContext.model) - val operations = index.getContainedOperations(serverCodegenContext.serviceShape).sortedBy { it.id } - val protocol = protocolGeneratorFactory.protocol(serverCodegenContext) as ServerProtocol - - val generator = ServerOperationRegistryGenerator(serverCodegenContext, protocol, operations) - val writer = RustWriter.forModule("operation_registry") - generator.render(writer) - - writer.toString() shouldContain - """ - /// ```rust - /// use std::net::SocketAddr; - /// use test_module::{input, output, error}; - /// use test_module::operation_registry::OperationRegistryBuilder; - /// use aws_smithy_http_server::routing::Router; - /// - /// #[tokio::main] - /// pub async fn main() { - /// let app: Router = OperationRegistryBuilder::default() - /// .frobnify(frobnify) - /// .say_hello(say_hello) - /// .build() - /// .expect("unable to build operation registry") - /// .into(); - /// - /// let bind: SocketAddr = format!("{}:{}", "127.0.0.1", "6969") - /// .parse() - /// .expect("unable to parse the server bind address and port"); - /// - /// let server = hyper::Server::bind(&bind).serve(app.into_make_service()); - /// - /// // Run your service! - /// // if let Err(err) = server.await { - /// // eprintln!("server error: {}", err); - /// // } - /// } - /// - /// /// Only the Frobnify operation is documented, - /// /// over multiple lines. - /// /// And here are #hash #tags! - /// async fn frobnify(input: input::FrobnifyInputOutput) -> Result { - /// todo!() - /// } - /// - /// async fn say_hello(input: input::SayHelloInputOutput) -> output::SayHelloInputOutput { - /// todo!() - /// } - /// ``` - /// - """.trimIndent() - } -} diff --git a/rust-runtime/aws-smithy-http-server/src/extension.rs b/rust-runtime/aws-smithy-http-server/src/extension.rs index 5c302a757a..ce265bd7f2 100644 --- a/rust-runtime/aws-smithy-http-server/src/extension.rs +++ b/rust-runtime/aws-smithy-http-server/src/extension.rs @@ -28,8 +28,6 @@ use tower::{layer::util::Stack, Layer, Service}; use crate::operation::{Operation, OperationShape}; use crate::plugin::{plugin_from_operation_name_fn, OperationNameFn, Plugin, PluginPipeline, PluginStack}; -#[allow(deprecated)] -use crate::request::RequestParts; pub use crate::request::extension::{Extension, MissingExtension}; @@ -236,37 +234,6 @@ impl Deref for RuntimeErrorExtension { } } -/// Extract an [`Extension`] from a request. -/// This is essentially the implementation of `FromRequest` for `Extension`, but with a -/// protocol-agnostic rejection type. The actual code-generated implementation simply delegates to -/// this function and converts the rejection type into a [`crate::runtime_error::RuntimeError`]. -#[deprecated( - since = "0.52.0", - note = "This was used for extraction under the older service builder. The `FromParts::from_parts` method is now used instead." -)] -#[allow(deprecated)] -pub async fn extract_extension( - req: &mut RequestParts, -) -> Result, crate::rejection::RequestExtensionNotFoundRejection> -where - T: Clone + Send + Sync + 'static, - B: Send, -{ - let value = req - .extensions() - .ok_or(crate::rejection::RequestExtensionNotFoundRejection::ExtensionsAlreadyExtracted)? - .get::() - .ok_or_else(|| { - crate::rejection::RequestExtensionNotFoundRejection::MissingExtension(format!( - "Extension of type `{}` was not found. Perhaps you forgot to add it?", - std::any::type_name::() - )) - }) - .map(|x| x.clone())?; - - Ok(Extension(value)) -} - #[cfg(test)] mod tests { use super::*; diff --git a/rust-runtime/aws-smithy-http-server/src/lib.rs b/rust-runtime/aws-smithy-http-server/src/lib.rs index 8fefaf242d..804b00d8a3 100644 --- a/rust-runtime/aws-smithy-http-server/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server/src/lib.rs @@ -28,15 +28,10 @@ pub mod routing; #[doc(hidden)] pub mod runtime_error; -#[doc(hidden)] -pub mod routers; - #[doc(inline)] pub(crate) use self::error::Error; -pub use self::request::extension::Extension; #[doc(inline)] -#[allow(deprecated)] -pub use self::routing::Router; +pub use self::request::extension::Extension; #[doc(inline)] pub use tower_http::add_extension::{AddExtension, AddExtensionLayer}; diff --git a/rust-runtime/aws-smithy-http-server/src/proto/aws_json/router.rs b/rust-runtime/aws-smithy-http-server/src/proto/aws_json/router.rs index 67041a62fc..4474dfb600 100644 --- a/rust-runtime/aws-smithy-http-server/src/proto/aws_json/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/proto/aws_json/router.rs @@ -9,9 +9,9 @@ use tower::Layer; use tower::Service; use crate::body::BoxBody; -use crate::routers::Router; use crate::routing::tiny_map::TinyMap; use crate::routing::Route; +use crate::routing::Router; use http::header::ToStrError; use thiserror::Error; @@ -117,3 +117,42 @@ impl FromIterator<(String, S)> for AwsJsonRouter { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{proto::test_helpers::req, routing::Router}; + + use http::{HeaderMap, HeaderValue, Method}; + use pretty_assertions::assert_eq; + + #[tokio::test] + async fn simple_routing() { + let routes = vec![("Service.Operation")]; + let router: AwsJsonRouter<_> = routes + .clone() + .into_iter() + .map(|operation| (operation.to_string(), ())) + .collect(); + + let mut headers = HeaderMap::new(); + headers.insert("x-amz-target", HeaderValue::from_static("Service.Operation")); + + // Valid request, should match. + router + .match_route(&req(&Method::POST, "/", Some(headers.clone()))) + .unwrap(); + + // No headers, should return `MissingHeader`. + let res = router.match_route(&req(&Method::POST, "/", None)); + assert_eq!(res.unwrap_err().to_string(), Error::MissingHeader.to_string()); + + // Wrong HTTP method, should return `MethodNotAllowed`. + let res = router.match_route(&req(&Method::GET, "/", Some(headers.clone()))); + assert_eq!(res.unwrap_err().to_string(), Error::MethodNotAllowed.to_string()); + + // Wrong URI, should return `NotRootUrl`. + let res = router.match_route(&req(&Method::POST, "/something", Some(headers))); + assert_eq!(res.unwrap_err().to_string(), Error::NotRootUrl.to_string()); + } +} diff --git a/rust-runtime/aws-smithy-http-server/src/proto/aws_json_10/router.rs b/rust-runtime/aws-smithy-http-server/src/proto/aws_json_10/router.rs index 5c582f569c..31d5ce8a9e 100644 --- a/rust-runtime/aws-smithy-http-server/src/proto/aws_json_10/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/proto/aws_json_10/router.rs @@ -7,7 +7,7 @@ use crate::body::{empty, BoxBody}; use crate::extension::RuntimeErrorExtension; use crate::proto::aws_json::router::Error; use crate::response::IntoResponse; -use crate::routers::{method_disallowed, UNKNOWN_OPERATION_EXCEPTION}; +use crate::routing::{method_disallowed, UNKNOWN_OPERATION_EXCEPTION}; use super::AwsJson1_0; diff --git a/rust-runtime/aws-smithy-http-server/src/proto/aws_json_11/router.rs b/rust-runtime/aws-smithy-http-server/src/proto/aws_json_11/router.rs index 8d0f8c0a06..18f3b4b329 100644 --- a/rust-runtime/aws-smithy-http-server/src/proto/aws_json_11/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/proto/aws_json_11/router.rs @@ -7,7 +7,7 @@ use crate::body::{empty, BoxBody}; use crate::extension::RuntimeErrorExtension; use crate::proto::aws_json::router::Error; use crate::response::IntoResponse; -use crate::routers::{method_disallowed, UNKNOWN_OPERATION_EXCEPTION}; +use crate::routing::{method_disallowed, UNKNOWN_OPERATION_EXCEPTION}; use super::AwsJson1_1; diff --git a/rust-runtime/aws-smithy-http-server/src/proto/mod.rs b/rust-runtime/aws-smithy-http-server/src/proto/mod.rs index 39344ab4f1..26fb17d893 100644 --- a/rust-runtime/aws-smithy-http-server/src/proto/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/proto/mod.rs @@ -9,3 +9,27 @@ pub mod aws_json_11; pub mod rest; pub mod rest_json_1; pub mod rest_xml; + +#[cfg(test)] +pub mod test_helpers { + use http::{HeaderMap, Method, Request}; + + /// Helper function to build a `Request`. Used in other test modules. + pub fn req(method: &Method, uri: &str, headers: Option) -> Request<()> { + let mut r = Request::builder().method(method).uri(uri).body(()).unwrap(); + if let Some(headers) = headers { + *r.headers_mut() = headers + } + r + } + + // Returns a `Response`'s body as a `String`, without consuming the response. + pub async fn get_body_as_string(body: B) -> String + where + B: http_body::Body + std::marker::Unpin, + B::Error: std::fmt::Debug, + { + let body_bytes = hyper::body::to_bytes(body).await.unwrap(); + String::from(std::str::from_utf8(&body_bytes).unwrap()) + } +} diff --git a/rust-runtime/aws-smithy-http-server/src/proto/rest/router.rs b/rust-runtime/aws-smithy-http-server/src/proto/rest/router.rs index 4c73eda72b..1d55f676d6 100644 --- a/rust-runtime/aws-smithy-http-server/src/proto/rest/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/proto/rest/router.rs @@ -6,17 +6,17 @@ use std::convert::Infallible; use crate::body::BoxBody; -use crate::routers::Router; use crate::routing::request_spec::Match; use crate::routing::request_spec::RequestSpec; use crate::routing::Route; +use crate::routing::Router; use tower::Layer; use tower::Service; use thiserror::Error; /// An AWS REST routing error. -#[derive(Debug, Error)] +#[derive(Debug, Error, PartialEq)] pub enum Error { /// Operation not found. #[error("operation not found")] @@ -108,3 +108,166 @@ impl FromIterator<(RequestSpec, S)> for RestRouter { Self { routes } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{proto::test_helpers::req, routing::request_spec::*}; + + use http::Method; + + // This test is a rewrite of `mux.spec.ts`. + // https://github.com/awslabs/smithy-typescript/blob/fbf97a9bf4c1d8cf7f285ea7c24e1f0ef280142a/smithy-typescript-ssdk-libs/server-common/src/httpbinding/mux.spec.ts + #[test] + fn simple_routing() { + let request_specs: Vec<(RequestSpec, &'static str)> = vec![ + ( + RequestSpec::from_parts( + Method::GET, + vec![ + PathSegment::Literal(String::from("a")), + PathSegment::Label, + PathSegment::Label, + ], + Vec::new(), + ), + "A", + ), + ( + RequestSpec::from_parts( + Method::GET, + vec![ + PathSegment::Literal(String::from("mg")), + PathSegment::Greedy, + PathSegment::Literal(String::from("z")), + ], + Vec::new(), + ), + "MiddleGreedy", + ), + ( + RequestSpec::from_parts( + Method::DELETE, + Vec::new(), + vec![ + QuerySegment::KeyValue(String::from("foo"), String::from("bar")), + QuerySegment::Key(String::from("baz")), + ], + ), + "Delete", + ), + ( + RequestSpec::from_parts( + Method::POST, + vec![PathSegment::Literal(String::from("query_key_only"))], + vec![QuerySegment::Key(String::from("foo"))], + ), + "QueryKeyOnly", + ), + ]; + + // Test both RestJson1 and RestXml routers. + let router: RestRouter<_> = request_specs + .into_iter() + .map(|(spec, svc_name)| (spec, svc_name)) + .collect(); + + let hits = vec![ + ("A", Method::GET, "/a/b/c"), + ("MiddleGreedy", Method::GET, "/mg/a/z"), + ("MiddleGreedy", Method::GET, "/mg/a/b/c/d/z?abc=def"), + ("Delete", Method::DELETE, "/?foo=bar&baz=quux"), + ("Delete", Method::DELETE, "/?foo=bar&baz"), + ("Delete", Method::DELETE, "/?foo=bar&baz=&"), + ("Delete", Method::DELETE, "/?foo=bar&baz=quux&baz=grault"), + ("QueryKeyOnly", Method::POST, "/query_key_only?foo=bar"), + ("QueryKeyOnly", Method::POST, "/query_key_only?foo"), + ("QueryKeyOnly", Method::POST, "/query_key_only?foo="), + ("QueryKeyOnly", Method::POST, "/query_key_only?foo=&"), + ]; + for (svc_name, method, uri) in &hits { + assert_eq!(router.match_route(&req(method, uri, None)).unwrap(), *svc_name); + } + + for (_, _, uri) in hits { + let res = router.match_route(&req(&Method::PATCH, uri, None)); + assert_eq!(res.unwrap_err(), Error::MethodNotAllowed); + } + + let misses = vec![ + (Method::GET, "/a"), + (Method::GET, "/a/b"), + (Method::GET, "/mg"), + (Method::GET, "/mg/q"), + (Method::GET, "/mg/z"), + (Method::GET, "/mg/a/b/z/c"), + (Method::DELETE, "/?foo=bar"), + (Method::DELETE, "/?foo=bar"), + (Method::DELETE, "/?baz=quux"), + (Method::POST, "/query_key_only?baz=quux"), + (Method::GET, "/"), + (Method::POST, "/"), + ]; + for (method, miss) in misses { + let res = router.match_route(&req(&method, miss, None)); + assert_eq!(res.unwrap_err(), Error::NotFound); + } + } + + #[tokio::test] + async fn basic_pattern_conflict_avoidance() { + let request_specs: Vec<(RequestSpec, &'static str)> = vec![ + ( + RequestSpec::from_parts( + Method::GET, + vec![PathSegment::Literal(String::from("a")), PathSegment::Label], + Vec::new(), + ), + "A1", + ), + ( + RequestSpec::from_parts( + Method::GET, + vec![ + PathSegment::Literal(String::from("a")), + PathSegment::Label, + PathSegment::Literal(String::from("a")), + ], + Vec::new(), + ), + "A2", + ), + ( + RequestSpec::from_parts( + Method::GET, + vec![PathSegment::Literal(String::from("b")), PathSegment::Greedy], + Vec::new(), + ), + "B1", + ), + ( + RequestSpec::from_parts( + Method::GET, + vec![PathSegment::Literal(String::from("b")), PathSegment::Greedy], + vec![QuerySegment::Key(String::from("q"))], + ), + "B2", + ), + ]; + + let router: RestRouter<_> = request_specs + .into_iter() + .map(|(spec, svc_name)| (spec, svc_name)) + .collect(); + + let hits = vec![ + ("A1", Method::GET, "/a/foo"), + ("A2", Method::GET, "/a/foo/a"), + ("B1", Method::GET, "/b/foo/bar/baz"), + ("B2", Method::GET, "/b/foo?q=baz"), + ]; + for (svc_name, method, uri) in hits { + assert_eq!(router.match_route(&req(&method, uri, None)).unwrap(), svc_name); + } + } +} diff --git a/rust-runtime/aws-smithy-http-server/src/proto/rest_json_1/router.rs b/rust-runtime/aws-smithy-http-server/src/proto/rest_json_1/router.rs index c737b665c4..189658d317 100644 --- a/rust-runtime/aws-smithy-http-server/src/proto/rest_json_1/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/proto/rest_json_1/router.rs @@ -7,7 +7,7 @@ use crate::body::BoxBody; use crate::extension::RuntimeErrorExtension; use crate::proto::rest::router::Error; use crate::response::IntoResponse; -use crate::routers::{method_disallowed, UNKNOWN_OPERATION_EXCEPTION}; +use crate::routing::{method_disallowed, UNKNOWN_OPERATION_EXCEPTION}; use super::RestJson1; diff --git a/rust-runtime/aws-smithy-http-server/src/proto/rest_xml/router.rs b/rust-runtime/aws-smithy-http-server/src/proto/rest_xml/router.rs index 1b1b21742f..b771884b04 100644 --- a/rust-runtime/aws-smithy-http-server/src/proto/rest_xml/router.rs +++ b/rust-runtime/aws-smithy-http-server/src/proto/rest_xml/router.rs @@ -8,8 +8,7 @@ use crate::body::BoxBody; use crate::extension::RuntimeErrorExtension; use crate::proto::rest::router::Error; use crate::response::IntoResponse; -use crate::routers::method_disallowed; -use crate::routers::UNKNOWN_OPERATION_EXCEPTION; +use crate::routing::{method_disallowed, UNKNOWN_OPERATION_EXCEPTION}; use super::RestXml; diff --git a/rust-runtime/aws-smithy-http-server/src/protocols.rs b/rust-runtime/aws-smithy-http-server/src/protocols.rs index bf56eea89f..2267c0384c 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocols.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocols.rs @@ -5,20 +5,11 @@ //! Protocol helpers. use crate::rejection::MissingContentTypeReason; -#[allow(deprecated)] -use crate::request::RequestParts; use http::HeaderMap; /// When there are no modeled inputs, /// a request body is empty and the content-type request header must not be set -#[allow(deprecated)] -pub fn content_type_header_empty_body_no_modeled_input( - req: &RequestParts, -) -> Result<(), MissingContentTypeReason> { - if req.headers().is_none() { - return Ok(()); - } - let headers = req.headers().unwrap(); +pub fn content_type_header_empty_body_no_modeled_input(headers: &HeaderMap) -> Result<(), MissingContentTypeReason> { if headers.contains_key(http::header::CONTENT_TYPE) { let found_mime = parse_content_type(headers)?; Err(MissingContentTypeReason::UnexpectedMimeType { @@ -42,15 +33,10 @@ fn parse_content_type(headers: &HeaderMap) -> Result( - req: &RequestParts, +pub fn content_type_header_classifier( + headers: &HeaderMap, expected_content_type: Option<&'static str>, ) -> Result<(), MissingContentTypeReason> { - // Allow no CONTENT-TYPE header - if req.headers().is_none() { - return Ok(()); - } - let headers = req.headers().unwrap(); // Headers are present, `unwrap` will not panic. if !headers.contains_key(http::header::CONTENT_TYPE) { return Ok(()); } @@ -79,12 +65,7 @@ pub fn content_type_header_classifier( } #[allow(deprecated)] -pub fn accept_header_classifier(req: &RequestParts, content_type: &'static str) -> bool { - // Allow no ACCEPT header - if req.headers().is_none() { - return true; - } - let headers = req.headers().unwrap(); +pub fn accept_header_classifier(headers: &HeaderMap, content_type: &'static str) -> bool { if !headers.contains_key(http::header::ACCEPT) { return true; } @@ -126,28 +107,25 @@ pub fn accept_header_classifier(req: &RequestParts, content_type: &'static #[cfg(test)] mod tests { use super::*; - use http::Request; + use http::header::{HeaderValue, ACCEPT, CONTENT_TYPE}; - fn req_content_type(content_type: &str) -> RequestParts<&str> { - let request = Request::builder() - .header("content-type", content_type) - .body("") - .unwrap(); - RequestParts::new(request) + fn req_content_type(content_type: &'static str) -> HeaderMap { + let mut headers = HeaderMap::new(); + headers.insert(CONTENT_TYPE, HeaderValue::from_str(content_type).unwrap()); + headers } - fn req_accept(content_type: &str) -> RequestParts<&str> { - let request = Request::builder().header("accept", content_type).body("").unwrap(); - RequestParts::new(request) + fn req_accept(accept: &'static str) -> HeaderMap { + let mut headers = HeaderMap::new(); + headers.insert(ACCEPT, HeaderValue::from_static(accept)); + headers } const EXPECTED_MIME_APPLICATION_JSON: Option<&'static str> = Some("application/json"); #[test] fn check_content_type_header_empty_body_no_modeled_input() { - let request = Request::builder().body("").unwrap(); - let request = RequestParts::new(request); - assert!(content_type_header_empty_body_no_modeled_input(&request).is_ok()); + assert!(content_type_header_empty_body_no_modeled_input(&HeaderMap::new()).is_ok()); } #[test] @@ -193,8 +171,7 @@ mod tests { #[test] fn check_missing_content_type_is_allowed() { - let request = RequestParts::new(Request::builder().body("").unwrap()); - let result = content_type_header_classifier(&request, EXPECTED_MIME_APPLICATION_JSON); + let result = content_type_header_classifier(&HeaderMap::new(), EXPECTED_MIME_APPLICATION_JSON); assert!(result.is_ok()); } @@ -241,9 +218,7 @@ mod tests { #[test] fn valid_empty_accept_header_classifier() { - let valid_request = Request::builder().body("").unwrap(); - let valid_request = RequestParts::new(valid_request); - assert!(accept_header_classifier(&valid_request, "application/json")); + assert!(accept_header_classifier(&HeaderMap::new(), "application/json")); } #[test] diff --git a/rust-runtime/aws-smithy-http-server/src/rejection.rs b/rust-runtime/aws-smithy-http-server/src/rejection.rs index 65fb8e2e62..f01344f765 100644 --- a/rust-runtime/aws-smithy-http-server/src/rejection.rs +++ b/rust-runtime/aws-smithy-http-server/src/rejection.rs @@ -9,13 +9,11 @@ //! handle requests and responses that return `Result` throughout the framework. These //! include functions to deserialize incoming requests and serialize outgoing responses. //! -//! All types end with `Rejection`. There are three types: +//! All types end with `Rejection`. There are two types: //! //! 1. [`RequestRejection`]s are used when the framework fails to deserialize the request into the //! corresponding operation input. -//! 1. [`RequestExtensionNotFoundRejection`]s are used when the framework fails to deserialize from -//! the request's extensions a particular [`crate::Extension`] that was expected to be found. -//! 1. [`ResponseRejection`]s are used when the framework fails to serialize the operation +//! 2. [`ResponseRejection`]s are used when the framework fails to serialize the operation //! output into a response. //! //! They are called _rejection_ types and not _error_ types to signal that the input was _rejected_ @@ -41,35 +39,10 @@ //! [`crate::runtime_error::RuntimeError`], thus allowing us to represent the full //! error chain. -// For some reason `deprecated(deprecated)` warns of its own deprecation. Putting `allow(deprecated)` at the module -// level remedies it. -#![allow(deprecated)] - use strum_macros::Display; use crate::response::IntoResponse; -/// Rejection used for when failing to extract an [`crate::Extension`] from an incoming [request's -/// extensions]. Contains one variant for each way the extractor can fail. -/// -/// [request's extensions]: https://docs.rs/http/latest/http/struct.Extensions.html -#[deprecated( - since = "0.52.0", - note = "This was used for extraction under the older service builder. The `MissingExtension` struct returned by `FromParts::from_parts` is now used." -)] -#[derive(Debug, Display)] -pub enum RequestExtensionNotFoundRejection { - /// Used when a particular [`crate::Extension`] was expected to be found in the request but we - /// did not find it. - /// This most likely means the service implementer simply forgot to add a [`tower::Layer`] that - /// registers the particular extension in their service to incoming requests. - MissingExtension(String), - // Used when the request extensions have already been taken by another extractor. - ExtensionsAlreadyExtracted, -} - -impl std::error::Error for RequestExtensionNotFoundRejection {} - /// Errors that can occur when serializing the operation output provided by the service implementer /// into an HTTP response. #[derive(Debug, Display)] @@ -104,8 +77,7 @@ convert_to_response_rejection!(aws_smithy_http::operation::error::SerializationE convert_to_response_rejection!(http::Error, Http); /// Errors that can occur when deserializing an HTTP request into an _operation input_, the input -/// that is passed as the first argument to operation handlers. To deserialize into the service's -/// registered state, a different rejection type is used, [`RequestExtensionNotFoundRejection`]. +/// that is passed as the first argument to operation handlers. /// /// This type allows us to easily keep track of all the possible errors that can occur in the /// lifecycle of an incoming HTTP request. @@ -131,10 +103,6 @@ convert_to_response_rejection!(http::Error, Http); // The variants are _roughly_ sorted in the order in which the HTTP request is processed. #[derive(Debug, Display)] pub enum RequestRejection { - /// Used when attempting to take the request's body, and it has already been taken (presumably - /// by an outer `Service` that handled the request before us). - BodyAlreadyExtracted, - /// Used when failing to convert non-streaming requests into a byte slab with /// `hyper::body::to_bytes`. HttpBody(crate::Error), @@ -149,10 +117,6 @@ pub enum RequestRejection { /// input it should represent. XmlDeserialize(crate::Error), - /// Used when attempting to take the request's headers, and they have already been taken (presumably - /// by an outer `Service` that handled the request before us). - HeadersAlreadyExtracted, - /// Used when failing to parse HTTP headers that are bound to input members with the `httpHeader` /// or the `httpPrefixHeaders` traits. HeaderParse(crate::Error), diff --git a/rust-runtime/aws-smithy-http-server/src/request/mod.rs b/rust-runtime/aws-smithy-http-server/src/request/mod.rs index 1b08c63289..be507b804f 100644 --- a/rust-runtime/aws-smithy-http-server/src/request/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/request/mod.rs @@ -54,7 +54,7 @@ use futures_util::{ future::{try_join, MapErr, MapOk, TryJoin}, TryFutureExt, }; -use http::{request::Parts, Extensions, HeaderMap, Request, StatusCode, Uri}; +use http::{request::Parts, Request, StatusCode}; use crate::{ body::{empty, BoxBody}, @@ -77,77 +77,6 @@ fn internal_server_error() -> http::Response { response } -#[doc(hidden)] -#[deprecated( - since = "0.52.0", - note = "This is not used by the new service builder. We use the `http::Parts` struct directly." -)] -#[derive(Debug)] -pub struct RequestParts { - uri: Uri, - headers: Option, - extensions: Option, - body: Option, -} - -#[allow(deprecated)] -impl RequestParts { - /// Create a new `RequestParts`. - /// - /// You generally shouldn't need to construct this type yourself, unless - /// using extractors outside of axum for example to implement a - /// [`tower::Service`]. - /// - /// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html - #[doc(hidden)] - pub fn new(req: Request) -> Self { - let ( - Parts { - uri, - headers, - extensions, - .. - }, - body, - ) = req.into_parts(); - - RequestParts { - uri, - headers: Some(headers), - extensions: Some(extensions), - body: Some(body), - } - } - - /// Gets a reference to the request headers. - /// - /// Returns `None` if the headers has been taken by another extractor. - #[doc(hidden)] - pub fn headers(&self) -> Option<&HeaderMap> { - self.headers.as_ref() - } - - /// Takes the body out of the request, leaving a `None` in its place. - #[doc(hidden)] - pub fn take_body(&mut self) -> Option { - self.body.take() - } - - /// Gets a reference the request URI. - #[doc(hidden)] - pub fn uri(&self) -> &Uri { - &self.uri - } - - /// Gets a reference to the request extensions. - /// - /// Returns `None` if the extensions has been taken by another extractor. - #[doc(hidden)] - pub fn extensions(&self) -> Option<&Extensions> { - self.extensions.as_ref() - } -} - /// Provides a protocol aware extraction from a [`Request`]. This borrows the [`Parts`], in contrast to /// [`FromRequest`] which consumes the entire [`http::Request`] including the body. pub trait FromParts: Sized { diff --git a/rust-runtime/aws-smithy-http-server/src/routers.rs b/rust-runtime/aws-smithy-http-server/src/routers.rs deleted file mode 100644 index ecffe36e0c..0000000000 --- a/rust-runtime/aws-smithy-http-server/src/routers.rs +++ /dev/null @@ -1,176 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -use std::{ - error::Error, - fmt, - future::{ready, Future, Ready}, - marker::PhantomData, - pin::Pin, - task::{Context, Poll}, -}; - -use bytes::Bytes; -use futures_util::{ - future::{Either, MapOk}, - TryFutureExt, -}; -use http::Response; -use http_body::Body as HttpBody; -use tower::{util::Oneshot, Service, ServiceExt}; -use tracing::debug; - -use crate::{ - body::{boxed, BoxBody}, - error::BoxError, - response::IntoResponse, -}; - -pub(crate) const UNKNOWN_OPERATION_EXCEPTION: &str = "UnknownOperationException"; - -/// Constructs common response to method disallowed. -pub(crate) fn method_disallowed() -> http::Response { - let mut responses = http::Response::default(); - *responses.status_mut() = http::StatusCode::METHOD_NOT_ALLOWED; - responses -} - -/// An interface for retrieving an inner [`Service`] given a [`http::Request`]. -pub trait Router { - type Service; - type Error; - - /// Matches a [`http::Request`] to a target [`Service`]. - fn match_route(&self, request: &http::Request) -> Result; -} - -/// A [`Service`] using the [`Router`] `R` to redirect messages to specific routes. -/// -/// The `Protocol` parameter is used to determine the serialization of errors. -pub struct RoutingService { - router: R, - _protocol: PhantomData, -} - -impl fmt::Debug for RoutingService -where - R: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("RoutingService") - .field("router", &self.router) - .field("_protocol", &self._protocol) - .finish() - } -} - -impl Clone for RoutingService -where - R: Clone, -{ - fn clone(&self) -> Self { - Self { - router: self.router.clone(), - _protocol: PhantomData, - } - } -} - -impl RoutingService { - /// Creates a [`RoutingService`] from a [`Router`]. - pub fn new(router: R) -> Self { - Self { - router, - _protocol: PhantomData, - } - } - - /// Maps a [`Router`] using a closure. - pub fn map(self, f: F) -> RoutingService - where - F: FnOnce(R) -> RNew, - { - RoutingService { - router: f(self.router), - _protocol: PhantomData, - } - } -} - -type EitherOneshotReady = Either< - MapOk>, fn(>>::Response) -> http::Response>, - Ready, >>::Error>>, ->; - -pin_project_lite::pin_project! { - pub struct RoutingFuture where S: Service> { - #[pin] - inner: EitherOneshotReady - } -} - -impl RoutingFuture -where - S: Service>, -{ - /// Creates a [`RoutingFuture`] from [`ServiceExt::oneshot`]. - pub(super) fn from_oneshot(future: Oneshot>) -> Self - where - S: Service, Response = http::Response>, - RespB: HttpBody + Send + 'static, - RespB::Error: Into, - { - Self { - inner: Either::Left(future.map_ok(|x| x.map(boxed))), - } - } - - /// Creates a [`RoutingFuture`] from [`Service::Response`]. - pub(super) fn from_response(response: http::Response) -> Self { - Self { - inner: Either::Right(ready(Ok(response))), - } - } -} - -impl Future for RoutingFuture -where - S: Service>, -{ - type Output = Result, S::Error>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.project().inner.poll(cx) - } -} - -impl Service> for RoutingService -where - R: Router, - R::Service: Service, Response = http::Response> + Clone, - R::Error: IntoResponse

+ Error, - RespB: HttpBody + Send + 'static, - RespB::Error: Into, -{ - type Response = Response; - type Error = >>::Error; - type Future = RoutingFuture; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: http::Request) -> Self::Future { - match self.router.match_route(&req) { - // Successfully routed, use the routes `Service::call`. - Ok(ok) => RoutingFuture::from_oneshot(ok.oneshot(req)), - // Failed to route, use the `R::Error`s `IntoResponse

`. - Err(error) => { - debug!(%error, "failed to route"); - RoutingFuture::from_response(error.into_response()) - } - } - } -} diff --git a/rust-runtime/aws-smithy-http-server/src/routing/future.rs b/rust-runtime/aws-smithy-http-server/src/routing/future.rs deleted file mode 100644 index dcb1d83c53..0000000000 --- a/rust-runtime/aws-smithy-http-server/src/routing/future.rs +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -// This code was copied and then modified from Tokio's Axum. - -/* Copyright (c) 2021 Tower Contributors - * - * Permission is hereby granted, free of charge, to any - * person obtaining a copy of this software and associated - * documentation files (the "Software"), to deal in the - * Software without restriction, including without - * limitation the rights to use, copy, modify, merge, - * publish, distribute, sublicense, and/or sell copies of - * the Software, and to permit persons to whom the Software - * is furnished to do so, subject to the following - * conditions: - * - * The above copyright notice and this permission notice - * shall be included in all copies or substantial portions - * of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF - * ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED - * TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A - * PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT - * SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION - * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR - * IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER - * DEALINGS IN THE SOFTWARE. - */ - -#![allow(deprecated)] - -//! Future types. - -use crate::routers::RoutingFuture; - -use super::Route; -pub use super::{into_make_service::IntoMakeService, route::RouteFuture}; - -opaque_future! { - #[deprecated( - since = "0.52.0", - note = "`OperationRegistry` is part of the deprecated service builder API. This type no longer appears in the public API." - )] - /// Response future for [`Router`](super::Router). - pub type RouterFuture = RoutingFuture, B>; -} diff --git a/rust-runtime/aws-smithy-http-server/src/routing/mod.rs b/rust-runtime/aws-smithy-http-server/src/routing/mod.rs index 2ff9c82ef5..4f0cb8c0fa 100644 --- a/rust-runtime/aws-smithy-http-server/src/routing/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/routing/mod.rs @@ -7,27 +7,6 @@ //! //! [Smithy specification]: https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html -use std::{ - convert::Infallible, - task::{Context, Poll}, -}; - -use self::request_spec::RequestSpec; -use crate::{ - body::{boxed, Body, BoxBody, HttpBody}, - proto::{ - aws_json::router::AwsJsonRouter, aws_json_10::AwsJson1_0, aws_json_11::AwsJson1_1, rest::router::RestRouter, - rest_json_1::RestJson1, rest_xml::RestXml, - }, -}; -use crate::{error::BoxError, routers::RoutingService}; - -use http::{Request, Response}; -use tower::layer::Layer; -use tower::{Service, ServiceBuilder}; -use tower_http::map_response_body::MapResponseBodyLayer; - -mod future; mod into_make_service; mod into_make_service_with_connect_info; #[cfg(feature = "aws-lambda")] @@ -41,571 +20,185 @@ mod route; pub(crate) mod tiny_map; +use std::{ + error::Error, + fmt, + future::{ready, Future, Ready}, + marker::PhantomData, + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use futures_util::{ + future::{Either, MapOk}, + TryFutureExt, +}; +use http::Response; +use http_body::Body as HttpBody; +use tower::{util::Oneshot, Service, ServiceExt}; +use tracing::debug; + +use crate::{ + body::{boxed, BoxBody}, + error::BoxError, + response::IntoResponse, +}; + #[cfg(feature = "aws-lambda")] #[cfg_attr(docsrs, doc(cfg(feature = "aws-lambda")))] pub use self::lambda_handler::LambdaHandler; #[allow(deprecated)] pub use self::{ - future::RouterFuture, into_make_service::IntoMakeService, into_make_service_with_connect_info::{Connected, IntoMakeServiceWithConnectInfo}, route::Route, }; -/// The router is a [`tower::Service`] that routes incoming requests to other `Service`s -/// based on the request's URI and HTTP method or on some specific header setting the target operation. -/// The former is adhering to the [Smithy specification], while the latter is adhering to -/// the [AwsJson specification]. -/// -/// The router is also [Protocol] aware and currently supports REST based protocols like [restJson1] or [restXml] -/// and RPC based protocols like [awsJson1.0] or [awsJson1.1]. -/// It currently does not support Smithy's [endpoint trait]. -/// -/// You should not **instantiate** this router directly; it will be created for you from the -/// code generated from your Smithy model by `smithy-rs`. -/// -/// [Smithy specification]: https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html -/// [AwsJson specification]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#protocol-behaviors -/// [Protocol]: https://awslabs.github.io/smithy/1.0/spec/aws/index.html#aws-protocols -/// [restJson1]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-restjson1-protocol.html -/// [restXml]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-restxml-protocol.html -/// [awsJson1.0]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html -/// [awsJson1.1]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html -/// [endpoint trait]: https://awslabs.github.io/smithy/1.0/spec/core/endpoint-traits.html#endpoint-trait -#[derive(Debug)] -#[deprecated( - since = "0.52.0", - note = "`OperationRegistry` is part of the deprecated service builder API. This type no longer appears in the public API." -)] -pub struct Router { - routes: Routes, +pub(crate) const UNKNOWN_OPERATION_EXCEPTION: &str = "UnknownOperationException"; + +/// Constructs common response to method disallowed. +pub(crate) fn method_disallowed() -> http::Response { + let mut responses = http::Response::default(); + *responses.status_mut() = http::StatusCode::METHOD_NOT_ALLOWED; + responses } -/// Protocol-aware routes types. -/// -/// RestJson1 and RestXml routes are stored in a `Vec` because there can be multiple matches on the -/// request URI and we thus need to iterate the whole list and use a ranking mechanism to choose. -/// -/// AwsJson 1.0 and 1.1 routes can be stored in a `HashMap` since the requested operation can be -/// directly found in the `X-Amz-Target` HTTP header. -#[derive(Debug)] -enum Routes { - RestXml(RoutingService>, RestXml>), - RestJson1(RoutingService>, RestJson1>), - AwsJson1_0(RoutingService>, AwsJson1_0>), - AwsJson1_1(RoutingService>, AwsJson1_1>), +/// An interface for retrieving an inner [`Service`] given a [`http::Request`]. +pub trait Router { + type Service; + type Error; + + /// Matches a [`http::Request`] to a target [`Service`]. + fn match_route(&self, request: &http::Request) -> Result; } -#[allow(deprecated)] -impl Clone for Router { - fn clone(&self) -> Self { - match &self.routes { - Routes::RestJson1(routes) => Router { - routes: Routes::RestJson1(routes.clone()), - }, - Routes::RestXml(routes) => Router { - routes: Routes::RestXml(routes.clone()), - }, - Routes::AwsJson1_0(routes) => Router { - routes: Routes::AwsJson1_0(routes.clone()), - }, - Routes::AwsJson1_1(routes) => Router { - routes: Routes::AwsJson1_1(routes.clone()), - }, - } - } +/// A [`Service`] using the [`Router`] `R` to redirect messages to specific routes. +/// +/// The `Protocol` parameter is used to determine the serialization of errors. +pub struct RoutingService { + router: R, + _protocol: PhantomData, } -#[allow(deprecated)] -impl Router +impl fmt::Debug for RoutingService where - B: Send + 'static, + R: fmt::Debug, { - /// Convert this router into a [`MakeService`], that is a [`Service`] whose - /// response is another service. - /// - /// This is useful when running your application with hyper's - /// [`Server`]. - /// - /// [`Server`]: hyper::server::Server - /// [`MakeService`]: tower::make::MakeService - pub fn into_make_service(self) -> IntoMakeService { - IntoMakeService::new(self) - } - - /// Apply a [`tower::Layer`] to the router. - /// - /// All requests to the router will be processed by the layer's - /// corresponding middleware. - /// - /// This can be used to add additional processing to all routes. - pub fn layer(self, layer: L) -> Router - where - L: Layer>, - L::Service: - Service, Response = Response, Error = Infallible> + Clone + Send + 'static, - >>::Future: Send + 'static, - NewResBody: HttpBody + Send + 'static, - NewResBody::Error: Into, - { - let layer = ServiceBuilder::new() - .layer(MapResponseBodyLayer::new(boxed)) - .layer(layer); - match self.routes { - Routes::RestJson1(routes) => Router { - routes: Routes::RestJson1(routes.map(|router| router.layer(layer).boxed())), - }, - Routes::RestXml(routes) => Router { - routes: Routes::RestXml(routes.map(|router| router.layer(layer).boxed())), - }, - Routes::AwsJson1_0(routes) => Router { - routes: Routes::AwsJson1_0(routes.map(|router| router.layer(layer).boxed())), - }, - Routes::AwsJson1_1(routes) => Router { - routes: Routes::AwsJson1_1(routes.map(|router| router.layer(layer).boxed())), - }, - } + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RoutingService") + .field("router", &self.router) + .field("_protocol", &self._protocol) + .finish() } +} - /// Create a new RestJson1 `Router` from an iterator over pairs of [`RequestSpec`]s and services. - /// - /// If the iterator is empty the router will respond `404 Not Found` to all requests. - #[doc(hidden)] - pub fn new_rest_json_router(routes: T) -> Self - where - T: IntoIterator< - Item = ( - tower::util::BoxCloneService, Response, Infallible>, - RequestSpec, - ), - >, - { - let svc = RoutingService::new( - routes - .into_iter() - .map(|(svc, request_spec)| (request_spec, Route::from_box_clone_service(svc))) - .collect(), - ); +impl Clone for RoutingService +where + R: Clone, +{ + fn clone(&self) -> Self { Self { - routes: Routes::RestJson1(svc), + router: self.router.clone(), + _protocol: PhantomData, } } +} - /// Create a new RestXml `Router` from an iterator over pairs of [`RequestSpec`]s and services. - /// - /// If the iterator is empty the router will respond `404 Not Found` to all requests. - #[doc(hidden)] - pub fn new_rest_xml_router(routes: T) -> Self - where - T: IntoIterator< - Item = ( - tower::util::BoxCloneService, Response, Infallible>, - RequestSpec, - ), - >, - { - let svc = RoutingService::new( - routes - .into_iter() - .map(|(svc, request_spec)| (request_spec, Route::from_box_clone_service(svc))) - .collect(), - ); +impl RoutingService { + /// Creates a [`RoutingService`] from a [`Router`]. + pub fn new(router: R) -> Self { Self { - routes: Routes::RestXml(svc), + router, + _protocol: PhantomData, } } - /// Create a new AwsJson 1.0 `Router` from an iterator over pairs of operation names and services. - /// - /// If the iterator is empty the router will respond `404 Not Found` to all requests. - #[doc(hidden)] - pub fn new_aws_json_10_router(routes: T) -> Self + /// Maps a [`Router`] using a closure. + pub fn map(self, f: F) -> RoutingService where - T: IntoIterator< - Item = ( - tower::util::BoxCloneService, Response, Infallible>, - String, - ), - >, + F: FnOnce(R) -> RNew, { - let svc = RoutingService::new( - routes - .into_iter() - .map(|(svc, operation)| (operation, Route::from_box_clone_service(svc))) - .collect(), - ); - - Self { - routes: Routes::AwsJson1_0(svc), + RoutingService { + router: f(self.router), + _protocol: PhantomData, } } +} - /// Create a new AwsJson 1.1 `Router` from a vector of pairs of operations and services. - /// - /// If the vector is empty the router will respond `404 Not Found` to all requests. - #[doc(hidden)] - pub fn new_aws_json_11_router(routes: T) -> Self - where - T: IntoIterator< - Item = ( - tower::util::BoxCloneService, Response, Infallible>, - String, - ), - >, - { - let svc = RoutingService::new( - routes - .into_iter() - .map(|(svc, operation)| (operation, Route::from_box_clone_service(svc))) - .collect(), - ); +type EitherOneshotReady = Either< + MapOk>, fn(>>::Response) -> http::Response>, + Ready, >>::Error>>, +>; - Self { - routes: Routes::AwsJson1_1(svc), - } +pin_project_lite::pin_project! { + pub struct RoutingFuture where S: Service> { + #[pin] + inner: EitherOneshotReady } } -#[allow(deprecated)] -impl Service> for Router +impl RoutingFuture where - B: Send + 'static, + S: Service>, { - type Response = Response; - type Error = Infallible; - type Future = RouterFuture; - - #[inline] - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - #[inline] - fn call(&mut self, req: Request) -> Self::Future { - let fut = match &mut self.routes { - // REST routes. - Routes::RestJson1(routes) => routes.call(req), - Routes::RestXml(routes) => routes.call(req), - // AwsJson routes. - Routes::AwsJson1_0(routes) => routes.call(req), - Routes::AwsJson1_1(routes) => routes.call(req), - }; - RouterFuture::new(fut) - } -} - -#[cfg(test)] -#[allow(deprecated)] -mod rest_tests { - use super::*; - use crate::{ - body::{boxed, BoxBody}, - routing::request_spec::*, - }; - use futures_util::Future; - use http::{HeaderMap, Method, StatusCode}; - use std::pin::Pin; - - /// Helper function to build a `Request`. Used in other test modules. - pub fn req(method: &Method, uri: &str, headers: Option) -> Request<()> { - let mut r = Request::builder().method(method).uri(uri).body(()).unwrap(); - if let Some(headers) = headers { - *r.headers_mut() = headers - } - r - } - - // Returns a `Response`'s body as a `String`, without consuming the response. - pub async fn get_body_as_string(res: &mut Response) -> String + /// Creates a [`RoutingFuture`] from [`ServiceExt::oneshot`]. + pub(super) fn from_oneshot(future: Oneshot>) -> Self where - B: http_body::Body + std::marker::Unpin, - B::Error: std::fmt::Debug, + S: Service, Response = http::Response>, + RespB: HttpBody + Send + 'static, + RespB::Error: Into, { - let body_mut = res.body_mut(); - let body_bytes = hyper::body::to_bytes(body_mut).await.unwrap(); - String::from(std::str::from_utf8(&body_bytes).unwrap()) - } - - /// A service that returns its name and the request's URI in the response body. - #[derive(Clone)] - struct NamedEchoUriService(String); - - impl Service> for NamedEchoUriService { - type Response = Response; - type Error = Infallible; - type Future = Pin> + Send>>; - - #[inline] - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - #[inline] - fn call(&mut self, req: Request) -> Self::Future { - let body = boxed(Body::from(format!("{} :: {}", self.0, req.uri()))); - let fut = async { Ok(Response::builder().status(&http::StatusCode::OK).body(body).unwrap()) }; - Box::pin(fut) + Self { + inner: Either::Left(future.map_ok(|x| x.map(boxed))), } } - // This test is a rewrite of `mux.spec.ts`. - // https://github.com/awslabs/smithy-typescript/blob/fbf97a9bf4c1d8cf7f285ea7c24e1f0ef280142a/smithy-typescript-ssdk-libs/server-common/src/httpbinding/mux.spec.ts - #[tokio::test] - async fn simple_routing() { - let request_specs: Vec<(RequestSpec, &str)> = vec![ - ( - RequestSpec::from_parts( - Method::GET, - vec![ - PathSegment::Literal(String::from("a")), - PathSegment::Label, - PathSegment::Label, - ], - Vec::new(), - ), - "A", - ), - ( - RequestSpec::from_parts( - Method::GET, - vec![ - PathSegment::Literal(String::from("mg")), - PathSegment::Greedy, - PathSegment::Literal(String::from("z")), - ], - Vec::new(), - ), - "MiddleGreedy", - ), - ( - RequestSpec::from_parts( - Method::DELETE, - Vec::new(), - vec![ - QuerySegment::KeyValue(String::from("foo"), String::from("bar")), - QuerySegment::Key(String::from("baz")), - ], - ), - "Delete", - ), - ( - RequestSpec::from_parts( - Method::POST, - vec![PathSegment::Literal(String::from("query_key_only"))], - vec![QuerySegment::Key(String::from("foo"))], - ), - "QueryKeyOnly", - ), - ]; - - // Test both RestJson1 and RestXml routers. - let router_json = Router::new_rest_json_router(request_specs.clone().into_iter().map(|(spec, svc_name)| { - ( - tower::util::BoxCloneService::new(NamedEchoUriService(String::from(svc_name))), - spec, - ) - })); - let router_xml = Router::new_rest_xml_router(request_specs.into_iter().map(|(spec, svc_name)| { - ( - tower::util::BoxCloneService::new(NamedEchoUriService(String::from(svc_name))), - spec, - ) - })); - - for mut router in [router_json, router_xml] { - let hits = vec![ - ("A", Method::GET, "/a/b/c"), - ("MiddleGreedy", Method::GET, "/mg/a/z"), - ("MiddleGreedy", Method::GET, "/mg/a/b/c/d/z?abc=def"), - ("Delete", Method::DELETE, "/?foo=bar&baz=quux"), - ("Delete", Method::DELETE, "/?foo=bar&baz"), - ("Delete", Method::DELETE, "/?foo=bar&baz=&"), - ("Delete", Method::DELETE, "/?foo=bar&baz=quux&baz=grault"), - ("QueryKeyOnly", Method::POST, "/query_key_only?foo=bar"), - ("QueryKeyOnly", Method::POST, "/query_key_only?foo"), - ("QueryKeyOnly", Method::POST, "/query_key_only?foo="), - ("QueryKeyOnly", Method::POST, "/query_key_only?foo=&"), - ]; - for (svc_name, method, uri) in &hits { - let mut res = router.call(req(method, uri, None)).await.unwrap(); - let actual_body = get_body_as_string(&mut res).await; - - assert_eq!(format!("{} :: {}", svc_name, uri), actual_body); - } - - for (_, _, uri) in hits { - let res = router.call(req(&Method::PATCH, uri, None)).await.unwrap(); - assert_eq!(StatusCode::METHOD_NOT_ALLOWED, res.status()); - } - - let misses = vec![ - (Method::GET, "/a"), - (Method::GET, "/a/b"), - (Method::GET, "/mg"), - (Method::GET, "/mg/q"), - (Method::GET, "/mg/z"), - (Method::GET, "/mg/a/b/z/c"), - (Method::DELETE, "/?foo=bar"), - (Method::DELETE, "/?foo=bar"), - (Method::DELETE, "/?baz=quux"), - (Method::POST, "/query_key_only?baz=quux"), - (Method::GET, "/"), - (Method::POST, "/"), - ]; - for (method, miss) in misses { - let res = router.call(req(&method, miss, None)).await.unwrap(); - assert_eq!(StatusCode::NOT_FOUND, res.status()); - } + /// Creates a [`RoutingFuture`] from [`Service::Response`]. + pub(super) fn from_response(response: http::Response) -> Self { + Self { + inner: Either::Right(ready(Ok(response))), } } +} - #[tokio::test] - async fn basic_pattern_conflict_avoidance() { - let request_specs: Vec<(RequestSpec, &str)> = vec![ - ( - RequestSpec::from_parts( - Method::GET, - vec![PathSegment::Literal(String::from("a")), PathSegment::Label], - Vec::new(), - ), - "A1", - ), - ( - RequestSpec::from_parts( - Method::GET, - vec![ - PathSegment::Literal(String::from("a")), - PathSegment::Label, - PathSegment::Literal(String::from("a")), - ], - Vec::new(), - ), - "A2", - ), - ( - RequestSpec::from_parts( - Method::GET, - vec![PathSegment::Literal(String::from("b")), PathSegment::Greedy], - Vec::new(), - ), - "B1", - ), - ( - RequestSpec::from_parts( - Method::GET, - vec![PathSegment::Literal(String::from("b")), PathSegment::Greedy], - vec![QuerySegment::Key(String::from("q"))], - ), - "B2", - ), - ]; - - let mut router = Router::new_rest_json_router(request_specs.into_iter().map(|(spec, svc_name)| { - ( - tower::util::BoxCloneService::new(NamedEchoUriService(String::from(svc_name))), - spec, - ) - })); - - let hits = vec![ - ("A1", Method::GET, "/a/foo"), - ("A2", Method::GET, "/a/foo/a"), - ("B1", Method::GET, "/b/foo/bar/baz"), - ("B2", Method::GET, "/b/foo?q=baz"), - ]; - for (svc_name, method, uri) in &hits { - let mut res = router.call(req(method, uri, None)).await.unwrap(); - let actual_body = get_body_as_string(&mut res).await; +impl Future for RoutingFuture +where + S: Service>, +{ + type Output = Result, S::Error>; - assert_eq!(format!("{} :: {}", svc_name, uri), actual_body); - } + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project().inner.poll(cx) } } -#[allow(deprecated)] -#[cfg(test)] -mod awsjson_tests { - use super::rest_tests::{get_body_as_string, req}; - use super::*; - use crate::body::boxed; - use futures_util::Future; - use http::{HeaderMap, HeaderValue, Method, StatusCode}; - use pretty_assertions::assert_eq; - use std::pin::Pin; - - /// A service that returns its name and the request's URI in the response body. - #[derive(Clone)] - struct NamedEchoOperationService(String); - - impl Service> for NamedEchoOperationService { - type Response = Response; - type Error = Infallible; - type Future = Pin> + Send>>; - - #[inline] - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } +impl Service> for RoutingService +where + R: Router, + R::Service: Service, Response = http::Response> + Clone, + R::Error: IntoResponse

+ Error, + RespB: HttpBody + Send + 'static, + RespB::Error: Into, +{ + type Response = Response; + type Error = >>::Error; + type Future = RoutingFuture; - #[inline] - fn call(&mut self, req: Request) -> Self::Future { - let target = req - .headers() - .get("x-amz-target") - .map(|x| x.to_str().unwrap()) - .unwrap_or("unknown"); - let body = boxed(Body::from(format!("{} :: {}", self.0, target))); - let fut = async { Ok(Response::builder().status(&http::StatusCode::OK).body(body).unwrap()) }; - Box::pin(fut) - } + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } - #[tokio::test] - async fn simple_routing() { - let routes = vec![("Service.Operation", "A")]; - let router_json10 = Router::new_aws_json_10_router(routes.clone().into_iter().map(|(operation, svc_name)| { - ( - tower::util::BoxCloneService::new(NamedEchoOperationService(String::from(svc_name))), - operation.to_string(), - ) - })); - let router_json11 = Router::new_aws_json_11_router(routes.into_iter().map(|(operation, svc_name)| { - ( - tower::util::BoxCloneService::new(NamedEchoOperationService(String::from(svc_name))), - operation.to_string(), - ) - })); - - for mut router in [router_json10, router_json11] { - let mut headers = HeaderMap::new(); - headers.insert("x-amz-target", HeaderValue::from_static("Service.Operation")); - - // Valid request, should return a valid body. - let mut res = router - .call(req(&Method::POST, "/", Some(headers.clone()))) - .await - .unwrap(); - let actual_body = get_body_as_string(&mut res).await; - assert_eq!(format!("{} :: {}", "A", "Service.Operation"), actual_body); - - // No headers, should return NOT_FOUND. - let res = router.call(req(&Method::POST, "/", None)).await.unwrap(); - assert_eq!(res.status(), StatusCode::NOT_FOUND); - - // Wrong HTTP method, should return METHOD_NOT_ALLOWED. - let res = router - .call(req(&Method::GET, "/", Some(headers.clone()))) - .await - .unwrap(); - assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED); - - // Wrong URI, should return NOT_FOUND. - let res = router - .call(req(&Method::POST, "/something", Some(headers))) - .await - .unwrap(); - assert_eq!(res.status(), StatusCode::NOT_FOUND); + fn call(&mut self, req: http::Request) -> Self::Future { + match self.router.match_route(&req) { + // Successfully routed, use the routes `Service::call`. + Ok(ok) => RoutingFuture::from_oneshot(ok.oneshot(req)), + // Failed to route, use the `R::Error`s `IntoResponse

`. + Err(error) => { + debug!(%error, "failed to route"); + RoutingFuture::from_response(error.into_response()) + } } } } diff --git a/rust-runtime/aws-smithy-http-server/src/routing/request_spec.rs b/rust-runtime/aws-smithy-http-server/src/routing/request_spec.rs index bbabefb019..84f431c24d 100644 --- a/rust-runtime/aws-smithy-http-server/src/routing/request_spec.rs +++ b/rust-runtime/aws-smithy-http-server/src/routing/request_spec.rs @@ -250,8 +250,9 @@ impl RequestSpec { #[cfg(test)] mod tests { - use super::super::rest_tests::req; use super::*; + use crate::proto::test_helpers::req; + use http::Method; #[test] diff --git a/rust-runtime/aws-smithy-http-server/src/routing/route.rs b/rust-runtime/aws-smithy-http-server/src/routing/route.rs index 67d7ec5353..7eda401f61 100644 --- a/rust-runtime/aws-smithy-http-server/src/routing/route.rs +++ b/rust-runtime/aws-smithy-http-server/src/routing/route.rs @@ -64,10 +64,6 @@ impl Route { service: BoxCloneService::new(svc), } } - - pub(super) fn from_box_clone_service(svc: BoxCloneService, Response, Infallible>) -> Self { - Self { service: svc } - } } impl Clone for Route { diff --git a/rust-runtime/aws-smithy-http-server/src/runtime_error.rs b/rust-runtime/aws-smithy-http-server/src/runtime_error.rs index 1201ca61aa..9fc1b8e1a7 100644 --- a/rust-runtime/aws-smithy-http-server/src/runtime_error.rs +++ b/rust-runtime/aws-smithy-http-server/src/runtime_error.rs @@ -174,13 +174,6 @@ impl IntoResponse for RuntimeError { } } -#[allow(deprecated)] -impl From for RuntimeError { - fn from(err: crate::rejection::RequestExtensionNotFoundRejection) -> Self { - Self::InternalFailure(crate::Error::new(err)) - } -} - impl From for RuntimeError { fn from(err: crate::rejection::ResponseRejection) -> Self { Self::Serialization(crate::Error::new(err)) diff --git a/rust-runtime/inlineable/src/lib.rs b/rust-runtime/inlineable/src/lib.rs index 2c2634110c..41af358919 100644 --- a/rust-runtime/inlineable/src/lib.rs +++ b/rust-runtime/inlineable/src/lib.rs @@ -15,8 +15,6 @@ mod json_errors; mod rest_xml_unwrapped_errors; #[allow(unused)] mod rest_xml_wrapped_errors; -#[allow(unused)] -mod server_operation_handler_trait; #[allow(unused)] mod endpoint_lib; diff --git a/rust-runtime/inlineable/src/server_operation_handler_trait.rs b/rust-runtime/inlineable/src/server_operation_handler_trait.rs deleted file mode 100644 index 633120912e..0000000000 --- a/rust-runtime/inlineable/src/server_operation_handler_trait.rs +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -use async_trait::async_trait; -use aws_smithy_http_server::{body::BoxBody, opaque_future}; -use futures_util::{ - future::{BoxFuture, Map}, - FutureExt, -}; -use http::{Request, Response}; -use std::marker::PhantomData; -use tower::Service; - -/// Struct that holds a handler, that is, a function provided by the user that implements the -/// Smithy operation. -#[deprecated( - since = "0.52.0", - note = "`OperationHandler` is part of the older service builder API. This type no longer appears in the public API." -)] -pub struct OperationHandler { - handler: H, - #[allow(clippy::type_complexity)] - _marker: PhantomData<(B, R, I)>, -} - -#[allow(deprecated)] -impl Clone for OperationHandler -where - H: Clone, -{ - fn clone(&self) -> Self { - Self { - handler: self.handler.clone(), - _marker: PhantomData, - } - } -} - -/// Construct an [`OperationHandler`] out of a function implementing the operation. -#[allow(deprecated)] -#[deprecated( - since = "0.52.0", - note = "`OperationHandler` is part of the older service builder API. This type no longer appears in the public API." -)] -pub fn operation(handler: H) -> OperationHandler { - OperationHandler { - handler, - _marker: PhantomData, - } -} - -#[allow(deprecated)] -impl Service> for OperationHandler -where - H: Handler, - B: Send + 'static, -{ - type Response = Response; - type Error = std::convert::Infallible; - type Future = OperationHandlerFuture; - - #[inline] - fn poll_ready( - &mut self, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - std::task::Poll::Ready(Ok(())) - } - - fn call(&mut self, req: Request) -> Self::Future { - let future = - Handler::call(self.handler.clone(), req).map(Ok::<_, std::convert::Infallible> as _); - OperationHandlerFuture::new(future) - } -} - -type WrapResultInResponseFn = - fn(Response) -> Result, std::convert::Infallible>; - -opaque_future! { - /// Response future for [`OperationHandler`]. - pub type OperationHandlerFuture = - Map>, WrapResultInResponseFn>; -} - -pub(crate) mod sealed { - #![allow(unreachable_pub, missing_docs, missing_debug_implementations)] - pub trait HiddenTrait {} - pub struct Hidden; - impl HiddenTrait for Hidden {} -} - -#[deprecated( - since = "0.52.0", - note = "The inlineable `Handler` is part of the deprecated service builder API. This type no longer appears in the public API." -)] -#[async_trait] -pub trait Handler: Clone + Send + Sized + 'static { - #[doc(hidden)] - type Sealed: sealed::HiddenTrait; - - async fn call(self, req: Request) -> Response; -}