Skip to content

Commit

Permalink
Fix _use_repl_context_if_available to return original function resu…
Browse files Browse the repository at this point in the history
…lt when `get_context` returns None (#5194)

* return original function result when context is None

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* add unit tests

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
  • Loading branch information
harupy committed Dec 23, 2021
1 parent 4c5bdaa commit c098669
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
3 changes: 2 additions & 1 deletion mlflow/utils/databricks_utils.py
Expand Up @@ -31,7 +31,8 @@ def wrapper(*args, **kwargs):
if context is not None and hasattr(context, name):
return getattr(context, name)
except Exception:
return f(*args, **kwargs)
pass
return f(*args, **kwargs)

return wrapper

Expand Down
22 changes: 22 additions & 0 deletions tests/utils/test_databricks_utils.py
Expand Up @@ -258,6 +258,18 @@ def mock_import(name, *args, **kwargs):


def test_use_repl_context_if_available(tmpdir):
# Simulate a case where `dbruntime.databricks_repl_context.get_context` is unavailable.
with pytest.raises(ModuleNotFoundError, match="No module named 'dbruntime'"):
from dbruntime.databricks_repl_context import get_context # pylint: disable=unused-import

command_context_mock = mock.MagicMock()
command_context_mock.jobId().get.return_value = "job_id"
with mock.patch(
"mlflow.utils.databricks_utils._get_command_context", return_value=command_context_mock
) as mock_get_command_context:
assert databricks_utils.get_job_id() == "job_id"
mock_get_command_context.assert_called_once()

# Create a fake databricks_repl_context module
tmpdir.mkdir("dbruntime").join("databricks_repl_context.py").write(
"""
Expand All @@ -267,6 +279,16 @@ def get_context():
)
sys.path.append(tmpdir.strpath)

# Simulate a case where the REPL context object is not initialized.
with mock.patch(
"dbruntime.databricks_repl_context.get_context",
return_value=None,
) as mock_get_context, mock.patch(
"mlflow.utils.databricks_utils._get_command_context", return_value=command_context_mock
) as mock_get_command_context:
assert databricks_utils.get_job_id() == "job_id"
mock_get_command_context.assert_called_once()

with mock.patch(
"dbruntime.databricks_repl_context.get_context",
return_value=mock.MagicMock(jobId="job_id"),
Expand Down

0 comments on commit c098669

Please sign in to comment.