From a1727d81a95f910929dac668fc8cedee4b564875 Mon Sep 17 00:00:00 2001 From: harupy <17039389+harupy@users.noreply.github.com> Date: Thu, 23 Dec 2021 12:16:57 +0900 Subject: [PATCH 1/2] return original function result when context is None Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- mlflow/utils/databricks_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlflow/utils/databricks_utils.py b/mlflow/utils/databricks_utils.py index f9701dca1d0b5..4e5e609028329 100644 --- a/mlflow/utils/databricks_utils.py +++ b/mlflow/utils/databricks_utils.py @@ -31,7 +31,8 @@ def wrapper(*args, **kwargs): if context is not None and hasattr(context, name): return getattr(context, name) except Exception: - return f(*args, **kwargs) + pass + return f(*args, **kwargs) return wrapper From ad88e08484ff33d5ccfbd82322ac82bab2ec8c0e Mon Sep 17 00:00:00 2001 From: harupy <17039389+harupy@users.noreply.github.com> Date: Thu, 23 Dec 2021 12:44:28 +0900 Subject: [PATCH 2/2] add unit tests Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- tests/utils/test_databricks_utils.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/utils/test_databricks_utils.py b/tests/utils/test_databricks_utils.py index 066e1ccc398ac..a3a6449a516f0 100644 --- a/tests/utils/test_databricks_utils.py +++ b/tests/utils/test_databricks_utils.py @@ -258,6 +258,18 @@ def mock_import(name, *args, **kwargs): def test_use_repl_context_if_available(tmpdir): + # Simulate a case where `dbruntime.databricks_repl_context.get_context` is unavailable. + with pytest.raises(ModuleNotFoundError, match="No module named 'dbruntime'"): + from dbruntime.databricks_repl_context import get_context # pylint: disable=unused-import + + command_context_mock = mock.MagicMock() + command_context_mock.jobId().get.return_value = "job_id" + with mock.patch( + "mlflow.utils.databricks_utils._get_command_context", return_value=command_context_mock + ) as mock_get_command_context: + assert databricks_utils.get_job_id() == "job_id" + mock_get_command_context.assert_called_once() + # Create a fake databricks_repl_context module tmpdir.mkdir("dbruntime").join("databricks_repl_context.py").write( """ @@ -267,6 +279,16 @@ def get_context(): ) sys.path.append(tmpdir.strpath) + # Simulate a case where the REPL context object is not initialized. + with mock.patch( + "dbruntime.databricks_repl_context.get_context", + return_value=None, + ) as mock_get_context, mock.patch( + "mlflow.utils.databricks_utils._get_command_context", return_value=command_context_mock + ) as mock_get_command_context: + assert databricks_utils.get_job_id() == "job_id" + mock_get_command_context.assert_called_once() + with mock.patch( "dbruntime.databricks_repl_context.get_context", return_value=mock.MagicMock(jobId="job_id"),