Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

codegen: Update REST streaming request payload content-type usage #367

Merged
merged 2 commits into from Jun 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 8 additions & 0 deletions .changelog/c1435149a35b4c4e9fd8b919f90bca32.json
@@ -0,0 +1,8 @@
{
"id": "c1435149-a35b-4c4e-9fd8-b919f90bca32",
"type": "feature",
"description": "`transport/http`: Add utility for setting context metadata when operation serializer automatically assigns content-type default value.",
"modules": [
"."
]
}
Expand Up @@ -476,11 +476,20 @@ protected abstract void writeMiddlewareDocumentSerializerDelegator(
* @param payloadShape the payload shape.
*/
protected void writeSetPayloadShapeHeader(GoWriter writer, Shape payloadShape) {
writer.pushState();

writer.putContext("withIsDefaultContentType", SymbolUtils.createValueSymbolBuilder(
"SetIsContentTypeDefaultValue", SmithyGoDependency.SMITHY_HTTP_TRANSPORT).build());
writer.putContext("payloadMediaType", getPayloadShapeMediaType(payloadShape));

writer.write("""
if !restEncoder.HasHeader("Content-Type") {
restEncoder.SetHeader("Content-Type").String($S)
ctx = $withIsDefaultContentType:T(ctx, true)
restEncoder.SetHeader("Content-Type").String($payloadMediaType:S)
}
""", getPayloadShapeMediaType(payloadShape));
""");

writer.popState();
}

/**
Expand Down Expand Up @@ -511,24 +520,24 @@ protected void writeMiddlewarePayloadSerializerDelegator(
Shape payloadShape = model.expectShape(memberShape.getTarget());

if (payloadShape.hasTrait(StreamingTrait.class)) {
writeSetPayloadShapeHeader(writer, payloadShape);
GoValueAccessUtils.writeIfNonZeroValueMember(context.getModel(), context.getSymbolProvider(), writer,
memberShape, "input", (s) -> {
writeSetPayloadShapeHeader(writer, payloadShape);
writer.write("payload := $L", s);
writeSetStream(writer, "payload");
});
} else if (payloadShape.isBlobShape()) {
writeSetPayloadShapeHeader(writer, payloadShape);
GoValueAccessUtils.writeIfNonZeroValueMember(context.getModel(), context.getSymbolProvider(), writer,
memberShape, "input", (s) -> {
writeSetPayloadShapeHeader(writer, payloadShape);
writer.addUseImports(SmithyGoDependency.BYTES);
writer.write("payload := bytes.NewReader($L)", s);
writeSetStream(writer, "payload");
});
} else if (payloadShape.isStringShape()) {
writeSetPayloadShapeHeader(writer, payloadShape);
GoValueAccessUtils.writeIfNonZeroValueMember(context.getModel(), context.getSymbolProvider(), writer,
memberShape, "input", (s) -> {
writeSetPayloadShapeHeader(writer, payloadShape);
writer.addUseImports(SmithyGoDependency.STRINGS);
if (payloadShape.hasTrait(EnumTrait.class)) {
writer.write("payload := strings.NewReader(string($L))", s);
Expand Down
79 changes: 79 additions & 0 deletions transport/http/middleware_headers.go
Expand Up @@ -7,6 +7,85 @@ import (
"github.com/aws/smithy-go/middleware"
)

type isContentTypeAutoSet struct{}

// SetIsContentTypeDefaultValue returns a Context specifying if the request's
// content-type header was set to a default value.
func SetIsContentTypeDefaultValue(ctx context.Context, isDefault bool) context.Context {
return context.WithValue(ctx, isContentTypeAutoSet{}, isDefault)
}

// GetIsContentTypeDefaultValue returns if the content-type HTTP header on the
// request is a default value that was auto assigned by an operation
// serializer. Allows middleware post serialization to know if the content-type
// was auto set to a default value or not.
//
// Also returns false if the Context value was never updated to include if
// content-type was set to a default value.
func GetIsContentTypeDefaultValue(ctx context.Context) bool {
v, _ := ctx.Value(isContentTypeAutoSet{}).(bool)
return v
}

// AddNoPayloadDefaultContentTypeRemover Adds the DefaultContentTypeRemover
// middleware to the stack after the operation serializer. This middleware will
// remove the content-type header from the request if it was set as a default
// value, and no request payload is present.
//
// Returns error if unable to add the middleware.
func AddNoPayloadDefaultContentTypeRemover(stack *middleware.Stack) (err error) {
err = stack.Serialize.Insert(removeDefaultContentType{},
"OperationSerializer", middleware.After)
if err != nil {
return fmt.Errorf("failed to add %s serialize middleware, %w",
removeDefaultContentType{}.ID(), err)
}

return nil
}

// RemoveNoPayloadDefaultContentTypeRemover removes the
// DefaultContentTypeRemover middleware from the stack. Returns an error if
// unable to remove the middleware.
func RemoveNoPayloadDefaultContentTypeRemover(stack *middleware.Stack) (err error) {
_, err = stack.Serialize.Remove(removeDefaultContentType{}.ID())
if err != nil {
return fmt.Errorf("failed to remove %s serialize middleware, %w",
removeDefaultContentType{}.ID(), err)

}
return nil
}

// removeDefaultContentType provides after serialization middleware that will
// remove the content-type header from an HTTP request if the header was set as
// a default value by the operation serializer, and there is no request payload.
type removeDefaultContentType struct{}

// ID returns the middleware ID
func (removeDefaultContentType) ID() string { return "RemoveDefaultContentType" }

// HandleSerialize implements the serialization middleware.
func (removeDefaultContentType) HandleSerialize(
ctx context.Context, input middleware.SerializeInput, next middleware.SerializeHandler,
) (
out middleware.SerializeOutput, meta middleware.Metadata, err error,
) {
req, ok := input.Request.(*Request)
if !ok {
return out, meta, fmt.Errorf(
"unexpected request type %T for removeDefaultContentType middleware",
input.Request)
}

if GetIsContentTypeDefaultValue(ctx) && req.GetStream() == nil {
req.Header.Del("Content-Type")
input.Request = req
}

return next.HandleSerialize(ctx, input)
}

type headerValue struct {
header string
value string
Expand Down