Skip to content

Commit

Permalink
feat(middleware-flexible-checksums): support trailing checksums with …
Browse files Browse the repository at this point in the history
…aws-chunked encoding (#3347)
  • Loading branch information
trivikr committed Feb 23, 2022
1 parent d9e4c4b commit 50b48fe
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 27 deletions.
12 changes: 11 additions & 1 deletion packages/middleware-flexible-checksums/src/configuration.ts
@@ -1,4 +1,4 @@
import { Encoder, HashConstructor, StreamHasher } from "@aws-sdk/types";
import { Encoder, GetAwsChunkedEncodingStream, HashConstructor, StreamHasher } from "@aws-sdk/types";

export interface PreviouslyResolved {
/**
Expand All @@ -7,6 +7,16 @@ export interface PreviouslyResolved {
*/
base64Encoder: Encoder;

/**
* A function that can calculate the length of a body.
*/
bodyLengthChecker: (body: any) => number | undefined;

/**
* A function that returns Readable Stream which follows aws-chunked encoding stream.
*/
getAwsChunkedEncodingStream: GetAwsChunkedEncodingStream;

/**
* A constructor for a class implementing the {@link Hash} interface that computes MD5 hashes.
* @internal
Expand Down
Expand Up @@ -9,16 +9,18 @@ import { getChecksumAlgorithmForRequest } from "./getChecksumAlgorithmForRequest
import { getChecksumLocationName } from "./getChecksumLocationName";
import { FlexibleChecksumsMiddlewareConfig } from "./getFlexibleChecksumsPlugin";
import { hasHeader } from "./hasHeader";
import { isStreaming } from "./isStreaming";
import { selectChecksumAlgorithmFunction } from "./selectChecksumAlgorithmFunction";
import { validateChecksumFromResponse } from "./validateChecksumFromResponse";

jest.mock("@aws-sdk/protocol-http");
jest.mock("./getChecksum");
jest.mock("./getChecksumAlgorithmForRequest");
jest.mock("./getChecksumLocationName");
jest.mock("./selectChecksumAlgorithmFunction");
jest.mock("./getChecksum");
jest.mock("./hasHeader");
jest.mock("./isStreaming");
jest.mock("./validateChecksumFromResponse");
jest.mock("./selectChecksumAlgorithmFunction");

describe(flexibleChecksumsMiddleware.name, () => {
const mockNext = jest.fn();
Expand All @@ -31,8 +33,8 @@ describe(flexibleChecksumsMiddleware.name, () => {
const mockConfig = {} as PreviouslyResolved;
const mockMiddlewareConfig = { input: mockInput } as FlexibleChecksumsMiddlewareConfig;

const mockBody = {};
const mockHeaders = {};
const mockBody = { body: "mockBody" };
const mockHeaders = { "content-length": 100 };
const mockRequest = { body: mockBody, headers: mockHeaders };
const mockArgs = { request: mockRequest } as BuildHandlerArguments<any>;
const mockResult = { response: {} };
Expand All @@ -41,19 +43,20 @@ describe(flexibleChecksumsMiddleware.name, () => {
mockNext.mockResolvedValueOnce(mockResult);
const { isInstance } = HttpRequest;
(isInstance as unknown as jest.Mock).mockReturnValue(true);
(getChecksum as jest.Mock).mockReturnValue(mockChecksum);
(getChecksumAlgorithmForRequest as jest.Mock).mockReturnValue(ChecksumAlgorithm.MD5);
(getChecksumLocationName as jest.Mock).mockReturnValue(mockChecksumLocationName);
(selectChecksumAlgorithmFunction as jest.Mock).mockReturnValue(mockChecksumAlgorithmFunction);
(getChecksum as jest.Mock).mockReturnValue(mockChecksum);
(hasHeader as jest.Mock).mockReturnValue(false);
(isStreaming as jest.Mock).mockReturnValue(false);
(selectChecksumAlgorithmFunction as jest.Mock).mockReturnValue(mockChecksumAlgorithmFunction);
});

afterEach(() => {
expect(mockNext).toHaveBeenCalledTimes(1);
jest.clearAllMocks();
});

describe("skips checksum computation", () => {
describe("skips", () => {
it("if not an instance of HttpRequest", async () => {
const { isInstance } = HttpRequest;
(isInstance as unknown as jest.Mock).mockReturnValue(false);
Expand All @@ -65,7 +68,6 @@ describe(flexibleChecksumsMiddleware.name, () => {
describe("request checksum", () => {
afterEach(() => {
expect(getChecksumAlgorithmForRequest).toHaveBeenCalledTimes(1);
expect(selectChecksumAlgorithmFunction).not.toHaveBeenCalled();
expect(getChecksum).not.toHaveBeenCalled();
});

Expand All @@ -75,6 +77,7 @@ describe(flexibleChecksumsMiddleware.name, () => {
await handler(mockArgs);
expect(getChecksumLocationName).not.toHaveBeenCalled();
expect(mockNext).toHaveBeenCalledWith(mockArgs);
expect(selectChecksumAlgorithmFunction).not.toHaveBeenCalled();
});

it("if header is already present", async () => {
Expand All @@ -87,6 +90,7 @@ describe(flexibleChecksumsMiddleware.name, () => {
(hasHeader as jest.Mock).mockReturnValue(true);
await handler(mockArgsWithChecksumHeader);
expect(getChecksumLocationName).toHaveBeenCalledTimes(1);
expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1);
expect(hasHeader).toHaveBeenCalledTimes(1);
expect(mockNext).toHaveBeenCalledWith(mockArgsWithChecksumHeader);
expect(hasHeader).toHaveBeenCalledWith(mockChecksumLocationName, mockHeadersWithChecksumHeader);
Expand All @@ -112,21 +116,57 @@ describe(flexibleChecksumsMiddleware.name, () => {
});
});

it("adds checksum in the request header", async () => {
const handler = flexibleChecksumsMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, {});
await handler(mockArgs);
expect(getChecksumLocationName).toHaveBeenCalledTimes(1);
expect(hasHeader).toHaveBeenCalledTimes(1);
expect(mockNext).toHaveBeenCalledWith({
...mockArgs,
request: {
...mockRequest,
headers: { ...mockHeaders, [mockChecksumLocationName]: mockChecksum },
},
describe("adds checksum in the request header", () => {
afterEach(() => {
expect(getChecksumAlgorithmForRequest).toHaveBeenCalledTimes(1);
expect(getChecksumLocationName).toHaveBeenCalledTimes(1);
expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1);
});

it("for streaming body", async () => {
(isStreaming as jest.Mock).mockReturnValue(true);
const mockUpdatedBody = { body: "mockUpdatedBody" };
const mockGetAwsChunkedEncodingStream = jest.fn().mockReturnValue(mockUpdatedBody);

const handler = flexibleChecksumsMiddleware(
{ ...mockConfig, getAwsChunkedEncodingStream: mockGetAwsChunkedEncodingStream },
mockMiddlewareConfig
)(mockNext, {});
await handler(mockArgs);

expect(mockNext).toHaveBeenCalledWith({
...mockArgs,
request: {
...mockRequest,
headers: {
...mockHeaders,
"content-length": undefined,
"content-encoding": "aws-chunked",
"transfer-encoding": "chunked",
"x-amz-decoded-content-length": mockHeaders["content-length"],
"x-amz-content-sha256": "STREAMING-UNSIGNED-PAYLOAD-TRAILER",
"x-amz-trailer": mockChecksumLocationName,
},
body: mockUpdatedBody,
},
});
expect(mockGetAwsChunkedEncodingStream).toHaveBeenCalledTimes(1);
});

it("for non-streaming body", async () => {
const handler = flexibleChecksumsMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, {});
await handler(mockArgs);
expect(hasHeader).toHaveBeenCalledTimes(1);
expect(mockNext).toHaveBeenCalledWith({
...mockArgs,
request: {
...mockRequest,
headers: { ...mockHeaders, [mockChecksumLocationName]: mockChecksum },
},
});
expect(hasHeader).toHaveBeenCalledWith(mockChecksumLocationName, mockHeaders);
expect(getChecksum).toHaveBeenCalledTimes(1);
});
expect(hasHeader).toHaveBeenCalledWith(mockChecksumLocationName, mockHeaders);
expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1);
expect(getChecksum).toHaveBeenCalledTimes(1);
});

it("validates checksum from the response header", async () => {
Expand Down
Expand Up @@ -13,6 +13,7 @@ import { getChecksumAlgorithmForRequest } from "./getChecksumAlgorithmForRequest
import { getChecksumLocationName } from "./getChecksumLocationName";
import { FlexibleChecksumsMiddlewareConfig } from "./getFlexibleChecksumsPlugin";
import { hasHeader } from "./hasHeader";
import { isStreaming } from "./isStreaming";
import { selectChecksumAlgorithmFunction } from "./selectChecksumAlgorithmFunction";
import { validateChecksumFromResponse } from "./validateChecksumFromResponse";

Expand All @@ -26,20 +27,38 @@ export const flexibleChecksumsMiddleware =

const { request } = args;
const { body: requestBody, headers } = request;
const { streamHasher, base64Encoder } = config;
const { base64Encoder, streamHasher } = config;
const { input, requestChecksumRequired, requestAlgorithmMember } = middlewareConfig;

const checksumAlgorithm = getChecksumAlgorithmForRequest(input, {
requestChecksumRequired,
requestAlgorithmMember,
});
let updatedBody = requestBody;
let updatedHeaders = headers;

if (checksumAlgorithm) {
const checksumLocationName = getChecksumLocationName(checksumAlgorithm);
// ToDo: Update trailer instead if it is Unsigned-payload.
if (!hasHeader(checksumLocationName, headers)) {
const checksumAlgorithmFn = selectChecksumAlgorithmFunction(checksumAlgorithm, config);
const checksumAlgorithmFn = selectChecksumAlgorithmFunction(checksumAlgorithm, config);
if (isStreaming(requestBody)) {
const { getAwsChunkedEncodingStream, bodyLengthChecker } = config;
updatedBody = getAwsChunkedEncodingStream(requestBody, {
base64Encoder,
bodyLengthChecker,
checksumLocationName,
checksumAlgorithmFn,
streamHasher,
});
updatedHeaders = {
...headers,
"content-encoding": "aws-chunked",
"transfer-encoding": "chunked",
"x-amz-decoded-content-length": headers["content-length"],
"x-amz-content-sha256": "STREAMING-UNSIGNED-PAYLOAD-TRAILER",
"x-amz-trailer": checksumLocationName,
};
delete updatedHeaders["content-length"];
} else if (!hasHeader(checksumLocationName, headers)) {
const checksum = await getChecksum(requestBody, { streamHasher, checksumAlgorithmFn, base64Encoder });
updatedHeaders = {
...headers,
Expand All @@ -53,6 +72,7 @@ export const flexibleChecksumsMiddleware =
request: {
...request,
headers: updatedHeaders,
body: updatedBody,
},
});

Expand Down

0 comments on commit 50b48fe

Please sign in to comment.