Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

service/sqs: Add support for validating message checksums by default #1748

Merged
merged 2 commits into from Jul 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 8 additions & 0 deletions .changelog/131fe156ee0640ff85a5ba09e3cda44c.json
@@ -0,0 +1,8 @@
{
"id": "131fe156-ee06-40ff-85a5-ba09e3cda44c",
"type": "feature",
"description": "Adds support for the SQS client to automatically validate message checksums for SendMessage, SendMessageBatch, and ReceiveMessage. A DisableMessageChecksumValidation parameter has been added to the Options struct for SQS package. Setting this to true will disable the checksum validation. This can be set when creating a client, or per operation call.",
"modules": [
"service/sqs"
]
}
@@ -0,0 +1,102 @@
package software.amazon.smithy.aws.go.codegen.customization;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;
import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.go.codegen.GoCodegenPlugin;
import software.amazon.smithy.go.codegen.GoSettings;
import software.amazon.smithy.go.codegen.SymbolUtils;
import software.amazon.smithy.go.codegen.integration.ConfigField;
import software.amazon.smithy.go.codegen.integration.GoIntegration;
import software.amazon.smithy.go.codegen.integration.MiddlewareRegistrar;
import software.amazon.smithy.go.codegen.integration.RuntimeClientPlugin;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.ShapeId;
import software.amazon.smithy.utils.MapUtils;
import software.amazon.smithy.utils.SetUtils;

public class SQSValidateMessageChecksum implements GoIntegration {
private static final Logger LOGGER = Logger.getLogger(SQSValidateMessageChecksum.class.getName());

/**
* Map of service shape to Set of operation shapes that need to have this
* customization.
*/
public static final Map<ShapeId, Set<ShapeId>> SERVICE_TO_OPERATION_MAP = MapUtils.of(
ShapeId.from("com.amazonaws.sqs#AmazonSQS"), SetUtils.of(
ShapeId.from("com.amazonaws.sqs#SendMessage"),
ShapeId.from("com.amazonaws.sqs#SendMessageBatch"),
ShapeId.from("com.amazonaws.sqs#ReceiveMessage")
)
);
static final String DISABLE_MESSAGE_CHECKSUM_VALIDATION_OPTION_NAME = "DisableMessageChecksumValidation";

private final List<RuntimeClientPlugin> runtimeClientPlugins = new ArrayList<>();

/**
* Builds the set of runtime plugs used by the customization.
*
* @param settings codegen settings
* @param model api model
*/
@Override
public void processFinalizedModel(GoSettings settings, Model model) {
ShapeId serviceId = settings.getService();
if (!SERVICE_TO_OPERATION_MAP.containsKey(serviceId)) {
return;
}

ServiceShape service = settings.getService(model);

// Add option to disable message checksum validation
runtimeClientPlugins.add(RuntimeClientPlugin.builder()
.servicePredicate((m, s) -> s.equals(service))
.addConfigField(ConfigField.builder()
.name(DISABLE_MESSAGE_CHECKSUM_VALIDATION_OPTION_NAME)
.type(SymbolUtils.createValueSymbolBuilder("bool")
.putProperty(SymbolUtils.GO_UNIVERSE_TYPE, true).build())
.documentation("Allows you to disable the client's validation of "
+ "response message checksums. Enabled by default. "
+ "Used by SendMessage, SendMessageBatch, and ReceiveMessage.")
.build())
.build());

for (ShapeId operationId : SERVICE_TO_OPERATION_MAP.get(serviceId)) {
final OperationShape operation = model.expectShape(operationId, OperationShape.class);

// Create a symbol provider because one is not available in this call.
SymbolProvider symbolProvider = GoCodegenPlugin.createSymbolProvider(model, settings);

String helperFuncName = addMiddlewareFuncName(symbolProvider.toSymbol(operation).getName());

runtimeClientPlugins.add(RuntimeClientPlugin.builder()
.servicePredicate((m, s) -> s.equals(service))
.operationPredicate((m, s, o) -> o.equals(operation))
.registerMiddleware(MiddlewareRegistrar.builder()
.resolvedFunction(SymbolUtils.createValueSymbolBuilder(helperFuncName)
.build())
.useClientOptions()
.build())
.build());
}
}

String addMiddlewareFuncName(String operationName) {
return "addValidate" + operationName + "Checksum";
}

/**
* Returns the list of runtime client plugins added by this customization
*
* @return runtime client plugins
*/
@Override
public List<RuntimeClientPlugin> getClientPlugins() {
return runtimeClientPlugins;
}
}
Expand Up @@ -45,4 +45,5 @@ software.amazon.smithy.aws.go.codegen.RequestResponseLogging
software.amazon.smithy.aws.go.codegen.customization.S3AddPutObjectUnseekableBodyDoc
software.amazon.smithy.aws.go.codegen.customization.BackfillEc2UnboxedToBoxedShapes
software.amazon.smithy.aws.go.codegen.customization.AdjustAwsRestJsonContentType
software.amazon.smithy.aws.go.codegen.customization.SQSValidateMessageChecksum
software.amazon.smithy.aws.go.codegen.EndpointDiscoveryGenerator
4 changes: 4 additions & 0 deletions service/sqs/api_client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions service/sqs/api_op_ReceiveMessage.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions service/sqs/api_op_SendMessage.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions service/sqs/api_op_SendMessageBatch.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

234 changes: 234 additions & 0 deletions service/sqs/cust_checksum_validation.go
@@ -0,0 +1,234 @@
package sqs

import (
"context"
"crypto/md5"
"encoding/hex"
"fmt"
"strings"

"github.com/aws/aws-sdk-go-v2/aws"
sqstypes "github.com/aws/aws-sdk-go-v2/service/sqs/types"
"github.com/aws/smithy-go/middleware"
)

// addValidateSendMessageChecksum adds the ValidateMessageChecksum middleware
// to the stack configured for the SendMessage Operation.
func addValidateSendMessageChecksum(stack *middleware.Stack, o Options) error {
return addValidateMessageChecksum(stack, o, validateSendMessageChecksum)
}

// validateSendMessageChecksum validates the SendMessage operation's input
// message payload MD5 checksum matches that returned by the API.
//
// The input and output types must match the SendMessage operation.
func validateSendMessageChecksum(input, output interface{}) error {
in, ok := input.(*SendMessageInput)
if !ok {
return fmt.Errorf("wrong input type, expect %T, got %T", in, input)
}
out, ok := output.(*SendMessageOutput)
if !ok {
return fmt.Errorf("wrong output type, expect %T, got %T", out, output)
}

// Nothing to validate if the members aren't populated.
if in.MessageBody == nil || out.MD5OfMessageBody == nil {
return nil
}

if err := validateMessageChecksum(*in.MessageBody, *out.MD5OfMessageBody); err != nil {
return messageChecksumError{
MessageID: aws.ToString(out.MessageId),
Err: err,
}
}
return nil
}

// addValidateSendMessageBatchChecksum adds the ValidateMessagechecksum
// middleware to the stack configured for the SendMessageBatch operation.
func addValidateSendMessageBatchChecksum(stack *middleware.Stack, o Options) error {
return addValidateMessageChecksum(stack, o, validateSendMessageBatchChecksum)
}

// validateSendMessageBatchChecksum validates the SendMessageBatch operation's
// input messages body MD5 checksum matches those returned by the API.
//
// The input and output types must match the SendMessageBatch operation.
func validateSendMessageBatchChecksum(input, output interface{}) error {
in, ok := input.(*SendMessageBatchInput)
if !ok {
return fmt.Errorf("wrong input type, expect %T, got %T", in, input)
}
out, ok := output.(*SendMessageBatchOutput)
if !ok {
return fmt.Errorf("wrong output type, expect %T, got %T", out, output)
}

outEntries := map[string]sqstypes.SendMessageBatchResultEntry{}
for _, e := range out.Successful {
outEntries[*e.Id] = e
}

var failedMessageErrs []messageChecksumError
for _, inEntry := range in.Entries {
outEntry, ok := outEntries[*inEntry.Id]
// Nothing to validate if the members aren't populated.
if !ok || inEntry.MessageBody == nil || outEntry.MD5OfMessageBody == nil {
continue
}

if err := validateMessageChecksum(*inEntry.MessageBody, *outEntry.MD5OfMessageBody); err != nil {
failedMessageErrs = append(failedMessageErrs, messageChecksumError{
MessageID: aws.ToString(outEntry.MessageId),
Err: err,
})
}
}

if len(failedMessageErrs) != 0 {
return batchMessageChecksumError{
Errs: failedMessageErrs,
}
}

return nil
}

// addValidateReceiveMessageChecksum adds the ValidateMessagechecksum
// middleware to the stack configured for the ReceiveMessage operation.
func addValidateReceiveMessageChecksum(stack *middleware.Stack, o Options) error {
return addValidateMessageChecksum(stack, o, validateReceiveMessageChecksum)
}

// validateReceiveMessageChecksum validates the ReceiveMessage operation's
// input messages body MD5 checksum matches those returned by the API.
//
// The input and output types must match the ReceiveMessage operation.
func validateReceiveMessageChecksum(_, output interface{}) error {
out, ok := output.(*ReceiveMessageOutput)
if !ok {
return fmt.Errorf("wrong output type, expect %T, got %T", out, output)
}

var failedMessageErrs []messageChecksumError
for _, msg := range out.Messages {
// Nothing to validate if the members aren't populated.
if msg.Body == nil || msg.MD5OfBody == nil {
continue
}

if err := validateMessageChecksum(*msg.Body, *msg.MD5OfBody); err != nil {
failedMessageErrs = append(failedMessageErrs, messageChecksumError{
MessageID: aws.ToString(msg.MessageId),
Err: err,
})
}
}

if len(failedMessageErrs) != 0 {
return batchMessageChecksumError{
Errs: failedMessageErrs,
}
}

return nil
}

// messageChecksumValidator provides the function signature for the operation's
// validator.
type messageChecksumValidator func(input, output interface{}) error

// addValidateMessageChecksum adds the ValidateMessageChecksum middleware to
// the stack with the passed in validator specified.
func addValidateMessageChecksum(stack *middleware.Stack, o Options, validate messageChecksumValidator) error {
if o.DisableMessageChecksumValidation {
return nil
}

m := validateMessageChecksumMiddleware{
validate: validate,
}
err := stack.Initialize.Add(m, middleware.Before)
if err != nil {
return fmt.Errorf("failed to add %s middleware, %w", m.ID(), err)
}

return nil
}

// validateMessageChecksumMiddleware provides the Initialize middleware for
// validating an operation's message checksum is validate. Needs to b
// configured with the operation's validator.
type validateMessageChecksumMiddleware struct {
validate messageChecksumValidator
}

// ID returns the Middleware ID.
func (validateMessageChecksumMiddleware) ID() string { return "SQSValidateMessageChecksum" }

// HandleInitialize implements the InitializeMiddleware interface providing a
// middleware that will validate an operation's message checksum based on
// calling the validate member.
func (m validateMessageChecksumMiddleware) HandleInitialize(
ctx context.Context, input middleware.InitializeInput, next middleware.InitializeHandler,
) (
out middleware.InitializeOutput, meta middleware.Metadata, err error,
) {
out, meta, err = next.HandleInitialize(ctx, input)
if err != nil {
return out, meta, err
}

err = m.validate(input.Parameters, out.Result)
if err != nil {
return out, meta, fmt.Errorf("message checksum validation failed, %w", err)
}

return out, meta, nil
}

// validateMessageChecksum compares the MD5 checksums of value parameter with
// the expected MD5 value. Returns an error if the computed checksum does not
// match the expected value.
func validateMessageChecksum(value, expect string) error {
msum := md5.Sum([]byte(value))
sum := hex.EncodeToString(msum[:])
if sum != expect {
return fmt.Errorf("expected MD5 checksum %s, got %s", expect, sum)
}

return nil
}

// messageChecksumError provides an error type for invalid message checksums.
type messageChecksumError struct {
MessageID string
Err error
}

func (e messageChecksumError) Error() string {
prefix := "message"
if e.MessageID != "" {
prefix += " " + e.MessageID
}
return fmt.Sprintf("%s has invalid checksum, %v", prefix, e.Err.Error())
}

// batchMessageChecksumError provides an error type for a collection of invalid
// message checksum errors.
type batchMessageChecksumError struct {
Errs []messageChecksumError
}

func (e batchMessageChecksumError) Error() string {
var w strings.Builder
fmt.Fprintf(&w, "message checksum errors")

for _, err := range e.Errs {
fmt.Fprintf(&w, "\n\t%s", err.Error())
}

return w.String()
}