Skip to content

Commit

Permalink
feat: bind expect state to context (#1468)
Browse files Browse the repository at this point in the history
* feat: bind expect state to context

This fixes calling expect.assertions inside concurrent

* refactor: cleanup

* chore: fix types

* chore: allow not passing expect to getMatcherContext
  • Loading branch information
sheremet-va committed Jun 13, 2022
1 parent 1e30295 commit 35ab058
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 63 deletions.
2 changes: 2 additions & 0 deletions packages/vitest/src/integrations/chai/constants.ts
@@ -0,0 +1,2 @@
export const GLOBAL_EXPECT = Symbol.for('expect-global')
export const MATCHERS_OBJECT = Symbol.for('matchers-object')
56 changes: 49 additions & 7 deletions packages/vitest/src/integrations/chai/index.ts
@@ -1,12 +1,15 @@
import chai from 'chai'
import chai, { util } from 'chai'
import './setup'
import type { Test } from '../../types'
import { getFullName } from '../../utils'
import type { MatcherState } from '../../types/chai'
import { getState, setState } from './jest-expect'
import { GLOBAL_EXPECT } from './constants'

export function createExpect(test?: Test) {
const expect = ((value: any, message?: string): Vi.Assertion => {
const { assertionCalls } = getState()
setState({ assertionCalls: assertionCalls + 1 })
const { assertionCalls } = getState(expect)
setState({ assertionCalls: assertionCalls + 1 }, expect)
const assert = chai.expect(value, message) as unknown as Vi.Assertion
if (test)
// @ts-expect-error internal
Expand All @@ -16,16 +19,55 @@ export function createExpect(test?: Test) {
}) as Vi.ExpectStatic
Object.assign(expect, chai.expect)

expect.getState = getState
expect.setState = setState
expect.getState = () => getState(expect)
expect.setState = state => setState(state as Partial<MatcherState>, expect)

setState({
assertionCalls: 0,
isExpectingAssertions: false,
isExpectingAssertionsError: null,
expectedAssertionsNumber: null,
expectedAssertionsNumberErrorGen: null,
testPath: test?.suite.file?.filepath,
currentTestName: test ? getFullName(test) : undefined,
}, expect)

// @ts-expect-error untyped
expect.extend = matchers => chai.expect.extend(expect, matchers)

function assertions(expected: number) {
const errorGen = () => new Error(`expected number of assertions to be ${expected}, but got ${expect.getState().assertionCalls}`)
if (Error.captureStackTrace)
Error.captureStackTrace(errorGen(), assertions)

expect.setState({
expectedAssertionsNumber: expected,
expectedAssertionsNumberErrorGen: errorGen,
})
}

function hasAssertions() {
const error = new Error('expected any number of assertion, but got none')
if (Error.captureStackTrace)
Error.captureStackTrace(error, hasAssertions)

expect.setState({
isExpectingAssertions: true,
isExpectingAssertionsError: error,
})
}

util.addMethod(expect, 'assertions', assertions)
util.addMethod(expect, 'hasAssertions', hasAssertions)

return expect
}

const expect = createExpect()
const globalExpect = createExpect()

Object.defineProperty(globalThis, GLOBAL_EXPECT, {
value: globalExpect,
})

export { assert, should } from 'chai'
export { chai, expect }
export { chai, globalExpect as expect }
@@ -1,4 +1,5 @@
import type { ChaiPlugin, MatcherState } from '../../types/chai'
import { GLOBAL_EXPECT } from './constants'
import { getState } from './jest-expect'
import * as matcherUtils from './jest-matcher-utils'

Expand All @@ -19,9 +20,9 @@ export abstract class AsymmetricMatcher<

constructor(protected sample: T, protected inverse = false) {}

protected getMatcherContext(): State {
protected getMatcherContext(expect?: Vi.ExpectStatic): State {
return {
...getState(),
...getState(expect || (globalThis as any)[GLOBAL_EXPECT]),
equals,
isNot: this.inverse,
utils: matcherUtils,
Expand Down
56 changes: 10 additions & 46 deletions packages/vitest/src/integrations/chai/jest-expect.ts
Expand Up @@ -10,31 +10,25 @@ import type { ChaiPlugin, MatcherState } from '../../types/chai'
import { arrayBufferEquality, generateToBeMessage, iterableEquality, equals as jestEquals, sparseArrayEquality, subsetEquality, typeEquality } from './jest-utils'
import type { AsymmetricMatcher } from './jest-asymmetric-matchers'
import { stringify } from './jest-matcher-utils'
import { MATCHERS_OBJECT } from './constants'

const MATCHERS_OBJECT = Symbol.for('matchers-object')

if (!Object.prototype.hasOwnProperty.call(global, MATCHERS_OBJECT)) {
const defaultState: Partial<MatcherState> = {
assertionCalls: 0,
isExpectingAssertions: false,
isExpectingAssertionsError: null,
expectedAssertionsNumber: null,
expectedAssertionsNumberErrorGen: null,
}
if (!Object.prototype.hasOwnProperty.call(globalThis, MATCHERS_OBJECT)) {
Object.defineProperty(globalThis, MATCHERS_OBJECT, {
value: {
state: defaultState,
},
value: new WeakMap<Vi.ExpectStatic, MatcherState>(),
})
}

export const getState = <State extends MatcherState = MatcherState>(): State =>
(globalThis as any)[MATCHERS_OBJECT].state
export const getState = <State extends MatcherState = MatcherState>(expect: Vi.ExpectStatic): State =>
(globalThis as any)[MATCHERS_OBJECT].get(expect)

export const setState = <State extends MatcherState = MatcherState>(
state: Partial<State>,
expect: Vi.ExpectStatic,
): void => {
Object.assign((globalThis as any)[MATCHERS_OBJECT].state, state)
const map = (globalThis as any)[MATCHERS_OBJECT]
const current = map.get(expect) || {}
Object.assign(current, state)
map.set(expect, current)
}

// Jest Expect Compact
Expand Down Expand Up @@ -676,36 +670,6 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
return proxy
})

utils.addMethod(
chai.expect,
'assertions',
function assertions(expected: number) {
const errorGen = () => new Error(`expected number of assertions to be ${expected}, but got ${getState().assertionCalls}`)
if (Error.captureStackTrace)
Error.captureStackTrace(errorGen(), assertions)

setState({
expectedAssertionsNumber: expected,
expectedAssertionsNumberErrorGen: errorGen,
})
},
)

utils.addMethod(
chai.expect,
'hasAssertions',
function hasAssertions() {
const error = new Error('expected any number of assertion, but got none')
if (Error.captureStackTrace)
Error.captureStackTrace(error, hasAssertions)

setState({
isExpectingAssertions: true,
isExpectingAssertionsError: error,
})
},
)

utils.addMethod(
chai.expect,
'addSnapshotSerializer',
Expand Down
10 changes: 5 additions & 5 deletions packages/vitest/src/integrations/chai/jest-extend.ts
Expand Up @@ -20,7 +20,7 @@ import {
const isAsyncFunction = (fn: unknown) =>
typeof fn === 'function' && (fn as any)[Symbol.toStringTag] === 'AsyncFunction'

const getMatcherState = (assertion: Chai.AssertionStatic & Chai.Assertion) => {
const getMatcherState = (assertion: Chai.AssertionStatic & Chai.Assertion, expect: Vi.ExpectStatic) => {
const obj = assertion._obj
const isNot = util.flag(assertion, 'negate') as boolean
const promise = util.flag(assertion, 'promise') || ''
Expand All @@ -31,7 +31,7 @@ const getMatcherState = (assertion: Chai.AssertionStatic & Chai.Assertion) => {
}

const matcherState: MatcherState = {
...getState(),
...getState(expect),
isNot,
utils: jestUtils,
promise,
Expand All @@ -58,7 +58,7 @@ function JestExtendPlugin(expect: Vi.ExpectStatic, matchers: MatchersObject): Ch
return (c, utils) => {
Object.entries(matchers).forEach(([expectAssertionName, expectAssertion]) => {
function expectSyncWrapper(this: Chai.AssertionStatic & Chai.Assertion, ...args: any[]) {
const { state, isNot, obj } = getMatcherState(this)
const { state, isNot, obj } = getMatcherState(this, expect)

// @ts-expect-error args wanting tuple
const { pass, message, actual, expected } = expectAssertion.call(state, obj, ...args) as SyncExpectationResult
Expand All @@ -68,7 +68,7 @@ function JestExtendPlugin(expect: Vi.ExpectStatic, matchers: MatchersObject): Ch
}

async function expectAsyncWrapper(this: Chai.AssertionStatic & Chai.Assertion, ...args: any[]) {
const { state, isNot, obj } = getMatcherState(this)
const { state, isNot, obj } = getMatcherState(this, expect)

// @ts-expect-error args wanting tuple
const { pass, message, actual, expected } = await expectAssertion.call(state, obj, ...args) as SyncExpectationResult
Expand All @@ -88,7 +88,7 @@ function JestExtendPlugin(expect: Vi.ExpectStatic, matchers: MatchersObject): Ch

asymmetricMatch(other: unknown) {
const { pass } = expectAssertion.call(
this.getMatcherContext(),
this.getMatcherContext(expect),
other,
...this.sample,
) as SyncExpectationResult
Expand Down
5 changes: 5 additions & 0 deletions packages/vitest/src/runtime/context.ts
Expand Up @@ -61,6 +61,11 @@ export function createTestContext(test: Test): TestContext {
return _expect
},
})
Object.defineProperty(context, '_local', {
get() {
return _expect != null
},
})

return context
}
Expand Down
16 changes: 13 additions & 3 deletions packages/vitest/src/runtime/run.ts
Expand Up @@ -2,8 +2,9 @@ import type { File, HookCleanupCallback, HookListener, ResolvedConfig, Suite, Su
import { vi } from '../integrations/vi'
import { getSnapshotClient } from '../integrations/snapshot/chai'
import { clearTimeout, getFullName, getWorkerState, hasFailed, hasTests, partitionSuiteChildren, setTimeout } from '../utils'
import { getState, setState } from '../integrations/chai/jest-expect'
import { takeCoverage } from '../integrations/coverage'
import { getState, setState } from '../integrations/chai/jest-expect'
import { GLOBAL_EXPECT } from '../integrations/chai/constants'
import { getFn, getHooks } from './map'
import { rpc } from './rpc'
import { collectTests } from './collect'
Expand Down Expand Up @@ -111,9 +112,18 @@ export async function runTest(test: Test) {
expectedAssertionsNumberErrorGen: null,
testPath: test.suite.file?.filepath,
currentTestName: getFullName(test),
})
}, (globalThis as any)[GLOBAL_EXPECT])
await getFn(test)()
const { assertionCalls, expectedAssertionsNumber, expectedAssertionsNumberErrorGen, isExpectingAssertions, isExpectingAssertionsError } = getState()
const {
assertionCalls,
expectedAssertionsNumber,
expectedAssertionsNumberErrorGen,
isExpectingAssertions,
isExpectingAssertionsError,
// @ts-expect-error local is private
} = test.context._local
? test.context.expect.getState()
: getState((globalThis as any)[GLOBAL_EXPECT])
if (expectedAssertionsNumber !== null && assertionCalls !== expectedAssertionsNumber)
throw expectedAssertionsNumberErrorGen!()
if (isExpectingAssertions === true && assertionCalls === 0)
Expand Down
19 changes: 19 additions & 0 deletions test/core/test/concurrent.spec.ts
@@ -0,0 +1,19 @@
import { test } from 'vitest'

function delay(ms: number) {
return new Promise(resolve => setTimeout(resolve, ms))
}

test.concurrent('test1', async ({ expect }) => {
expect.assertions(1)
await delay(10).then(() => {
expect(1).eq(1)
})
})

test.concurrent('test2', async ({ expect }) => {
expect.assertions(1)
await delay(100).then(() => {
expect(2).eq(2)
})
})

0 comments on commit 35ab058

Please sign in to comment.