Skip to content

Commit

Permalink
JAX integration test robustness.
Browse files Browse the repository at this point in the history
This commit improves the robustness (i.e., portability) of JAX-specific
integration tests in our test suite, preventing these tests from
erroneously failing with spurious warnings in the event of a mismatch
between the CPU flags with which the low-level C-based `jaxlib` package
was compiled and the CPU flags supported by the current system.
(*Unsettling settlement of unmentionable sediment!*)
  • Loading branch information
leycec committed May 4, 2024
1 parent 1b3b589 commit 910579e
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 28 deletions.
40 changes: 39 additions & 1 deletion beartype/_util/error/utilerrwarn.py
Expand Up @@ -15,19 +15,57 @@
from beartype.typing import (
Any,
Iterable,
Iterator,
)
from beartype._data.error.dataerrmagic import EXCEPTION_PLACEHOLDER
from beartype._data.hint.datahinttyping import TypeWarning
from beartype._util.error.utilerrtest import is_exception_message_str
from beartype._util.py.utilpyversion import IS_PYTHON_AT_LEAST_3_12
from beartype._util.py.utilpyversion import (
IS_PYTHON_AT_LEAST_3_11,
IS_PYTHON_AT_LEAST_3_12,
)
from beartype._util.text.utiltextmunge import uppercase_str_char_first
from collections.abc import Iterable as IterableABC
from contextlib import contextmanager
from warnings import (
WarningMessage,
catch_warnings,
simplefilter,
warn,
warn_explicit,
)

# ....................{ CONTEXTS }....................
#FIXME: Unit test us up, please.
@contextmanager
def warnings_ignored() -> Iterator[None]:
'''
Context manager temporarily ignoring *all* warnings transitively emitted
within the body of this context.
Yields
------
None
This context manager yields *no* objects.
See Also
--------
https://stackoverflow.com/a/14463362/2809027
StackOverflow answer strongly inspiring this implementation.
'''

# If the active Python interpreter targets Python > 3.11, prefer an
# efficient one-liner yielding the desired outcome. Get it? Yielding? ...heh
if IS_PYTHON_AT_LEAST_3_11:
with catch_warnings(action='ignore'): # type: ignore[call-overload]
yield
# Else, the active Python interpreter targets Python <= 3.10. In this case,
# fallback to an inefficient generator yielding the same outcome.
else:
with catch_warnings():
simplefilter('ignore')
yield

# ....................{ WARNERS }....................
# If the active Python interpreter targets Python >= 3.12, the standard
# warnings.warn() function supports the optional "skip_file_prefixes" parameter
Expand Down
8 changes: 4 additions & 4 deletions beartype/_util/module/utilmodimport.py
Expand Up @@ -97,8 +97,8 @@ def import_module_or_none(
module_name : str
Fully-qualified name of the module to be imported.
exception_cls : Type[Exception]
Type of exception to be raised by this function. Defaults to
:class:`._BeartypeUtilModuleException`.
Type of exception to be raised in the event of a fatal error. Defaults
to :class:`._BeartypeUtilModuleException`.
exception_prefix : str, optional
Human-readable label prefixing the representation of this object in the
exception message. Defaults to the empty string.
Expand All @@ -119,8 +119,8 @@ def import_module_or_none(
Warns
-----
BeartypeModuleUnimportableWarning
If a module with this name exists *but* that module is unimportable
due to raising module-scoped exceptions at importation time.
If a module with this name exists *but* that module is unimportable due
to raising module-scoped exceptions at importation time.
'''

# Avoid circular import dependencies.
Expand Down
44 changes: 35 additions & 9 deletions beartype/_util/module/utilmodtest.py
Expand Up @@ -12,7 +12,10 @@

# ....................{ IMPORTS }....................
from beartype.roar._roarexc import _BeartypeUtilModuleException
from beartype.typing import Optional
from beartype._cave._cavefast import ModuleType
from beartype._data.hint.datahinttyping import TypeException
from beartype._util.error.utilerrwarn import warnings_ignored
from beartype._util.text.utiltextidentifier import die_unless_identifier
from beartype._util.text.utiltextversion import convert_str_version_to_tuple
from importlib.metadata import version as get_module_version # type: ignore[attr-defined]
Expand Down Expand Up @@ -100,7 +103,13 @@ def die_unless_module_attr_name(
# attribute name.

# ....................{ TESTERS }....................
def is_module(module_name: str) -> bool:
def is_module(
# Mandatory parameters.
module_name: str,

# Optional parameters.
is_warnings_ignore: bool = False,
) -> bool:
'''
:data:`True` only if the module or C extension with the passed
fully-qualified name is importable under the active Python interpreter.
Expand All @@ -114,6 +123,10 @@ def is_module(module_name: str) -> bool:
----------
module_name : str
Fully-qualified name of the module to be imported.
is_warnings_ignore : bool, optional
:data:`True` only if this tester ignores *all* warnings transitively
emitted as a side effect by the importation of this module. Defaults to
:data:`False` for safety.
Returns
-------
Expand All @@ -123,15 +136,25 @@ def is_module(module_name: str) -> bool:
Warns
-----
BeartypeModuleUnimportableWarning
If a module with this name exists *but* that module is unimportable
due to raising module-scoped exceptions at importation time.
If a module with this name exists *but* that module is unimportable due
to raising module-scoped exceptions at importation time.
'''

# Avoid circular import dependencies.
from beartype._util.module.utilmodimport import import_module_or_none

# Module with this name if this module is importable *OR* "None" otherwise.
module = import_module_or_none(module_name)
module: Optional[ModuleType] = None

# If ignoring *ALL* warnings transitively emitted as a side effect by the
# importation of this module, attempt to dynamically import this module
# under a context manager ignoring these warnings.
if is_warnings_ignore:
with warnings_ignored():
module = import_module_or_none(module_name)
# Else, dynamically import this module *WITHOUT* ignoring these warnings.
else:
module = import_module_or_none(module_name)

# Return true only if this module is importable.
return module is not None
Expand Down Expand Up @@ -169,8 +192,8 @@ def is_module_version_at_least(module_name: str, version_minimum: str) -> bool:
Warns
-----
BeartypeModuleUnimportableWarning
If a module with this name exists *but* that module is unimportable
due to raising module-scoped exceptions at importation time.
If a module with this name exists *but* that module is unimportable due
to raising module-scoped exceptions at importation time.
'''
assert isinstance(version_minimum, str), (
f'{repr(version_minimum)} not string.')
Expand All @@ -194,7 +217,7 @@ def is_module_version_at_least(module_name: str, version_minimum: str) -> bool:

# ....................{ TESTERS ~ package }....................
#FIXME: Unit test us up, please.
def is_package(package_name: str) -> bool:
def is_package(package_name: str, **kwargs) -> bool:
'''
:data:`True` only if the package with the passed fully-qualified name is
importable under the active Python interpreter.
Expand All @@ -209,6 +232,9 @@ def is_package(package_name: str) -> bool:
package_name : str
Fully-qualified name of the package to be imported.
All remaining keyword parameters are passed as is to the lower-level
:func:`.is_module` tester.
Returns
-------
bool
Expand All @@ -218,9 +244,9 @@ def is_package(package_name: str) -> bool:
-----
BeartypeModuleUnimportableWarning
If a package with this name exists *but* that package is unimportable
due to raising module-scoped exceptions from the top-level `__init__`
due to raising module-scoped exceptions from the top-level ``__init__``
submodule of this package at importation time.
'''

# Be the one liner you want to see in the world.
return is_module(f'{package_name}.__init__')
return is_module(f'{package_name}.__init__', **kwargs)
2 changes: 1 addition & 1 deletion beartype/meta.py
Expand Up @@ -209,7 +209,7 @@ def _convert_version_str_to_tuple(version_str: str): # -> _Tuple[int, ...]:
# For further details, see http://semver.org.
#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

VERSION = '0.18.6'
VERSION = '0.19.0'
'''
Human-readable package version as a ``.``-delimited string.
'''
Expand Down
34 changes: 28 additions & 6 deletions beartype_test/a90_func/z90_lib/a80_jax/test_equinox.py
Expand Up @@ -61,10 +61,21 @@ def test_equinox_filter_jit() -> None:
#dangerous and *MUST* be isolated to a subprocess. Honestly, what a pain.
#See similar logic in "test_jax" also requiring a similar resolution.
# If any requisite JAX package is unimportable, silently reduce to a noop.
#
# Note that merely testing the importability of a JAX package emits warnings
# in unpredictably hardware-dependent edge cases. Since that then induces
# test failure, these tests necessarily ignore these warnings. For example,
# if the low-level C-based "jaxlib" package was compiled on a newer system
# supporting the assembly-level AVX instruction set that the current system
# fails to support, these tests would emit this warning:
# E RuntimeError: This version of jaxlib was built using AVX
# instructions, which your CPU and/or operating system do not
# support. You may be able work around this issue by building jaxlib
# from source.
if not (
is_package('equinox') and
is_package('jax') and
is_package('jaxtyping')
is_package('equinox', is_warnings_ignore=True) and
is_package('jax', is_warnings_ignore=True) and
is_package('jaxtyping', is_warnings_ignore=True)
):
return
# Else, all requisite JAX packages is importable.
Expand Down Expand Up @@ -158,10 +169,21 @@ def test_equinox_module_subclass() -> None:
#dangerous and *MUST* be isolated to a subprocess. Honestly, what a pain.
#See similar logic in "test_jax" also requiring a similar resolution.
# If any requisite JAX package is unimportable, silently reduce to a noop.
#
# Note that merely testing the importability of a JAX package emits warnings
# in unpredictably hardware-dependent edge cases. Since that then induces
# test failure, these tests necessarily ignore these warnings. For example,
# if the low-level C-based "jaxlib" package was compiled on a newer system
# supporting the assembly-level AVX instruction set that the current system
# fails to support, these tests would emit this warning:
# E RuntimeError: This version of jaxlib was built using AVX
# instructions, which your CPU and/or operating system do not
# support. You may be able work around this issue by building jaxlib
# from source.
if not (
is_package('equinox') and
is_package('jax') and
is_package('jaxtyping')
is_package('equinox', is_warnings_ignore=True) and
is_package('jax', is_warnings_ignore=True) and
is_package('jaxtyping', is_warnings_ignore=True)
):
return
# Else, all requisite JAX packages is importable.
Expand Down
20 changes: 17 additions & 3 deletions beartype_test/a90_func/z90_lib/a80_jax/test_jax.py
Expand Up @@ -60,10 +60,24 @@ def test_jax_jit() -> None:
#FIXME: *EVEN THIS ISN"T SAFE.* Any importation whatsoever from JAX is
#dangerous and *MUST* be isolated to a subprocess. Honestly, what a pain.
#See similar logic in "test_equinox" also requiring a similar resolution.
# If either the "jax" or "jaxtyping" packages are unimportable, silently reduce to a noop.
if not (is_package('jax') and is_package('jaxtyping')):
# If any requisite JAX package is unimportable, silently reduce to a noop.
#
# Note that merely testing the importability of a JAX package emits warnings
# in unpredictably hardware-dependent edge cases. Since that then induces
# test failure, these tests necessarily ignore these warnings. For example,
# if the low-level C-based "jaxlib" package was compiled on a newer system
# supporting the assembly-level AVX instruction set that the current system
# fails to support, these tests would emit this warning:
# E RuntimeError: This version of jaxlib was built using AVX
# instructions, which your CPU and/or operating system do not
# support. You may be able work around this issue by building jaxlib
# from source.
if not (
is_package('jax', is_warnings_ignore=True) and
is_package('jaxtyping', is_warnings_ignore=True)
):
return
# Else, the "jax" package is importable.
# Else, all requisite JAX packages is importable.

# ....................{ IMPORTS ~ late }....................
# Defer JAX-dependent imports.
Expand Down
14 changes: 12 additions & 2 deletions doc/src/_links.rst
Expand Up @@ -325,8 +325,6 @@
https://www.sympy.org
.. _TensorFlow:
https://www.tensorflow.org
.. _equinox:
https://github.com/patrick-kidger/equinox
.. _nptyping:
https://github.com/ramonhagenaars/nptyping
.. _numerary:
Expand All @@ -344,7 +342,19 @@
.. _mypy-boto3:
https://mypy-boto3.readthedocs.io

.. # ------------------( LINKS ~ py : package : equinox )------------------
.. _equinox:
https://docs.kidger.site/equinox
.. _equinox.Module:
https://docs.kidger.site/equinox/api/module/module
.. _equinox.filter_jit:
https://docs.kidger.site/equinox/api/transformations/#equinox.filter_jit

.. # ------------------( LINKS ~ py : package : jax )------------------
.. _jax:
https://jax.readthedocs.io
.. _jax.jit:
https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit
.. _jax.numpy:
https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html

Expand Down
8 changes: 6 additions & 2 deletions doc/src/pep.rst
Expand Up @@ -324,10 +324,14 @@ you into stunned disbelief that somebody typed all this. [#rsi]_
+------------------------+-----------------------------------------------------------+--------------------------+---------------------------+
| :mod:`enum` | :obj:`~enum.Enum` | **0.16.0**\ \ *current* | *none* |
+------------------------+-----------------------------------------------------------+--------------------------+---------------------------+
| equinox_ | *all* || **0.17.0**\ \ *current* |
+------------------------+-----------------------------------------------------------+--------------------------+---------------------------+
| | :obj:`~enum.StrEnum` | **0.16.0**\ \ *current* | *none* |
+------------------------+-----------------------------------------------------------+--------------------------+---------------------------+
| equinox_ | `Module <equinox.module_>`__ || **0.17.0**\ \ *current* |
+------------------------+-----------------------------------------------------------+--------------------------+---------------------------+
| | `@filter_jit <equinox.filter_jit_>`__ || **0.19.0**\ \ *current* |
+------------------------+-----------------------------------------------------------+--------------------------+---------------------------+
| jax_ | `@jit <jax.jit_>`__ || **0.19.0**\ \ *current* |
+------------------------+-----------------------------------------------------------+--------------------------+---------------------------+
| :mod:`functools` | :obj:`~functools.lru_cache` || **0.15.0**\ \ *current* |
+------------------------+-----------------------------------------------------------+--------------------------+---------------------------+
| nuitka_ | *all* || **0.12.0**\ \ *current* |
Expand Down

0 comments on commit 910579e

Please sign in to comment.