
Commit 852f567

Use REPL context attributes if available to avoid calling JVM methods (#5132)

* update get_notebook_path

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* add decorator

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* pass func to wraps

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* refactor

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* update docstring

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* test

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* fix

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* hardcode prefix

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* add _use_message_metadata_if_exists

Signed-off-by: harupy <hkawamura0130@gmail.com>

* use context metadata

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* debug

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* fix attr name

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* remove print

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* fix tests

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* fix docstring

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* use get_context

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* rename functions

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* fix module name

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* fix module name

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* use boolean attributes

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* refactor

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
harupy committed Dec 22, 2021
1 parent e1f0a24 commit 852f567
Showing 2 changed files with 83 additions and 0 deletions.
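
In short, the change wraps the existing Databricks helpers in mlflow/utils/databricks_utils.py with a decorator that first consults the lightweight REPL context object exposed by the dbruntime package, and only falls back to the original implementation (which goes through dbutils/py4j into the JVM) when the requested attribute is not available. A minimal usage sketch of the resulting behavior, assuming a Databricks runtime where dbruntime is importable (illustrative only, not part of the commit):

# Illustrative sketch: assumes a Databricks runtime where the dbruntime
# package is importable. Outside Databricks the import inside the
# decorator fails and the helpers silently fall back to their original
# JVM-backed code paths.
from mlflow.utils import databricks_utils

# If dbruntime.databricks_repl_context.get_context() returns a context
# object exposing `jobId`, get_job_id() returns it directly, without a
# py4j round trip through _get_command_context().
job_id = databricks_utils.get_job_id()

# Boolean helpers are short-circuited the same way via attributes such
# as isInNotebook, isInJob, and isInCluster.
in_notebook = databricks_utils.is_in_databricks_notebook()
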
44 changes: 44 additions & 0 deletions mlflow/utils/databricks_utils.py
@@ -1,6 +1,7 @@
import os
import logging
import subprocess
import functools

from mlflow.exceptions import MlflowException
from mlflow.utils.rest_utils import MlflowHostCreds
@@ -11,6 +12,32 @@
_logger = logging.getLogger(__name__)


def _use_repl_context_if_available(name):
    """
    Creates a decorator to insert a short circuit that returns the specified REPL context attribute
    if it's available.

    :param name: Attribute name (e.g. "apiUrl").
    :return: Decorator to insert the short circuit.
    """

    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            try:
                from dbruntime.databricks_repl_context import get_context

                context = get_context()
                if context is not None and hasattr(context, name):
                    return getattr(context, name)
            except Exception:
                pass
            # Fall back to the original (JVM-backed) implementation.
            return f(*args, **kwargs)

        return wrapper

    return decorator


def _get_dbutils():
    try:
        import IPython
@@ -50,6 +77,7 @@ def _get_context_tag(context_tag_key):
        return None


@_use_repl_context_if_available("aclPathOfAclRoot")
def acl_path_of_acl_root():
    try:
        return _get_command_context().aclPathOfAclRoot().get()
@@ -72,6 +100,7 @@ def is_databricks_default_tracking_uri(tracking_uri):
    return tracking_uri.lower().strip() == "databricks"


@_use_repl_context_if_available("isInNotebook")
def is_in_databricks_notebook():
    if _get_property_from_spark_context("spark.databricks.notebook.id") is not None:
        return True
@@ -81,6 +110,7 @@ def is_in_databricks_notebook():
        return False


@_use_repl_context_if_available("isInJob")
def is_in_databricks_job():
    try:
        return get_job_id() is not None and get_job_run_id() is not None
@@ -111,6 +141,7 @@ def is_dbfs_fuse_available():
            return False


@_use_repl_context_if_available("isInCluster")
def is_in_cluster():
    try:
        spark_session = _get_active_spark_session()
@@ -122,6 +153,7 @@ def is_in_cluster():
        return False


@_use_repl_context_if_available("notebookId")
def get_notebook_id():
    """Should only be called if is_in_databricks_notebook is true"""
    notebook_id = _get_property_from_spark_context("spark.databricks.notebook.id")
@@ -133,6 +165,7 @@ def get_notebook_id():
    return None


@_use_repl_context_if_available("notebookPath")
def get_notebook_path():
    """Should only be called if is_in_databricks_notebook is true"""
    path = _get_property_from_spark_context("spark.databricks.notebook.path")
@@ -144,6 +177,7 @@ def get_notebook_path():
        return _get_extra_context("notebook_path")


@_use_repl_context_if_available("runtimeVersion")
def get_databricks_runtime():
    if is_in_databricks_runtime():
        spark_session = _get_active_spark_session()
@@ -154,13 +188,15 @@ def get_databricks_runtime():
    return None


@_use_repl_context_if_available("clusterId")
def get_cluster_id():
    spark_session = _get_active_spark_session()
    if spark_session is None:
        return None
    return spark_session.conf.get("spark.databricks.clusterUsageTags.clusterId")


@_use_repl_context_if_available("jobGroupId")
def get_job_group_id():
    try:
        dbutils = _get_dbutils()
@@ -171,6 +207,7 @@ def get_job_group_id():
        return None


@_use_repl_context_if_available("replId")
def get_repl_id():
"""
:return: The ID of the current Databricks Python REPL
Expand Down Expand Up @@ -198,20 +235,23 @@ def get_repl_id():
pass


@_use_repl_context_if_available("jobId")
def get_job_id():
    try:
        return _get_command_context().jobId().get()
    except Exception:
        return _get_context_tag("jobId")


@_use_repl_context_if_available("idInJob")
def get_job_run_id():
    try:
        return _get_command_context().idInJob().get()
    except Exception:
        return _get_context_tag("idInJob")


@_use_repl_context_if_available("jobTaskType")
def get_job_type():
"""Should only be called if is_in_databricks_job is true"""
try:
Expand All @@ -220,6 +260,7 @@ def get_job_type():
return _get_context_tag("jobTaskType")


@_use_repl_context_if_available("commandRunId")
def get_command_run_id():
    try:
        return _get_command_context().commandRunId().get()
@@ -228,6 +269,7 @@ def get_command_run_id():
        return None


@_use_repl_context_if_available("apiUrl")
def get_webapp_url():
"""Should only be called if is_in_databricks_notebook or is_in_databricks_jobs is true"""
url = _get_property_from_spark_context("spark.databricks.api.url")
Expand All @@ -239,13 +281,15 @@ def get_webapp_url():
return _get_extra_context("api_url")


@_use_repl_context_if_available("workspaceId")
def get_workspace_id():
    try:
        return _get_command_context().workspaceId().get()
    except Exception:
        return _get_context_tag("orgId")


@_use_repl_context_if_available("browserHostName")
def get_browser_hostname():
    try:
        return _get_command_context().browserHostName().get()
39 changes: 39 additions & 0 deletions tests/utils/test_databricks_utils.py
@@ -255,3 +255,42 @@ def mock_import(name, *args, **kwargs):

    with mock.patch("builtins.__import__", side_effect=mock_import):
        assert databricks_utils.get_repl_id() == "testReplId2"


def test_use_repl_context_if_available(tmpdir):
    # Create a fake databricks_repl_context module
    tmpdir.mkdir("dbruntime").join("databricks_repl_context.py").write(
        """
def get_context():
    pass
"""
    )
    sys.path.append(tmpdir.strpath)

    with mock.patch(
        "dbruntime.databricks_repl_context.get_context",
        return_value=mock.MagicMock(jobId="job_id"),
    ) as mock_get_context, mock.patch("mlflow.utils.databricks_utils._get_dbutils") as mock_dbutils:
        assert databricks_utils.get_job_id() == "job_id"
        mock_get_context.assert_called_once()
        mock_dbutils.assert_not_called()

    with mock.patch(
        "dbruntime.databricks_repl_context.get_context",
        return_value=mock.MagicMock(notebookId="notebook_id"),
    ) as mock_get_context, mock.patch(
        "mlflow.utils.databricks_utils._get_property_from_spark_context"
    ) as mock_spark_context:
        assert databricks_utils.get_notebook_id() == "notebook_id"
        mock_get_context.assert_called_once()
        mock_spark_context.assert_not_called()

    with mock.patch(
        "dbruntime.databricks_repl_context.get_context",
        return_value=mock.MagicMock(isInCluster=True),
    ) as mock_get_context, mock.patch(
        "mlflow.utils._spark_utils._get_active_spark_session"
    ) as mock_spark_session:
        assert databricks_utils.is_in_cluster()
        mock_get_context.assert_called_once()
        mock_spark_session.assert_not_called()
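
The fake-module setup above works because the decorator imports dbruntime.databricks_repl_context lazily inside the wrapper: writing a stub module to a temporary directory and appending that directory to sys.path makes the import resolve outside Databricks, after which mock.patch can substitute get_context. A condensed sketch of the same pattern (the helper names and tmp_path handling are illustrative, not part of the commit):

# Condensed sketch of the fake-module technique used in the test above.
# The package/module names mirror what the decorator imports; everything
# else (fake_dbruntime, tmp_path) is illustrative.
import sys
from unittest import mock

from mlflow.utils import databricks_utils


def fake_dbruntime(tmp_path):
    pkg = tmp_path / "dbruntime"
    pkg.mkdir()
    # The stub only needs to make the import succeed; get_context itself
    # is replaced with mock.patch below.
    (pkg / "databricks_repl_context.py").write_text("def get_context():\n    pass\n")
    sys.path.append(str(tmp_path))


def check_short_circuit(tmp_path):
    fake_dbruntime(tmp_path)
    with mock.patch(
        "dbruntime.databricks_repl_context.get_context",
        return_value=mock.MagicMock(jobId="job_id"),
    ):
        # The REPL context attribute wins; no dbutils/JVM call is made.
        assert databricks_utils.get_job_id() == "job_id"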
