diff --git a/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.spec.ts b/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.spec.ts index 492d8c4d3295..f653281759d8 100644 --- a/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.spec.ts +++ b/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.spec.ts @@ -4,23 +4,23 @@ import { BuildHandlerArguments } from "@aws-sdk/types"; import { PreviouslyResolved } from "./configuration"; import { ChecksumAlgorithm } from "./constants"; import { flexibleChecksumsMiddleware } from "./flexibleChecksumsMiddleware"; -import { getChecksum } from "./getChecksum"; import { getChecksumAlgorithmForRequest } from "./getChecksumAlgorithmForRequest"; import { getChecksumLocationName } from "./getChecksumLocationName"; import { FlexibleChecksumsMiddlewareConfig } from "./getFlexibleChecksumsPlugin"; import { hasHeader } from "./hasHeader"; import { isStreaming } from "./isStreaming"; import { selectChecksumAlgorithmFunction } from "./selectChecksumAlgorithmFunction"; +import { stringHasher } from "./stringHasher"; import { validateChecksumFromResponse } from "./validateChecksumFromResponse"; jest.mock("@aws-sdk/protocol-http"); -jest.mock("./getChecksum"); jest.mock("./getChecksumAlgorithmForRequest"); jest.mock("./getChecksumLocationName"); jest.mock("./hasHeader"); jest.mock("./isStreaming"); -jest.mock("./validateChecksumFromResponse"); jest.mock("./selectChecksumAlgorithmFunction"); +jest.mock("./stringHasher"); +jest.mock("./validateChecksumFromResponse"); describe(flexibleChecksumsMiddleware.name, () => { const mockNext = jest.fn(); @@ -43,10 +43,9 @@ describe(flexibleChecksumsMiddleware.name, () => { mockNext.mockResolvedValueOnce(mockResult); const { isInstance } = HttpRequest; (isInstance as unknown as jest.Mock).mockReturnValue(true); - (getChecksum as jest.Mock).mockReturnValue(mockChecksum); (getChecksumAlgorithmForRequest as jest.Mock).mockReturnValue(ChecksumAlgorithm.MD5); (getChecksumLocationName as jest.Mock).mockReturnValue(mockChecksumLocationName); - (hasHeader as jest.Mock).mockReturnValue(false); + (hasHeader as jest.Mock).mockReturnValue(true); (isStreaming as jest.Mock).mockReturnValue(false); (selectChecksumAlgorithmFunction as jest.Mock).mockReturnValue(mockChecksumAlgorithmFunction); }); @@ -68,7 +67,6 @@ describe(flexibleChecksumsMiddleware.name, () => { describe("request checksum", () => { afterEach(() => { expect(getChecksumAlgorithmForRequest).toHaveBeenCalledTimes(1); - expect(getChecksum).not.toHaveBeenCalled(); }); it("if checksumAlgorithm is not defined", async () => { @@ -87,7 +85,6 @@ describe(flexibleChecksumsMiddleware.name, () => { ...mockArgs, request: { ...mockRequest, headers: mockHeadersWithChecksumHeader }, }; - (hasHeader as jest.Mock).mockReturnValue(true); await handler(mockArgsWithChecksumHeader); expect(getChecksumLocationName).toHaveBeenCalledTimes(1); expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1); @@ -154,7 +151,15 @@ describe(flexibleChecksumsMiddleware.name, () => { }); it("for non-streaming body", async () => { - const handler = flexibleChecksumsMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, {}); + const mockRawChecksum = Buffer.from(mockChecksum); + const mockBase64Encoder = jest.fn().mockReturnValue(mockChecksum); + (stringHasher as jest.Mock).mockResolvedValue(mockRawChecksum); + (hasHeader as jest.Mock).mockReturnValue(false); + + const handler = flexibleChecksumsMiddleware( + { ...mockConfig, base64Encoder: mockBase64Encoder }, + mockMiddlewareConfig + )(mockNext, {}); await handler(mockArgs); expect(hasHeader).toHaveBeenCalledTimes(1); expect(mockNext).toHaveBeenCalledWith({ @@ -165,7 +170,8 @@ describe(flexibleChecksumsMiddleware.name, () => { }, }); expect(hasHeader).toHaveBeenCalledWith(mockChecksumLocationName, mockHeaders); - expect(getChecksum).toHaveBeenCalledTimes(1); + expect(stringHasher).toHaveBeenCalledWith(mockChecksumAlgorithmFunction, mockRequest.body); + expect(mockBase64Encoder).toHaveBeenCalledWith(mockRawChecksum); }); }); diff --git a/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.ts b/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.ts index 7f8883bf67f3..c358bc70aca4 100644 --- a/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.ts +++ b/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.ts @@ -8,13 +8,13 @@ import { } from "@aws-sdk/types"; import { PreviouslyResolved } from "./configuration"; -import { getChecksum } from "./getChecksum"; import { getChecksumAlgorithmForRequest } from "./getChecksumAlgorithmForRequest"; import { getChecksumLocationName } from "./getChecksumLocationName"; import { FlexibleChecksumsMiddlewareConfig } from "./getFlexibleChecksumsPlugin"; import { hasHeader } from "./hasHeader"; import { isStreaming } from "./isStreaming"; import { selectChecksumAlgorithmFunction } from "./selectChecksumAlgorithmFunction"; +import { stringHasher } from "./stringHasher"; import { validateChecksumFromResponse } from "./validateChecksumFromResponse"; export const flexibleChecksumsMiddleware = @@ -59,10 +59,10 @@ export const flexibleChecksumsMiddleware = }; delete updatedHeaders["content-length"]; } else if (!hasHeader(checksumLocationName, headers)) { - const checksum = await getChecksum(requestBody, { streamHasher, checksumAlgorithmFn, base64Encoder }); + const rawChecksum = await stringHasher(checksumAlgorithmFn, requestBody); updatedHeaders = { ...headers, - [checksumLocationName]: checksum, + [checksumLocationName]: base64Encoder(rawChecksum), }; } }