Use REPL context attributes if available to avoid calling JVM methods #5132

Merged (22 commits) on Dec 22, 2021
44 changes: 44 additions & 0 deletions mlflow/utils/databricks_utils.py
@@ -1,6 +1,7 @@
import os
import logging
import subprocess
import functools

from mlflow.exceptions import MlflowException
from mlflow.utils.rest_utils import MlflowHostCreds
@@ -10,6 +11,34 @@

_logger = logging.getLogger(__name__)

_ENV_VAR_PREFIX = "DATABRICKS_"


def _use_env_var_if_exists(env_var, *, if_exists=os.getenv):
"""
Creates a decorator to insert a short circuit that returns `if_exists(env_var)` if
the environment variable `env_var` exists.

:param env_var: The name of an environment variable to use.
    :param if_exists: A function to call with `env_var` when the variable exists. Defaults to `os.getenv`.
:return: A decorator to insert the short circuit.
"""

def decorator(f):
@functools.wraps(f)
Review comment (Member): nice use of the decorator factory here. +1

def wrapper(*args, **kwargs):
if env_var in os.environ:
return if_exists(env_var)
return f(*args, **kwargs)

return wrapper

return decorator


def _return_true(_):
return True


def _get_dbutils():
try:
@@ -50,6 +79,7 @@ def _get_context_tag(context_tag_key):
return None


@_use_env_var_if_exists(_ENV_VAR_PREFIX + "ACL_PATH_OF_ACL_ROOT")
def acl_path_of_acl_root():
try:
return _get_command_context().aclPathOfAclRoot().get()
@@ -72,6 +102,7 @@ def is_databricks_default_tracking_uri(tracking_uri):
return tracking_uri.lower().strip() == "databricks"


@_use_env_var_if_exists(_ENV_VAR_PREFIX + "NOTEBOOK_ID", if_exists=_return_true)
def is_in_databricks_notebook():
if _get_property_from_spark_context("spark.databricks.notebook.id") is not None:
return True
@@ -111,6 +142,7 @@ def is_dbfs_fuse_available():
return False


@_use_env_var_if_exists(_ENV_VAR_PREFIX + "CLUSTER_ID", if_exists=_return_true)
def is_in_cluster():
try:
spark_session = _get_active_spark_session()
@@ -122,6 +154,7 @@ def is_in_cluster():
return False


@_use_env_var_if_exists(_ENV_VAR_PREFIX + "NOTEBOOK_ID")
def get_notebook_id():
"""Should only be called if is_in_databricks_notebook is true"""
notebook_id = _get_property_from_spark_context("spark.databricks.notebook.id")
@@ -133,6 +166,7 @@
return None


@_use_env_var_if_exists(_ENV_VAR_PREFIX + "NOTEBOOK_PATH")
def get_notebook_path():
"""Should only be called if is_in_databricks_notebook is true"""
path = _get_property_from_spark_context("spark.databricks.notebook.path")
Review comment (Member): does this work with ephemeral notebooks within and without jobs?

@@ -144,6 +178,7 @@
return _get_extra_context("notebook_path")


@_use_env_var_if_exists(_ENV_VAR_PREFIX + "RUNTIME_VERSION")
def get_databricks_runtime():
if is_in_databricks_runtime():
spark_session = _get_active_spark_session()
@@ -154,13 +189,15 @@
return None


@_use_env_var_if_exists(_ENV_VAR_PREFIX + "CLUSTER_ID")
def get_cluster_id():
spark_session = _get_active_spark_session()
if spark_session is None:
return None
return spark_session.conf.get("spark.databricks.clusterUsageTags.clusterId")


@_use_env_var_if_exists(_ENV_VAR_PREFIX + "JOB_GROUP_ID")
def get_job_group_id():
try:
dbutils = _get_dbutils()
@@ -171,6 +208,7 @@ def get_job_group_id():
return None


@_use_env_var_if_exists(_ENV_VAR_PREFIX + "REPL_ID")
def get_repl_id():
"""
:return: The ID of the current Databricks Python REPL
@@ -198,20 +236,23 @@ def get_repl_id():
pass


@_use_env_var_if_exists(_ENV_VAR_PREFIX + "JOB_ID")
def get_job_id():
try:
return _get_command_context().jobId().get()
except Exception:
return _get_context_tag("jobId")


@_use_env_var_if_exists(_ENV_VAR_PREFIX + "ID_IN_JOB")
def get_job_run_id():
try:
return _get_command_context().idInJob().get()
except Exception:
return _get_context_tag("idInJob")


@_use_env_var_if_exists(_ENV_VAR_PREFIX + "JOB_TASK_TYPE")
def get_job_type():
"""Should only be called if is_in_databricks_job is true"""
try:
@@ -228,6 +269,7 @@ def get_command_run_id():
return None


@_use_env_var_if_exists(_ENV_VAR_PREFIX + "API_URL")
def get_webapp_url():
"""Should only be called if is_in_databricks_notebook or is_in_databricks_jobs is true"""
url = _get_property_from_spark_context("spark.databricks.api.url")
@@ -239,13 +281,15 @@
return _get_extra_context("api_url")


@_use_env_var_if_exists(_ENV_VAR_PREFIX + "WORKSPACE_ID")
def get_workspace_id():
try:
return _get_command_context().workspaceId().get()
except Exception:
return _get_context_tag("orgId")


@_use_env_var_if_exists(_ENV_VAR_PREFIX + "BROWSER_HOST_NAME")
def get_browser_hostname():
try:
return _get_command_context().browserHostName().get()
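
To make the pattern concrete, here is a minimal standalone sketch (not part of the diff) of how the short circuit behaves. The decorator and `_return_true` mirror the code above; the placeholder function bodies stand in for the real JVM/Spark lookups:

import functools
import os

def _use_env_var_if_exists(env_var, *, if_exists=os.getenv):
    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            # Short circuit: if the env var is set, never call the wrapped function.
            if env_var in os.environ:
                return if_exists(env_var)
            return f(*args, **kwargs)
        return wrapper
    return decorator

def _return_true(_):
    return True

@_use_env_var_if_exists("DATABRICKS_CLUSTER_ID")
def get_cluster_id():
    return "jvm-result"  # placeholder for the Spark-conf lookup

@_use_env_var_if_exists("DATABRICKS_CLUSTER_ID", if_exists=_return_true)
def is_in_cluster():
    return False  # placeholder for the Spark-session probe

os.environ["DATABRICKS_CLUSTER_ID"] = "0123-456789"
assert get_cluster_id() == "0123-456789"  # env var wins; the wrapped function never runs
assert is_in_cluster() is True            # presence alone satisfies the boolean check

del os.environ["DATABRICKS_CLUSTER_ID"]
assert get_cluster_id() == "jvm-result"   # falls back to the original implementation

With the default `if_exists=os.getenv`, the decorated function returns the variable's value; with `if_exists=_return_true`, it becomes a pure presence check, which is why the boolean helpers above pass that override.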
23 changes: 22 additions & 1 deletion tests/utils/test_databricks_utils.py
@@ -11,7 +11,7 @@
is_databricks_default_tracking_uri,
)
from mlflow.utils.uri import construct_db_uri_from_profile
from tests.helper_functions import mock_method_chain
from tests.helper_functions import mock_method_chain, multi_context


def test_no_throw():
@@ -254,3 +254,24 @@ def mock_import(name, *args, **kwargs):

with mock.patch("builtins.__import__", side_effect=mock_import):
assert databricks_utils.get_repl_id() == "testReplId2"


def test_use_env_var_if_exists():
with mock.patch.dict(
"os.environ",
{
databricks_utils._ENV_VAR_PREFIX + "NOTEBOOK_ID": "1",
databricks_utils._ENV_VAR_PREFIX + "CLUSTER_ID": "a",
},
clear=True,
):
with multi_context(
mock.patch("mlflow.utils.databricks_utils._get_dbutils"),
mock.patch("mlflow.utils.databricks_utils._get_property_from_spark_context"),
mock.patch("mlflow.utils._spark_utils._get_active_spark_session"),
) as mocks:
assert databricks_utils.get_notebook_id() == "1"
assert databricks_utils.is_in_databricks_notebook()
assert databricks_utils.get_cluster_id() == "a"
assert databricks_utils.is_in_cluster()
assert all(m.call_count == 0 for m in mocks)
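
The test patches every JVM/Spark entry point, sets the two environment variables, and asserts that none of the mocks were ever called, which proves the short circuit takes effect. It uses `multi_context` from `tests/helper_functions.py` to enter several patchers at once; a rough, hypothetical equivalent (an assumption about that helper, not its actual source) can be built on `contextlib.ExitStack`:

import contextlib

@contextlib.contextmanager
def multi_context(*context_managers):
    # Enter each context manager in order and yield the entered values,
    # so the caller can unpack the mocks and assert on them afterwards.
    with contextlib.ExitStack() as stack:
        yield [stack.enter_context(cm) for cm in context_managers]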