
Use REPL context attributes if available to avoid calling JVM methods #5132

Merged

Merged 22 commits on Dec 22, 2021
54 changes: 54 additions & 0 deletions mlflow/utils/databricks_utils.py
@@ -1,6 +1,7 @@
import os
import logging
import subprocess
import functools

from mlflow.exceptions import MlflowException
from mlflow.utils.rest_utils import MlflowHostCreds
@@ -11,6 +12,43 @@
_logger = logging.getLogger(__name__)


def _get_context_attribute(name):
try:
from dbruntime.databricks_repl_context import get_context

return getattr(get_context(), name)
except Exception:
return None
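
For context, `dbruntime.databricks_repl_context.get_context()` is importable only inside a Databricks REPL, so the bare `except Exception` above also covers the `ImportError` raised everywhere else. A minimal sketch of the resulting behavior (assumes no `dbruntime` is on the path; not part of this diff):

# Sketch: outside Databricks the dbruntime import fails, the except clause
# swallows the ImportError, and the lookup degrades to None, sending callers
# down their original JVM/Spark code paths.
assert _get_context_attribute("notebookId") is None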


def _use_context_attribute_if_available(name, *, if_available=lambda x: x):
"""
Creates a decorator to insert a short circuit that returns `if_available(name)`
if the specified context attribute is available.

:param name: Context attribute name.
:param if_available: Function to evaluate when the specified context attribute is available.
Defaults to `lambda x: x`.
:return: Decorator to insert the short circuit.
"""

def decorator(f):
@functools.wraps(f)
        # Inline review comment (Member): nice use of the decorator factory here. +1
def wrapper(*args, **kwargs):
metadata = _get_context_attribute(name)
if metadata:
return if_available(metadata)
return f(*args, **kwargs)

return wrapper

return decorator


def _return_true(_):
return True
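
To make the short circuit concrete, here is a self-contained sketch of the same decorator-factory pattern (the stand-in names `_short_circuit`, `_lookup`, and `_FAKE_CONTEXT` are hypothetical, not part of this diff):

import functools

_FAKE_CONTEXT = {"clusterId": "0123-456789-abcdef"}  # hypothetical REPL context


def _lookup(name):
    # Stand-in for _get_context_attribute: returns None when unavailable.
    return _FAKE_CONTEXT.get(name)


def _short_circuit(name, *, if_available=lambda x: x):
    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            value = _lookup(name)
            if value:
                return if_available(value)  # attribute present: skip f entirely
            return f(*args, **kwargs)  # attribute absent: original path runs
        return wrapper
    return decorator


@_short_circuit("clusterId")
def get_cluster_id_example():
    raise RuntimeError("would have called into the JVM")


@_short_circuit("notebookId", if_available=lambda _: True)
def is_in_notebook_example():
    return False  # fallback when no notebookId is present


print(get_cluster_id_example())  # -> "0123-456789-abcdef" (no JVM call)
print(is_in_notebook_example())  # -> False (notebookId missing, fallback runs)

The `if_available=lambda _: True` variant mirrors `_return_true` above: for boolean predicates, the attribute's presence, not its value, is the answer.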


def _get_dbutils():
try:
import IPython
@@ -50,6 +88,7 @@ def _get_context_tag(context_tag_key):
return None


@_use_context_attribute_if_available("aclPathOfAclRoot")
def acl_path_of_acl_root():
try:
return _get_command_context().aclPathOfAclRoot().get()
@@ -72,6 +111,7 @@ def is_databricks_default_tracking_uri(tracking_uri):
return tracking_uri.lower().strip() == "databricks"


@_use_context_attribute_if_available("notebookId", if_available=_return_true)
def is_in_databricks_notebook():
if _get_property_from_spark_context("spark.databricks.notebook.id") is not None:
return True
@@ -111,6 +151,7 @@ def is_dbfs_fuse_available():
return False


@_use_context_attribute_if_available("clusterId", if_available=_return_true)
def is_in_cluster():
try:
spark_session = _get_active_spark_session()
@@ -122,6 +163,7 @@ def is_in_cluster():
return False


@_use_context_attribute_if_available("notebookId")
def get_notebook_id():
"""Should only be called if is_in_databricks_notebook is true"""
notebook_id = _get_property_from_spark_context("spark.databricks.notebook.id")
@@ -133,6 +175,7 @@ def get_notebook_id():
return None


@_use_context_attribute_if_available("notebookPath")
def get_notebook_path():
"""Should only be called if is_in_databricks_notebook is true"""
path = _get_property_from_spark_context("spark.databricks.notebook.path")
    # Inline review comment (Member): does this work with ephemeral notebooks
    # within and without jobs?

@@ -144,6 +187,7 @@ def get_notebook_path():
return _get_extra_context("notebook_path")


@_use_context_attribute_if_available("runtimeVersion")
def get_databricks_runtime():
if is_in_databricks_runtime():
spark_session = _get_active_spark_session()
@@ -154,13 +198,15 @@ def get_databricks_runtime():
return None


@_use_context_attribute_if_available("clusterId")
def get_cluster_id():
spark_session = _get_active_spark_session()
if spark_session is None:
return None
return spark_session.conf.get("spark.databricks.clusterUsageTags.clusterId")


@_use_context_attribute_if_available("jobGroupId")
def get_job_group_id():
try:
dbutils = _get_dbutils()
Expand All @@ -171,6 +217,7 @@ def get_job_group_id():
return None


@_use_context_attribute_if_available("replId")
def get_repl_id():
"""
:return: The ID of the current Databricks Python REPL
@@ -198,20 +245,23 @@ def get_repl_id():
pass


@_use_context_attribute_if_available("jobId")
def get_job_id():
try:
return _get_command_context().jobId().get()
except Exception:
return _get_context_tag("jobId")


@_use_context_attribute_if_available("idInJob")
def get_job_run_id():
try:
return _get_command_context().idInJob().get()
except Exception:
return _get_context_tag("idInJob")


@_use_context_attribute_if_available("jobTaskType")
def get_job_type():
"""Should only be called if is_in_databricks_job is true"""
try:
@@ -220,6 +270,7 @@ def get_job_type():
return _get_context_tag("jobTaskType")


@_use_context_attribute_if_available("commandRunId")
def get_command_run_id():
try:
return _get_command_context().commandRunId().get()
@@ -228,6 +279,7 @@ def get_command_run_id():
return None


@_use_context_attribute_if_available("apiUrl")
def get_webapp_url():
"""Should only be called if is_in_databricks_notebook or is_in_databricks_jobs is true"""
url = _get_property_from_spark_context("spark.databricks.api.url")
@@ -239,13 +291,15 @@ def get_webapp_url():
return _get_extra_context("api_url")


@_use_context_attribute_if_available("workspaceId")
def get_workspace_id():
try:
return _get_command_context().workspaceId().get()
except Exception:
return _get_context_tag("orgId")


@_use_context_attribute_if_available("browserHostName")
def get_browser_hostname():
try:
return _get_command_context().browserHostName().get()
39 changes: 39 additions & 0 deletions tests/utils/test_databricks_utils.py
@@ -254,3 +254,42 @@ def mock_import(name, *args, **kwargs):

with mock.patch("builtins.__import__", side_effect=mock_import):
assert databricks_utils.get_repl_id() == "testReplId2"


def test_use_context_attribute_if_available():
with mock.patch(
"mlflow.utils.databricks_utils._get_context_attribute",
return_value="job_id",
) as mock_context_metadata, mock.patch(
"mlflow.utils.databricks_utils._get_dbutils"
) as mock_dbutils:
assert databricks_utils.get_job_id() == "job_id"
mock_context_metadata.assert_called_once_with("jobId")
mock_dbutils.assert_not_called()

with mock.patch(
"mlflow.utils.databricks_utils._get_context_attribute",
return_value="notebook_id",
) as mock_context_metadata, mock.patch(
"mlflow.utils.databricks_utils._get_property_from_spark_context"
) as mock_spark_context:
assert databricks_utils.get_notebook_id() == "notebook_id"
mock_context_metadata.assert_called_once_with("notebookId")
mock_context_metadata.reset_mock()
assert databricks_utils.is_in_databricks_notebook()
mock_context_metadata.assert_called_once_with("notebookId")
mock_spark_context.assert_not_called()

with mock.patch(
"mlflow.utils.databricks_utils._get_context_attribute",
return_value="cluster_id",
) as mock_context_metadata, mock.patch(
"mlflow.utils._spark_utils._get_active_spark_session"
) as mock_spark_session:
assert databricks_utils.get_cluster_id() == "cluster_id"
mock_context_metadata.assert_called_once_with("clusterId")
mock_context_metadata.reset_mock()
assert databricks_utils.is_in_cluster()
mock_context_metadata.assert_called_once_with("clusterId")
mock_spark_session.assert_not_called()
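
These assertions cover the short-circuit path only. A complementary negative-path check (sketch only, not part of this PR, assuming `_get_command_context` and `_get_context_tag` behave as shown in the diff above) might look like:

def test_fallback_when_context_attribute_unavailable():
    # Sketch: with no REPL context attribute, the decorated function should
    # fall through to its original body and hit the context-tag fallback.
    with mock.patch(
        "mlflow.utils.databricks_utils._get_context_attribute",
        return_value=None,
    ), mock.patch(
        "mlflow.utils.databricks_utils._get_command_context",
        side_effect=Exception("JVM unavailable"),
    ), mock.patch(
        "mlflow.utils.databricks_utils._get_context_tag",
        return_value="fallback_job_id",
    ) as mock_context_tag:
        assert databricks_utils.get_job_id() == "fallback_job_id"
        mock_context_tag.assert_called_once_with("jobId")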