Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

service/sqs: Add support for validating message checksums by default #1748

Merged
merged 2 commits into from Jul 1, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 8 additions & 0 deletions .changelog/131fe156ee0640ff85a5ba09e3cda44c.json
@@ -0,0 +1,8 @@
{
"id": "131fe156-ee06-40ff-85a5-ba09e3cda44c",
"type": "feature",
"description": "Adds support for the SQS client to automatically validate message checksums for SendMessage, SendMessageBatch, and ReceiveMessage. A DisableMessageChecksumValidation parameter has been added to the Options struct for SQS package. Setting this to true will disable the checksum validation. This can be set when creating a client, or per operation call.",
"modules": [
"service/sqs"
]
}
@@ -0,0 +1,102 @@
package software.amazon.smithy.aws.go.codegen.customization;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;
import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.go.codegen.GoCodegenPlugin;
import software.amazon.smithy.go.codegen.GoSettings;
import software.amazon.smithy.go.codegen.SymbolUtils;
import software.amazon.smithy.go.codegen.integration.ConfigField;
import software.amazon.smithy.go.codegen.integration.GoIntegration;
import software.amazon.smithy.go.codegen.integration.MiddlewareRegistrar;
import software.amazon.smithy.go.codegen.integration.RuntimeClientPlugin;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.ShapeId;
import software.amazon.smithy.utils.MapUtils;
import software.amazon.smithy.utils.SetUtils;

public class SQSValidateMessageChecksum implements GoIntegration {
private static final Logger LOGGER = Logger.getLogger(SQSValidateMessageChecksum.class.getName());

/**
* Map of service shape to Set of operation shapes that need to have this
* customization.
*/
public static final Map<ShapeId, Set<ShapeId>> SERVICE_TO_OPERATION_MAP = MapUtils.of(
ShapeId.from("com.amazonaws.sqs#AmazonSQS"), SetUtils.of(
ShapeId.from("com.amazonaws.sqs#SendMessage"),
ShapeId.from("com.amazonaws.sqs#SendMessageBatch"),
ShapeId.from("com.amazonaws.sqs#ReceiveMessage")
)
);
static final String DISABLE_MESSAGE_CHECKSUM_VALIDATION_OPTION_NAME = "DisableMessageChecksumValidation";

private final List<RuntimeClientPlugin> runtimeClientPlugins = new ArrayList<>();

/**
* Builds the set of runtime plugs used by the customization.
*
* @param settings codegen settings
* @param model api model
*/
@Override
public void processFinalizedModel(GoSettings settings, Model model) {
ShapeId serviceId = settings.getService();
if (!SERVICE_TO_OPERATION_MAP.containsKey(serviceId)) {
return;
}

ServiceShape service = settings.getService(model);

// Add option to disable message checksum validation
runtimeClientPlugins.add(RuntimeClientPlugin.builder()
.servicePredicate((m, s) -> s.equals(service))
.addConfigField(ConfigField.builder()
.name(DISABLE_MESSAGE_CHECKSUM_VALIDATION_OPTION_NAME)
.type(SymbolUtils.createValueSymbolBuilder("bool")
.putProperty(SymbolUtils.GO_UNIVERSE_TYPE, true).build())
.documentation("Allows you to disable the client's validation of "
+ "response message checksums. Enabled by default. "
+ "Used by SendMessage, SendMessageBatch, and ReceiveMessage.")
.build())
.build());

for (ShapeId operationId : SERVICE_TO_OPERATION_MAP.get(serviceId)) {
final OperationShape operation = model.expectShape(operationId, OperationShape.class);

// Create a symbol provider because one is not available in this call.
SymbolProvider symbolProvider = GoCodegenPlugin.createSymbolProvider(model, settings);

String helperFuncName = addMiddlewareFuncName(symbolProvider.toSymbol(operation).getName());

runtimeClientPlugins.add(RuntimeClientPlugin.builder()
.servicePredicate((m, s) -> s.equals(service))
.operationPredicate((m, s, o) -> o.equals(operation))
.registerMiddleware(MiddlewareRegistrar.builder()
.resolvedFunction(SymbolUtils.createValueSymbolBuilder(helperFuncName)
.build())
.useClientOptions()
.build())
.build());
}
}

String addMiddlewareFuncName(String operationName) {
return "addValidate" + operationName + "Checksum";
}

/**
* Returns the list of runtime client plugins added by this customization
*
* @return runtime client plugins
*/
@Override
public List<RuntimeClientPlugin> getClientPlugins() {
return runtimeClientPlugins;
}
}
Expand Up @@ -45,4 +45,5 @@ software.amazon.smithy.aws.go.codegen.RequestResponseLogging
software.amazon.smithy.aws.go.codegen.customization.S3AddPutObjectUnseekableBodyDoc
software.amazon.smithy.aws.go.codegen.customization.BackfillEc2UnboxedToBoxedShapes
software.amazon.smithy.aws.go.codegen.customization.AdjustAwsRestJsonContentType
software.amazon.smithy.aws.go.codegen.customization.SQSValidateMessageChecksum
software.amazon.smithy.aws.go.codegen.EndpointDiscoveryGenerator
4 changes: 4 additions & 0 deletions service/sqs/api_client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions service/sqs/api_op_ReceiveMessage.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions service/sqs/api_op_SendMessage.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions service/sqs/api_op_SendMessageBatch.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

214 changes: 214 additions & 0 deletions service/sqs/cust_checksum_validation.go
@@ -0,0 +1,214 @@
package sqs

import (
"context"
"crypto/md5"
"encoding/hex"
"fmt"
"strings"

"github.com/aws/aws-sdk-go-v2/aws"
sqstypes "github.com/aws/aws-sdk-go-v2/service/sqs/types"
"github.com/aws/smithy-go/middleware"
)

//********************************
// SendMessage checksum validation
//********************************
jasdel marked this conversation as resolved.
Show resolved Hide resolved
func addValidateSendMessageChecksum(stack *middleware.Stack, o Options) error {
return addValidateMessageChecksum(stack, o, validateSendMessageChecksum)
}

func validateSendMessageChecksum(input, output interface{}) error {
in, ok := input.(*SendMessageInput)
if !ok {
return fmt.Errorf("wrong input type, expect %T, got %T", in, input)
}
out, ok := output.(*SendMessageOutput)
if !ok {
return fmt.Errorf("wrong output type, expect %T, got %T", out, output)
}

// Nothing to validate if the members aren't populated.
if in.MessageBody == nil || out.MD5OfMessageBody == nil {
return nil
}

if err := validateMessageChecksum(*in.MessageBody, *out.MD5OfMessageBody); err != nil {
return messageChecksumError{
MessageID: aws.ToString(out.MessageId),
Err: err,
}
}
return nil
}

//*************************************
// SendMessageBatch checksum validation
//*************************************
func addValidateSendMessageBatchChecksum(stack *middleware.Stack, o Options) error {
return addValidateMessageChecksum(stack, o, validateSendMessageBatchChecksum)
}

func validateSendMessageBatchChecksum(input, output interface{}) error {
in, ok := input.(*SendMessageBatchInput)
if !ok {
return fmt.Errorf("wrong input type, expect %T, got %T", in, input)
}
out, ok := output.(*SendMessageBatchOutput)
if !ok {
return fmt.Errorf("wrong output type, expect %T, got %T", out, output)
}

outEntries := map[string]sqstypes.SendMessageBatchResultEntry{}
for _, e := range out.Successful {
outEntries[*e.Id] = e
}

var failedMessageErrs []error
for _, inEntry := range in.Entries {
outEntry, ok := outEntries[*inEntry.Id]
// Nothing to validate if the members aren't populated.
if !ok || inEntry.MessageBody == nil || outEntry.MD5OfMessageBody == nil {
continue
}

if err := validateMessageChecksum(*inEntry.MessageBody, *outEntry.MD5OfMessageBody); err != nil {
failedMessageErrs = append(failedMessageErrs, messageChecksumError{
MessageID: aws.ToString(outEntry.MessageId),
Err: err,
})
}
}

if len(failedMessageErrs) != 0 {
return batchMessageChecksumError{
Errs: failedMessageErrs,
}
}

return nil
}

//***********************************
// ReceiveMessage checksum validation
//***********************************
func addValidateReceiveMessageChecksum(stack *middleware.Stack, o Options) error {
return addValidateMessageChecksum(stack, o, validateReceiveMessageChecksum)
}

func validateReceiveMessageChecksum(_, output interface{}) error {
out, ok := output.(*ReceiveMessageOutput)
if !ok {
return fmt.Errorf("wrong output type, expect %T, got %T", out, output)
}

var failedMessageErrs []error
for _, msg := range out.Messages {
// Nothing to validate if the members aren't populated.
if msg.Body == nil || msg.MD5OfBody == nil {
continue
}

if err := validateMessageChecksum(*msg.Body, *msg.MD5OfBody); err != nil {
failedMessageErrs = append(failedMessageErrs, messageChecksumError{
MessageID: aws.ToString(msg.MessageId),
Err: err,
})
}
}

if len(failedMessageErrs) != 0 {
return batchMessageChecksumError{
Errs: failedMessageErrs,
}
}

return nil
}

//***************************************
// Message checksum validation middleware
//***************************************
type messageChecksumValidator func(input, output interface{}) error

func addValidateMessageChecksum(stack *middleware.Stack, o Options, validate messageChecksumValidator) error {
if o.DisableMessageChecksumValidation {
return nil
}

m := validateMessageChecksumMiddleware{
validate: validate,
}
err := stack.Initialize.Add(m, middleware.Before)
if err != nil {
return fmt.Errorf("failed to add %s middleware, %w", m.ID(), err)
}

return nil
}

type validateMessageChecksumMiddleware struct {
validate messageChecksumValidator
}

func (validateMessageChecksumMiddleware) ID() string { return "SQSValidateMessageChecksum" }

func (m validateMessageChecksumMiddleware) HandleInitialize(
ctx context.Context, input middleware.InitializeInput, next middleware.InitializeHandler,
) (
out middleware.InitializeOutput, meta middleware.Metadata, err error,
) {
out, meta, err = next.HandleInitialize(ctx, input)
if err != nil {
return out, meta, err
}

err = m.validate(input.Parameters, out.Result)
if err != nil {
return out, meta, fmt.Errorf("message checksum validation failed, %w", err)
}

return out, meta, nil
}

func validateMessageChecksum(value, expect string) error {
msum := md5.Sum([]byte(value))
sum := hex.EncodeToString(msum[:])
if sum != expect {
return fmt.Errorf("expected MD5 checksum %s, got %s", expect, sum)
}

return nil
}

//************************
// Message checksum errors
//************************
type messageChecksumError struct {
MessageID string
Err error
}

func (e messageChecksumError) Error() string {
prefix := "message"
if e.MessageID != "" {
prefix += " " + e.MessageID
}
return fmt.Sprintf("%s has invalid checksum, %v", prefix, e.Err.Error())
}

type batchMessageChecksumError struct {
Errs []error
}

func (e batchMessageChecksumError) Error() string {
var w strings.Builder
fmt.Fprintf(&w, "message checksum errors")

for _, err := range e.Errs {
fmt.Fprintf(&w, "\n\t%s", err.Error())
}

return w.String()
}