Skip to content

Commit

Permalink
use boolean attributes
Browse files Browse the repository at this point in the history
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
  • Loading branch information
harupy committed Dec 16, 2021
1 parent 3c3676c commit 5ac0475
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 25 deletions.
34 changes: 13 additions & 21 deletions mlflow/utils/databricks_utils.py
Expand Up @@ -13,42 +13,33 @@


def _get_context_attribute(name):
try:
from dbruntime.databricks_repl_context import get_context
from dbruntime.databricks_repl_context import get_context

return getattr(get_context(), name)
except Exception:
return None
return getattr(get_context(), name)


def _use_context_attribute_if_available(name, *, if_available=lambda x: x):
def _use_context_attribute_if_available(name):
"""
Creates a decorator to insert a short circuit that returns `if_available(name)`
if the specified context attribute is available.
Creates a decorator to insert a short circuit that returns the specified context attribute if
it's available.
:param name: Context attribute name.
:param if_available: Function to evaluate when the specified context attribute is available.
Defaults to `lambda x: x`.
:param name: Context attribute name (e.g. "api_url").
:return: Decorator to insert the short circuit.
"""

def decorator(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
metadata = _get_context_attribute(name)
if metadata:
return if_available(metadata)
return f(*args, **kwargs)
try:
return _get_context_attribute(name)
except Exception:
return f(*args, **kwargs)

return wrapper

return decorator


def _return_true(_):
return True


def _get_dbutils():
try:
import IPython
Expand Down Expand Up @@ -111,7 +102,7 @@ def is_databricks_default_tracking_uri(tracking_uri):
return tracking_uri.lower().strip() == "databricks"


@_use_context_attribute_if_available("notebookId", if_available=_return_true)
@_use_context_attribute_if_available("isInNotebook")
def is_in_databricks_notebook():
if _get_property_from_spark_context("spark.databricks.notebook.id") is not None:
return True
Expand All @@ -121,6 +112,7 @@ def is_in_databricks_notebook():
return False


@_use_context_attribute_if_available("isInJob")
def is_in_databricks_job():
try:
return get_job_id() is not None and get_job_run_id() is not None
Expand Down Expand Up @@ -151,7 +143,7 @@ def is_dbfs_fuse_available():
return False


@_use_context_attribute_if_available("clusterId", if_available=_return_true)
@_use_context_attribute_if_available("isInCluster")
def is_in_cluster():
try:
spark_session = _get_active_spark_session()
Expand Down
8 changes: 4 additions & 4 deletions tests/utils/test_databricks_utils.py
Expand Up @@ -270,26 +270,26 @@ def test_use_context_attribute_if_available():

with mock.patch(
"mlflow.utils.databricks_utils._get_context_attribute",
return_value="notebook_id",
side_effect={"notebookId": "notebook_id", "isInNotebook": True}.get,
) as mock_context_metadata, mock.patch(
"mlflow.utils.databricks_utils._get_property_from_spark_context"
) as mock_spark_context:
assert databricks_utils.get_notebook_id() == "notebook_id"
mock_context_metadata.assert_called_once_with("notebookId")
mock_context_metadata.reset_mock()
assert databricks_utils.is_in_databricks_notebook()
mock_context_metadata.assert_called_once_with("notebookId")
mock_context_metadata.assert_called_once_with("isInNotebook")
mock_spark_context.assert_not_called()

with mock.patch(
"mlflow.utils.databricks_utils._get_context_attribute",
return_value="cluster_id",
side_effect={"clusterId": "cluster_id", "isInCluster": True}.get,
) as mock_context_metadata, mock.patch(
"mlflow.utils._spark_utils._get_active_spark_session"
) as mock_spark_session:
assert databricks_utils.get_cluster_id() == "cluster_id"
mock_context_metadata.assert_called_once_with("clusterId")
mock_context_metadata.reset_mock()
assert databricks_utils.is_in_cluster()
mock_context_metadata.assert_called_once_with("clusterId")
mock_context_metadata.assert_called_once_with("isInCluster")
mock_spark_session.assert_not_called()

0 comments on commit 5ac0475

Please sign in to comment.