diff --git a/src/function/probability/pickRandom.js b/src/function/probability/pickRandom.js index e0104553d6..24c1404d54 100644 --- a/src/function/probability/pickRandom.js +++ b/src/function/probability/pickRandom.js @@ -1,7 +1,7 @@ import { factory } from '../../utils/factory' import { isNumber } from '../../utils/is' -import { arraySize } from '../../utils/array' import { createRng } from './util/seededRNG' +import { flatten } from '../../utils/array' const name = 'pickRandom' const dependencies = ['typed', 'config', '?on'] @@ -76,15 +76,11 @@ export const createPickRandom = /* #__PURE__ */ factory(name, dependencies, ({ t number = 1 } - possibles = possibles.valueOf() // get Array + possibles = flatten(possibles.valueOf()).valueOf() // get Array if (weights) { weights = weights.valueOf() // get Array } - if (arraySize(possibles).length > 1) { - throw new Error('Only one dimensional vectors supported') - } - let totalWeights = 0 if (typeof weights !== 'undefined') { diff --git a/test/unit-tests/function/probability/pickRandom.test.js b/test/unit-tests/function/probability/pickRandom.test.js index fc0cbdd0c8..5ee9e1334c 100644 --- a/test/unit-tests/function/probability/pickRandom.test.js +++ b/test/unit-tests/function/probability/pickRandom.test.js @@ -1,6 +1,7 @@ import assert from 'assert' import { filter, times } from 'lodash' import math from '../../../../src/bundleAny' +import { flatten } from '../../../../src/utils/array' const math2 = math.create({ randomSeed: 'test2' }) const pickRandom = math2.pickRandom @@ -10,12 +11,6 @@ describe('pickRandom', function () { assert.strictEqual(typeof math.pickRandom, 'function') }) - it('should throw an error when providing a multi dimensional matrix', function () { - assert.throws(function () { - pickRandom(math.matrix([[1, 2], [3, 4]])) - }, /Only one dimensional vectors supported/) - }) - it('should throw an error if the length of the weights does not match the length of the possibles', function () { const possibles = [11, 22, 33, 44, 55] const weights = [1, 5, 2, 4] @@ -68,9 +63,9 @@ describe('pickRandom', function () { const weights = [1, 5, 2, 4, 6] const number = 5 - assert.strictEqual(pickRandom(possibles, number), possibles) - assert.strictEqual(pickRandom(possibles, number, weights), possibles) - assert.strictEqual(pickRandom(possibles, weights, number), possibles) + pickRandom(possibles, number).forEach((element, index) => assert.strictEqual(element, possibles[index])) + pickRandom(possibles, number, weights).forEach((element, index) => assert.strictEqual(element, possibles[index])) + pickRandom(possibles, weights, number).forEach((element, index) => assert.strictEqual(element, possibles[index])) }) it('should return the given array if the given number is greater than its length', function () { @@ -78,9 +73,9 @@ describe('pickRandom', function () { const weights = [1, 5, 2, 4, 6] const number = 6 - assert.strictEqual(pickRandom(possibles, number), possibles) - assert.strictEqual(pickRandom(possibles, number, weights), possibles) - assert.strictEqual(pickRandom(possibles, weights, number), possibles) + pickRandom(possibles, number).forEach((element, index) => assert.strictEqual(element, possibles[index])) + pickRandom(possibles, number, weights).forEach((element, index) => assert.strictEqual(element, possibles[index])) + pickRandom(possibles, weights, number).forEach((element, index) => assert.strictEqual(element, possibles[index])) }) it('should return an empty array if the given number is 0', function () { @@ -117,6 +112,30 @@ describe('pickRandom', function () { assert.strictEqual(pickRandom(possibles, weights, number).length, number) }) + it('should pick a number from the given multi dimensional array following an uniform distribution', function () { + const possibles = [[11, 12], [22, 23], [33, 34], [44, 45], [55, 56]] + const picked = [] + + times(1000, () => picked.push(pickRandom(possibles))) + + flatten(possibles).forEach(possible => { + const count = filter(flatten(picked), val => val === possible).length + assert.strictEqual(math.round(count / picked.length, 1), 0.1) + }) + }) + + it('should pick a value from the given multi dimensional array following an uniform distribution', function () { + // just to be sure that works for any kind of array + const possibles = [[[11], [12]], ['test', 45], 'another test', 10, false, [1.3, 4.5, true]] + const picked = [] + + times(1000, () => picked.push(pickRandom(possibles))) + flatten(possibles).forEach(possible => { + const count = filter(picked, val => val === possible).length + assert.strictEqual(math.round(count / picked.length, 1), 0.1) + }) + }) + it('should pick a value from the given array following an uniform distribution if only possibles are passed', function () { const possibles = [11, 22, 33, 44, 55] const picked = []