diff --git a/.changelog/c1435149a35b4c4e9fd8b919f90bca32.json b/.changelog/c1435149a35b4c4e9fd8b919f90bca32.json new file mode 100644 index 000000000..e306ac321 --- /dev/null +++ b/.changelog/c1435149a35b4c4e9fd8b919f90bca32.json @@ -0,0 +1,8 @@ +{ + "id": "c1435149-a35b-4c4e-9fd8-b919f90bca32", + "type": "feature", + "description": "`transport/http`: Add utility for setting context metadata when operation serializer automatically assigns content-type default value.", + "modules": [ + "." + ] +} \ No newline at end of file diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpBindingProtocolGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpBindingProtocolGenerator.java index dfeb669a8..b904a9b91 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpBindingProtocolGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpBindingProtocolGenerator.java @@ -476,11 +476,20 @@ protected abstract void writeMiddlewareDocumentSerializerDelegator( * @param payloadShape the payload shape. */ protected void writeSetPayloadShapeHeader(GoWriter writer, Shape payloadShape) { + writer.pushState(); + + writer.putContext("withIsDefaultContentType", SymbolUtils.createValueSymbolBuilder( + "SetIsContentTypeDefaultValue", SmithyGoDependency.SMITHY_HTTP_TRANSPORT).build()); + writer.putContext("payloadMediaType", getPayloadShapeMediaType(payloadShape)); + writer.write(""" if !restEncoder.HasHeader("Content-Type") { - restEncoder.SetHeader("Content-Type").String($S) + ctx = $withIsDefaultContentType:T(ctx, true) + restEncoder.SetHeader("Content-Type").String($payloadMediaType:S) } - """, getPayloadShapeMediaType(payloadShape)); + """); + + writer.popState(); } /** @@ -511,24 +520,24 @@ protected void writeMiddlewarePayloadSerializerDelegator( Shape payloadShape = model.expectShape(memberShape.getTarget()); if (payloadShape.hasTrait(StreamingTrait.class)) { + writeSetPayloadShapeHeader(writer, payloadShape); GoValueAccessUtils.writeIfNonZeroValueMember(context.getModel(), context.getSymbolProvider(), writer, memberShape, "input", (s) -> { - writeSetPayloadShapeHeader(writer, payloadShape); writer.write("payload := $L", s); writeSetStream(writer, "payload"); }); } else if (payloadShape.isBlobShape()) { + writeSetPayloadShapeHeader(writer, payloadShape); GoValueAccessUtils.writeIfNonZeroValueMember(context.getModel(), context.getSymbolProvider(), writer, memberShape, "input", (s) -> { - writeSetPayloadShapeHeader(writer, payloadShape); writer.addUseImports(SmithyGoDependency.BYTES); writer.write("payload := bytes.NewReader($L)", s); writeSetStream(writer, "payload"); }); } else if (payloadShape.isStringShape()) { + writeSetPayloadShapeHeader(writer, payloadShape); GoValueAccessUtils.writeIfNonZeroValueMember(context.getModel(), context.getSymbolProvider(), writer, memberShape, "input", (s) -> { - writeSetPayloadShapeHeader(writer, payloadShape); writer.addUseImports(SmithyGoDependency.STRINGS); if (payloadShape.hasTrait(EnumTrait.class)) { writer.write("payload := strings.NewReader(string($L))", s); diff --git a/transport/http/middleware_headers.go b/transport/http/middleware_headers.go index 49884e6af..eac32b4ba 100644 --- a/transport/http/middleware_headers.go +++ b/transport/http/middleware_headers.go @@ -7,6 +7,85 @@ import ( "github.com/aws/smithy-go/middleware" ) +type isContentTypeAutoSet struct{} + +// SetIsContentTypeDefaultValue returns a Context specifying if the request's +// content-type header was set to a default value. +func SetIsContentTypeDefaultValue(ctx context.Context, isDefault bool) context.Context { + return context.WithValue(ctx, isContentTypeAutoSet{}, isDefault) +} + +// GetIsContentTypeDefaultValue returns if the content-type HTTP header on the +// request is a default value that was auto assigned by an operation +// serializer. Allows middleware post serialization to know if the content-type +// was auto set to a default value or not. +// +// Also returns false if the Context value was never updated to include if +// content-type was set to a default value. +func GetIsContentTypeDefaultValue(ctx context.Context) bool { + v, _ := ctx.Value(isContentTypeAutoSet{}).(bool) + return v +} + +// AddNoPayloadDefaultContentTypeRemover Adds the DefaultContentTypeRemover +// middleware to the stack after the operation serializer. This middleware will +// remove the content-type header from the request if it was set as a default +// value, and no request payload is present. +// +// Returns error if unable to add the middleware. +func AddNoPayloadDefaultContentTypeRemover(stack *middleware.Stack) (err error) { + err = stack.Serialize.Insert(removeDefaultContentType{}, + "OperationSerializer", middleware.After) + if err != nil { + return fmt.Errorf("failed to add %s serialize middleware, %w", + removeDefaultContentType{}.ID(), err) + } + + return nil +} + +// RemoveNoPayloadDefaultContentTypeRemover removes the +// DefaultContentTypeRemover middleware from the stack. Returns an error if +// unable to remove the middleware. +func RemoveNoPayloadDefaultContentTypeRemover(stack *middleware.Stack) (err error) { + _, err = stack.Serialize.Remove(removeDefaultContentType{}.ID()) + if err != nil { + return fmt.Errorf("failed to remove %s serialize middleware, %w", + removeDefaultContentType{}.ID(), err) + + } + return nil +} + +// removeDefaultContentType provides after serialization middleware that will +// remove the content-type header from an HTTP request if the header was set as +// a default value by the operation serializer, and there is no request payload. +type removeDefaultContentType struct{} + +// ID returns the middleware ID +func (removeDefaultContentType) ID() string { return "RemoveDefaultContentType" } + +// HandleSerialize implements the serialization middleware. +func (removeDefaultContentType) HandleSerialize( + ctx context.Context, input middleware.SerializeInput, next middleware.SerializeHandler, +) ( + out middleware.SerializeOutput, meta middleware.Metadata, err error, +) { + req, ok := input.Request.(*Request) + if !ok { + return out, meta, fmt.Errorf( + "unexpected request type %T for removeDefaultContentType middleware", + input.Request) + } + + if GetIsContentTypeDefaultValue(ctx) && req.GetStream() == nil { + req.Header.Del("Content-Type") + input.Request = req + } + + return next.HandleSerialize(ctx, input) +} + type headerValue struct { header string value string