diff --git a/mlflow/utils/databricks_utils.py b/mlflow/utils/databricks_utils.py
index f088af5c83ef5..f9701dca1d0b5 100644
--- a/mlflow/utils/databricks_utils.py
+++ b/mlflow/utils/databricks_utils.py
@@ -1,6 +1,7 @@
 import os
 import logging
 import subprocess
+import functools
 
 from mlflow.exceptions import MlflowException
 from mlflow.utils.rest_utils import MlflowHostCreds
@@ -11,6 +12,33 @@
 _logger = logging.getLogger(__name__)
 
 
+def _use_repl_context_if_available(name):
+    """
+    Creates a decorator to insert a short circuit that returns the specified REPL context attribute
+    if it's available.
+
+    :param name: Attribute name (e.g. "apiUrl").
+    :return: Decorator to insert the short circuit.
+    """
+
+    def decorator(f):
+        @functools.wraps(f)
+        def wrapper(*args, **kwargs):
+            try:
+                from dbruntime.databricks_repl_context import get_context
+
+                context = get_context()
+                if context is not None and hasattr(context, name):
+                    return getattr(context, name)
+            except Exception:
+                pass
+            return f(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
 def _get_dbutils():
     try:
         import IPython
@@ -50,6 +78,7 @@ def _get_context_tag(context_tag_key):
         return None
 
 
+@_use_repl_context_if_available("aclPathOfAclRoot")
 def acl_path_of_acl_root():
     try:
         return _get_command_context().aclPathOfAclRoot().get()
@@ -72,6 +101,7 @@ def is_databricks_default_tracking_uri(tracking_uri):
     return tracking_uri.lower().strip() == "databricks"
 
 
+@_use_repl_context_if_available("isInNotebook")
 def is_in_databricks_notebook():
     if _get_property_from_spark_context("spark.databricks.notebook.id") is not None:
         return True
@@ -81,6 +111,7 @@ def is_in_databricks_notebook():
         return False
 
 
+@_use_repl_context_if_available("isInJob")
 def is_in_databricks_job():
     try:
         return get_job_id() is not None and get_job_run_id() is not None
@@ -111,6 +142,7 @@ def is_dbfs_fuse_available():
             return False
 
 
+@_use_repl_context_if_available("isInCluster")
 def is_in_cluster():
     try:
         spark_session = _get_active_spark_session()
@@ -122,6 +154,7 @@ def is_in_cluster():
         return False
 
 
+@_use_repl_context_if_available("notebookId")
 def get_notebook_id():
     """Should only be called if is_in_databricks_notebook is true"""
     notebook_id = _get_property_from_spark_context("spark.databricks.notebook.id")
@@ -133,6 +166,7 @@ def get_notebook_id():
     return None
 
 
+@_use_repl_context_if_available("notebookPath")
 def get_notebook_path():
     """Should only be called if is_in_databricks_notebook is true"""
     path = _get_property_from_spark_context("spark.databricks.notebook.path")
@@ -144,6 +178,7 @@ def get_notebook_path():
         return _get_extra_context("notebook_path")
 
 
+@_use_repl_context_if_available("runtimeVersion")
 def get_databricks_runtime():
     if is_in_databricks_runtime():
         spark_session = _get_active_spark_session()
@@ -154,6 +189,7 @@ def get_databricks_runtime():
     return None
 
 
+@_use_repl_context_if_available("clusterId")
 def get_cluster_id():
     spark_session = _get_active_spark_session()
     if spark_session is None:
@@ -161,6 +197,7 @@ def get_cluster_id():
     return spark_session.conf.get("spark.databricks.clusterUsageTags.clusterId")
 
 
+@_use_repl_context_if_available("jobGroupId")
 def get_job_group_id():
     try:
         dbutils = _get_dbutils()
@@ -171,6 +208,7 @@ def get_job_group_id():
         return None
 
 
+@_use_repl_context_if_available("replId")
 def get_repl_id():
     """
     :return: The ID of the current Databricks Python REPL
@@ -198,6 +236,7 @@ def get_repl_id():
         pass
 
 
+@_use_repl_context_if_available("jobId")
 def get_job_id():
     try:
         return _get_command_context().jobId().get()
@@ -205,6 +244,7 @@ def get_job_id():
         return _get_context_tag("jobId")
 
 
+@_use_repl_context_if_available("idInJob")
 def get_job_run_id():
     try:
         return _get_command_context().idInJob().get()
@@ -212,6 +252,7 @@ def get_job_run_id():
         return _get_context_tag("idInJob")
 
 
+@_use_repl_context_if_available("jobTaskType")
 def get_job_type():
     """Should only be called if is_in_databricks_job is true"""
     try:
@@ -220,6 +261,7 @@ def get_job_type():
         return _get_context_tag("jobTaskType")
 
 
+@_use_repl_context_if_available("commandRunId")
 def get_command_run_id():
     try:
         return _get_command_context().commandRunId().get()
@@ -228,6 +270,7 @@ def get_command_run_id():
         return None
 
 
+@_use_repl_context_if_available("apiUrl")
 def get_webapp_url():
     """Should only be called if is_in_databricks_notebook or is_in_databricks_jobs is true"""
     url = _get_property_from_spark_context("spark.databricks.api.url")
@@ -239,6 +282,7 @@ def get_webapp_url():
         return _get_extra_context("api_url")
 
 
+@_use_repl_context_if_available("workspaceId")
 def get_workspace_id():
     try:
         return _get_command_context().workspaceId().get()
@@ -246,6 +290,7 @@ def get_workspace_id():
         return _get_context_tag("orgId")
 
 
+@_use_repl_context_if_available("browserHostName")
 def get_browser_hostname():
     try:
         return _get_command_context().browserHostName().get()
diff --git a/tests/utils/test_databricks_utils.py b/tests/utils/test_databricks_utils.py
index 816598234f895..066e1ccc398ac 100644
--- a/tests/utils/test_databricks_utils.py
+++ b/tests/utils/test_databricks_utils.py
@@ -255,3 +255,42 @@ def mock_import(name, *args, **kwargs):
 
     with mock.patch("builtins.__import__", side_effect=mock_import):
         assert databricks_utils.get_repl_id() == "testReplId2"
+
+
+def test_use_repl_context_if_available(tmpdir):
+    # Create a fake databricks_repl_context module
+    tmpdir.mkdir("dbruntime").join("databricks_repl_context.py").write(
+        """
+def get_context():
+    pass
+"""
+    )
+    sys.path.append(tmpdir.strpath)
+
+    with mock.patch(
+        "dbruntime.databricks_repl_context.get_context",
+        return_value=mock.MagicMock(jobId="job_id"),
+    ) as mock_get_context, mock.patch("mlflow.utils.databricks_utils._get_dbutils") as mock_dbutils:
+        assert databricks_utils.get_job_id() == "job_id"
+        mock_get_context.assert_called_once()
+        mock_dbutils.assert_not_called()
+
+    with mock.patch(
+        "dbruntime.databricks_repl_context.get_context",
+        return_value=mock.MagicMock(notebookId="notebook_id"),
+    ) as mock_get_context, mock.patch(
+        "mlflow.utils.databricks_utils._get_property_from_spark_context"
+    ) as mock_spark_context:
+        assert databricks_utils.get_notebook_id() == "notebook_id"
+        mock_get_context.assert_called_once()
+        mock_spark_context.assert_not_called()
+
+    with mock.patch(
+        "dbruntime.databricks_repl_context.get_context",
+        return_value=mock.MagicMock(isInCluster=True),
+    ) as mock_get_context, mock.patch(
+        "mlflow.utils._spark_utils._get_active_spark_session"
+    ) as mock_spark_session:
+        assert databricks_utils.is_in_cluster()
+        mock_get_context.assert_called_once()
+        mock_spark_session.assert_not_called()
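For illustration, a minimal sketch of how the decorated helpers behave once the patch above is applied. It assumes mlflow is importable and is run outside Databricks; the fake dbruntime module, the fake_ctx object, and the sys.modules stubbing are demo assumptions, not part of the patch.

    # Hypothetical demo: stub out dbruntime.databricks_repl_context so the
    # decorator's in-function import resolves to our fake module.
    import sys
    import types

    fake_ctx = types.SimpleNamespace(jobId="123")
    repl_context = types.ModuleType("dbruntime.databricks_repl_context")
    repl_context.get_context = lambda: fake_ctx
    sys.modules["dbruntime"] = types.ModuleType("dbruntime")
    sys.modules["dbruntime.databricks_repl_context"] = repl_context

    from mlflow.utils import databricks_utils

    # Short circuit: the decorator returns the REPL context attribute
    # directly, without calling the wrapped function (no dbutils, no Spark).
    assert databricks_utils.get_job_id() == "123"

    # Fallback: with no usable context, the original function body runs;
    # outside Databricks this is expected to come back as None.
    repl_context.get_context = lambda: None
    print(databricks_utils.get_job_id())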