From b3562348b2ebfc68a0eb5c54263fb0900f5ec014 Mon Sep 17 00:00:00 2001 From: Trivikram Kamat <16024985+trivikr@users.noreply.github.com> Date: Fri, 15 Apr 2022 10:11:54 -0700 Subject: [PATCH] chore(middleware-flexible-checksums): call stringHasher from flexibleChecksumsMiddleware (#3363) * chore: call stringHasher from flexibleChecksumsMiddleware * test: call stringHasher from flexibleChecksumsMiddleware * test: move mockBase64Encoder to where needed --- .../src/flexibleChecksumsMiddleware.spec.ts | 24 ++++++++++++------- .../src/flexibleChecksumsMiddleware.ts | 6 ++--- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.spec.ts b/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.spec.ts index 17423ad38837..fb5e68f70986 100644 --- a/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.spec.ts +++ b/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.spec.ts @@ -4,22 +4,22 @@ import { BuildHandlerArguments } from "@aws-sdk/types"; import { PreviouslyResolved } from "./configuration"; import { ChecksumAlgorithm } from "./constants"; import { flexibleChecksumsMiddleware } from "./flexibleChecksumsMiddleware"; -import { getChecksum } from "./getChecksum"; import { getChecksumAlgorithmForRequest } from "./getChecksumAlgorithmForRequest"; import { getChecksumLocationName } from "./getChecksumLocationName"; import { hasHeader } from "./hasHeader"; import { isStreaming } from "./isStreaming"; import { selectChecksumAlgorithmFunction } from "./selectChecksumAlgorithmFunction"; +import { stringHasher } from "./stringHasher"; import { validateChecksumFromResponse } from "./validateChecksumFromResponse"; jest.mock("@aws-sdk/protocol-http"); -jest.mock("./getChecksum"); jest.mock("./getChecksumAlgorithmForRequest"); jest.mock("./getChecksumLocationName"); jest.mock("./hasHeader"); jest.mock("./isStreaming"); -jest.mock("./validateChecksumFromResponse"); jest.mock("./selectChecksumAlgorithmFunction"); +jest.mock("./stringHasher"); +jest.mock("./validateChecksumFromResponse"); describe(flexibleChecksumsMiddleware.name, () => { const mockNext = jest.fn(); @@ -42,10 +42,9 @@ describe(flexibleChecksumsMiddleware.name, () => { mockNext.mockResolvedValueOnce(mockResult); const { isInstance } = HttpRequest; (isInstance as unknown as jest.Mock).mockReturnValue(true); - (getChecksum as jest.Mock).mockReturnValue(mockChecksum); (getChecksumAlgorithmForRequest as jest.Mock).mockReturnValue(ChecksumAlgorithm.MD5); (getChecksumLocationName as jest.Mock).mockReturnValue(mockChecksumLocationName); - (hasHeader as jest.Mock).mockReturnValue(false); + (hasHeader as jest.Mock).mockReturnValue(true); (isStreaming as jest.Mock).mockReturnValue(false); (selectChecksumAlgorithmFunction as jest.Mock).mockReturnValue(mockChecksumAlgorithmFunction); }); @@ -67,7 +66,6 @@ describe(flexibleChecksumsMiddleware.name, () => { describe("request checksum", () => { afterEach(() => { expect(getChecksumAlgorithmForRequest).toHaveBeenCalledTimes(1); - expect(getChecksum).not.toHaveBeenCalled(); }); it("if checksumAlgorithm is not defined", async () => { @@ -86,7 +84,6 @@ describe(flexibleChecksumsMiddleware.name, () => { ...mockArgs, request: { ...mockRequest, headers: mockHeadersWithChecksumHeader }, }; - (hasHeader as jest.Mock).mockReturnValue(true); await handler(mockArgsWithChecksumHeader); expect(getChecksumLocationName).toHaveBeenCalledWith(ChecksumAlgorithm.MD5); expect(selectChecksumAlgorithmFunction).toHaveBeenCalledWith(ChecksumAlgorithm.MD5, mockConfig); @@ -168,7 +165,15 @@ describe(flexibleChecksumsMiddleware.name, () => { }); it("for non-streaming body", async () => { - const handler = flexibleChecksumsMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, {}); + const mockRawChecksum = Buffer.from(mockChecksum); + const mockBase64Encoder = jest.fn().mockReturnValue(mockChecksum); + (stringHasher as jest.Mock).mockResolvedValue(mockRawChecksum); + (hasHeader as jest.Mock).mockReturnValue(false); + + const handler = flexibleChecksumsMiddleware( + { ...mockConfig, base64Encoder: mockBase64Encoder }, + mockMiddlewareConfig + )(mockNext, {}); await handler(mockArgs); expect(hasHeader).toHaveBeenCalledTimes(1); expect(mockNext).toHaveBeenCalledWith({ @@ -179,7 +184,8 @@ describe(flexibleChecksumsMiddleware.name, () => { }, }); expect(hasHeader).toHaveBeenCalledWith(mockChecksumLocationName, mockHeaders); - expect(getChecksum).toHaveBeenCalledTimes(1); + expect(stringHasher).toHaveBeenCalledWith(mockChecksumAlgorithmFunction, mockRequest.body); + expect(mockBase64Encoder).toHaveBeenCalledWith(mockRawChecksum); }); }); diff --git a/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.ts b/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.ts index 7f8883bf67f3..c358bc70aca4 100644 --- a/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.ts +++ b/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.ts @@ -8,13 +8,13 @@ import { } from "@aws-sdk/types"; import { PreviouslyResolved } from "./configuration"; -import { getChecksum } from "./getChecksum"; import { getChecksumAlgorithmForRequest } from "./getChecksumAlgorithmForRequest"; import { getChecksumLocationName } from "./getChecksumLocationName"; import { FlexibleChecksumsMiddlewareConfig } from "./getFlexibleChecksumsPlugin"; import { hasHeader } from "./hasHeader"; import { isStreaming } from "./isStreaming"; import { selectChecksumAlgorithmFunction } from "./selectChecksumAlgorithmFunction"; +import { stringHasher } from "./stringHasher"; import { validateChecksumFromResponse } from "./validateChecksumFromResponse"; export const flexibleChecksumsMiddleware = @@ -59,10 +59,10 @@ export const flexibleChecksumsMiddleware = }; delete updatedHeaders["content-length"]; } else if (!hasHeader(checksumLocationName, headers)) { - const checksum = await getChecksum(requestBody, { streamHasher, checksumAlgorithmFn, base64Encoder }); + const rawChecksum = await stringHasher(checksumAlgorithmFn, requestBody); updatedHeaders = { ...headers, - [checksumLocationName]: checksum, + [checksumLocationName]: base64Encoder(rawChecksum), }; } }