Introduce pytest-raises-without-match rule to prevent `pytest.raises` from being called without `match` argument (#5015)

* squash

Signed-off-by: harupy <hkawamura0130@gmail.com>

* remove disable=pytest-raises-without-match

Signed-off-by: harupy <hkawamura0130@gmail.com>

* fix tests

Signed-off-by: harupy <hkawamura0130@gmail.com>
harupy committed Dec 5, 2021
1 parent f827fa4 commit 011ee78
Showing 84 changed files with 857 additions and 526 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/master.yml
@@ -41,7 +41,10 @@ jobs:
run: |
source ./dev/install-common-deps.sh
pip install -r requirements/lint-requirements.txt
- name: Run tests
- name: Test custom pylint-plugins
run : |
pytest tests/pylint_plugins
- name: Run lint checks
run: |
./dev/lint.sh
7 changes: 5 additions & 2 deletions mlflow/entities/run_info.py
@@ -4,19 +4,22 @@
from mlflow.exceptions import MlflowException

from mlflow.protos.service_pb2 import RunInfo as ProtoRunInfo
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE


def check_run_is_active(run_info):
    if run_info.lifecycle_stage != LifecycleStage.ACTIVE:
        raise MlflowException(
            "The run {} must be in 'active' lifecycle_stage.".format(run_info.run_id)
            "The run {} must be in 'active' lifecycle_stage.".format(run_info.run_id),
            error_code=INVALID_PARAMETER_VALUE,
        )


def check_run_is_deleted(run_info):
    if run_info.lifecycle_stage != LifecycleStage.DELETED:
        raise MlflowException(
            "The run {} must be in 'deleted' lifecycle_stage.".format(run_info.run_id)
            "The run {} must be in 'deleted' lifecycle_stage.".format(run_info.run_id),
            error_code=INVALID_PARAMETER_VALUE,
        )


2 changes: 1 addition & 1 deletion mlflow/utils/uri.py
@@ -258,7 +258,7 @@ def is_databricks_model_registry_artifacts_uri(artifact_uri):
def construct_run_url(hostname, experiment_id, run_id, workspace_id=None):
    if not hostname or not experiment_id or not run_id:
        raise MlflowException(
            "Hostname, experiment ID, and run ID are all required to construct" "a run URL"
            "Hostname, experiment ID, and run ID are all required to construct a run URL"
        )
    prefix = hostname
    if workspace_id and workspace_id != "0":
5 changes: 5 additions & 0 deletions pylint_plugins/__init__.py
@@ -0,0 +1,5 @@
from .pytest_raises_without_match import PytestRaisesWithoutMatch


def register(linter):
    linter.register_checker(PytestRaisesWithoutMatch(linter))
52 changes: 52 additions & 0 deletions pylint_plugins/pytest_raises_without_match/README.md
@@ -0,0 +1,52 @@
# `pytest-raises-without-match`

This custom pylint rule disallows calling `pytest.raises` without a `match` argument
to avoid capturing unintended exceptions and eliminate false-positive tests.

## Example

Suppose we want to test that this function raises `Exception("bar")` when `condition2` is satisfied.

```python
def func():
    if condition1:
        raise Exception("foo")

    if condition2:
        raise Exception("bar")
```

### Bad

```python
def test_func():
    with pytest.raises(Exception):
        func()
```

- This test passes when `condition1` is unintentionally satisfied.
- Future code readers will struggle to identify which exception `pytest.raises` should match.

### Good

```python
def test_func():
    with pytest.raises(Exception, match="bar"):
        func()
```

- This test fails when `condition1` is unintentionally satisfied.
- Future code readers can quickly identify which exception `pytest.raises` should match by searching for `bar`.
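
Note that `match` is applied as a regular expression (pytest uses `re.search` against the string representation of the raised exception), so metacharacters such as parentheses or brackets should be escaped when they are meant literally. A small illustration (a hypothetical test, not part of this commit):

```python
import re

import pytest


def test_func_with_special_chars():
    # `re.escape` keeps the parentheses literal instead of treating them as a regex group.
    with pytest.raises(ValueError, match=re.escape("bad value (code=42)")):
        raise ValueError("bad value (code=42)")
```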

## How to disable this rule

```python
def test_func():
    with pytest.raises(Exception):  # pylint: disable=pytest-raises-without-match
        func()
```

## References

- https://docs.pytest.org/en/latest/how-to/assert.html#assertions-about-expected-exceptions
- https://docs.pytest.org/en/latest/reference/reference.html#pytest.raises
38 changes: 38 additions & 0 deletions pylint_plugins/pytest_raises_without_match/__init__.py
@@ -0,0 +1,38 @@
import astroid
from pylint.interfaces import IAstroidChecker
from pylint.checkers import BaseChecker


class PytestRaisesWithoutMatch(BaseChecker):
    __implements__ = IAstroidChecker

    name = "pytest-raises-without-match"
    msgs = {
        "W0001": (
            "`pytest.raises` must be called with `match` argument",
            name,
            "Use `pytest.raises(<exception>, match=...)`",
        ),
    }
    priority = -1

    @staticmethod
    def _is_pytest_raises_call(node: astroid.Call):
        if not isinstance(node.func, astroid.Attribute) or not isinstance(
            node.func.expr, astroid.Name
        ):
            return False
        return node.func.expr.name == "pytest" and node.func.attrname == "raises"

    @staticmethod
    def _called_with_match(node: astroid.Call):
        # Note `match` is a keyword-only argument:
        # https://docs.pytest.org/en/latest/reference/reference.html#pytest.raises
        return any(k.arg == "match" for k in node.keywords)

    def visit_call(self, node: astroid.Call):
        if not PytestRaisesWithoutMatch._is_pytest_raises_call(node):
            return

        if not PytestRaisesWithoutMatch._called_with_match(node):
            self.add_message(self.name, node=node)
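
The new CI step runs `pytest tests/pylint_plugins`, but those test files are not shown in this excerpt. As a rough sketch of how the checker's helpers could be exercised directly with `astroid` (the test name and assertions below are assumptions, not the commit's actual tests):

```python
import astroid

from pylint_plugins import PytestRaisesWithoutMatch


def test_pytest_raises_detection_sketch():
    # Build Call nodes for `pytest.raises` with and without a `match` argument.
    bad = astroid.extract_node('pytest.raises(Exception)')
    good = astroid.extract_node('pytest.raises(Exception, match="bar")')

    # Both are recognized as `pytest.raises` calls ...
    assert PytestRaisesWithoutMatch._is_pytest_raises_call(bad)
    assert PytestRaisesWithoutMatch._is_pytest_raises_call(good)
    # ... but only the call that passes `match` satisfies `_called_with_match`.
    assert PytestRaisesWithoutMatch._called_with_match(good)
```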
2 changes: 1 addition & 1 deletion pylintrc
@@ -33,7 +33,7 @@ jobs=0

# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=
load-plugins=pylint_plugins

# Pickle collected data for later comparisons.
persistent=yes
2 changes: 1 addition & 1 deletion tests/autologging/test_autologging_behaviors_unit.py
@@ -224,7 +224,7 @@ def parallel_fn():
time.sleep(np.random.random())
patch_destination.fn()

with pytest.raises(Exception):
with pytest.raises(Exception, match="enablement error"):
test_autolog(silent=True)

with pytest.warns(None):
4 changes: 2 additions & 2 deletions tests/autologging/test_autologging_client.py
@@ -233,7 +233,7 @@ def test_client_correctly_operates_as_context_manager_for_synchronous_flush():
assert run_tags_1 == tags_to_log

exc_to_raise = Exception("test exception")
with pytest.raises(Exception) as raised_exc_info:
with pytest.raises(Exception, match=str(exc_to_raise)) as raised_exc_info:
with mlflow.start_run(), MlflowAutologgingQueueingClient() as client:
run_id_2 = mlflow.active_run().info.run_id
client.log_params(run_id_2, params_to_log)
@@ -264,7 +264,7 @@ def test_logging_failures_are_handled_as_expected():
client.log_metrics(run_id=pending_run_id, metrics={"a": 1})
client.set_terminated(run_id=pending_run_id, status="KILLED")

with pytest.raises(MlflowException) as exc:
with pytest.raises(MlflowException, match="Batch logging failed!") as exc:
client.flush()

runs = mlflow.search_runs(experiment_ids=[experiment_id], output_format="list")
28 changes: 14 additions & 14 deletions tests/autologging/test_autologging_safety_unit.py
@@ -282,7 +282,7 @@ def patch_impl(original, *args, **kwargs):

safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

with pytest.raises(Exception) as exc:
with pytest.raises(Exception, match=str(exc_to_throw)) as exc:
patch_destination.fn()

assert exc.value == exc_to_throw
@@ -319,7 +319,7 @@ def patch_impl(original, *args, **kwargs):
raise exc_to_throw

safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
with pytest.raises(Exception) as exc:
with pytest.raises(Exception, match=str(exc_to_throw)) as exc:
patch_destination.fn()

assert exc.value == exc_to_throw
@@ -860,7 +860,7 @@ def non_throwing_function():
def throwing_function():
raise exc_to_throw

with pytest.raises(Exception) as exc:
with pytest.raises(Exception, match=str(exc_to_throw)) as exc:
throwing_function()

assert exc.value == exc_to_throw
@@ -913,7 +913,7 @@ class ThrowingClass(baseclass, metaclass=metaclass):
def function(self):
raise exc_to_throw

with pytest.raises(Exception) as exc:
with pytest.raises(Exception, match=str(exc_to_throw)) as exc:
ThrowingClass().function()

assert exc.value == exc_to_throw
@@ -998,14 +998,14 @@ def patch_function(original, *args, **kwargs):

patch_function = with_managed_run("test_integration", patch_function)

with pytest.raises(Exception):
with pytest.raises(Exception, match="bad implementation"):
patch_function(lambda: "foo")

assert patch_function_active_run is not None
status1 = client.get_run(patch_function_active_run.info.run_id).info.status
assert RunStatus.from_string(status1) == RunStatus.FAILED

with mlflow.start_run() as active_run, pytest.raises(Exception):
with mlflow.start_run() as active_run, pytest.raises(Exception, match="bad implementation"):
patch_function(lambda: "foo")
assert patch_function_active_run == active_run
# `with_managed_run` should not terminate a preexisting MLflow run,
@@ -1053,14 +1053,14 @@ def _on_exception(self, exception):

TestPatch = with_managed_run("test_integration", TestPatch)

with pytest.raises(Exception):
with pytest.raises(Exception, match="bad implementation"):
TestPatch.call(lambda: "foo")

assert patch_function_active_run is not None
status1 = client.get_run(patch_function_active_run.info.run_id).info.status
assert RunStatus.from_string(status1) == RunStatus.FAILED

with mlflow.start_run() as active_run, pytest.raises(Exception):
with mlflow.start_run() as active_run, pytest.raises(Exception, match="bad implementation"):
TestPatch.call(lambda: "foo")
assert patch_function_active_run == active_run
# `with_managed_run` should not terminate a preexisting MLflow run,
@@ -1108,7 +1108,7 @@ def original():
"test_integration", lambda original, *args, **kwargs: original(*args, **kwargs)
)

with pytest.raises(KeyboardInterrupt):
with pytest.raises(KeyboardInterrupt, match=""):
patch_function_1(original)

assert not mlflow.active_run()
@@ -1124,7 +1124,7 @@ def _on_exception(self, exception):

patch_function_2 = with_managed_run("test_integration", PatchFunction2)

with pytest.raises(KeyboardInterrupt):
with pytest.raises(KeyboardInterrupt, match=""):

patch_function_2.call(original)

@@ -1418,8 +1418,8 @@ def patch_fn(original):

# If use safe_patch to patch, exception would not come from original fn and so would be logged
patch_destination.fn = patch_fn
with pytest.raises(Exception):
patch_destination.fn()
with pytest.raises(Exception, match="Exception that should stop autologging session"):
patch_destination.fn(lambda: None)

assert _AutologgingSessionManager.active_session() is None

@@ -1573,7 +1573,7 @@ def _predict(self, X, a, b):
@property
def predict(self):
if not self._has_predict:
raise AttributeError()
raise AttributeError("does not have predict")
return self._predict

class ExtendedEstimator(BaseEstimator):
@@ -1624,7 +1624,7 @@ def autolog(disable=False, exclusive=False, silent=False):  # pylint: disable=un

bad_estimator = EstimatorCls(has_predict=False)
assert not hasattr(bad_estimator, "predict")
with pytest.raises(AttributeError):
with pytest.raises(AttributeError, match="does not have predict"):
bad_estimator.predict(X=1, a=2, b=3)

autolog(disable=True)
4 changes: 2 additions & 2 deletions tests/autologging/test_autologging_utils.py
@@ -911,9 +911,9 @@ def f4(self, *args, **kwargs):
assert 3 == get_instance_method_first_arg_value(Test.f1, [3], {"cd2": 4})
assert 3 == get_instance_method_first_arg_value(Test.f1, [], {"ab1": 3, "cd2": 4})
assert 3 == get_instance_method_first_arg_value(Test.f2, [3, 4], {})
with pytest.raises(AssertionError):
with pytest.raises(AssertionError, match=""):
get_instance_method_first_arg_value(Test.f3, [], {"ab1": 3, "cd2": 4})
with pytest.raises(AssertionError):
with pytest.raises(AssertionError, match=""):
get_instance_method_first_arg_value(Test.f4, [], {"ab1": 3, "cd2": 4})


6 changes: 4 additions & 2 deletions tests/azureml/test_deploy.py
@@ -287,7 +287,9 @@ def test_deploy_throws_exception_if_model_does_not_contain_pyfunc_flavor(sklearn
del model_config.flavors[pyfunc.FLAVOR_NAME]
model_config.save(model_config_path)

with AzureMLMocks(), pytest.raises(MlflowException) as exc:
with AzureMLMocks(), pytest.raises(
MlflowException, match="does not contain the `python_function` flavor"
) as exc:
workspace = get_azure_workspace()
mlflow.azureml.deploy(model_uri=model_path, workspace=workspace)
assert exc.error_code == INVALID_PARAMETER_VALUE
Expand All @@ -304,7 +306,7 @@ def test_deploy_throws_exception_if_model_python_version_is_less_than_three(
model_config.flavors[pyfunc.FLAVOR_NAME][pyfunc.PY_VERSION] = "2.7.6"
model_config.save(model_config_path)

with AzureMLMocks(), pytest.raises(MlflowException) as exc:
with AzureMLMocks(), pytest.raises(MlflowException, match="Python 3 and above") as exc:
workspace = get_azure_workspace()
mlflow.azureml.deploy(model_uri=model_path, workspace=workspace)
assert exc.error_code == INVALID_PARAMETER_VALUE
6 changes: 4 additions & 2 deletions tests/azureml/test_image_creation.py
@@ -417,7 +417,9 @@ def test_build_image_throws_exception_if_model_does_not_contain_pyfunc_flavor(
del model_config.flavors[pyfunc.FLAVOR_NAME]
model_config.save(model_config_path)

with AzureMLMocks(), pytest.raises(MlflowException) as exc:
with AzureMLMocks(), pytest.raises(
MlflowException, match="does not contain the `python_function` flavor"
) as exc:
workspace = get_azure_workspace()
mlflow.azureml.build_image(model_uri=model_path, workspace=workspace)
assert exc.error_code == INVALID_PARAMETER_VALUE
@@ -434,7 +436,7 @@ def test_build_image_throws_exception_if_model_python_version_is_less_than_three
model_config.flavors[pyfunc.FLAVOR_NAME][pyfunc.PY_VERSION] = "2.7.6"
model_config.save(model_config_path)

with AzureMLMocks(), pytest.raises(MlflowException) as exc:
with AzureMLMocks(), pytest.raises(MlflowException, match="Python 3 and above") as exc:
workspace = get_azure_workspace()
mlflow.azureml.build_image(model_uri=model_path, workspace=workspace)
assert exc.error_code == INVALID_PARAMETER_VALUE
2 changes: 1 addition & 1 deletion tests/data/test_data.py
@@ -50,7 +50,7 @@ def test_download_uri():
# Verify exceptions are thrown when downloading from unsupported/invalid URIs
invalid_prefixes = ["file://", "/tmp"]
for prefix in invalid_prefixes:
with temp_directory() as dst_dir, pytest.raises(DownloadException):
with temp_directory() as dst_dir, pytest.raises(DownloadException, match="`uri` must be"):
download_uri(
uri=os.path.join(prefix, "some/path"), output_path=os.path.join(dst_dir, "tmp-file")
)
8 changes: 5 additions & 3 deletions tests/deployments/test_deployments.py
@@ -45,7 +45,9 @@ def test_get_success():


def test_wrong_target_name():
with pytest.raises(MlflowException):
with pytest.raises(
MlflowException, match='No plugin found for managing model deployments to "wrong_target"'
):
deployments.get_deploy_client("wrong_target")


@@ -56,15 +58,15 @@ class DummyPlugin:
dummy_plugin = DummyPlugin()
plugin_manager = DeploymentPlugins()
plugin_manager.registry["dummy"] = dummy_plugin
with pytest.raises(MlflowException):
with pytest.raises(MlflowException, match="Plugin registered for the target dummy"):
plugin_manager["dummy"] # pylint: disable=pointless-statement


def test_plugin_raising_error():
client = deployments.get_deploy_client(f_target)
# special case to raise error
os.environ["raiseError"] = "True"
with pytest.raises(RuntimeError):
with pytest.raises(RuntimeError, match="Error requested"):
client.list_deployments()
os.environ["raiseError"] = "False"

2 changes: 1 addition & 1 deletion tests/entities/test_run.py
@@ -95,6 +95,6 @@ def test_string_repr(self):

def test_creating_run_with_absent_info_throws_exception(self):
run_data = TestRunData._create()[0]
with pytest.raises(MlflowException) as no_info_exc:
with pytest.raises(MlflowException, match="run_info cannot be None") as no_info_exc:
Run(None, run_data)
assert "run_info cannot be None" in str(no_info_exc)
