Skip to content

Commit

Permalink
Fix numpy ValueError with seeds < 0 or > 2 ** 32 - 1 (#275)
Browse files Browse the repository at this point in the history
Fixes #269
  • Loading branch information
adamchainz committed Jul 10, 2020
1 parent c610c42 commit b38226e
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 4 deletions.
3 changes: 3 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
History
-------

* Fix numpy error ``ValueError: Seed must be between 0 and 2**32 - 1`` when
passed a seed outside of this range.

3.4.0 (2020-05-27)
------------------

Expand Down
19 changes: 15 additions & 4 deletions src/pytest_randomly.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import hashlib
import random
import sys

Expand Down Expand Up @@ -140,11 +141,12 @@ def _reseed(config, offset=0):
faker_random.setstate(random_states[seed])

if have_numpy:
if seed not in np_random_states:
np_random.seed(seed)
np_random_states[seed] = np_random.get_state()
numpy_seed = _truncate_seed_for_numpy(seed)
if numpy_seed not in np_random_states:
np_random.seed(numpy_seed)
np_random_states[numpy_seed] = np_random.get_state()
else:
np_random.set_state(np_random_states[seed])
np_random.set_state(np_random_states[numpy_seed])

if entrypoint_reseeds is None:
entrypoint_reseeds = [
Expand All @@ -154,6 +156,15 @@ def _reseed(config, offset=0):
reseed(seed)


def _truncate_seed_for_numpy(seed):
seed = abs(seed)
if seed <= 2 ** 32 - 1:
return seed

seed_bytes = seed.to_bytes(seed.bit_length(), "big")
return int.from_bytes(hashlib.sha512(seed_bytes).digest()[: 32 // 8], "big")


def pytest_report_header(config):
seed = config.getoption("randomly_seed")
_reseed(config)
Expand Down
14 changes: 14 additions & 0 deletions tests/test_pytest_randomly.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,20 @@ def test_two():
out.assert_outcomes(passed=2)


def test_numpy_doesnt_crash_with_large_seed(ourtestdir):
ourtestdir.makepyfile(
test_one="""
import numpy as np
def test_one():
assert np.random.rand() >= 0.0
"""
)

out = ourtestdir.runpytest("--randomly-seed=7106521602475165645")
out.assert_outcomes(passed=1)


def test_failing_import(testdir):
"""Test with pytest raising CollectError or ImportError.
Expand Down

0 comments on commit b38226e

Please sign in to comment.