From 681ca6569aa95acbbdbd3bc47adf225de2a684e0 Mon Sep 17 00:00:00 2001 From: Jason Del Ponte <961963+jasdel@users.noreply.github.com> Date: Fri, 1 Jul 2022 12:58:22 -0700 Subject: [PATCH] service/sqs: Add support for validating message checksums by default (#1748) Adds support for the SQS client to automatically validate message checksums for SendMessage, SendMessageBatch, and ReceiveMessage. This brings the v2 SDK up to speed with the v1 SDK's behavior. A DisableMessageChecksumValidation parameter has been added to the Options struct for SQS package. Setting this to true will disable the checksum validation. This can be set when creating a client, or per operation call. --- .../131fe156ee0640ff85a5ba09e3cda44c.json | 8 + .../SQSValidateMessageChecksum.java | 102 ++++ ...mithy.go.codegen.integration.GoIntegration | 1 + service/sqs/api_client.go | 4 + service/sqs/api_op_ReceiveMessage.go | 3 + service/sqs/api_op_SendMessage.go | 3 + service/sqs/api_op_SendMessageBatch.go | 3 + service/sqs/cust_checksum_validation.go | 234 +++++++++ service/sqs/cust_checksum_validation_test.go | 473 ++++++++++++++++++ 9 files changed, 831 insertions(+) create mode 100644 .changelog/131fe156ee0640ff85a5ba09e3cda44c.json create mode 100644 codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/customization/SQSValidateMessageChecksum.java create mode 100644 service/sqs/cust_checksum_validation.go create mode 100644 service/sqs/cust_checksum_validation_test.go diff --git a/.changelog/131fe156ee0640ff85a5ba09e3cda44c.json b/.changelog/131fe156ee0640ff85a5ba09e3cda44c.json new file mode 100644 index 00000000000..fdd744ce62e --- /dev/null +++ b/.changelog/131fe156ee0640ff85a5ba09e3cda44c.json @@ -0,0 +1,8 @@ +{ + "id": "131fe156-ee06-40ff-85a5-ba09e3cda44c", + "type": "feature", + "description": "Adds support for the SQS client to automatically validate message checksums for SendMessage, SendMessageBatch, and ReceiveMessage. A DisableMessageChecksumValidation parameter has been added to the Options struct for SQS package. Setting this to true will disable the checksum validation. This can be set when creating a client, or per operation call.", + "modules": [ + "service/sqs" + ] +} \ No newline at end of file diff --git a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/customization/SQSValidateMessageChecksum.java b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/customization/SQSValidateMessageChecksum.java new file mode 100644 index 00000000000..bc7dcfcb109 --- /dev/null +++ b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/customization/SQSValidateMessageChecksum.java @@ -0,0 +1,102 @@ +package software.amazon.smithy.aws.go.codegen.customization; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.logging.Logger; +import software.amazon.smithy.codegen.core.SymbolProvider; +import software.amazon.smithy.go.codegen.GoCodegenPlugin; +import software.amazon.smithy.go.codegen.GoSettings; +import software.amazon.smithy.go.codegen.SymbolUtils; +import software.amazon.smithy.go.codegen.integration.ConfigField; +import software.amazon.smithy.go.codegen.integration.GoIntegration; +import software.amazon.smithy.go.codegen.integration.MiddlewareRegistrar; +import software.amazon.smithy.go.codegen.integration.RuntimeClientPlugin; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.shapes.OperationShape; +import software.amazon.smithy.model.shapes.ServiceShape; +import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.utils.MapUtils; +import software.amazon.smithy.utils.SetUtils; + +public class SQSValidateMessageChecksum implements GoIntegration { + private static final Logger LOGGER = Logger.getLogger(SQSValidateMessageChecksum.class.getName()); + + /** + * Map of service shape to Set of operation shapes that need to have this + * customization. + */ + public static final Map> SERVICE_TO_OPERATION_MAP = MapUtils.of( + ShapeId.from("com.amazonaws.sqs#AmazonSQS"), SetUtils.of( + ShapeId.from("com.amazonaws.sqs#SendMessage"), + ShapeId.from("com.amazonaws.sqs#SendMessageBatch"), + ShapeId.from("com.amazonaws.sqs#ReceiveMessage") + ) + ); + static final String DISABLE_MESSAGE_CHECKSUM_VALIDATION_OPTION_NAME = "DisableMessageChecksumValidation"; + + private final List runtimeClientPlugins = new ArrayList<>(); + + /** + * Builds the set of runtime plugs used by the customization. + * + * @param settings codegen settings + * @param model api model + */ + @Override + public void processFinalizedModel(GoSettings settings, Model model) { + ShapeId serviceId = settings.getService(); + if (!SERVICE_TO_OPERATION_MAP.containsKey(serviceId)) { + return; + } + + ServiceShape service = settings.getService(model); + + // Add option to disable message checksum validation + runtimeClientPlugins.add(RuntimeClientPlugin.builder() + .servicePredicate((m, s) -> s.equals(service)) + .addConfigField(ConfigField.builder() + .name(DISABLE_MESSAGE_CHECKSUM_VALIDATION_OPTION_NAME) + .type(SymbolUtils.createValueSymbolBuilder("bool") + .putProperty(SymbolUtils.GO_UNIVERSE_TYPE, true).build()) + .documentation("Allows you to disable the client's validation of " + + "response message checksums. Enabled by default. " + + "Used by SendMessage, SendMessageBatch, and ReceiveMessage.") + .build()) + .build()); + + for (ShapeId operationId : SERVICE_TO_OPERATION_MAP.get(serviceId)) { + final OperationShape operation = model.expectShape(operationId, OperationShape.class); + + // Create a symbol provider because one is not available in this call. + SymbolProvider symbolProvider = GoCodegenPlugin.createSymbolProvider(model, settings); + + String helperFuncName = addMiddlewareFuncName(symbolProvider.toSymbol(operation).getName()); + + runtimeClientPlugins.add(RuntimeClientPlugin.builder() + .servicePredicate((m, s) -> s.equals(service)) + .operationPredicate((m, s, o) -> o.equals(operation)) + .registerMiddleware(MiddlewareRegistrar.builder() + .resolvedFunction(SymbolUtils.createValueSymbolBuilder(helperFuncName) + .build()) + .useClientOptions() + .build()) + .build()); + } + } + + String addMiddlewareFuncName(String operationName) { + return "addValidate" + operationName + "Checksum"; + } + + /** + * Returns the list of runtime client plugins added by this customization + * + * @return runtime client plugins + */ + @Override + public List getClientPlugins() { + return runtimeClientPlugins; + } +} diff --git a/codegen/smithy-aws-go-codegen/src/main/resources/META-INF/services/software.amazon.smithy.go.codegen.integration.GoIntegration b/codegen/smithy-aws-go-codegen/src/main/resources/META-INF/services/software.amazon.smithy.go.codegen.integration.GoIntegration index 056f964899c..1b0378be1ec 100644 --- a/codegen/smithy-aws-go-codegen/src/main/resources/META-INF/services/software.amazon.smithy.go.codegen.integration.GoIntegration +++ b/codegen/smithy-aws-go-codegen/src/main/resources/META-INF/services/software.amazon.smithy.go.codegen.integration.GoIntegration @@ -45,4 +45,5 @@ software.amazon.smithy.aws.go.codegen.RequestResponseLogging software.amazon.smithy.aws.go.codegen.customization.S3AddPutObjectUnseekableBodyDoc software.amazon.smithy.aws.go.codegen.customization.BackfillEc2UnboxedToBoxedShapes software.amazon.smithy.aws.go.codegen.customization.AdjustAwsRestJsonContentType +software.amazon.smithy.aws.go.codegen.customization.SQSValidateMessageChecksum software.amazon.smithy.aws.go.codegen.EndpointDiscoveryGenerator diff --git a/service/sqs/api_client.go b/service/sqs/api_client.go index ce463acad66..57602e07446 100644 --- a/service/sqs/api_client.go +++ b/service/sqs/api_client.go @@ -75,6 +75,10 @@ type Options struct { // clients initial default settings. DefaultsMode aws.DefaultsMode + // Allows you to disable the client's validation of response message checksums. + // Enabled by default. Used by SendMessage, SendMessageBatch, and ReceiveMessage. + DisableMessageChecksumValidation bool + // The endpoint options to be used when attempting to resolve an endpoint. EndpointOptions EndpointResolverOptions diff --git a/service/sqs/api_op_ReceiveMessage.go b/service/sqs/api_op_ReceiveMessage.go index 37726c7a47b..49d0c6d770b 100644 --- a/service/sqs/api_op_ReceiveMessage.go +++ b/service/sqs/api_op_ReceiveMessage.go @@ -290,6 +290,9 @@ func (c *Client) addOperationReceiveMessageMiddlewares(stack *middleware.Stack, if err = smithyhttp.AddCloseResponseBodyMiddleware(stack); err != nil { return err } + if err = addValidateReceiveMessageChecksum(stack, options); err != nil { + return err + } if err = addOpReceiveMessageValidationMiddleware(stack); err != nil { return err } diff --git a/service/sqs/api_op_SendMessage.go b/service/sqs/api_op_SendMessage.go index c96fe9b2c80..23d9ca07fe3 100644 --- a/service/sqs/api_op_SendMessage.go +++ b/service/sqs/api_op_SendMessage.go @@ -239,6 +239,9 @@ func (c *Client) addOperationSendMessageMiddlewares(stack *middleware.Stack, opt if err = smithyhttp.AddCloseResponseBodyMiddleware(stack); err != nil { return err } + if err = addValidateSendMessageChecksum(stack, options); err != nil { + return err + } if err = addOpSendMessageValidationMiddleware(stack); err != nil { return err } diff --git a/service/sqs/api_op_SendMessageBatch.go b/service/sqs/api_op_SendMessageBatch.go index eec3a92508d..2730be7f950 100644 --- a/service/sqs/api_op_SendMessageBatch.go +++ b/service/sqs/api_op_SendMessageBatch.go @@ -129,6 +129,9 @@ func (c *Client) addOperationSendMessageBatchMiddlewares(stack *middleware.Stack if err = smithyhttp.AddCloseResponseBodyMiddleware(stack); err != nil { return err } + if err = addValidateSendMessageBatchChecksum(stack, options); err != nil { + return err + } if err = addOpSendMessageBatchValidationMiddleware(stack); err != nil { return err } diff --git a/service/sqs/cust_checksum_validation.go b/service/sqs/cust_checksum_validation.go new file mode 100644 index 00000000000..a16d005dbf5 --- /dev/null +++ b/service/sqs/cust_checksum_validation.go @@ -0,0 +1,234 @@ +package sqs + +import ( + "context" + "crypto/md5" + "encoding/hex" + "fmt" + "strings" + + "github.com/aws/aws-sdk-go-v2/aws" + sqstypes "github.com/aws/aws-sdk-go-v2/service/sqs/types" + "github.com/aws/smithy-go/middleware" +) + +// addValidateSendMessageChecksum adds the ValidateMessageChecksum middleware +// to the stack configured for the SendMessage Operation. +func addValidateSendMessageChecksum(stack *middleware.Stack, o Options) error { + return addValidateMessageChecksum(stack, o, validateSendMessageChecksum) +} + +// validateSendMessageChecksum validates the SendMessage operation's input +// message payload MD5 checksum matches that returned by the API. +// +// The input and output types must match the SendMessage operation. +func validateSendMessageChecksum(input, output interface{}) error { + in, ok := input.(*SendMessageInput) + if !ok { + return fmt.Errorf("wrong input type, expect %T, got %T", in, input) + } + out, ok := output.(*SendMessageOutput) + if !ok { + return fmt.Errorf("wrong output type, expect %T, got %T", out, output) + } + + // Nothing to validate if the members aren't populated. + if in.MessageBody == nil || out.MD5OfMessageBody == nil { + return nil + } + + if err := validateMessageChecksum(*in.MessageBody, *out.MD5OfMessageBody); err != nil { + return messageChecksumError{ + MessageID: aws.ToString(out.MessageId), + Err: err, + } + } + return nil +} + +// addValidateSendMessageBatchChecksum adds the ValidateMessagechecksum +// middleware to the stack configured for the SendMessageBatch operation. +func addValidateSendMessageBatchChecksum(stack *middleware.Stack, o Options) error { + return addValidateMessageChecksum(stack, o, validateSendMessageBatchChecksum) +} + +// validateSendMessageBatchChecksum validates the SendMessageBatch operation's +// input messages body MD5 checksum matches those returned by the API. +// +// The input and output types must match the SendMessageBatch operation. +func validateSendMessageBatchChecksum(input, output interface{}) error { + in, ok := input.(*SendMessageBatchInput) + if !ok { + return fmt.Errorf("wrong input type, expect %T, got %T", in, input) + } + out, ok := output.(*SendMessageBatchOutput) + if !ok { + return fmt.Errorf("wrong output type, expect %T, got %T", out, output) + } + + outEntries := map[string]sqstypes.SendMessageBatchResultEntry{} + for _, e := range out.Successful { + outEntries[*e.Id] = e + } + + var failedMessageErrs []messageChecksumError + for _, inEntry := range in.Entries { + outEntry, ok := outEntries[*inEntry.Id] + // Nothing to validate if the members aren't populated. + if !ok || inEntry.MessageBody == nil || outEntry.MD5OfMessageBody == nil { + continue + } + + if err := validateMessageChecksum(*inEntry.MessageBody, *outEntry.MD5OfMessageBody); err != nil { + failedMessageErrs = append(failedMessageErrs, messageChecksumError{ + MessageID: aws.ToString(outEntry.MessageId), + Err: err, + }) + } + } + + if len(failedMessageErrs) != 0 { + return batchMessageChecksumError{ + Errs: failedMessageErrs, + } + } + + return nil +} + +// addValidateReceiveMessageChecksum adds the ValidateMessagechecksum +// middleware to the stack configured for the ReceiveMessage operation. +func addValidateReceiveMessageChecksum(stack *middleware.Stack, o Options) error { + return addValidateMessageChecksum(stack, o, validateReceiveMessageChecksum) +} + +// validateReceiveMessageChecksum validates the ReceiveMessage operation's +// input messages body MD5 checksum matches those returned by the API. +// +// The input and output types must match the ReceiveMessage operation. +func validateReceiveMessageChecksum(_, output interface{}) error { + out, ok := output.(*ReceiveMessageOutput) + if !ok { + return fmt.Errorf("wrong output type, expect %T, got %T", out, output) + } + + var failedMessageErrs []messageChecksumError + for _, msg := range out.Messages { + // Nothing to validate if the members aren't populated. + if msg.Body == nil || msg.MD5OfBody == nil { + continue + } + + if err := validateMessageChecksum(*msg.Body, *msg.MD5OfBody); err != nil { + failedMessageErrs = append(failedMessageErrs, messageChecksumError{ + MessageID: aws.ToString(msg.MessageId), + Err: err, + }) + } + } + + if len(failedMessageErrs) != 0 { + return batchMessageChecksumError{ + Errs: failedMessageErrs, + } + } + + return nil +} + +// messageChecksumValidator provides the function signature for the operation's +// validator. +type messageChecksumValidator func(input, output interface{}) error + +// addValidateMessageChecksum adds the ValidateMessageChecksum middleware to +// the stack with the passed in validator specified. +func addValidateMessageChecksum(stack *middleware.Stack, o Options, validate messageChecksumValidator) error { + if o.DisableMessageChecksumValidation { + return nil + } + + m := validateMessageChecksumMiddleware{ + validate: validate, + } + err := stack.Initialize.Add(m, middleware.Before) + if err != nil { + return fmt.Errorf("failed to add %s middleware, %w", m.ID(), err) + } + + return nil +} + +// validateMessageChecksumMiddleware provides the Initialize middleware for +// validating an operation's message checksum is validate. Needs to b +// configured with the operation's validator. +type validateMessageChecksumMiddleware struct { + validate messageChecksumValidator +} + +// ID returns the Middleware ID. +func (validateMessageChecksumMiddleware) ID() string { return "SQSValidateMessageChecksum" } + +// HandleInitialize implements the InitializeMiddleware interface providing a +// middleware that will validate an operation's message checksum based on +// calling the validate member. +func (m validateMessageChecksumMiddleware) HandleInitialize( + ctx context.Context, input middleware.InitializeInput, next middleware.InitializeHandler, +) ( + out middleware.InitializeOutput, meta middleware.Metadata, err error, +) { + out, meta, err = next.HandleInitialize(ctx, input) + if err != nil { + return out, meta, err + } + + err = m.validate(input.Parameters, out.Result) + if err != nil { + return out, meta, fmt.Errorf("message checksum validation failed, %w", err) + } + + return out, meta, nil +} + +// validateMessageChecksum compares the MD5 checksums of value parameter with +// the expected MD5 value. Returns an error if the computed checksum does not +// match the expected value. +func validateMessageChecksum(value, expect string) error { + msum := md5.Sum([]byte(value)) + sum := hex.EncodeToString(msum[:]) + if sum != expect { + return fmt.Errorf("expected MD5 checksum %s, got %s", expect, sum) + } + + return nil +} + +// messageChecksumError provides an error type for invalid message checksums. +type messageChecksumError struct { + MessageID string + Err error +} + +func (e messageChecksumError) Error() string { + prefix := "message" + if e.MessageID != "" { + prefix += " " + e.MessageID + } + return fmt.Sprintf("%s has invalid checksum, %v", prefix, e.Err.Error()) +} + +// batchMessageChecksumError provides an error type for a collection of invalid +// message checksum errors. +type batchMessageChecksumError struct { + Errs []messageChecksumError +} + +func (e batchMessageChecksumError) Error() string { + var w strings.Builder + fmt.Fprintf(&w, "message checksum errors") + + for _, err := range e.Errs { + fmt.Fprintf(&w, "\n\t%s", err.Error()) + } + + return w.String() +} diff --git a/service/sqs/cust_checksum_validation_test.go b/service/sqs/cust_checksum_validation_test.go new file mode 100644 index 00000000000..7e6e5b66caf --- /dev/null +++ b/service/sqs/cust_checksum_validation_test.go @@ -0,0 +1,473 @@ +package sqs + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + sqstypes "github.com/aws/aws-sdk-go-v2/service/sqs/types" + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" +) + +func TestValidateSendMessageChecksum(t *testing.T) { + cases := map[string]struct { + input *SendMessageInput + output *SendMessageOutput + handlerErr error + + expectErr string + }{ + "success": { + input: &SendMessageInput{ + MessageBody: aws.String("test"), + }, + output: &SendMessageOutput{ + MD5OfMessageBody: aws.String("098f6bcd4621d373cade4e832627b4f6"), + MessageId: aws.String("id"), + }, + }, + "no input message": { + input: &SendMessageInput{}, + output: &SendMessageOutput{ + MD5OfMessageBody: aws.String("098f6bcd4621d373cade4e832627b4f6"), + MessageId: aws.String("id"), + }, + }, + "no md5": { + input: &SendMessageInput{ + MessageBody: aws.String("test"), + }, + output: &SendMessageOutput{ + MessageId: aws.String("id"), + }, + }, + "no message id": { + input: &SendMessageInput{ + MessageBody: aws.String("test"), + }, + output: &SendMessageOutput{ + MD5OfMessageBody: aws.String("098f6bcd4621d373cade4e832627b4f6"), + }, + }, + "invalid checksum": { + input: &SendMessageInput{ + MessageBody: aws.String("test"), + }, + output: &SendMessageOutput{ + MD5OfMessageBody: aws.String("01234556"), + MessageId: aws.String("id"), + }, + expectErr: "message id has invalid checksum", + }, + "response error": { + input: &SendMessageInput{ + MessageBody: aws.String("test"), + }, + handlerErr: fmt.Errorf("some error"), + expectErr: "some error", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + m := validateMessageChecksumMiddleware{ + validate: validateSendMessageChecksum, + } + + next := mockInitializeHandler{ + Output: middleware.InitializeOutput{Result: c.output}, + Err: c.handlerErr, + } + input := middleware.InitializeInput{ + Parameters: c.input, + } + _, _, err := m.HandleInitialize(context.Background(), input, next) + if c.expectErr != "" { + if err == nil { + t.Fatalf("expect %v error, got none", c.expectErr) + } + if e, a := c.expectErr, err.Error(); !strings.Contains(a, e) { + t.Fatalf("expect %v error, got %v", e, a) + } + return + } + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + }) + } +} + +func TestValidateSendMessageBatchChecksum(t *testing.T) { + successMD5 := "098f6bcd4621d373cade4e832627b4f6" + invalidMD5 := "11111111111111111111111111111111" + + cases := map[string]struct { + input *SendMessageBatchInput + output *SendMessageBatchOutput + handlerErr error + + expectErrs []string + }{ + "success": { + input: &SendMessageBatchInput{ + Entries: []sqstypes.SendMessageBatchRequestEntry{ + {Id: aws.String("1"), MessageBody: aws.String("test")}, + {Id: aws.String("2"), MessageBody: aws.String("test")}, + {Id: aws.String("3"), MessageBody: aws.String("test")}, + {Id: aws.String("4"), MessageBody: aws.String("test")}, + }, + }, + output: &SendMessageBatchOutput{ + Successful: []sqstypes.SendMessageBatchResultEntry{ + {MD5OfMessageBody: &successMD5, MessageId: aws.String("123"), Id: aws.String("1")}, + {MD5OfMessageBody: &successMD5, MessageId: aws.String("456"), Id: aws.String("2")}, + {MD5OfMessageBody: &successMD5, MessageId: aws.String("789"), Id: aws.String("3")}, + {MD5OfMessageBody: &successMD5, MessageId: aws.String("012"), Id: aws.String("4")}, + }, + }, + }, + "no input body": { + input: &SendMessageBatchInput{ + Entries: []sqstypes.SendMessageBatchRequestEntry{ + {Id: aws.String("1")}, + }, + }, + output: &SendMessageBatchOutput{ + Successful: []sqstypes.SendMessageBatchResultEntry{ + {MD5OfMessageBody: &invalidMD5, MessageId: aws.String("123"), Id: aws.String("1")}, + }, + }, + }, + "no md5": { + input: &SendMessageBatchInput{ + Entries: []sqstypes.SendMessageBatchRequestEntry{ + {Id: aws.String("1"), MessageBody: aws.String("test")}, + }, + }, + output: &SendMessageBatchOutput{ + Successful: []sqstypes.SendMessageBatchResultEntry{ + {MessageId: aws.String("123"), Id: aws.String("1")}, + }, + }, + }, + "server side failure": { + input: &SendMessageBatchInput{ + Entries: []sqstypes.SendMessageBatchRequestEntry{ + {Id: aws.String("1"), MessageBody: aws.String("test")}, + {Id: aws.String("2"), MessageBody: aws.String("test")}, + {Id: aws.String("3"), MessageBody: aws.String("test")}, + {Id: aws.String("4"), MessageBody: aws.String("test")}, + }, + }, + output: &SendMessageBatchOutput{ + Successful: []sqstypes.SendMessageBatchResultEntry{ + {MD5OfMessageBody: &successMD5, MessageId: aws.String("123"), Id: aws.String("1")}, + {MD5OfMessageBody: &successMD5, MessageId: aws.String("456"), Id: aws.String("2")}, + {MD5OfMessageBody: &successMD5, MessageId: aws.String("012"), Id: aws.String("4")}, + }, + Failed: []sqstypes.BatchResultErrorEntry{ + {Id: aws.String("3"), Code: aws.String("test"), Message: aws.String("test")}, + }, + }, + }, + "partial invalid checksum": { + input: &SendMessageBatchInput{ + Entries: []sqstypes.SendMessageBatchRequestEntry{ + {Id: aws.String("1"), MessageBody: aws.String("test")}, + {Id: aws.String("2"), MessageBody: aws.String("test")}, + {Id: aws.String("3"), MessageBody: aws.String("test")}, + {Id: aws.String("4"), MessageBody: aws.String("test")}, + }, + }, + output: &SendMessageBatchOutput{ + Successful: []sqstypes.SendMessageBatchResultEntry{ + {MD5OfMessageBody: &successMD5, MessageId: aws.String("123"), Id: aws.String("1")}, + {MD5OfMessageBody: &successMD5, MessageId: aws.String("456"), Id: aws.String("2")}, + {MD5OfMessageBody: &invalidMD5, MessageId: aws.String("789"), Id: aws.String("3")}, + {MD5OfMessageBody: &successMD5, MessageId: aws.String("012"), Id: aws.String("4")}, + }, + }, + expectErrs: []string{"message 789 has invalid checksum"}, + }, + "complete invalid checksum": { + input: &SendMessageBatchInput{ + Entries: []sqstypes.SendMessageBatchRequestEntry{ + {Id: aws.String("1"), MessageBody: aws.String("test")}, + {Id: aws.String("2"), MessageBody: aws.String("test")}, + {Id: aws.String("3"), MessageBody: aws.String("test")}, + {Id: aws.String("4"), MessageBody: aws.String("test")}, + }, + }, + output: &SendMessageBatchOutput{ + Successful: []sqstypes.SendMessageBatchResultEntry{ + {MD5OfMessageBody: &invalidMD5, MessageId: aws.String("123"), Id: aws.String("1")}, + {MD5OfMessageBody: &invalidMD5, MessageId: aws.String("456"), Id: aws.String("2")}, + {MD5OfMessageBody: &invalidMD5, MessageId: aws.String("789"), Id: aws.String("3")}, + {MD5OfMessageBody: &invalidMD5, MessageId: aws.String("012"), Id: aws.String("4")}, + }, + }, + expectErrs: []string{ + "message 123 has invalid checksum", + "message 456 has invalid checksum", + "message 789 has invalid checksum", + "message 012 has invalid checksum", + }, + }, + "invalid checksum no message id": { + input: &SendMessageBatchInput{ + Entries: []sqstypes.SendMessageBatchRequestEntry{ + {Id: aws.String("1"), MessageBody: aws.String("test")}, + }, + }, + output: &SendMessageBatchOutput{ + Successful: []sqstypes.SendMessageBatchResultEntry{ + {MD5OfMessageBody: &invalidMD5, Id: aws.String("1")}, + }, + }, + expectErrs: []string{"message has invalid checksum"}, + }, + "response error": { + input: &SendMessageBatchInput{ + Entries: []sqstypes.SendMessageBatchRequestEntry{ + {Id: aws.String("1"), MessageBody: aws.String("test")}, + }, + }, + handlerErr: fmt.Errorf("some error"), + expectErrs: []string{"some error"}, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + m := validateMessageChecksumMiddleware{ + validate: validateSendMessageBatchChecksum, + } + + next := mockInitializeHandler{ + Output: middleware.InitializeOutput{Result: c.output}, + Err: c.handlerErr, + } + input := middleware.InitializeInput{ + Parameters: c.input, + } + _, _, err := m.HandleInitialize(context.Background(), input, next) + if len(c.expectErrs) != 0 { + if err == nil { + t.Fatalf("expect error(s), got none") + } + for i, expectErr := range c.expectErrs { + if e, a := expectErr, err.Error(); !strings.Contains(a, e) { + t.Errorf("%d expect %v error, got %v", i, e, a) + } + } + return + } + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + }) + } +} + +func TestValidateReceiveMessageChecksum(t *testing.T) { + successMD5 := "098f6bcd4621d373cade4e832627b4f6" + invalidMD5 := "11111111111111111111111111111111" + + cases := map[string]struct { + output *ReceiveMessageOutput + handlerErr error + + expectErrs []string + }{ + "success": { + output: &ReceiveMessageOutput{ + Messages: []sqstypes.Message{ + {Body: aws.String("test"), MD5OfBody: &successMD5}, + {Body: aws.String("test"), MD5OfBody: &successMD5}, + {Body: aws.String("test"), MD5OfBody: &successMD5}, + {Body: aws.String("test"), MD5OfBody: &successMD5}, + }, + }, + }, + "no body": { + output: &ReceiveMessageOutput{ + Messages: []sqstypes.Message{ + {MD5OfBody: &successMD5}, + }, + }, + }, + "no md5": { + output: &ReceiveMessageOutput{ + Messages: []sqstypes.Message{ + {Body: aws.String("test")}, + }, + }, + }, + "message with no ID partial invalid checksum": { + output: &ReceiveMessageOutput{ + Messages: []sqstypes.Message{ + {Body: aws.String("test"), MD5OfBody: &successMD5}, + {Body: aws.String("test"), MD5OfBody: &successMD5}, + {Body: aws.String("test"), MD5OfBody: &invalidMD5}, + {Body: aws.String("test"), MD5OfBody: &successMD5}, + }, + }, + expectErrs: []string{"message has invalid checksum"}, + }, + "message with ID partial invalid checksum": { + output: &ReceiveMessageOutput{ + Messages: []sqstypes.Message{ + {Body: aws.String("test"), MD5OfBody: &successMD5}, + {Body: aws.String("test"), MD5OfBody: &successMD5}, + {Body: aws.String("test"), MD5OfBody: &invalidMD5, MessageId: aws.String("123")}, + {Body: aws.String("test"), MD5OfBody: &successMD5}, + }, + }, + expectErrs: []string{"message 123 has invalid checksum"}, + }, + "complete invalid checksum": { + output: &ReceiveMessageOutput{ + Messages: []sqstypes.Message{ + {Body: aws.String("test"), MD5OfBody: &invalidMD5}, + {Body: aws.String("test"), MD5OfBody: &invalidMD5, MessageId: aws.String("123")}, + {Body: aws.String("test"), MD5OfBody: &invalidMD5, MessageId: aws.String("456")}, + {Body: aws.String("test"), MD5OfBody: &invalidMD5}, + }, + }, + expectErrs: []string{ + "message has invalid checksum", + "message 123 has invalid checksum", + "message 456 has invalid checksum", + "message has invalid checksum", + }, + }, + "response error": { + handlerErr: fmt.Errorf("some error"), + expectErrs: []string{"some error"}, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + m := validateMessageChecksumMiddleware{ + validate: validateReceiveMessageChecksum, + } + + next := mockInitializeHandler{ + Output: middleware.InitializeOutput{Result: c.output}, + Err: c.handlerErr, + } + + input := middleware.InitializeInput{ + Parameters: &ReceiveMessageInput{}, + } + _, _, err := m.HandleInitialize(context.Background(), input, next) + if len(c.expectErrs) != 0 { + if err == nil { + t.Fatalf("expect error(s), got none") + } + for i, expectErr := range c.expectErrs { + if e, a := expectErr, err.Error(); !strings.Contains(a, e) { + t.Errorf("%d expect %v error, got %v", i, e, a) + } + } + return + } + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + }) + } +} + +func TestAddValidateSendMessageChecksum(t *testing.T) { + cases := map[string]map[string]struct { + options Options + fn func(*middleware.Stack, Options) error + expectInStack bool + }{ + "SendMessage": { + "enabled": { + options: Options{}, + fn: addValidateSendMessageChecksum, + expectInStack: true, + }, + "disabled": { + options: Options{ + DisableMessageChecksumValidation: true, + }, + fn: addValidateSendMessageChecksum, + expectInStack: false, + }, + }, + "SendMessageBatch": { + "enabled": { + options: Options{}, + fn: addValidateSendMessageBatchChecksum, + expectInStack: true, + }, + "disabled": { + options: Options{ + DisableMessageChecksumValidation: true, + }, + fn: addValidateSendMessageBatchChecksum, + expectInStack: false, + }, + }, + "ReceiveMessage": { + "enabled": { + options: Options{}, + fn: addValidateReceiveMessageChecksum, + expectInStack: true, + }, + "disabled": { + options: Options{ + DisableMessageChecksumValidation: true, + }, + fn: addValidateReceiveMessageChecksum, + expectInStack: false, + }, + }, + } + + for opName, opCases := range cases { + t.Run(opName, func(t *testing.T) { + for name, c := range opCases { + t.Run(name, func(t *testing.T) { + options := c.options.Copy() + stack := middleware.NewStack("test", smithyhttp.NewStackRequest) + + err := c.fn(stack, options) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + expectID := validateMessageChecksumMiddleware{}.ID() + if e, a := expectID, stack.String(); c.expectInStack != strings.Contains(a, e) { + t.Errorf("expect %v in stack %v:\n%s", e, c.expectInStack, a) + } + }) + } + }) + } +} + +//****************** +// Testing Utilities +//****************** +type mockInitializeHandler struct { + Output middleware.InitializeOutput + Err error +} + +func (m mockInitializeHandler) HandleInitialize( + ctx context.Context, in middleware.InitializeInput, +) ( + out middleware.InitializeOutput, meta middleware.Metadata, err error, +) { + return m.Output, meta, m.Err +}