From 5ac047586a413a3854f3c6249328c766aa9756c2 Mon Sep 17 00:00:00 2001 From: harupy <17039389+harupy@users.noreply.github.com> Date: Thu, 16 Dec 2021 09:29:27 +0900 Subject: [PATCH] use boolean attributes Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- mlflow/utils/databricks_utils.py | 34 +++++++++++----------------- tests/utils/test_databricks_utils.py | 8 +++---- 2 files changed, 17 insertions(+), 25 deletions(-) diff --git a/mlflow/utils/databricks_utils.py b/mlflow/utils/databricks_utils.py index f5d774b0c915b..70d84629cc7ff 100644 --- a/mlflow/utils/databricks_utils.py +++ b/mlflow/utils/databricks_utils.py @@ -13,42 +13,33 @@ def _get_context_attribute(name): - try: - from dbruntime.databricks_repl_context import get_context + from dbruntime.databricks_repl_context import get_context - return getattr(get_context(), name) - except Exception: - return None + return getattr(get_context(), name) -def _use_context_attribute_if_available(name, *, if_available=lambda x: x): +def _use_context_attribute_if_available(name): """ - Creates a decorator to insert a short circuit that returns `if_available(name)` - if the specified context attribute is available. + Creates a decorator to insert a short circuit that returns the specified context attribute if + it's available. - :param name: Context attribute name. - :param if_available: Function to evaluate when the specified context attribute is available. - Defaults to `lambda x: x`. + :param name: Context attribute name (e.g. "api_url"). :return: Decorator to insert the short circuit. """ def decorator(f): @functools.wraps(f) def wrapper(*args, **kwargs): - metadata = _get_context_attribute(name) - if metadata: - return if_available(metadata) - return f(*args, **kwargs) + try: + return _get_context_attribute(name) + except Exception: + return f(*args, **kwargs) return wrapper return decorator -def _return_true(_): - return True - - def _get_dbutils(): try: import IPython @@ -111,7 +102,7 @@ def is_databricks_default_tracking_uri(tracking_uri): return tracking_uri.lower().strip() == "databricks" -@_use_context_attribute_if_available("notebookId", if_available=_return_true) +@_use_context_attribute_if_available("isInNotebook") def is_in_databricks_notebook(): if _get_property_from_spark_context("spark.databricks.notebook.id") is not None: return True @@ -121,6 +112,7 @@ def is_in_databricks_notebook(): return False +@_use_context_attribute_if_available("isInJob") def is_in_databricks_job(): try: return get_job_id() is not None and get_job_run_id() is not None @@ -151,7 +143,7 @@ def is_dbfs_fuse_available(): return False -@_use_context_attribute_if_available("clusterId", if_available=_return_true) +@_use_context_attribute_if_available("isInCluster") def is_in_cluster(): try: spark_session = _get_active_spark_session() diff --git a/tests/utils/test_databricks_utils.py b/tests/utils/test_databricks_utils.py index 3b7f7405804a4..66d1b510f54b1 100644 --- a/tests/utils/test_databricks_utils.py +++ b/tests/utils/test_databricks_utils.py @@ -270,7 +270,7 @@ def test_use_context_attribute_if_available(): with mock.patch( "mlflow.utils.databricks_utils._get_context_attribute", - return_value="notebook_id", + side_effect={"notebookId": "notebook_id", "isInNotebook": True}.get, ) as mock_context_metadata, mock.patch( "mlflow.utils.databricks_utils._get_property_from_spark_context" ) as mock_spark_context: @@ -278,12 +278,12 @@ def test_use_context_attribute_if_available(): mock_context_metadata.assert_called_once_with("notebookId") mock_context_metadata.reset_mock() assert databricks_utils.is_in_databricks_notebook() - mock_context_metadata.assert_called_once_with("notebookId") + mock_context_metadata.assert_called_once_with("isInNotebook") mock_spark_context.assert_not_called() with mock.patch( "mlflow.utils.databricks_utils._get_context_attribute", - return_value="cluster_id", + side_effect={"clusterId": "cluster_id", "isInCluster": True}.get, ) as mock_context_metadata, mock.patch( "mlflow.utils._spark_utils._get_active_spark_session" ) as mock_spark_session: @@ -291,5 +291,5 @@ def test_use_context_attribute_if_available(): mock_context_metadata.assert_called_once_with("clusterId") mock_context_metadata.reset_mock() assert databricks_utils.is_in_cluster() - mock_context_metadata.assert_called_once_with("clusterId") + mock_context_metadata.assert_called_once_with("isInCluster") mock_spark_session.assert_not_called()