From 22e69c3bdba7c4af616f6f0ce4e74821dc665b10 Mon Sep 17 00:00:00 2001 From: Harutaka Kawamura Date: Mon, 22 Nov 2021 12:29:55 +0900 Subject: [PATCH] Enable `line-too-long` for pylint (#5085) * Enable line-too-long for pylint Signed-off-by: harupy * replace noqa Signed-off-by: harupy * lint Signed-off-by: harupy * fix Signed-off-by: harupy * define multi_context Signed-off-by: harupy * use mock_method_chain Signed-off-by: harupy * shorten URIs Signed-off-by: harupy * fix test_gcs_artifact_repo.py Signed-off-by: harupy * remove noaq Signed-off-by: harupy * use list Signed-off-by: harupy * fix examples Signed-off-by: harupy * always set return_value and side_effect Signed-off-by: harupy * enforce keyword arguments Signed-off-by: harupy --- dev/set_matrix.py | 2 +- mlflow/pytorch/_pytorch_autolog.py | 2 +- mlflow/sklearn/utils.py | 4 +- mlflow/tracking/client.py | 4 +- mlflow/xgboost/__init__.py | 2 +- pylintrc | 5 +- .../test_autologging_safety_unit.py | 4 +- tests/helper_functions.py | 41 ++++++++++++- tests/resources/mlflow-test-plugin/setup.py | 10 ++-- tests/shap/test_log.py | 3 +- .../artifact/test_databricks_artifact_repo.py | 2 +- .../store/artifact/test_gcs_artifact_repo.py | 58 +++++++++++++++---- .../context/test_databricks_job_context.py | 8 ++- .../test_databricks_notebook_context.py | 7 ++- tests/tracking/fluent/test_fluent.py | 40 +++++++++++-- tests/tracking/test_tracking.py | 4 +- tests/utils/test_databricks_utils.py | 47 ++++++++------- tests/utils/test_uri.py | 38 ++++++------ 18 files changed, 203 insertions(+), 78 deletions(-) diff --git a/dev/set_matrix.py b/dev/set_matrix.py index aa3d5160c5242..d90f1e556907d 100644 --- a/dev/set_matrix.py +++ b/dev/set_matrix.py @@ -562,7 +562,7 @@ def main(args): if "GITHUB_ACTIONS" in os.environ: # `::set-output` is a special syntax for GitHub Actions to set an action's output parameter. - # https://docs.github.com/en/free-pro-team@latest/actions/reference/workflow-commands-for-github-actions#setting-an-output-parameter # noqa + # https://docs.github.com/en/free-pro-team@latest/actions/reference/workflow-commands-for-github-actions#setting-an-output-parameter # Note that this actually doesn't print anything to the console. print("::set-output name=matrix::{}".format(json.dumps(matrix))) diff --git a/mlflow/pytorch/_pytorch_autolog.py b/mlflow/pytorch/_pytorch_autolog.py index 9b9ffc37b3cf0..87d75fb0fcc0e 100644 --- a/mlflow/pytorch/_pytorch_autolog.py +++ b/mlflow/pytorch/_pytorch_autolog.py @@ -69,7 +69,7 @@ def _log_metrics(self, trainer, pl_module): # pytorch-lightning runs a few steps of validation in the beginning of training # as a sanity check to catch bugs without having to wait for the training routine # to complete. During this check, we should skip logging metrics. - # https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#num-sanity-val-steps # noqa + # https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#num-sanity-val-steps sanity_checking = ( # `running_sanity_check` has been renamed to `sanity_checking`: # https://github.com/PyTorchLightning/pytorch-lightning/pull/9209 diff --git a/mlflow/sklearn/utils.py b/mlflow/sklearn/utils.py index 1a51a0d79678b..930c1a9210c8d 100644 --- a/mlflow/sklearn/utils.py +++ b/mlflow/sklearn/utils.py @@ -92,7 +92,7 @@ def _get_sample_weight(arg_names, args, kwargs): # In most cases, X_var_name and y_var_name become "X" and "y", respectively. # However, certain sklearn models use different variable names for X and y. 
- # E.g., see: https://scikit-learn.org/stable/modules/generated/sklearn.multioutput.MultiOutputClassifier.html#sklearn.multioutput.MultiOutputClassifier.fit # noqa: E501 + # E.g., see: https://scikit-learn.org/stable/modules/generated/sklearn.multioutput.MultiOutputClassifier.html#sklearn.multioutput.MultiOutputClassifier.fit X_var_name, y_var_name = fit_arg_names[:2] Xy = _get_Xy(fit_args, fit_kwargs, X_var_name, y_var_name) sample_weight = ( @@ -598,7 +598,7 @@ def _create_child_runs_for_parameter_search( parameter search estimator - `cv_estimator`, which provides relevant performance metrics for each point in the parameter search space. One child run is created for each point in the parameter search space. For additional information, see - `https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html`_. # noqa: E501 + `https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html`_. :param autologging_client: An instance of `MlflowAutologgingQueueingClient` used for efficiently logging run data to MLflow Tracking. diff --git a/mlflow/tracking/client.py b/mlflow/tracking/client.py index 1e8d03612baa9..f1e06dae2609e 100644 --- a/mlflow/tracking/client.py +++ b/mlflow/tracking/client.py @@ -1235,7 +1235,7 @@ def _is_numpy_array(image): return isinstance(image, np.ndarray) def _normalize_to_uint8(x): - # Based on: https://github.com/matplotlib/matplotlib/blob/06567e021f21be046b6d6dcf00380c1cb9adaf3c/lib/matplotlib/image.py#L684 # noqa + # Based on: https://github.com/matplotlib/matplotlib/blob/06567e021f21be046b6d6dcf00380c1cb9adaf3c/lib/matplotlib/image.py#L684 is_int = np.issubdtype(x.dtype, np.integer) low = 0 @@ -1268,7 +1268,7 @@ def _normalize_to_uint8(x): "Please install it via: pip install Pillow" ) from exc - # Ref.: https://numpy.org/doc/stable/reference/generated/numpy.dtype.kind.html#numpy-dtype-kind # noqa + # Ref.: https://numpy.org/doc/stable/reference/generated/numpy.dtype.kind.html#numpy-dtype-kind valid_data_types = { "b": "bool", "i": "signed integer", diff --git a/mlflow/xgboost/__init__.py b/mlflow/xgboost/__init__.py index c9cee4633f678..3d664a51dcfbb 100644 --- a/mlflow/xgboost/__init__.py +++ b/mlflow/xgboost/__init__.py @@ -426,7 +426,7 @@ def record_eval_results(eval_results, metrics_logger): # In xgboost >= 1.3.0, user-defined callbacks should inherit # `xgboost.callback.TrainingCallback`: - # https://xgboost.readthedocs.io/en/latest/python/callbacks.html#defining-your-own-callback # noqa + # https://xgboost.readthedocs.io/en/latest/python/callbacks.html#defining-your-own-callback return AutologCallback(metrics_logger, eval_results) else: from mlflow.xgboost._autolog import autolog_callback diff --git a/pylintrc b/pylintrc index 013dc2c817c41..cb8bc449d36e2 100644 --- a/pylintrc +++ b/pylintrc @@ -177,7 +177,8 @@ disable=print-statement, # either give multiple identifier separated by comma (,) or put this option # multiple time (only on the command line, not in the configuration file where # it should appear only once). See also the "--disable" option for examples. -enable=c-extension-no-member +enable=c-extension-no-member, + line-too-long [REPORTS] @@ -330,7 +331,7 @@ variable-naming-style=snake_case expected-line-ending-format= # Regexp for a line that is allowed to be longer than the limit. -ignore-long-lines=^\s*(# )??$ +ignore-long-lines=https?://\S+ # Number of spaces of indent required inside a hanging or continued line. 
indent-after-paren=4 diff --git a/tests/autologging/test_autologging_safety_unit.py b/tests/autologging/test_autologging_safety_unit.py index f42e08f0b062e..b886f1202a84b 100644 --- a/tests/autologging/test_autologging_safety_unit.py +++ b/tests/autologging/test_autologging_safety_unit.py @@ -644,7 +644,7 @@ def patch_impl(original, *args, **kwargs): assert patch_success.exception is original_success.exception is None -def test_safe_patch_makes_expected_event_logging_calls_when_patch_implementation_throws_and_original_succeeds( # noqa +def test_safe_patch_makes_expected_event_logging_calls_when_patch_implementation_throws_and_original_succeeds( # pylint: disable=line-too-long patch_destination, test_autologging_integration, mock_event_logger, @@ -684,7 +684,7 @@ def patch_impl(original, *args, **kwargs): assert patch_error.exception == exc_to_raise -def test_safe_patch_makes_expected_event_logging_calls_when_patch_implementation_throws_and_original_throws( # noqa +def test_safe_patch_makes_expected_event_logging_calls_when_patch_implementation_throws_and_original_throws( # pylint: disable=line-too-long patch_destination, test_autologging_integration, mock_event_logger, diff --git a/tests/helper_functions.py b/tests/helper_functions.py index d76214019722f..c15fe34262dbb 100644 --- a/tests/helper_functions.py +++ b/tests/helper_functions.py @@ -1,6 +1,9 @@ import os import random +import functools from unittest import mock +from contextlib import ExitStack, contextmanager + import requests import time @@ -241,7 +244,7 @@ def __exit__(self, tp, val, traceback): pgrp = os.getpgid(self._proc.pid) os.killpg(pgrp, signal.SIGTERM) else: - # https://stackoverflow.com/questions/47016723/windows-equivalent-for-spawning-and-killing-separate-process-group-in-python-3 # noqa + # https://stackoverflow.com/questions/47016723/windows-equivalent-for-spawning-and-killing-separate-process-group-in-python-3 self._proc.send_signal(signal.CTRL_BREAK_EVENT) self._proc.kill() @@ -406,3 +409,39 @@ def decorator(f): return pytest.mark.allow_infer_pip_requirements_fallback(f) if condition else f return decorator + + +def mock_method_chain(mock_obj, methods, return_value=None, side_effect=None): + """ + Mock a chain of methods. + + Examples + -------- + >>> from unittest import mock + >>> m = mock.MagicMock() + >>> mock_method_chain(m, ["a", "b"], return_value=0) + >>> m.a().b() + 0 + >>> mock_method_chain(m, ["c.d", "e"], return_value=1) + >>> m.c.d().e() + 1 + >>> mock_method_chain(m, ["f"], side_effect=Exception("side_effect")) + >>> m.f() + Traceback (most recent call last): + ... 
+ Exception: side_effect + """ + length = len(methods) + for idx, method in enumerate(methods): + mock_obj = functools.reduce(getattr, method.split("."), mock_obj) + if idx != length - 1: + mock_obj = mock_obj.return_value + else: + mock_obj.return_value = return_value + mock_obj.side_effect = side_effect + + +@contextmanager +def multi_context(*cms): + with ExitStack() as stack: + yield list(map(stack.enter_context, cms)) diff --git a/tests/resources/mlflow-test-plugin/setup.py b/tests/resources/mlflow-test-plugin/setup.py index 61af2d9bfc593..dd985de91951b 100644 --- a/tests/resources/mlflow-test-plugin/setup.py +++ b/tests/resources/mlflow-test-plugin/setup.py @@ -13,17 +13,17 @@ # Define a Tracking Store plugin for tracking URIs with scheme 'file-plugin' "mlflow.tracking_store": "file-plugin=mlflow_test_plugin.file_store:PluginFileStore", # Define a ArtifactRepository plugin for artifact URIs with scheme 'file-plugin' - "mlflow.artifact_repository": "file-plugin=mlflow_test_plugin.local_artifact:PluginLocalArtifactRepository", # noqa + "mlflow.artifact_repository": "file-plugin=mlflow_test_plugin.local_artifact:PluginLocalArtifactRepository", # pylint: disable=line-too-long # Define a RunContextProvider plugin. The entry point name for run context providers # is not used, and so is set to the string "unused" here - "mlflow.run_context_provider": "unused=mlflow_test_plugin.run_context_provider:PluginRunContextProvider", # noqa + "mlflow.run_context_provider": "unused=mlflow_test_plugin.run_context_provider:PluginRunContextProvider", # pylint: disable=line-too-long # Define a RequestHeaderProvider plugin. The entry point name for request header providers # is not used, and so is set to the string "unused" here - "mlflow.request_header_provider": "unused=mlflow_test_plugin.request_header_provider:PluginRequestHeaderProvider", # noqa + "mlflow.request_header_provider": "unused=mlflow_test_plugin.request_header_provider:PluginRequestHeaderProvider", # pylint: disable=line-too-long # Define a Model Registry Store plugin for tracking URIs with scheme 'file-plugin' - "mlflow.model_registry_store": "file-plugin=mlflow_test_plugin.sqlalchemy_store:PluginRegistrySqlAlchemyStore", # noqa + "mlflow.model_registry_store": "file-plugin=mlflow_test_plugin.sqlalchemy_store:PluginRegistrySqlAlchemyStore", # pylint: disable=line-too-long # Define a MLflow Project Backend plugin called 'dummy-backend' - "mlflow.project_backend": "dummy-backend=mlflow_test_plugin.dummy_backend:PluginDummyProjectBackend", # noqa + "mlflow.project_backend": "dummy-backend=mlflow_test_plugin.dummy_backend:PluginDummyProjectBackend", # pylint: disable=line-too-long # Define a MLflow model deployment plugin for target 'faketarget' "mlflow.deployments": "faketarget=mlflow_test_plugin.fake_deployment_plugin", }, diff --git a/tests/shap/test_log.py b/tests/shap/test_log.py index acbec88a3a9b0..b454136b2585e 100644 --- a/tests/shap/test_log.py +++ b/tests/shap/test_log.py @@ -293,7 +293,8 @@ def test_pyfunc_serve_and_score(): # `link` defaults to `shap.links.identity` which is decorated by `numba.jit` and causes # the following error when loading the explainer for serving: # ``` - # Exception: The passed link function needs to be callable and have a callable .inverse property! # noqa + # Exception: The passed link function needs to be callable and have a callable + # .inverse property! # ``` # As a workaround, use an identify function that's NOT decorated by `numba.jit`. 
link=create_identity_function(), diff --git a/tests/store/artifact/test_databricks_artifact_repo.py b/tests/store/artifact/test_databricks_artifact_repo.py index c344af34972ba..e0e1d1ae5e003 100644 --- a/tests/store/artifact/test_databricks_artifact_repo.py +++ b/tests/store/artifact/test_databricks_artifact_repo.py @@ -152,7 +152,7 @@ def test_init_artifact_uri(self, artifact_uri, expected_uri, expected_db_uri): ("dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts", ""), ("dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts/arty", "arty"), ( - "dbfs://prof@databricks/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts/arty", # noqa + "dbfs://prof@databricks/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts/arty", # pylint: disable=line-too-long "arty", ), ( diff --git a/tests/store/artifact/test_gcs_artifact_repo.py b/tests/store/artifact/test_gcs_artifact_repo.py index d9d555351dc49..63a3f82a15b88 100644 --- a/tests/store/artifact/test_gcs_artifact_repo.py +++ b/tests/store/artifact/test_gcs_artifact_repo.py @@ -5,10 +5,11 @@ from unittest import mock from google.cloud.storage import client as gcs_client +from google.auth.exceptions import DefaultCredentialsError from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository from mlflow.store.artifact.gcs_artifact_repo import GCSArtifactRepository -from google.auth.exceptions import DefaultCredentialsError +from tests.helper_functions import mock_method_chain @pytest.fixture @@ -123,8 +124,15 @@ def test_log_artifact(gcs_mock, tmpdir): # This will call isfile on the code path being used, # thus testing that it's being called with an actually file path - gcs_mock.Client.return_value.bucket.return_value.blob.return_value.upload_from_filename.side_effect = ( # noqa - os.path.isfile + mock_method_chain( + gcs_mock, + [ + "Client", + "bucket", + "blob", + "upload_from_filename", + ], + side_effect=os.path.isfile, ) repo.log_artifact(fpath) @@ -141,8 +149,15 @@ def test_log_artifacts(gcs_mock, tmpdir): subd.join("b.txt").write("B") subd.join("c.txt").write("C") - gcs_mock.Client.return_value.bucket.return_value.blob.return_value.upload_from_filename.side_effect = ( # noqa - os.path.isfile + mock_method_chain( + gcs_mock, + [ + "Client", + "bucket", + "blob", + "upload_from_filename", + ], + side_effect=os.path.isfile, ) repo.log_artifacts(subd.strpath) @@ -165,8 +180,15 @@ def mkfile(fname): f = tmpdir.join(fname) f.write("hello world!") - gcs_mock.Client.return_value.bucket.return_value.blob.return_value.download_to_filename.side_effect = ( # noqa - mkfile + mock_method_chain( + gcs_mock, + [ + "Client", + "bucket", + "blob", + "download_to_filename", + ], + side_effect=mkfile, ) repo.download_artifacts("test.txt") @@ -230,10 +252,24 @@ def mkfile(fname): f = tmpdir.join(fname) f.write("hello world!") - gcs_mock.Client.return_value.bucket.return_value.list_blobs.side_effect = get_mock_listing - - gcs_mock.Client.return_value.bucket.return_value.blob.return_value.download_to_filename.side_effect = ( # noqa - mkfile + mock_method_chain( + gcs_mock, + [ + "Client", + "bucket", + "list_blobs", + ], + side_effect=get_mock_listing, + ) + mock_method_chain( + gcs_mock, + [ + "Client", + "bucket", + "blob", + "download_to_filename", + ], + side_effect=mkfile, ) # Ensure that the root directory can be downloaded successfully diff --git a/tests/tracking/context/test_databricks_job_context.py b/tests/tracking/context/test_databricks_job_context.py index d5b290f496277..cfbf502118556 100644 
--- a/tests/tracking/context/test_databricks_job_context.py +++ b/tests/tracking/context/test_databricks_job_context.py @@ -10,6 +10,7 @@ MLFLOW_DATABRICKS_WEBAPP_URL, ) from mlflow.tracking.context.databricks_job_context import DatabricksJobRunContext +from tests.helper_functions import multi_context def test_databricks_job_run_context_in_context(): @@ -23,7 +24,12 @@ def test_databricks_job_run_context_tags(): patch_job_type = mock.patch("mlflow.utils.databricks_utils.get_job_type") patch_webapp_url = mock.patch("mlflow.utils.databricks_utils.get_webapp_url") - with patch_job_id as job_id_mock, patch_job_run_id as job_run_id_mock, patch_job_type as job_type_mock, patch_webapp_url as webapp_url_mock: # noqa + with multi_context(patch_job_id, patch_job_run_id, patch_job_type, patch_webapp_url) as ( + job_id_mock, + job_run_id_mock, + job_type_mock, + webapp_url_mock, + ): assert DatabricksJobRunContext().tags() == { MLFLOW_SOURCE_NAME: "jobs/{job_id}/run/{job_run_id}".format( job_id=job_id_mock.return_value, job_run_id=job_run_id_mock.return_value diff --git a/tests/tracking/context/test_databricks_notebook_context.py b/tests/tracking/context/test_databricks_notebook_context.py index 50b54524c2c00..15a8d9c185c8f 100644 --- a/tests/tracking/context/test_databricks_notebook_context.py +++ b/tests/tracking/context/test_databricks_notebook_context.py @@ -9,6 +9,7 @@ MLFLOW_DATABRICKS_WEBAPP_URL, ) from mlflow.tracking.context.databricks_notebook_context import DatabricksNotebookRunContext +from tests.helper_functions import multi_context def test_databricks_notebook_run_context_in_context(): @@ -21,7 +22,11 @@ def test_databricks_notebook_run_context_tags(): patch_notebook_path = mock.patch("mlflow.utils.databricks_utils.get_notebook_path") patch_webapp_url = mock.patch("mlflow.utils.databricks_utils.get_webapp_url") - with patch_notebook_id as notebook_id_mock, patch_notebook_path as notebook_path_mock, patch_webapp_url as webapp_url_mock: # noqa + with multi_context(patch_notebook_id, patch_notebook_path, patch_webapp_url) as ( + notebook_id_mock, + notebook_path_mock, + webapp_url_mock, + ): assert DatabricksNotebookRunContext().tags() == { MLFLOW_SOURCE_NAME: notebook_path_mock.return_value, MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.NOTEBOOK), diff --git a/tests/tracking/fluent/test_fluent.py b/tests/tracking/fluent/test_fluent.py index d8cc63076f213..cbd71035ce135 100644 --- a/tests/tracking/fluent/test_fluent.py +++ b/tests/tracking/fluent/test_fluent.py @@ -47,6 +47,7 @@ from mlflow.utils.file_utils import TempDir from tests.tracking.integration_test_utils import _init_server +from tests.helper_functions import multi_context class HelperEnv: @@ -334,7 +335,15 @@ def test_start_run_defaults(empty_active_run_stack): # pylint: disable=unused-a create_run_patch = mock.patch.object(MlflowClient, "create_run") - with experiment_id_patch, databricks_notebook_patch, user_patch, source_name_patch, source_type_patch, source_version_patch, create_run_patch: # noqa + with multi_context( + experiment_id_patch, + databricks_notebook_patch, + user_patch, + source_name_patch, + source_type_patch, + source_version_patch, + create_run_patch, + ): active_run = start_run() MlflowClient.create_run.assert_called_once_with( experiment_id=mock_experiment_id, tags=expected_tags @@ -386,7 +395,16 @@ def test_start_run_defaults_databricks_notebook( create_run_patch = mock.patch.object(MlflowClient, "create_run") - with experiment_id_patch, databricks_notebook_patch, user_patch, source_version_patch, 
notebook_id_patch, notebook_path_patch, webapp_url_patch, create_run_patch: # noqa + with multi_context( + experiment_id_patch, + databricks_notebook_patch, + user_patch, + source_version_patch, + notebook_id_patch, + notebook_path_patch, + webapp_url_patch, + create_run_patch, + ): active_run = start_run() MlflowClient.create_run.assert_called_once_with( experiment_id=mock_experiment_id, tags=expected_tags @@ -435,7 +453,15 @@ def test_start_run_creates_new_run_with_user_specified_tags(): create_run_patch = mock.patch.object(MlflowClient, "create_run") - with experiment_id_patch, databricks_notebook_patch, user_patch, source_name_patch, source_type_patch, source_version_patch, create_run_patch: # noqa + with multi_context( + experiment_id_patch, + databricks_notebook_patch, + user_patch, + source_name_patch, + source_type_patch, + source_version_patch, + create_run_patch, + ): active_run = start_run(tags=user_specified_tags) MlflowClient.create_run.assert_called_once_with( experiment_id=mock_experiment_id, tags=expected_tags @@ -483,7 +509,13 @@ def test_start_run_with_parent(): create_run_patch = mock.patch.object(MlflowClient, "create_run") - with databricks_notebook_patch, active_run_stack_patch, create_run_patch, user_patch, source_name_patch: # noqa + with multi_context( + databricks_notebook_patch, + active_run_stack_patch, + create_run_patch, + user_patch, + source_name_patch, + ): active_run = start_run(experiment_id=mock_experiment_id, nested=True) MlflowClient.create_run.assert_called_once_with( experiment_id=mock_experiment_id, tags=expected_tags diff --git a/tests/tracking/test_tracking.py b/tests/tracking/test_tracking.py index 27bfa388e7870..0683513ca8fec 100644 --- a/tests/tracking/test_tracking.py +++ b/tests/tracking/test_tracking.py @@ -786,7 +786,7 @@ def test_log_image_numpy_shape(size): @pytest.mark.parametrize( "dtype", [ - # Ref.: https://numpy.org/doc/stable/user/basics.types.html#array-types-and-conversions-between-types # noqa + # Ref.: https://numpy.org/doc/stable/user/basics.types.html#array-types-and-conversions-between-types "int8", "int16", "int32", @@ -947,7 +947,7 @@ def test_get_artifact_uri_uses_currently_active_run_id(): ), ( "mysql+driver://user:password@host:port/dbname/subpath/#fragment", - "mysql+driver://user:password@host:port/dbname/subpath/{run_id}/artifacts/{path}#fragment", # noqa + "mysql+driver://user:password@host:port/dbname/subpath/{run_id}/artifacts/{path}#fragment", # pylint: disable=line-too-long ), ("s3://bucketname/rootpath", "s3://bucketname/rootpath/{run_id}/artifacts/{path}"), ("/dirname/rootpa#th?", "/dirname/rootpa#th?/{run_id}/artifacts/{path}"), diff --git a/tests/utils/test_databricks_utils.py b/tests/utils/test_databricks_utils.py index 9577a88aa4d8c..94a837006e3d6 100644 --- a/tests/utils/test_databricks_utils.py +++ b/tests/utils/test_databricks_utils.py @@ -11,6 +11,7 @@ is_databricks_default_tracking_uri, ) from mlflow.utils.uri import construct_db_uri_from_profile +from tests.helper_functions import mock_method_chain def test_no_throw(): @@ -105,12 +106,12 @@ def test_get_workspace_info_from_databricks_secrets(): def test_get_workspace_info_from_dbutils(): mock_dbutils = mock.MagicMock() - mock_dbutils.notebook.entry_point.getDbutils.return_value.notebook.return_value.getContext.return_value.browserHostName.return_value.get.return_value = ( # noqa - "mlflow.databricks.com" - ) - 
mock_dbutils.notebook.entry_point.getDbutils.return_value.notebook.return_value.getContext.return_value.workspaceId.return_value.get.return_value = ( # noqa - "1111" + methods = ["notebook.entry_point.getDbutils", "notebook", "getContext"] + mock_method_chain( + mock_dbutils, methods + ["browserHostName", "get"], return_value="mlflow.databricks.com" ) + mock_method_chain(mock_dbutils, methods + ["workspaceId", "get"], return_value="1111") + with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils): workspace_host, workspace_id = get_workspace_info_from_dbutils() assert workspace_host == "https://mlflow.databricks.com" @@ -119,15 +120,12 @@ def test_get_workspace_info_from_dbutils(): def test_get_workspace_info_from_dbutils_no_browser_host_name(): mock_dbutils = mock.MagicMock() - mock_dbutils.notebook.entry_point.getDbutils.return_value.notebook.return_value.getContext.return_value.browserHostName.return_value.get.return_value = ( # noqa - None - ) - mock_dbutils.notebook.entry_point.getDbutils.return_value.notebook.return_value.getContext.return_value.apiUrl.return_value.get.return_value = ( # noqa - "https://mlflow.databricks.com" - ) - mock_dbutils.notebook.entry_point.getDbutils.return_value.notebook.return_value.getContext.return_value.workspaceId.return_value.get.return_value = ( # noqa - "1111" + methods = ["notebook.entry_point.getDbutils", "notebook", "getContext"] + mock_method_chain(mock_dbutils, methods + ["browserHostName", "get"], return_value=None) + mock_method_chain( + mock_dbutils, methods + ["apiUrl", "get"], return_value="https://mlflow.databricks.com" ) + mock_method_chain(mock_dbutils, methods + ["workspaceId", "get"], return_value="1111") with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils): workspace_host, workspace_id = get_workspace_info_from_dbutils() assert workspace_host == "https://mlflow.databricks.com" @@ -136,22 +134,29 @@ def test_get_workspace_info_from_dbutils_no_browser_host_name(): def test_get_workspace_info_from_dbutils_old_runtimes(): mock_dbutils = mock.MagicMock() - mock_dbutils.notebook.entry_point.getDbutils.return_value.notebook.return_value.getContext.return_value.toJson.return_value = ( # noqa - '{"tags": {"orgId" : "1111", "browserHostName": "mlflow.databricks.com"}}' + methods = ["notebook.entry_point.getDbutils", "notebook", "getContext"] + mock_method_chain( + mock_dbutils, + methods + ["toJson", "get"], + return_value='{"tags": {"orgId" : "1111", "browserHostName": "mlflow.databricks.com"}}', ) - mock_dbutils.notebook.entry_point.getDbutils.return_value.notebook.return_value.getContext.return_value.browserHostName.return_value.get.return_value = ( # noqa - "mlflow.databricks.com" + mock_method_chain( + mock_dbutils, methods + ["browserHostName", "get"], return_value="mlflow.databricks.com" ) + # Mock out workspace ID tag mock_workspace_id_tag_opt = mock.MagicMock() mock_workspace_id_tag_opt.isDefined.return_value = True mock_workspace_id_tag_opt.get.return_value = "1111" - mock_dbutils.notebook.entry_point.getDbutils.return_value.notebook.return_value.getContext.return_value.tags.return_value.get.return_value = ( # noqa - mock_workspace_id_tag_opt + mock_method_chain( + mock_dbutils, methods + ["tags", "get"], return_value=mock_workspace_id_tag_opt ) + # Mimic old runtimes by raising an exception when the nonexistent "workspaceId" method is called - 
mock_dbutils.notebook.entry_point.getDbutils.return_value.notebook.return_value.getContext.return_value.workspaceId.side_effect = Exception( # noqa - "workspaceId method not defined!" + mock_method_chain( + mock_dbutils, + methods + ["workspaceId"], + side_effect=Exception("workspaceId method not defined!"), ) with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils): workspace_host, workspace_id = get_workspace_info_from_dbutils() diff --git a/tests/utils/test_uri.py b/tests/utils/test_uri.py index 497cd9e77bcaf..c060201a1066d 100644 --- a/tests/utils/test_uri.py +++ b/tests/utils/test_uri.py @@ -266,24 +266,24 @@ def test_append_to_uri_path_preserves_uri_schemes_hosts_queries_and_fragments(): ("dbscheme+dbdriver://#somefrag", "subpath", "dbscheme+dbdriver:subpath#somefrag"), ("dbscheme+dbdriver:///#somefrag", "/subpath", "dbscheme+dbdriver:/subpath#somefrag"), ( - "dbscheme+dbdriver://root:password?creds=mycreds", + "dbscheme+dbdriver://root:password?creds=creds", "subpath", - "dbscheme+dbdriver://root:password/subpath?creds=mycreds", + "dbscheme+dbdriver://root:password/subpath?creds=creds", ), ( - "dbscheme+dbdriver://root:password/path/?creds=mycreds", + "dbscheme+dbdriver://root:password/path/?creds=creds", "/subpath/anotherpath", - "dbscheme+dbdriver://root:password/path/subpath/anotherpath?creds=mycreds", + "dbscheme+dbdriver://root:password/path/subpath/anotherpath?creds=creds", ), ( - "dbscheme+dbdriver://root:password///path/?creds=mycreds", + "dbscheme+dbdriver://root:password///path/?creds=creds", "subpath/anotherpath", - "dbscheme+dbdriver://root:password///path/subpath/anotherpath?creds=mycreds", + "dbscheme+dbdriver://root:password///path/subpath/anotherpath?creds=creds", ), ( - "dbscheme+dbdriver://root:password///path/?creds=mycreds", + "dbscheme+dbdriver://root:password///path/?creds=creds", "/subpath", - "dbscheme+dbdriver://root:password///path/subpath?creds=mycreds", + "dbscheme+dbdriver://root:password///path/subpath?creds=creds", ), ( "dbscheme+dbdriver://root:password#myfragment", @@ -291,30 +291,30 @@ def test_append_to_uri_path_preserves_uri_schemes_hosts_queries_and_fragments(): "dbscheme+dbdriver://root:password/subpath#myfragment", ), ( - "dbscheme+dbdriver://root:password//path/#myfragmentwith$pecial@", + "dbscheme+dbdriver://root:password//path/#fragmentwith$pecial@", "subpath/anotherpath", - "dbscheme+dbdriver://root:password//path/subpath/anotherpath#myfragmentwith$pecial@", # noqa + "dbscheme+dbdriver://root:password//path/subpath/anotherpath#fragmentwith$pecial@", ), ( - "dbscheme+dbdriver://root:password@myhostname?creds=mycreds#myfragmentwith$pecial@", + "dbscheme+dbdriver://root:password@host?creds=creds#fragmentwith$pecial@", "subpath", - "dbscheme+dbdriver://root:password@myhostname/subpath?creds=mycreds#myfragmentwith$pecial@", # noqa + "dbscheme+dbdriver://root:password@host/subpath?creds=creds#fragmentwith$pecial@", ), ( - "dbscheme+dbdriver://root:password@myhostname.com/path?creds=mycreds#*frag@*", + "dbscheme+dbdriver://root:password@host.com/path?creds=creds#*frag@*", "subpath/dir", - "dbscheme+dbdriver://root:password@myhostname.com/path/subpath/dir?creds=mycreds#*frag@*", # noqa + "dbscheme+dbdriver://root:password@host.com/path/subpath/dir?creds=creds#*frag@*", ), ( - "dbscheme-dbdriver://root:password@myhostname.com/path?creds=mycreds#*frag@*", + "dbscheme-dbdriver://root:password@host.com/path?creds=creds#*frag@*", "subpath/dir", - 
"dbscheme-dbdriver://root:password@myhostname.com/path/subpath/dir?creds=mycreds#*frag@*", # noqa + "dbscheme-dbdriver://root:password@host.com/path/subpath/dir?creds=creds#*frag@*", ), ( - "dbscheme+dbdriver://root:password@myhostname.com/path?creds=mycreds,param=value#*frag@*", # noqa + "dbscheme+dbdriver://root:password@host.com/path?creds=creds,param=value#*frag@*", "subpath/dir", - "dbscheme+dbdriver://root:password@myhostname.com/path/subpath/dir?" - "creds=mycreds,param=value#*frag@*", + "dbscheme+dbdriver://root:password@host.com/path/subpath/dir?" + "creds=creds,param=value#*frag@*", ), ] )