diff --git a/examples/mocks/src/zustand.ts b/examples/mocks/__mocks__/zustand.ts similarity index 55% rename from examples/mocks/src/zustand.ts rename to examples/mocks/__mocks__/zustand.ts index 9d1c16beb43b..6d3bc5dd95ac 100644 --- a/examples/mocks/src/zustand.ts +++ b/examples/mocks/__mocks__/zustand.ts @@ -1,13 +1,8 @@ import actualCreate from 'zustand' -// a variable to hold reset functions for all stores declared in the app -const storeResetFns = new Set() - // when creating a store, we get its initial state, create a reset function and add it in the set const create = vi.fn((createState) => { const store = actualCreate(createState) - const initialState = store.getState() - storeResetFns.add(() => store.setState(initialState, true)) return store }) diff --git a/examples/mocks/src/zustand-magic.ts b/examples/mocks/src/zustand-magic.ts new file mode 100644 index 000000000000..dc364a65803f --- /dev/null +++ b/examples/mocks/src/zustand-magic.ts @@ -0,0 +1,5 @@ +import zustand from 'zustand' + +export const magic = () => { + return zustand() +} diff --git a/examples/mocks/test/self-importing.test.ts b/examples/mocks/test/self-importing.test.ts index bb1d8f98bb5a..2b244b567089 100644 --- a/examples/mocks/test/self-importing.test.ts +++ b/examples/mocks/test/self-importing.test.ts @@ -1,4 +1,5 @@ import zustand from 'zustand' +import { magic } from '../src/zustand-magic' vi.mock('zustand') @@ -6,4 +7,10 @@ describe('zustand didn\'t go into an infinite loop', () => { test('zustand is mocked', () => { expect(vi.isMockFunction(zustand)).toBe(true) }) + + test('magic calls zustand', () => { + const store = magic() + expect(zustand).toHaveBeenCalled() + expect(store).toBeTypeOf('function') + }) }) diff --git a/packages/vitest/src/runtime/mocker.ts b/packages/vitest/src/runtime/mocker.ts index 9b5c9c912e29..ed166a8c8eee 100644 --- a/packages/vitest/src/runtime/mocker.ts +++ b/packages/vitest/src/runtime/mocker.ts @@ -250,7 +250,10 @@ export class VitestMocker { } if (typeof mock === 'function' && !callstack.includes(`mock:${dep}`)) { callstack.push(`mock:${dep}`) - return this.callFunctionMock(dep, mock) + const result = await this.callFunctionMock(dep, mock) + const indexMock = callstack.indexOf(`mock:${dep}`) + callstack.splice(indexMock, 1) + return result } if (typeof mock === 'string' && !callstack.includes(mock)) dep = mock