diff --git a/mlflow/utils/databricks_utils.py b/mlflow/utils/databricks_utils.py index f9701dca1d0b5..4e5e609028329 100644 --- a/mlflow/utils/databricks_utils.py +++ b/mlflow/utils/databricks_utils.py @@ -31,7 +31,8 @@ def wrapper(*args, **kwargs): if context is not None and hasattr(context, name): return getattr(context, name) except Exception: - return f(*args, **kwargs) + pass + return f(*args, **kwargs) return wrapper diff --git a/tests/utils/test_databricks_utils.py b/tests/utils/test_databricks_utils.py index 066e1ccc398ac..a3a6449a516f0 100644 --- a/tests/utils/test_databricks_utils.py +++ b/tests/utils/test_databricks_utils.py @@ -258,6 +258,18 @@ def mock_import(name, *args, **kwargs): def test_use_repl_context_if_available(tmpdir): + # Simulate a case where `dbruntime.databricks_repl_context.get_context` is unavailable. + with pytest.raises(ModuleNotFoundError, match="No module named 'dbruntime'"): + from dbruntime.databricks_repl_context import get_context # pylint: disable=unused-import + + command_context_mock = mock.MagicMock() + command_context_mock.jobId().get.return_value = "job_id" + with mock.patch( + "mlflow.utils.databricks_utils._get_command_context", return_value=command_context_mock + ) as mock_get_command_context: + assert databricks_utils.get_job_id() == "job_id" + mock_get_command_context.assert_called_once() + # Create a fake databricks_repl_context module tmpdir.mkdir("dbruntime").join("databricks_repl_context.py").write( """ @@ -267,6 +279,16 @@ def get_context(): ) sys.path.append(tmpdir.strpath) + # Simulate a case where the REPL context object is not initialized. + with mock.patch( + "dbruntime.databricks_repl_context.get_context", + return_value=None, + ) as mock_get_context, mock.patch( + "mlflow.utils.databricks_utils._get_command_context", return_value=command_context_mock + ) as mock_get_command_context: + assert databricks_utils.get_job_id() == "job_id" + mock_get_command_context.assert_called_once() + with mock.patch( "dbruntime.databricks_repl_context.get_context", return_value=mock.MagicMock(jobId="job_id"),