diff --git a/CHANGELOG.md b/CHANGELOG.md index 88cfba26a809..337de80ea7f2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ### Features +- `[feat(@jest/mock)]` Add `withImplementation` method for temporarily overriding a mock. - `[feat(@jest/environment, jest-runtime)]` Allow `jest.requireActual` and `jest.requireMock` to take a type argument ([#13253](https://github.com/facebook/jest/pull/13253)) - `[feat(@jest/environment]` Allow `jest.mock` and `jest.doMock` to take a type argument ([#13254](https://github.com/facebook/jest/pull/13254)) - `[@jest/fake-timers]` Add `jest.now()` to return the current fake clock time ([#13244](https://github.com/facebook/jest/pull/13244), [13246](https://github.com/facebook/jest/pull/13246)) diff --git a/packages/jest-mock/README.md b/packages/jest-mock/README.md index 5bc75093f8fa..5ec503f64e39 100644 --- a/packages/jest-mock/README.md +++ b/packages/jest-mock/README.md @@ -98,3 +98,9 @@ In case both `.mockImplementationOnce()` / `.mockImplementation()` and `.mockRet - if the last call is `.mockReturnValueOnce()` or `.mockReturnValue()`, use the specific return value or default return value. If specific return values are used up or no default return value is set, fall back to try `.mockImplementation()`; - if the last call is `.mockImplementationOnce()` or `.mockImplementation()`, run the specific implementation and return the result or run default implementation and return the result. + +##### `.withImplementation(function, callback)` + +Temporarily overrides the default mock implementation within the callback, then restores it's previous implementation. + +If the callback is async or returns a promise like object, `withImplementation` will return a promise. Awaiting the promise will await the callback and reset the implementation. diff --git a/packages/jest-mock/src/__tests__/index.test.ts b/packages/jest-mock/src/__tests__/index.test.ts index 6b9f716d7ef0..9f1efe45f5e1 100644 --- a/packages/jest-mock/src/__tests__/index.test.ts +++ b/packages/jest-mock/src/__tests__/index.test.ts @@ -1073,6 +1073,59 @@ describe('moduleMocker', () => { }); }); + describe('withImplementation', () => { + it('sets an implementation which is available within the callback', async () => { + const mock1 = jest.fn(); + const mock2 = jest.fn(); + + const Module = jest.fn(() => ({someFn: mock1})); + const testFn = function () { + const m = new Module(); + m.someFn(); + }; + + Module.withImplementation( + () => ({someFn: mock2}), + () => { + testFn(); + expect(mock2).toHaveBeenCalled(); + expect(mock1).not.toHaveBeenCalled(); + }, + ); + + testFn(); + expect(mock1).toHaveBeenCalled(); + }); + + it('returns a promise if the provided callback is asynchronous', async () => { + const mock1 = jest.fn(); + const mock2 = jest.fn(); + + const Module = jest.fn(() => ({someFn: mock1})); + const testFn = function () { + const m = new Module(); + m.someFn(); + }; + + const promise = Module.withImplementation( + () => ({someFn: mock2}), + async () => { + testFn(); + expect(mock2).toHaveBeenCalled(); + expect(mock1).not.toHaveBeenCalled(); + }, + ); + + // Is there a better way to detect a promise? + expect(typeof promise.then).toBe('function'); + + await promise; + + testFn(); + expect(mock1).toHaveBeenCalled(); + }); + }); + test('mockReturnValue does not override mockImplementationOnce', () => { const mockFn = jest .fn() diff --git a/packages/jest-mock/src/index.ts b/packages/jest-mock/src/index.ts index 69319a05c583..f986d615d5b1 100644 --- a/packages/jest-mock/src/index.ts +++ b/packages/jest-mock/src/index.ts @@ -124,6 +124,12 @@ type RejectType = ReturnType extends PromiseLike ? unknown : never; +type WithImplementationSyncCallbackReturn = void | undefined; +type WithImplementationAsyncCallbackReturn = Promise; +type WithImplementationCallbackReturn = + | WithImplementationSyncCallbackReturn + | WithImplementationAsyncCallbackReturn; + export interface MockInstance { _isMockFunction: true; _protoImpl: Function; @@ -135,6 +141,12 @@ export interface MockInstance { mockRestore(): void; mockImplementation(fn: T): this; mockImplementationOnce(fn: T): this; + withImplementation( + fn: T, + callback: () => R, + ): R extends WithImplementationAsyncCallbackReturn + ? Promise + : undefined; mockName(name: string): this; mockReturnThis(): this; mockReturnValue(value: ReturnType): this; @@ -768,6 +780,34 @@ export class ModuleMocker { return f; }; + f.withImplementation = ( + fn: UnknownFunction, + callback: () => R, + // @ts-expect-error: Type guards are not advanced enough for this use case + ): R extends WithImplementationAsyncCallbackReturn + ? Promise + : undefined => { + // Remember previous mock implementation, then set new one + const mockConfig = this._ensureMockConfig(f); + const previousImplementation = mockConfig.mockImpl; + mockConfig.mockImpl = fn; + + const returnedValue = callback(); + + if ( + typeof returnedValue === 'object' && + returnedValue !== null && + typeof returnedValue.then === 'function' + ) { + // @ts-expect-error: Type guards are not advanced enough for this use case + return returnedValue.then(() => { + mockConfig.mockImpl = previousImplementation; + }); + } else { + mockConfig.mockImpl = previousImplementation; + } + }; + f.mockImplementation = (fn: UnknownFunction) => { // next function call will use mock implementation return value const mockConfig = this._ensureMockConfig(f);