Skip to content

Commit

Permalink
Distinct cache keys when api_version=None
Browse files Browse the repository at this point in the history
  • Loading branch information
honno committed Sep 28, 2022
1 parent 057d4de commit bc4bcb6
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 11 deletions.
2 changes: 1 addition & 1 deletion hypothesis-python/src/hypothesis/extra/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,7 +1076,7 @@ def complex_dtypes(

namespace = StrategiesNamespace(**kwargs)
try:
_args_to_xps[(xp, api_version)] = namespace
_args_to_xps[(xp, None if inferred_version else api_version)] = namespace
except TypeError:
pass

Expand Down
53 changes: 43 additions & 10 deletions hypothesis-python/tests/array_api/test_strategies_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,32 @@
)
from hypothesis.strategies import SearchStrategy

pytestmark = pytest.mark.filterwarnings("ignore::hypothesis.errors.HypothesisWarning")

@pytest.mark.filterwarnings("ignore::hypothesis.errors.HypothesisWarning")
def test_caching(xp, monkeypatch):

class HashableArrayModuleFactory:
"""
mock_xp cannot be hashed and thus cannot be used in our cache. So just for
the purposes of testing the cache, we wrap it with an unsafe hash method.
"""

def __getattr__(self, name):
return getattr(mock_xp, name)

def __hash__(self):
return hash(tuple(sorted(mock_xp.__dict__)))


@pytest.mark.parametrize("api_version", ["2021.12", None])
def test_caching(api_version, monkeypatch):
"""Caches namespaces respective to arguments."""
try:
hash(xp)
except TypeError:
pytest.skip("xp not hashable")
assert isinstance(array_api._args_to_xps, WeakValueDictionary)
xp = HashableArrayModuleFactory()
assert isinstance(array_api._args_to_xps, WeakValueDictionary) # sanity check
monkeypatch.setattr(array_api, "_args_to_xps", WeakValueDictionary())
assert len(array_api._args_to_xps) == 0 # sanity check
xps1 = array_api.make_strategies_namespace(xp, api_version="2021.12")
xps1 = array_api.make_strategies_namespace(xp, api_version=api_version)
assert len(array_api._args_to_xps) == 1
xps2 = array_api.make_strategies_namespace(xp, api_version="2021.12")
xps2 = array_api.make_strategies_namespace(xp, api_version=api_version)
assert len(array_api._args_to_xps) == 1
assert isinstance(xps2, SimpleNamespace)
assert xps2 is xps1
Expand All @@ -43,7 +55,28 @@ def test_caching(xp, monkeypatch):
assert len(array_api._args_to_xps) == 0


@pytest.mark.filterwarnings("ignore::hypothesis.errors.HypothesisWarning")
@pytest.mark.parametrize(
"api_version1, api_version2", [(None, "2021.12"), ("2021.12", None)]
)
def test_inferred_namespace_is_cached_seperately(
api_version1, api_version2, monkeypatch
):
"""Results from inferred versions do not share the same cache key as results
from specified versions."""
xp = HashableArrayModuleFactory()
xp.__array_api_version__ = "2021.12"
assert isinstance(array_api._args_to_xps, WeakValueDictionary) # sanity check
monkeypatch.setattr(array_api, "_args_to_xps", WeakValueDictionary())
assert len(array_api._args_to_xps) == 0 # sanity check
xps1 = array_api.make_strategies_namespace(xp, api_version=api_version1)
assert xps1.api_version == "2021.12" # sanity check
assert len(array_api._args_to_xps) == 1
xps2 = array_api.make_strategies_namespace(xp, api_version=api_version2)
assert xps2.api_version == "2021.12" # sanity check
assert len(array_api._args_to_xps) == 2
assert xps2 is not xps1


def test_complex_dtypes_raises_on_2021_12():
"""Accessing complex_dtypes() for 2021.12 strategy namespace raises helpful
error, but accessing on future versions returns expected strategy."""
Expand Down

0 comments on commit bc4bcb6

Please sign in to comment.