From 50b48fec155cb87f85799f001beac6ca6434372b Mon Sep 17 00:00:00 2001 From: Trivikram Kamat <16024985+trivikr@users.noreply.github.com> Date: Wed, 23 Feb 2022 09:40:10 -0800 Subject: [PATCH] feat(middleware-flexible-checksums): support trailing checksums with aws-chunked encoding (#3347) --- .../src/configuration.ts | 12 ++- .../src/flexibleChecksumsMiddleware.spec.ts | 84 ++++++++++++++----- .../src/flexibleChecksumsMiddleware.ts | 28 ++++++- 3 files changed, 97 insertions(+), 27 deletions(-) diff --git a/packages/middleware-flexible-checksums/src/configuration.ts b/packages/middleware-flexible-checksums/src/configuration.ts index be3e7a41aeb9..6f8a0f04f0c2 100644 --- a/packages/middleware-flexible-checksums/src/configuration.ts +++ b/packages/middleware-flexible-checksums/src/configuration.ts @@ -1,4 +1,4 @@ -import { Encoder, HashConstructor, StreamHasher } from "@aws-sdk/types"; +import { Encoder, GetAwsChunkedEncodingStream, HashConstructor, StreamHasher } from "@aws-sdk/types"; export interface PreviouslyResolved { /** @@ -7,6 +7,16 @@ export interface PreviouslyResolved { */ base64Encoder: Encoder; + /** + * A function that can calculate the length of a body. + */ + bodyLengthChecker: (body: any) => number | undefined; + + /** + * A function that returns Readable Stream which follows aws-chunked encoding stream. + */ + getAwsChunkedEncodingStream: GetAwsChunkedEncodingStream; + /** * A constructor for a class implementing the {@link Hash} interface that computes MD5 hashes. * @internal diff --git a/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.spec.ts b/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.spec.ts index 144d78d10f40..492d8c4d3295 100644 --- a/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.spec.ts +++ b/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.spec.ts @@ -9,16 +9,18 @@ import { getChecksumAlgorithmForRequest } from "./getChecksumAlgorithmForRequest import { getChecksumLocationName } from "./getChecksumLocationName"; import { FlexibleChecksumsMiddlewareConfig } from "./getFlexibleChecksumsPlugin"; import { hasHeader } from "./hasHeader"; +import { isStreaming } from "./isStreaming"; import { selectChecksumAlgorithmFunction } from "./selectChecksumAlgorithmFunction"; import { validateChecksumFromResponse } from "./validateChecksumFromResponse"; jest.mock("@aws-sdk/protocol-http"); +jest.mock("./getChecksum"); jest.mock("./getChecksumAlgorithmForRequest"); jest.mock("./getChecksumLocationName"); -jest.mock("./selectChecksumAlgorithmFunction"); -jest.mock("./getChecksum"); jest.mock("./hasHeader"); +jest.mock("./isStreaming"); jest.mock("./validateChecksumFromResponse"); +jest.mock("./selectChecksumAlgorithmFunction"); describe(flexibleChecksumsMiddleware.name, () => { const mockNext = jest.fn(); @@ -31,8 +33,8 @@ describe(flexibleChecksumsMiddleware.name, () => { const mockConfig = {} as PreviouslyResolved; const mockMiddlewareConfig = { input: mockInput } as FlexibleChecksumsMiddlewareConfig; - const mockBody = {}; - const mockHeaders = {}; + const mockBody = { body: "mockBody" }; + const mockHeaders = { "content-length": 100 }; const mockRequest = { body: mockBody, headers: mockHeaders }; const mockArgs = { request: mockRequest } as BuildHandlerArguments; const mockResult = { response: {} }; @@ -41,11 +43,12 @@ describe(flexibleChecksumsMiddleware.name, () => { mockNext.mockResolvedValueOnce(mockResult); const { isInstance } = HttpRequest; (isInstance as unknown as jest.Mock).mockReturnValue(true); + (getChecksum as jest.Mock).mockReturnValue(mockChecksum); (getChecksumAlgorithmForRequest as jest.Mock).mockReturnValue(ChecksumAlgorithm.MD5); (getChecksumLocationName as jest.Mock).mockReturnValue(mockChecksumLocationName); - (selectChecksumAlgorithmFunction as jest.Mock).mockReturnValue(mockChecksumAlgorithmFunction); - (getChecksum as jest.Mock).mockReturnValue(mockChecksum); (hasHeader as jest.Mock).mockReturnValue(false); + (isStreaming as jest.Mock).mockReturnValue(false); + (selectChecksumAlgorithmFunction as jest.Mock).mockReturnValue(mockChecksumAlgorithmFunction); }); afterEach(() => { @@ -53,7 +56,7 @@ describe(flexibleChecksumsMiddleware.name, () => { jest.clearAllMocks(); }); - describe("skips checksum computation", () => { + describe("skips", () => { it("if not an instance of HttpRequest", async () => { const { isInstance } = HttpRequest; (isInstance as unknown as jest.Mock).mockReturnValue(false); @@ -65,7 +68,6 @@ describe(flexibleChecksumsMiddleware.name, () => { describe("request checksum", () => { afterEach(() => { expect(getChecksumAlgorithmForRequest).toHaveBeenCalledTimes(1); - expect(selectChecksumAlgorithmFunction).not.toHaveBeenCalled(); expect(getChecksum).not.toHaveBeenCalled(); }); @@ -75,6 +77,7 @@ describe(flexibleChecksumsMiddleware.name, () => { await handler(mockArgs); expect(getChecksumLocationName).not.toHaveBeenCalled(); expect(mockNext).toHaveBeenCalledWith(mockArgs); + expect(selectChecksumAlgorithmFunction).not.toHaveBeenCalled(); }); it("if header is already present", async () => { @@ -87,6 +90,7 @@ describe(flexibleChecksumsMiddleware.name, () => { (hasHeader as jest.Mock).mockReturnValue(true); await handler(mockArgsWithChecksumHeader); expect(getChecksumLocationName).toHaveBeenCalledTimes(1); + expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1); expect(hasHeader).toHaveBeenCalledTimes(1); expect(mockNext).toHaveBeenCalledWith(mockArgsWithChecksumHeader); expect(hasHeader).toHaveBeenCalledWith(mockChecksumLocationName, mockHeadersWithChecksumHeader); @@ -112,21 +116,57 @@ describe(flexibleChecksumsMiddleware.name, () => { }); }); - it("adds checksum in the request header", async () => { - const handler = flexibleChecksumsMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, {}); - await handler(mockArgs); - expect(getChecksumLocationName).toHaveBeenCalledTimes(1); - expect(hasHeader).toHaveBeenCalledTimes(1); - expect(mockNext).toHaveBeenCalledWith({ - ...mockArgs, - request: { - ...mockRequest, - headers: { ...mockHeaders, [mockChecksumLocationName]: mockChecksum }, - }, + describe("adds checksum in the request header", () => { + afterEach(() => { + expect(getChecksumAlgorithmForRequest).toHaveBeenCalledTimes(1); + expect(getChecksumLocationName).toHaveBeenCalledTimes(1); + expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1); + }); + + it("for streaming body", async () => { + (isStreaming as jest.Mock).mockReturnValue(true); + const mockUpdatedBody = { body: "mockUpdatedBody" }; + const mockGetAwsChunkedEncodingStream = jest.fn().mockReturnValue(mockUpdatedBody); + + const handler = flexibleChecksumsMiddleware( + { ...mockConfig, getAwsChunkedEncodingStream: mockGetAwsChunkedEncodingStream }, + mockMiddlewareConfig + )(mockNext, {}); + await handler(mockArgs); + + expect(mockNext).toHaveBeenCalledWith({ + ...mockArgs, + request: { + ...mockRequest, + headers: { + ...mockHeaders, + "content-length": undefined, + "content-encoding": "aws-chunked", + "transfer-encoding": "chunked", + "x-amz-decoded-content-length": mockHeaders["content-length"], + "x-amz-content-sha256": "STREAMING-UNSIGNED-PAYLOAD-TRAILER", + "x-amz-trailer": mockChecksumLocationName, + }, + body: mockUpdatedBody, + }, + }); + expect(mockGetAwsChunkedEncodingStream).toHaveBeenCalledTimes(1); + }); + + it("for non-streaming body", async () => { + const handler = flexibleChecksumsMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, {}); + await handler(mockArgs); + expect(hasHeader).toHaveBeenCalledTimes(1); + expect(mockNext).toHaveBeenCalledWith({ + ...mockArgs, + request: { + ...mockRequest, + headers: { ...mockHeaders, [mockChecksumLocationName]: mockChecksum }, + }, + }); + expect(hasHeader).toHaveBeenCalledWith(mockChecksumLocationName, mockHeaders); + expect(getChecksum).toHaveBeenCalledTimes(1); }); - expect(hasHeader).toHaveBeenCalledWith(mockChecksumLocationName, mockHeaders); - expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1); - expect(getChecksum).toHaveBeenCalledTimes(1); }); it("validates checksum from the response header", async () => { diff --git a/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.ts b/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.ts index cf3618d3a4e0..7f8883bf67f3 100644 --- a/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.ts +++ b/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.ts @@ -13,6 +13,7 @@ import { getChecksumAlgorithmForRequest } from "./getChecksumAlgorithmForRequest import { getChecksumLocationName } from "./getChecksumLocationName"; import { FlexibleChecksumsMiddlewareConfig } from "./getFlexibleChecksumsPlugin"; import { hasHeader } from "./hasHeader"; +import { isStreaming } from "./isStreaming"; import { selectChecksumAlgorithmFunction } from "./selectChecksumAlgorithmFunction"; import { validateChecksumFromResponse } from "./validateChecksumFromResponse"; @@ -26,20 +27,38 @@ export const flexibleChecksumsMiddleware = const { request } = args; const { body: requestBody, headers } = request; - const { streamHasher, base64Encoder } = config; + const { base64Encoder, streamHasher } = config; const { input, requestChecksumRequired, requestAlgorithmMember } = middlewareConfig; const checksumAlgorithm = getChecksumAlgorithmForRequest(input, { requestChecksumRequired, requestAlgorithmMember, }); + let updatedBody = requestBody; let updatedHeaders = headers; if (checksumAlgorithm) { const checksumLocationName = getChecksumLocationName(checksumAlgorithm); - // ToDo: Update trailer instead if it is Unsigned-payload. - if (!hasHeader(checksumLocationName, headers)) { - const checksumAlgorithmFn = selectChecksumAlgorithmFunction(checksumAlgorithm, config); + const checksumAlgorithmFn = selectChecksumAlgorithmFunction(checksumAlgorithm, config); + if (isStreaming(requestBody)) { + const { getAwsChunkedEncodingStream, bodyLengthChecker } = config; + updatedBody = getAwsChunkedEncodingStream(requestBody, { + base64Encoder, + bodyLengthChecker, + checksumLocationName, + checksumAlgorithmFn, + streamHasher, + }); + updatedHeaders = { + ...headers, + "content-encoding": "aws-chunked", + "transfer-encoding": "chunked", + "x-amz-decoded-content-length": headers["content-length"], + "x-amz-content-sha256": "STREAMING-UNSIGNED-PAYLOAD-TRAILER", + "x-amz-trailer": checksumLocationName, + }; + delete updatedHeaders["content-length"]; + } else if (!hasHeader(checksumLocationName, headers)) { const checksum = await getChecksum(requestBody, { streamHasher, checksumAlgorithmFn, base64Encoder }); updatedHeaders = { ...headers, @@ -53,6 +72,7 @@ export const flexibleChecksumsMiddleware = request: { ...request, headers: updatedHeaders, + body: updatedBody, }, });