Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix _use_repl_context_if_available to return original function result when get_context returns None #5194

Merged
merged 2 commits into from Dec 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion mlflow/utils/databricks_utils.py
Expand Up @@ -31,7 +31,8 @@ def wrapper(*args, **kwargs):
if context is not None and hasattr(context, name):
return getattr(context, name)
except Exception:
return f(*args, **kwargs)
pass
return f(*args, **kwargs)
Comment on lines -34 to +35
Copy link
Member Author

@harupy harupy Dec 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original code:

    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            try:
                from dbruntime.databricks_repl_context import get_context

                # We initialize the REPL context using a pre-command hook.
                # If the pre-commnad hook is not called, `context` remains `None`.
                context = get_context()
                if context is not None and hasattr(context, name):
                    return getattr(context, name)
            except Exception:
                # This line is called only when we encounter an exception
                # in the try clause.
                return f(*args, **kwargs)


return wrapper

Expand Down
22 changes: 22 additions & 0 deletions tests/utils/test_databricks_utils.py
Expand Up @@ -258,6 +258,18 @@ def mock_import(name, *args, **kwargs):


def test_use_repl_context_if_available(tmpdir):
# Simulate a case where `dbruntime.databricks_repl_context.get_context` is unavailable.
with pytest.raises(ModuleNotFoundError, match="No module named 'dbruntime'"):
from dbruntime.databricks_repl_context import get_context # pylint: disable=unused-import

command_context_mock = mock.MagicMock()
command_context_mock.jobId().get.return_value = "job_id"
with mock.patch(
"mlflow.utils.databricks_utils._get_command_context", return_value=command_context_mock
) as mock_get_command_context:
assert databricks_utils.get_job_id() == "job_id"
mock_get_command_context.assert_called_once()

# Create a fake databricks_repl_context module
tmpdir.mkdir("dbruntime").join("databricks_repl_context.py").write(
"""
Expand All @@ -267,6 +279,16 @@ def get_context():
)
sys.path.append(tmpdir.strpath)

# Simulate a case where the REPL context object is not initialized.
with mock.patch(
"dbruntime.databricks_repl_context.get_context",
return_value=None,
) as mock_get_context, mock.patch(
"mlflow.utils.databricks_utils._get_command_context", return_value=command_context_mock
) as mock_get_command_context:
assert databricks_utils.get_job_id() == "job_id"
mock_get_command_context.assert_called_once()

with mock.patch(
"dbruntime.databricks_repl_context.get_context",
return_value=mock.MagicMock(jobId="job_id"),
Expand Down