Use REPL context attributes if available to avoid calling JVM methods #5132

Merged: 22 commits merged on Dec 22, 2021
44 changes: 44 additions & 0 deletions mlflow/utils/databricks_utils.py
@@ -1,6 +1,7 @@
import os
import logging
import subprocess
import functools

from mlflow.exceptions import MlflowException
from mlflow.utils.rest_utils import MlflowHostCreds
@@ -11,6 +12,32 @@
_logger = logging.getLogger(__name__)


def _use_repl_context_if_available(name):
    """
    Creates a decorator to insert a short circuit that returns the specified REPL context
    attribute if it's available.

    :param name: Attribute name (e.g. "apiUrl").
    :return: Decorator to insert the short circuit.
    """

    def decorator(f):
        @functools.wraps(f)
Review comment (Member): nice use of the decorator factory here. +1

        def wrapper(*args, **kwargs):
            try:
                from dbruntime.databricks_repl_context import get_context

                context = get_context()
                if context is not None and hasattr(context, name):
                    return getattr(context, name)
            except Exception:
                pass
            # Fall through to the original implementation when the REPL context
            # is unavailable or lacks the requested attribute.
            return f(*args, **kwargs)

        return wrapper

    return decorator
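
The short circuit can be tried outside Databricks by stubbing the dbruntime module in sys.modules. Below is a minimal, self-contained sketch of the behavior; the stub module and the SimpleNamespace context are illustrative assumptions, not part of this PR:

import functools
import sys
import types

def _use_repl_context_if_available(name):
    # Same shape as the decorator in the diff above.
    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            try:
                from dbruntime.databricks_repl_context import get_context

                context = get_context()
                if context is not None and hasattr(context, name):
                    return getattr(context, name)
            except Exception:
                pass
            return f(*args, **kwargs)

        return wrapper

    return decorator

# Stub dbruntime.databricks_repl_context so the import inside wrapper succeeds.
_fake = types.ModuleType("dbruntime.databricks_repl_context")
_fake.get_context = lambda: types.SimpleNamespace(jobId="123")
sys.modules["dbruntime"] = types.ModuleType("dbruntime")
sys.modules["dbruntime.databricks_repl_context"] = _fake

@_use_repl_context_if_available("jobId")
def get_job_id():
    raise RuntimeError("JVM fallback; not reached while the context has jobId")

print(get_job_id())  # prints "123", served from the REPL context without a JVM call

Removing the sys.modules entries makes the import inside wrapper fail, and the call falls through to the decorated function, which is exactly the pre-PR behavior.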


def _get_dbutils():
try:
import IPython
@@ -50,6 +77,7 @@ def _get_context_tag(context_tag_key):
        return None


@_use_repl_context_if_available("aclPathOfAclRoot")
def acl_path_of_acl_root():
    try:
        return _get_command_context().aclPathOfAclRoot().get()
@@ -72,6 +100,7 @@ def is_databricks_default_tracking_uri(tracking_uri):
    return tracking_uri.lower().strip() == "databricks"


@_use_repl_context_if_available("isInNotebook")
def is_in_databricks_notebook():
    if _get_property_from_spark_context("spark.databricks.notebook.id") is not None:
        return True
@@ -81,6 +110,7 @@ def is_in_databricks_notebook():
    return False


@_use_repl_context_if_available("isInJob")
def is_in_databricks_job():
    try:
        return get_job_id() is not None and get_job_run_id() is not None
@@ -111,6 +141,7 @@ def is_dbfs_fuse_available():
            return False


@_use_repl_context_if_available("isInCluster")
def is_in_cluster():
    try:
        spark_session = _get_active_spark_session()
@@ -122,6 +153,7 @@ def is_in_cluster():
        return False


@_use_repl_context_if_available("notebookId")
def get_notebook_id():
    """Should only be called if is_in_databricks_notebook is true"""
    notebook_id = _get_property_from_spark_context("spark.databricks.notebook.id")
@@ -133,6 +165,7 @@ def get_notebook_id():
    return None


@_use_repl_context_if_available("notebookPath")
def get_notebook_path():
    """Should only be called if is_in_databricks_notebook is true"""
    path = _get_property_from_spark_context("spark.databricks.notebook.path")
Review comment (Member): does this work with ephemeral notebooks within and without jobs?

@@ -144,6 +177,7 @@ def get_notebook_path():
    return _get_extra_context("notebook_path")


@_use_repl_context_if_available("runtimeVersion")
def get_databricks_runtime():
    if is_in_databricks_runtime():
        spark_session = _get_active_spark_session()
@@ -154,13 +188,15 @@ def get_databricks_runtime():
    return None


@_use_repl_context_if_available("clusterId")
def get_cluster_id():
    spark_session = _get_active_spark_session()
    if spark_session is None:
        return None
    return spark_session.conf.get("spark.databricks.clusterUsageTags.clusterId")


@_use_repl_context_if_available("jobGroupId")
def get_job_group_id():
    try:
        dbutils = _get_dbutils()
@@ -171,6 +207,7 @@ def get_job_group_id():
        return None


@_use_repl_context_if_available("replId")
def get_repl_id():
    """
    :return: The ID of the current Databricks Python REPL
@@ -198,20 +235,23 @@ def get_repl_id():
        pass


@_use_repl_context_if_available("jobId")
def get_job_id():
    try:
        return _get_command_context().jobId().get()
    except Exception:
        return _get_context_tag("jobId")


@_use_repl_context_if_available("idInJob")
def get_job_run_id():
    try:
        return _get_command_context().idInJob().get()
    except Exception:
        return _get_context_tag("idInJob")


@_use_repl_context_if_available("jobTaskType")
def get_job_type():
    """Should only be called if is_in_databricks_job is true"""
    try:
@@ -220,6 +260,7 @@ def get_job_type():
        return _get_context_tag("jobTaskType")


@_use_repl_context_if_available("commandRunId")
def get_command_run_id():
    try:
        return _get_command_context().commandRunId().get()
@@ -228,6 +269,7 @@ def get_command_run_id():
        return None


@_use_repl_context_if_available("apiUrl")
def get_webapp_url():
    """Should only be called if is_in_databricks_notebook or is_in_databricks_jobs is true"""
    url = _get_property_from_spark_context("spark.databricks.api.url")
@@ -239,13 +281,15 @@ def get_webapp_url():
    return _get_extra_context("api_url")


@_use_repl_context_if_available("workspaceId")
def get_workspace_id():
    try:
        return _get_command_context().workspaceId().get()
    except Exception:
        return _get_context_tag("orgId")


@_use_repl_context_if_available("browserHostName")
def get_browser_hostname():
    try:
        return _get_command_context().browserHostName().get()
39 changes: 39 additions & 0 deletions tests/utils/test_databricks_utils.py
@@ -255,3 +255,42 @@ def mock_import(name, *args, **kwargs):

    with mock.patch("builtins.__import__", side_effect=mock_import):
        assert databricks_utils.get_repl_id() == "testReplId2"


def test_use_repl_context_if_available(tmpdir):
    # Create a fake databricks_repl_context module
    tmpdir.mkdir("dbruntime").join("databricks_repl_context.py").write(
        """
def get_context():
    pass
"""
    )
    sys.path.append(tmpdir.strpath)

    with mock.patch(
        "dbruntime.databricks_repl_context.get_context",
        return_value=mock.MagicMock(jobId="job_id"),
    ) as mock_get_context, mock.patch("mlflow.utils.databricks_utils._get_dbutils") as mock_dbutils:
        assert databricks_utils.get_job_id() == "job_id"
        mock_get_context.assert_called_once()
        mock_dbutils.assert_not_called()

    with mock.patch(
        "dbruntime.databricks_repl_context.get_context",
        return_value=mock.MagicMock(notebookId="notebook_id"),
    ) as mock_get_context, mock.patch(
        "mlflow.utils.databricks_utils._get_property_from_spark_context"
    ) as mock_spark_context:
        assert databricks_utils.get_notebook_id() == "notebook_id"
        mock_get_context.assert_called_once()
        mock_spark_context.assert_not_called()

    with mock.patch(
        "dbruntime.databricks_repl_context.get_context",
        return_value=mock.MagicMock(isInCluster=True),
    ) as mock_get_context, mock.patch(
        "mlflow.utils._spark_utils._get_active_spark_session"
    ) as mock_spark_session:
        assert databricks_utils.is_in_cluster()
        mock_get_context.assert_called_once()
        mock_spark_session.assert_not_called()
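
The tests above exercise the short-circuit path. A companion check for the fallback path, sketched below, is hypothetical and not part of this PR; it assumes that placing None entries in sys.modules forces the dbruntime import inside the wrapper to raise ImportError, so the JVM-backed implementation must run:

import sys
from unittest import mock

from mlflow.utils import databricks_utils


def test_use_repl_context_falls_back_without_dbruntime():
    # None entries in sys.modules make "from dbruntime.databricks_repl_context
    # import get_context" raise ImportError inside the wrapper.
    with mock.patch.dict(
        sys.modules,
        {"dbruntime": None, "dbruntime.databricks_repl_context": None},
    ), mock.patch("mlflow.utils.databricks_utils._get_command_context") as mock_ctx:
        mock_ctx.return_value.jobId.return_value.get.return_value = "fallback_job_id"
        assert databricks_utils.get_job_id() == "fallback_job_id"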