Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce pytest-raises-without-match rule rule to prevent pytest.raises from being called without match argument #5015

Merged
merged 3 commits into from Dec 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/master.yml
Expand Up @@ -41,7 +41,10 @@ jobs:
run: |
source ./dev/install-common-deps.sh
pip install -r requirements/lint-requirements.txt
- name: Run tests
- name: Test custom pylint-plugins
run : |
pytest tests/pylint_plugins
- name: Run lint checks
run: |
./dev/lint.sh
r:
Expand Down
7 changes: 5 additions & 2 deletions mlflow/entities/run_info.py
Expand Up @@ -4,19 +4,22 @@
from mlflow.exceptions import MlflowException

from mlflow.protos.service_pb2 import RunInfo as ProtoRunInfo
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE


def check_run_is_active(run_info):
if run_info.lifecycle_stage != LifecycleStage.ACTIVE:
raise MlflowException(
"The run {} must be in 'active' lifecycle_stage.".format(run_info.run_id)
"The run {} must be in 'active' lifecycle_stage.".format(run_info.run_id),
error_code=INVALID_PARAMETER_VALUE,
)


def check_run_is_deleted(run_info):
if run_info.lifecycle_stage != LifecycleStage.DELETED:
raise MlflowException(
"The run {} must be in 'deleted' lifecycle_stage.".format(run_info.run_id)
"The run {} must be in 'deleted' lifecycle_stage.".format(run_info.run_id),
error_code=INVALID_PARAMETER_VALUE,
)


Expand Down
2 changes: 1 addition & 1 deletion mlflow/utils/uri.py
Expand Up @@ -258,7 +258,7 @@ def is_databricks_model_registry_artifacts_uri(artifact_uri):
def construct_run_url(hostname, experiment_id, run_id, workspace_id=None):
if not hostname or not experiment_id or not run_id:
raise MlflowException(
"Hostname, experiment ID, and run ID are all required to construct" "a run URL"
"Hostname, experiment ID, and run ID are all required to construct a run URL"
)
prefix = hostname
if workspace_id and workspace_id != "0":
Expand Down
5 changes: 5 additions & 0 deletions pylint_plugins/__init__.py
@@ -0,0 +1,5 @@
from .pytest_raises_without_match import PytestRaisesWithoutMatch


def register(linter):
linter.register_checker(PytestRaisesWithoutMatch(linter))
52 changes: 52 additions & 0 deletions pylint_plugins/pytest_raises_without_match/README.md
@@ -0,0 +1,52 @@
# `pytest-raises-without-match`

This custom pylint rule disallows calling `pytest.raises` without a `match` argument
to avoid capturing unintended exceptions and eliminate false-positive tests.

## Example

Suppose we want to test this function throws `Exception("bar")` when `condition2` is satisfied.

```python
def func():
if condition1:
raise Exception("foo")

if condition2:
raise Exception("bar")
```

### Bad

```python
def test_func():
with pytest.raises(Exception):
func()
```

- This test passes when `condition1` is unintentionally satisfied.
- Future code readers will struggle to identify which exception `pytest.raises` should match.

### Good

```python
def test_func():
with pytest.raises(Exception, match="bar"):
func()
```

- This test fails when `condition1` is unintentionally satisfied.
- Future code readers can quickly identify which exception `pytest.raises` should match by searching `bar`.

## How to disable this rule

```python
def test_func():
with pytest.raises(Exception): # pylint: disable=pytest-raises-without-match
func()
```

## References

- https://docs.pytest.org/en/latest/how-to/assert.html#assertions-about-expected-exceptions
- https://docs.pytest.org/en/latest/reference/reference.html#pytest.raises
38 changes: 38 additions & 0 deletions pylint_plugins/pytest_raises_without_match/__init__.py
@@ -0,0 +1,38 @@
import astroid
from pylint.interfaces import IAstroidChecker
from pylint.checkers import BaseChecker


class PytestRaisesWithoutMatch(BaseChecker):
__implements__ = IAstroidChecker

name = "pytest-raises-without-match"
msgs = {
"W0001": (
"`pytest.raises` must be called with `match` argument` ",
name,
"Use `pytest.raises(<exception>, match=...)`",
),
}
priority = -1

@staticmethod
def _is_pytest_raises_call(node: astroid.Call):
if not isinstance(node.func, astroid.Attribute) or not isinstance(
node.func.expr, astroid.Name
):
return False
return node.func.expr.name == "pytest" and node.func.attrname == "raises"

@staticmethod
def _called_with_match(node: astroid.Call):
# Note `match` is a keyword-only argument:
# https://docs.pytest.org/en/latest/reference/reference.html#pytest.raises
return any(k.arg == "match" for k in node.keywords)

def visit_call(self, node: astroid.Call):
if not PytestRaisesWithoutMatch._is_pytest_raises_call(node):
return

if not PytestRaisesWithoutMatch._called_with_match(node):
self.add_message(self.name, node=node)
2 changes: 1 addition & 1 deletion pylintrc
Expand Up @@ -33,7 +33,7 @@ jobs=0

# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=
load-plugins=pylint_plugins

# Pickle collected data for later comparisons.
persistent=yes
Expand Down
2 changes: 1 addition & 1 deletion tests/autologging/test_autologging_behaviors_unit.py
Expand Up @@ -224,7 +224,7 @@ def parallel_fn():
time.sleep(np.random.random())
patch_destination.fn()

with pytest.raises(Exception):
with pytest.raises(Exception, match="enablement error"):
test_autolog(silent=True)

with pytest.warns(None):
Expand Down
4 changes: 2 additions & 2 deletions tests/autologging/test_autologging_client.py
Expand Up @@ -233,7 +233,7 @@ def test_client_correctly_operates_as_context_manager_for_synchronous_flush():
assert run_tags_1 == tags_to_log

exc_to_raise = Exception("test exception")
with pytest.raises(Exception) as raised_exc_info:
with pytest.raises(Exception, match=str(exc_to_raise)) as raised_exc_info:
with mlflow.start_run(), MlflowAutologgingQueueingClient() as client:
run_id_2 = mlflow.active_run().info.run_id
client.log_params(run_id_2, params_to_log)
Expand Down Expand Up @@ -264,7 +264,7 @@ def test_logging_failures_are_handled_as_expected():
client.log_metrics(run_id=pending_run_id, metrics={"a": 1})
client.set_terminated(run_id=pending_run_id, status="KILLED")

with pytest.raises(MlflowException) as exc:
with pytest.raises(MlflowException, match="Batch logging failed!") as exc:
client.flush()

runs = mlflow.search_runs(experiment_ids=[experiment_id], output_format="list")
Expand Down
28 changes: 14 additions & 14 deletions tests/autologging/test_autologging_safety_unit.py
Expand Up @@ -282,7 +282,7 @@ def patch_impl(original, *args, **kwargs):

safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

with pytest.raises(Exception) as exc:
with pytest.raises(Exception, match=str(exc_to_throw)) as exc:
patch_destination.fn()

assert exc.value == exc_to_throw
Expand Down Expand Up @@ -319,7 +319,7 @@ def patch_impl(original, *args, **kwargs):
raise exc_to_throw

safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
with pytest.raises(Exception) as exc:
with pytest.raises(Exception, match=str(exc_to_throw)) as exc:
patch_destination.fn()

assert exc.value == exc_to_throw
Expand Down Expand Up @@ -860,7 +860,7 @@ def non_throwing_function():
def throwing_function():
raise exc_to_throw

with pytest.raises(Exception) as exc:
with pytest.raises(Exception, match=str(exc_to_throw)) as exc:
throwing_function()

assert exc.value == exc_to_throw
Expand Down Expand Up @@ -913,7 +913,7 @@ class ThrowingClass(baseclass, metaclass=metaclass):
def function(self):
raise exc_to_throw

with pytest.raises(Exception) as exc:
with pytest.raises(Exception, match=str(exc_to_throw)) as exc:
ThrowingClass().function()

assert exc.value == exc_to_throw
Expand Down Expand Up @@ -998,14 +998,14 @@ def patch_function(original, *args, **kwargs):

patch_function = with_managed_run("test_integration", patch_function)

with pytest.raises(Exception):
with pytest.raises(Exception, match="bad implementation"):
patch_function(lambda: "foo")

assert patch_function_active_run is not None
status1 = client.get_run(patch_function_active_run.info.run_id).info.status
assert RunStatus.from_string(status1) == RunStatus.FAILED

with mlflow.start_run() as active_run, pytest.raises(Exception):
with mlflow.start_run() as active_run, pytest.raises(Exception, match="bad implementation"):
patch_function(lambda: "foo")
assert patch_function_active_run == active_run
# `with_managed_run` should not terminate a preexisting MLflow run,
Expand Down Expand Up @@ -1053,14 +1053,14 @@ def _on_exception(self, exception):

TestPatch = with_managed_run("test_integration", TestPatch)

with pytest.raises(Exception):
with pytest.raises(Exception, match="bad implementation"):
TestPatch.call(lambda: "foo")

assert patch_function_active_run is not None
status1 = client.get_run(patch_function_active_run.info.run_id).info.status
assert RunStatus.from_string(status1) == RunStatus.FAILED

with mlflow.start_run() as active_run, pytest.raises(Exception):
with mlflow.start_run() as active_run, pytest.raises(Exception, match="bad implementation"):
TestPatch.call(lambda: "foo")
assert patch_function_active_run == active_run
# `with_managed_run` should not terminate a preexisting MLflow run,
Expand Down Expand Up @@ -1108,7 +1108,7 @@ def original():
"test_integration", lambda original, *args, **kwargs: original(*args, **kwargs)
)

with pytest.raises(KeyboardInterrupt):
with pytest.raises(KeyboardInterrupt, match=""):
patch_function_1(original)

assert not mlflow.active_run()
Expand All @@ -1124,7 +1124,7 @@ def _on_exception(self, exception):

patch_function_2 = with_managed_run("test_integration", PatchFunction2)

with pytest.raises(KeyboardInterrupt):
with pytest.raises(KeyboardInterrupt, match=""):

patch_function_2.call(original)

Expand Down Expand Up @@ -1418,8 +1418,8 @@ def patch_fn(original):

# If use safe_patch to patch, exception would not come from original fn and so would be logged
patch_destination.fn = patch_fn
with pytest.raises(Exception):
patch_destination.fn()
with pytest.raises(Exception, match="Exception that should stop autologging session"):
patch_destination.fn(lambda: None)

assert _AutologgingSessionManager.active_session() is None

Expand Down Expand Up @@ -1573,7 +1573,7 @@ def _predict(self, X, a, b):
@property
def predict(self):
if not self._has_predict:
raise AttributeError()
raise AttributeError("does not have predict")
return self._predict

class ExtendedEstimator(BaseEstimator):
Expand Down Expand Up @@ -1624,7 +1624,7 @@ def autolog(disable=False, exclusive=False, silent=False): # pylint: disable=un

bad_estimator = EstimatorCls(has_predict=False)
assert not hasattr(bad_estimator, "predict")
with pytest.raises(AttributeError):
with pytest.raises(AttributeError, match="does not have predict"):
bad_estimator.predict(X=1, a=2, b=3)

autolog(disable=True)
Expand Down
4 changes: 2 additions & 2 deletions tests/autologging/test_autologging_utils.py
Expand Up @@ -911,9 +911,9 @@ def f4(self, *args, **kwargs):
assert 3 == get_instance_method_first_arg_value(Test.f1, [3], {"cd2": 4})
assert 3 == get_instance_method_first_arg_value(Test.f1, [], {"ab1": 3, "cd2": 4})
assert 3 == get_instance_method_first_arg_value(Test.f2, [3, 4], {})
with pytest.raises(AssertionError):
with pytest.raises(AssertionError, match=""):
get_instance_method_first_arg_value(Test.f3, [], {"ab1": 3, "cd2": 4})
with pytest.raises(AssertionError):
with pytest.raises(AssertionError, match=""):
get_instance_method_first_arg_value(Test.f4, [], {"ab1": 3, "cd2": 4})


Expand Down
6 changes: 4 additions & 2 deletions tests/azureml/test_deploy.py
Expand Up @@ -287,7 +287,9 @@ def test_deploy_throws_exception_if_model_does_not_contain_pyfunc_flavor(sklearn
del model_config.flavors[pyfunc.FLAVOR_NAME]
model_config.save(model_config_path)

with AzureMLMocks(), pytest.raises(MlflowException) as exc:
with AzureMLMocks(), pytest.raises(
MlflowException, match="does not contain the `python_function` flavor"
) as exc:
workspace = get_azure_workspace()
mlflow.azureml.deploy(model_uri=model_path, workspace=workspace)
assert exc.error_code == INVALID_PARAMETER_VALUE
Expand All @@ -304,7 +306,7 @@ def test_deploy_throws_exception_if_model_python_version_is_less_than_three(
model_config.flavors[pyfunc.FLAVOR_NAME][pyfunc.PY_VERSION] = "2.7.6"
model_config.save(model_config_path)

with AzureMLMocks(), pytest.raises(MlflowException) as exc:
with AzureMLMocks(), pytest.raises(MlflowException, match="Python 3 and above") as exc:
workspace = get_azure_workspace()
mlflow.azureml.deploy(model_uri=model_path, workspace=workspace)
assert exc.error_code == INVALID_PARAMETER_VALUE
Expand Down
6 changes: 4 additions & 2 deletions tests/azureml/test_image_creation.py
Expand Up @@ -417,7 +417,9 @@ def test_build_image_throws_exception_if_model_does_not_contain_pyfunc_flavor(
del model_config.flavors[pyfunc.FLAVOR_NAME]
model_config.save(model_config_path)

with AzureMLMocks(), pytest.raises(MlflowException) as exc:
with AzureMLMocks(), pytest.raises(
MlflowException, match="does not contain the `python_function` flavor"
) as exc:
workspace = get_azure_workspace()
mlflow.azureml.build_image(model_uri=model_path, workspace=workspace)
assert exc.error_code == INVALID_PARAMETER_VALUE
Expand All @@ -434,7 +436,7 @@ def test_build_image_throws_exception_if_model_python_version_is_less_than_three
model_config.flavors[pyfunc.FLAVOR_NAME][pyfunc.PY_VERSION] = "2.7.6"
model_config.save(model_config_path)

with AzureMLMocks(), pytest.raises(MlflowException) as exc:
with AzureMLMocks(), pytest.raises(MlflowException, match="Python 3 and above") as exc:
workspace = get_azure_workspace()
mlflow.azureml.build_image(model_uri=model_path, workspace=workspace)
assert exc.error_code == INVALID_PARAMETER_VALUE
Expand Down
2 changes: 1 addition & 1 deletion tests/data/test_data.py
Expand Up @@ -50,7 +50,7 @@ def test_download_uri():
# Verify exceptions are thrown when downloading from unsupported/invalid URIs
invalid_prefixes = ["file://", "/tmp"]
for prefix in invalid_prefixes:
with temp_directory() as dst_dir, pytest.raises(DownloadException):
with temp_directory() as dst_dir, pytest.raises(DownloadException, match="`uri` must be"):
download_uri(
uri=os.path.join(prefix, "some/path"), output_path=os.path.join(dst_dir, "tmp-file")
)
8 changes: 5 additions & 3 deletions tests/deployments/test_deployments.py
Expand Up @@ -45,7 +45,9 @@ def test_get_success():


def test_wrong_target_name():
with pytest.raises(MlflowException):
with pytest.raises(
MlflowException, match='No plugin found for managing model deployments to "wrong_target"'
):
deployments.get_deploy_client("wrong_target")


Expand All @@ -56,15 +58,15 @@ class DummyPlugin:
dummy_plugin = DummyPlugin()
plugin_manager = DeploymentPlugins()
plugin_manager.registry["dummy"] = dummy_plugin
with pytest.raises(MlflowException):
with pytest.raises(MlflowException, match="Plugin registered for the target dummy"):
plugin_manager["dummy"] # pylint: disable=pointless-statement


def test_plugin_raising_error():
client = deployments.get_deploy_client(f_target)
# special case to raise error
os.environ["raiseError"] = "True"
with pytest.raises(RuntimeError):
with pytest.raises(RuntimeError, match="Error requested"):
client.list_deployments()
os.environ["raiseError"] = "False"

Expand Down
2 changes: 1 addition & 1 deletion tests/entities/test_run.py
Expand Up @@ -95,6 +95,6 @@ def test_string_repr(self):

def test_creating_run_with_absent_info_throws_exception(self):
run_data = TestRunData._create()[0]
with pytest.raises(MlflowException) as no_info_exc:
with pytest.raises(MlflowException, match="run_info cannot be None") as no_info_exc:
Run(None, run_data)
assert "run_info cannot be None" in str(no_info_exc)