From a035719fcf20b12118a4ac724697ace8696fe282 Mon Sep 17 00:00:00 2001 From: harupy <17039389+harupy@users.noreply.github.com> Date: Tue, 30 Nov 2021 21:57:16 +0900 Subject: [PATCH 01/21] update get_notebook_path Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- mlflow/utils/databricks_utils.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/mlflow/utils/databricks_utils.py b/mlflow/utils/databricks_utils.py index f088af5c83ef5..61041aa2d700d 100644 --- a/mlflow/utils/databricks_utils.py +++ b/mlflow/utils/databricks_utils.py @@ -135,13 +135,14 @@ def get_notebook_id(): def get_notebook_path(): """Should only be called if is_in_databricks_notebook is true""" - path = _get_property_from_spark_context("spark.databricks.notebook.path") - if path is not None: - return path - try: - return _get_command_context().notebookPath().get() - except Exception: - return _get_extra_context("notebook_path") + return os.environ.get("NOTEBOOK_PATH") + # path = _get_property_from_spark_context("spark.databricks.notebook.path") + # if path is not None: + # return path + # try: + # return _get_command_context().notebookPath().get() + # except Exception: + # return _get_extra_context("notebook_path") def get_databricks_runtime(): From c1d440cb4699c6a89ced3d1965b4dbd2f73e9146 Mon Sep 17 00:00:00 2001 From: harupy <17039389+harupy@users.noreply.github.com> Date: Wed, 1 Dec 2021 09:15:45 +0900 Subject: [PATCH 02/21] add decorator Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- mlflow/utils/databricks_utils.py | 55 +++++++++++++++++++++++++++----- 1 file changed, 47 insertions(+), 8 deletions(-) diff --git a/mlflow/utils/databricks_utils.py b/mlflow/utils/databricks_utils.py index 61041aa2d700d..f433db51d2e9a 100644 --- a/mlflow/utils/databricks_utils.py +++ b/mlflow/utils/databricks_utils.py @@ -1,6 +1,7 @@ import os import logging import subprocess +import functools from mlflow.exceptions import MlflowException from mlflow.utils.rest_utils import MlflowHostCreds @@ -10,6 +11,30 @@ _logger = logging.getLogger(__name__) +_env_var_prefix = "DATABRICKS_" + + +def _use_env_var_if_exists(env_var, *, if_exists=lambda x: os.environ[x]): + """ + Creates a decorator to insert a short circuit that's activated when the specified environment + variable exists. + + :param env_var: The name of an environment variable to use. + :param if_exists: A function to evalute if `env_var` exists. Defaults to + `lambda x: os.environ[x]`. 
+ """ + + def decorator(f): + @functools.wraps + def wrapper(*args, **kwargs): + if env_var in os.environ: + return if_exists(env_var) + return f(*args, **kwargs) + + return wrapper + + return decorator + def _get_dbutils(): try: @@ -50,6 +75,7 @@ def _get_context_tag(context_tag_key): return None +@_use_env_var_if_exists(_env_var_prefix + "ACL_PATH_OF_ACL_ROOT") def acl_path_of_acl_root(): try: return _get_command_context().aclPathOfAclRoot().get() @@ -72,6 +98,7 @@ def is_databricks_default_tracking_uri(tracking_uri): return tracking_uri.lower().strip() == "databricks" +@_use_env_var_if_exists(_env_var_prefix + "NOTEBOOK_ID", if_exists=lambda x: x in os.environ) def is_in_databricks_notebook(): if _get_property_from_spark_context("spark.databricks.notebook.id") is not None: return True @@ -111,6 +138,7 @@ def is_dbfs_fuse_available(): return False +@_use_env_var_if_exists(_env_var_prefix + "CLUSTER_ID", if_exists=lambda x: x in os.environ) def is_in_cluster(): try: spark_session = _get_active_spark_session() @@ -122,6 +150,7 @@ def is_in_cluster(): return False +@_use_env_var_if_exists(_env_var_prefix + "NOTEBOOK_ID") def get_notebook_id(): """Should only be called if is_in_databricks_notebook is true""" notebook_id = _get_property_from_spark_context("spark.databricks.notebook.id") @@ -133,18 +162,19 @@ def get_notebook_id(): return None +@_use_env_var_if_exists(_env_var_prefix + "NOTEBOOK_PATH") def get_notebook_path(): """Should only be called if is_in_databricks_notebook is true""" - return os.environ.get("NOTEBOOK_PATH") - # path = _get_property_from_spark_context("spark.databricks.notebook.path") - # if path is not None: - # return path - # try: - # return _get_command_context().notebookPath().get() - # except Exception: - # return _get_extra_context("notebook_path") + path = _get_property_from_spark_context("spark.databricks.notebook.path") + if path is not None: + return path + try: + return _get_command_context().notebookPath().get() + except Exception: + return _get_extra_context("notebook_path") +@_use_env_var_if_exists(_env_var_prefix + "RUNTIME_VERSION") def get_databricks_runtime(): if is_in_databricks_runtime(): spark_session = _get_active_spark_session() @@ -155,6 +185,7 @@ def get_databricks_runtime(): return None +@_use_env_var_if_exists(_env_var_prefix + "CLUSTER_ID") def get_cluster_id(): spark_session = _get_active_spark_session() if spark_session is None: @@ -162,6 +193,7 @@ def get_cluster_id(): return spark_session.conf.get("spark.databricks.clusterUsageTags.clusterId") +@_use_env_var_if_exists(_env_var_prefix + "JOB_GROUP_ID") def get_job_group_id(): try: dbutils = _get_dbutils() @@ -172,6 +204,7 @@ def get_job_group_id(): return None +@_use_env_var_if_exists(_env_var_prefix + "REPL_ID") def get_repl_id(): """ :return: The ID of the current Databricks Python REPL @@ -199,6 +232,7 @@ def get_repl_id(): pass +@_use_env_var_if_exists(_env_var_prefix + "JOB_ID") def get_job_id(): try: return _get_command_context().jobId().get() @@ -206,6 +240,7 @@ def get_job_id(): return _get_context_tag("jobId") +@_use_env_var_if_exists(_env_var_prefix + "JOB_RUN_ID") def get_job_run_id(): try: return _get_command_context().idInJob().get() @@ -213,6 +248,7 @@ def get_job_run_id(): return _get_context_tag("idInJob") +@_use_env_var_if_exists(_env_var_prefix + "JOB_TYPE") def get_job_type(): """Should only be called if is_in_databricks_job is true""" try: @@ -229,6 +265,7 @@ def get_command_run_id(): return None +@_use_env_var_if_exists(_env_var_prefix + "API_URL") def 
get_webapp_url(): """Should only be called if is_in_databricks_notebook or is_in_databricks_jobs is true""" url = _get_property_from_spark_context("spark.databricks.api.url") @@ -240,6 +277,7 @@ def get_webapp_url(): return _get_extra_context("api_url") +@_use_env_var_if_exists(_env_var_prefix + "WORKSPACE_ID") def get_workspace_id(): try: return _get_command_context().workspaceId().get() @@ -247,6 +285,7 @@ def get_workspace_id(): return _get_context_tag("orgId") +@_use_env_var_if_exists(_env_var_prefix + "BROWSER_HOST_NAME") def get_browser_hostname(): try: return _get_command_context().browserHostName().get() From 70c5c7ea39d2b1218ee4fed677d4b0eda95c9103 Mon Sep 17 00:00:00 2001 From: harupy <17039389+harupy@users.noreply.github.com> Date: Wed, 1 Dec 2021 09:24:37 +0900 Subject: [PATCH 03/21] pass func to wraps Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- mlflow/utils/databricks_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlflow/utils/databricks_utils.py b/mlflow/utils/databricks_utils.py index f433db51d2e9a..628d951269443 100644 --- a/mlflow/utils/databricks_utils.py +++ b/mlflow/utils/databricks_utils.py @@ -25,7 +25,7 @@ def _use_env_var_if_exists(env_var, *, if_exists=lambda x: os.environ[x]): """ def decorator(f): - @functools.wraps + @functools.wraps(f) def wrapper(*args, **kwargs): if env_var in os.environ: return if_exists(env_var) From 4c981cbac9f7bef41944c3733a67e7769386ba8e Mon Sep 17 00:00:00 2001 From: harupy <17039389+harupy@users.noreply.github.com> Date: Wed, 1 Dec 2021 11:14:44 +0900 Subject: [PATCH 04/21] refactor Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- mlflow/utils/databricks_utils.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/mlflow/utils/databricks_utils.py b/mlflow/utils/databricks_utils.py index 628d951269443..9c1c70120c5e0 100644 --- a/mlflow/utils/databricks_utils.py +++ b/mlflow/utils/databricks_utils.py @@ -14,14 +14,13 @@ _env_var_prefix = "DATABRICKS_" -def _use_env_var_if_exists(env_var, *, if_exists=lambda x: os.environ[x]): +def _use_env_var_if_exists(env_var, *, if_exists=os.getenv): """ - Creates a decorator to insert a short circuit that's activated when the specified environment - variable exists. + Creates a decorator to insert a short circuit that returns `if_exists(env_var)` if `env_var` + exists. :param env_var: The name of an environment variable to use. - :param if_exists: A function to evalute if `env_var` exists. Defaults to - `lambda x: os.environ[x]`. + :param if_exists: A function to evaluate if `env_var` exists. Defaults to `os.getenv`. 
""" def decorator(f): @@ -36,6 +35,10 @@ def wrapper(*args, **kwargs): return decorator +def _returns_true(_): + return True + + def _get_dbutils(): try: import IPython @@ -98,7 +101,7 @@ def is_databricks_default_tracking_uri(tracking_uri): return tracking_uri.lower().strip() == "databricks" -@_use_env_var_if_exists(_env_var_prefix + "NOTEBOOK_ID", if_exists=lambda x: x in os.environ) +@_use_env_var_if_exists(_env_var_prefix + "NOTEBOOK_ID", if_exists=_returns_true) def is_in_databricks_notebook(): if _get_property_from_spark_context("spark.databricks.notebook.id") is not None: return True @@ -138,7 +141,7 @@ def is_dbfs_fuse_available(): return False -@_use_env_var_if_exists(_env_var_prefix + "CLUSTER_ID", if_exists=lambda x: x in os.environ) +@_use_env_var_if_exists(_env_var_prefix + "CLUSTER_ID", if_exists=_returns_true) def is_in_cluster(): try: spark_session = _get_active_spark_session() From e413f168b5f4c482592d5ba1b079dbe14e107926 Mon Sep 17 00:00:00 2001 From: harupy <17039389+harupy@users.noreply.github.com> Date: Wed, 1 Dec 2021 11:24:58 +0900 Subject: [PATCH 05/21] update docstring Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- mlflow/utils/databricks_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mlflow/utils/databricks_utils.py b/mlflow/utils/databricks_utils.py index 9c1c70120c5e0..b2a9109f4ab07 100644 --- a/mlflow/utils/databricks_utils.py +++ b/mlflow/utils/databricks_utils.py @@ -16,11 +16,12 @@ def _use_env_var_if_exists(env_var, *, if_exists=os.getenv): """ - Creates a decorator to insert a short circuit that returns `if_exists(env_var)` if `env_var` - exists. + Creates a decorator to insert a short circuit that returns `if_exists(env_var)` if + the environment variable `env_var` exists. :param env_var: The name of an environment variable to use. :param if_exists: A function to evaluate if `env_var` exists. Defaults to `os.getenv`. + :return: A decorator to insert the short circuit. 
""" def decorator(f): From b83e7e2930f5eeb0bd1195030da3a7bf165e780b Mon Sep 17 00:00:00 2001 From: harupy <17039389+harupy@users.noreply.github.com> Date: Wed, 1 Dec 2021 14:01:30 +0900 Subject: [PATCH 06/21] test Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- tests/utils/test_databricks_utils.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/tests/utils/test_databricks_utils.py b/tests/utils/test_databricks_utils.py index 94a837006e3d6..bca2b723f26e7 100644 --- a/tests/utils/test_databricks_utils.py +++ b/tests/utils/test_databricks_utils.py @@ -11,7 +11,7 @@ is_databricks_default_tracking_uri, ) from mlflow.utils.uri import construct_db_uri_from_profile -from tests.helper_functions import mock_method_chain +from tests.helper_functions import mock_method_chain, multi_context def test_no_throw(): @@ -254,3 +254,24 @@ def mock_import(name, *args, **kwargs): with mock.patch("builtins.__import__", side_effect=mock_import): assert databricks_utils.get_repl_id() == "testReplId2" + + +def test_use_env_var_if_exists(): + with mock.patch.dict( + "os.environ", + { + databricks_utils._ENV_VAR_PREFIX + "NOTEBOOK_ID": "1", + databricks_utils._ENV_VAR_PREFIX + "CLUSTER_ID": "a", + }, + clear=True, + ): + with multi_context( + mock.patch("mlflow.utils.databricks_utils._get_dbutils"), + mock.patch("mlflow.utils.databricks_utils._get_property_from_spark_context"), + mock.patch("mlflow.utils._spark_utils._get_active_spark_session"), + ) as mocks: + assert databricks_utils.get_notebook_id() == "1" + assert databricks_utils.is_in_databricks_notebook() + assert databricks_utils.get_cluster_id() == "a" + assert databricks_utils.is_in_cluster() + assert all(m.call_count == 0 for m in mocks) From bf30d272cd9db0c122126cd6b51081d2c5b6b420 Mon Sep 17 00:00:00 2001 From: harupy <17039389+harupy@users.noreply.github.com> Date: Wed, 1 Dec 2021 15:35:28 +0900 Subject: [PATCH 07/21] fix Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- mlflow/utils/databricks_utils.py | 34 ++++++++++++++++---------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/mlflow/utils/databricks_utils.py b/mlflow/utils/databricks_utils.py index b2a9109f4ab07..5b9e1e563a16c 100644 --- a/mlflow/utils/databricks_utils.py +++ b/mlflow/utils/databricks_utils.py @@ -11,7 +11,7 @@ _logger = logging.getLogger(__name__) -_env_var_prefix = "DATABRICKS_" +_ENV_VAR_PREFIX = "DATABRICKS_" def _use_env_var_if_exists(env_var, *, if_exists=os.getenv): @@ -36,7 +36,7 @@ def wrapper(*args, **kwargs): return decorator -def _returns_true(_): +def _return_true(_): return True @@ -79,7 +79,7 @@ def _get_context_tag(context_tag_key): return None -@_use_env_var_if_exists(_env_var_prefix + "ACL_PATH_OF_ACL_ROOT") +@_use_env_var_if_exists(_ENV_VAR_PREFIX + "ACL_PATH_OF_ACL_ROOT") def acl_path_of_acl_root(): try: return _get_command_context().aclPathOfAclRoot().get() @@ -102,7 +102,7 @@ def is_databricks_default_tracking_uri(tracking_uri): return tracking_uri.lower().strip() == "databricks" -@_use_env_var_if_exists(_env_var_prefix + "NOTEBOOK_ID", if_exists=_returns_true) +@_use_env_var_if_exists(_ENV_VAR_PREFIX + "NOTEBOOK_ID", if_exists=_return_true) def is_in_databricks_notebook(): if _get_property_from_spark_context("spark.databricks.notebook.id") is not None: return True @@ -142,7 +142,7 @@ def is_dbfs_fuse_available(): return False -@_use_env_var_if_exists(_env_var_prefix + "CLUSTER_ID", if_exists=_returns_true) +@_use_env_var_if_exists(_ENV_VAR_PREFIX + 
"CLUSTER_ID", if_exists=_return_true) def is_in_cluster(): try: spark_session = _get_active_spark_session() @@ -154,7 +154,7 @@ def is_in_cluster(): return False -@_use_env_var_if_exists(_env_var_prefix + "NOTEBOOK_ID") +@_use_env_var_if_exists(_ENV_VAR_PREFIX + "NOTEBOOK_ID") def get_notebook_id(): """Should only be called if is_in_databricks_notebook is true""" notebook_id = _get_property_from_spark_context("spark.databricks.notebook.id") @@ -166,7 +166,7 @@ def get_notebook_id(): return None -@_use_env_var_if_exists(_env_var_prefix + "NOTEBOOK_PATH") +@_use_env_var_if_exists(_ENV_VAR_PREFIX + "NOTEBOOK_PATH") def get_notebook_path(): """Should only be called if is_in_databricks_notebook is true""" path = _get_property_from_spark_context("spark.databricks.notebook.path") @@ -178,7 +178,7 @@ def get_notebook_path(): return _get_extra_context("notebook_path") -@_use_env_var_if_exists(_env_var_prefix + "RUNTIME_VERSION") +@_use_env_var_if_exists(_ENV_VAR_PREFIX + "RUNTIME_VERSION") def get_databricks_runtime(): if is_in_databricks_runtime(): spark_session = _get_active_spark_session() @@ -189,7 +189,7 @@ def get_databricks_runtime(): return None -@_use_env_var_if_exists(_env_var_prefix + "CLUSTER_ID") +@_use_env_var_if_exists(_ENV_VAR_PREFIX + "CLUSTER_ID") def get_cluster_id(): spark_session = _get_active_spark_session() if spark_session is None: @@ -197,7 +197,7 @@ def get_cluster_id(): return spark_session.conf.get("spark.databricks.clusterUsageTags.clusterId") -@_use_env_var_if_exists(_env_var_prefix + "JOB_GROUP_ID") +@_use_env_var_if_exists(_ENV_VAR_PREFIX + "JOB_GROUP_ID") def get_job_group_id(): try: dbutils = _get_dbutils() @@ -208,7 +208,7 @@ def get_job_group_id(): return None -@_use_env_var_if_exists(_env_var_prefix + "REPL_ID") +@_use_env_var_if_exists(_ENV_VAR_PREFIX + "REPL_ID") def get_repl_id(): """ :return: The ID of the current Databricks Python REPL @@ -236,7 +236,7 @@ def get_repl_id(): pass -@_use_env_var_if_exists(_env_var_prefix + "JOB_ID") +@_use_env_var_if_exists(_ENV_VAR_PREFIX + "JOB_ID") def get_job_id(): try: return _get_command_context().jobId().get() @@ -244,7 +244,7 @@ def get_job_id(): return _get_context_tag("jobId") -@_use_env_var_if_exists(_env_var_prefix + "JOB_RUN_ID") +@_use_env_var_if_exists(_ENV_VAR_PREFIX + "ID_IN_JOB") def get_job_run_id(): try: return _get_command_context().idInJob().get() @@ -252,7 +252,7 @@ def get_job_run_id(): return _get_context_tag("idInJob") -@_use_env_var_if_exists(_env_var_prefix + "JOB_TYPE") +@_use_env_var_if_exists(_ENV_VAR_PREFIX + "JOB_TASK_TYPE") def get_job_type(): """Should only be called if is_in_databricks_job is true""" try: @@ -269,7 +269,7 @@ def get_command_run_id(): return None -@_use_env_var_if_exists(_env_var_prefix + "API_URL") +@_use_env_var_if_exists(_ENV_VAR_PREFIX + "API_URL") def get_webapp_url(): """Should only be called if is_in_databricks_notebook or is_in_databricks_jobs is true""" url = _get_property_from_spark_context("spark.databricks.api.url") @@ -281,7 +281,7 @@ def get_webapp_url(): return _get_extra_context("api_url") -@_use_env_var_if_exists(_env_var_prefix + "WORKSPACE_ID") +@_use_env_var_if_exists(_ENV_VAR_PREFIX + "WORKSPACE_ID") def get_workspace_id(): try: return _get_command_context().workspaceId().get() @@ -289,7 +289,7 @@ def get_workspace_id(): return _get_context_tag("orgId") -@_use_env_var_if_exists(_env_var_prefix + "BROWSER_HOST_NAME") +@_use_env_var_if_exists(_ENV_VAR_PREFIX + "BROWSER_HOST_NAME") def get_browser_hostname(): try: return 
_get_command_context().browserHostName().get() From a610852349e60fd10e3d51f1b345e29da5f3e173 Mon Sep 17 00:00:00 2001 From: harupy <17039389+harupy@users.noreply.github.com> Date: Thu, 2 Dec 2021 15:56:29 +0900 Subject: [PATCH 08/21] hardcode prefix Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- mlflow/utils/databricks_utils.py | 32 +++++++++++++--------------- tests/utils/test_databricks_utils.py | 5 +---- 2 files changed, 16 insertions(+), 21 deletions(-) diff --git a/mlflow/utils/databricks_utils.py b/mlflow/utils/databricks_utils.py index 5b9e1e563a16c..a3fed8c98586d 100644 --- a/mlflow/utils/databricks_utils.py +++ b/mlflow/utils/databricks_utils.py @@ -11,8 +11,6 @@ _logger = logging.getLogger(__name__) -_ENV_VAR_PREFIX = "DATABRICKS_" - def _use_env_var_if_exists(env_var, *, if_exists=os.getenv): """ @@ -79,7 +77,7 @@ def _get_context_tag(context_tag_key): return None -@_use_env_var_if_exists(_ENV_VAR_PREFIX + "ACL_PATH_OF_ACL_ROOT") +@_use_env_var_if_exists("DATABRICKS_ACL_PATH_OF_ACL_ROOT") def acl_path_of_acl_root(): try: return _get_command_context().aclPathOfAclRoot().get() @@ -102,7 +100,7 @@ def is_databricks_default_tracking_uri(tracking_uri): return tracking_uri.lower().strip() == "databricks" -@_use_env_var_if_exists(_ENV_VAR_PREFIX + "NOTEBOOK_ID", if_exists=_return_true) +@_use_env_var_if_exists("DATABRICKS_NOTEBOOK_ID", if_exists=_return_true) def is_in_databricks_notebook(): if _get_property_from_spark_context("spark.databricks.notebook.id") is not None: return True @@ -142,7 +140,7 @@ def is_dbfs_fuse_available(): return False -@_use_env_var_if_exists(_ENV_VAR_PREFIX + "CLUSTER_ID", if_exists=_return_true) +@_use_env_var_if_exists("DATABRICKS_CLUSTER_ID", if_exists=_return_true) def is_in_cluster(): try: spark_session = _get_active_spark_session() @@ -154,7 +152,7 @@ def is_in_cluster(): return False -@_use_env_var_if_exists(_ENV_VAR_PREFIX + "NOTEBOOK_ID") +@_use_env_var_if_exists("DATABRICKS_NOTEBOOK_ID") def get_notebook_id(): """Should only be called if is_in_databricks_notebook is true""" notebook_id = _get_property_from_spark_context("spark.databricks.notebook.id") @@ -166,7 +164,7 @@ def get_notebook_id(): return None -@_use_env_var_if_exists(_ENV_VAR_PREFIX + "NOTEBOOK_PATH") +@_use_env_var_if_exists("DATABRICKS_NOTEBOOK_PATH") def get_notebook_path(): """Should only be called if is_in_databricks_notebook is true""" path = _get_property_from_spark_context("spark.databricks.notebook.path") @@ -178,7 +176,7 @@ def get_notebook_path(): return _get_extra_context("notebook_path") -@_use_env_var_if_exists(_ENV_VAR_PREFIX + "RUNTIME_VERSION") +@_use_env_var_if_exists("DATABRICKS_RUNTIME_VERSION") def get_databricks_runtime(): if is_in_databricks_runtime(): spark_session = _get_active_spark_session() @@ -189,7 +187,7 @@ def get_databricks_runtime(): return None -@_use_env_var_if_exists(_ENV_VAR_PREFIX + "CLUSTER_ID") +@_use_env_var_if_exists("DATABRICKS_CLUSTER_ID") def get_cluster_id(): spark_session = _get_active_spark_session() if spark_session is None: @@ -197,7 +195,7 @@ def get_cluster_id(): return spark_session.conf.get("spark.databricks.clusterUsageTags.clusterId") -@_use_env_var_if_exists(_ENV_VAR_PREFIX + "JOB_GROUP_ID") +@_use_env_var_if_exists("DATABRICKS_JOB_GROUP_ID") def get_job_group_id(): try: dbutils = _get_dbutils() @@ -208,7 +206,7 @@ def get_job_group_id(): return None -@_use_env_var_if_exists(_ENV_VAR_PREFIX + "REPL_ID") +@_use_env_var_if_exists("DATABRICKS_REPL_ID") def get_repl_id(): """ :return: The ID of the current 
Databricks Python REPL @@ -236,7 +234,7 @@ def get_repl_id(): pass -@_use_env_var_if_exists(_ENV_VAR_PREFIX + "JOB_ID") +@_use_env_var_if_exists("DATABRICKS_JOB_ID") def get_job_id(): try: return _get_command_context().jobId().get() @@ -244,7 +242,7 @@ def get_job_id(): return _get_context_tag("jobId") -@_use_env_var_if_exists(_ENV_VAR_PREFIX + "ID_IN_JOB") +@_use_env_var_if_exists("DATABRICKS_ID_IN_JOB") def get_job_run_id(): try: return _get_command_context().idInJob().get() @@ -252,7 +250,7 @@ def get_job_run_id(): return _get_context_tag("idInJob") -@_use_env_var_if_exists(_ENV_VAR_PREFIX + "JOB_TASK_TYPE") +@_use_env_var_if_exists("DATABRICKS_JOB_TASK_TYPE") def get_job_type(): """Should only be called if is_in_databricks_job is true""" try: @@ -269,7 +267,7 @@ def get_command_run_id(): return None -@_use_env_var_if_exists(_ENV_VAR_PREFIX + "API_URL") +@_use_env_var_if_exists("DATABRICKS_API_URL") def get_webapp_url(): """Should only be called if is_in_databricks_notebook or is_in_databricks_jobs is true""" url = _get_property_from_spark_context("spark.databricks.api.url") @@ -281,7 +279,7 @@ def get_webapp_url(): return _get_extra_context("api_url") -@_use_env_var_if_exists(_ENV_VAR_PREFIX + "WORKSPACE_ID") +@_use_env_var_if_exists("DATABRICKS_WORKSPACE_ID") def get_workspace_id(): try: return _get_command_context().workspaceId().get() @@ -289,7 +287,7 @@ def get_workspace_id(): return _get_context_tag("orgId") -@_use_env_var_if_exists(_ENV_VAR_PREFIX + "BROWSER_HOST_NAME") +@_use_env_var_if_exists("DATABRICKS_BROWSER_HOST_NAME") def get_browser_hostname(): try: return _get_command_context().browserHostName().get() diff --git a/tests/utils/test_databricks_utils.py b/tests/utils/test_databricks_utils.py index bca2b723f26e7..a9c65e4adc951 100644 --- a/tests/utils/test_databricks_utils.py +++ b/tests/utils/test_databricks_utils.py @@ -259,10 +259,7 @@ def mock_import(name, *args, **kwargs): def test_use_env_var_if_exists(): with mock.patch.dict( "os.environ", - { - databricks_utils._ENV_VAR_PREFIX + "NOTEBOOK_ID": "1", - databricks_utils._ENV_VAR_PREFIX + "CLUSTER_ID": "a", - }, + {"DATABRICKS_NOTEBOOK_ID": "1", "DATABRICKS_CLUSTER_ID": "a"}, clear=True, ): with multi_context( From 62784928b560db4f610929e05c8755d9a631475b Mon Sep 17 00:00:00 2001 From: harupy Date: Thu, 2 Dec 2021 22:04:17 +0900 Subject: [PATCH 09/21] add _use_message_metadata_if_exists Signed-off-by: harupy --- mlflow/utils/databricks_utils.py | 34 ++++++++++++++++++++++++++++ tests/utils/test_databricks_utils.py | 9 ++++++++ 2 files changed, 43 insertions(+) diff --git a/mlflow/utils/databricks_utils.py b/mlflow/utils/databricks_utils.py index a3fed8c98586d..975d42a31f6c7 100644 --- a/mlflow/utils/databricks_utils.py +++ b/mlflow/utils/databricks_utils.py @@ -38,6 +38,39 @@ def _return_true(_): return True +def _get_message_metadata(): + try: + import IPython + + ip_shell = IPython.get_ipython() + return ip_shell.parent_header["metadata"] + except Exception: + return None + + +def _use_message_metadata_if_exists(key): + """ + Creates a decorator to insert a short circuit that returns specified Jupyter message metadata + if it exists. + + :param key: Metadata key. + :return: A decorator to insert the short circuit. 
+ """ + + def decorator(f): + @functools.wraps(f) + def wrapper(*args, **kwargs): + metadata = _get_message_metadata() + if metadata and key in metadata: + return metadata[key] + + return f(*args, **kwargs) + + return wrapper + + return decorator + + def _get_dbutils(): try: import IPython @@ -259,6 +292,7 @@ def get_job_type(): return _get_context_tag("jobTaskType") +@_use_message_metadata_if_exists("commandRunId") def get_command_run_id(): try: return _get_command_context().commandRunId().get() diff --git a/tests/utils/test_databricks_utils.py b/tests/utils/test_databricks_utils.py index a9c65e4adc951..97974a38d155e 100644 --- a/tests/utils/test_databricks_utils.py +++ b/tests/utils/test_databricks_utils.py @@ -272,3 +272,12 @@ def test_use_env_var_if_exists(): assert databricks_utils.get_cluster_id() == "a" assert databricks_utils.is_in_cluster() assert all(m.call_count == 0 for m in mocks) + + +def test_use_message_metadata_if_exists(): + with mock.patch( + "mlflow.utils.databricks_utils._get_message_metadata", + return_value={"commandRunId": "1"}, + ), mock.patch("mlflow.utils.databricks_utils._get_dbutils") as mock_dbutils: + assert databricks_utils.get_command_run_id() == "1" + mock_dbutils.assert_not_called() From a2213e36418e028d4a19bc785e6a1947ab9c3670 Mon Sep 17 00:00:00 2001 From: harupy <17039389+harupy@users.noreply.github.com> Date: Mon, 6 Dec 2021 11:30:14 +0900 Subject: [PATCH 10/21] use context metadata Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- mlflow/utils/databricks_utils.py | 70 ++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 30 deletions(-) diff --git a/mlflow/utils/databricks_utils.py b/mlflow/utils/databricks_utils.py index 975d42a31f6c7..8a66fc1a302ce 100644 --- a/mlflow/utils/databricks_utils.py +++ b/mlflow/utils/databricks_utils.py @@ -12,21 +12,31 @@ _logger = logging.getLogger(__name__) -def _use_env_var_if_exists(env_var, *, if_exists=os.getenv): +def _get_context_metadata(name): + try: + from dbruntime.set_context import context + + return getattr(context, name) + except Exception: + return None + + +def _use_context_metadata_if_available(name, *, if_available=lambda x: x): """ - Creates a decorator to insert a short circuit that returns `if_exists(env_var)` if - the environment variable `env_var` exists. + Creates a decorator to insert a short circuit that returns `if_available(name)` + if the specified metadata is available. - :param env_var: The name of an environment variable to use. - :param if_exists: A function to evaluate if `env_var` exists. Defaults to `os.getenv`. + :param key: Metadata name. + :param if_available: A function to evaluate if `env_var` exists. Defaults to `lambda x: x`. :return: A decorator to insert the short circuit. """ def decorator(f): @functools.wraps(f) def wrapper(*args, **kwargs): - if env_var in os.environ: - return if_exists(env_var) + metadata = _get_context_metadata(name) + if metadata: + return if_available(metadata) return f(*args, **kwargs) return wrapper @@ -38,20 +48,20 @@ def _return_true(_): return True -def _get_message_metadata(): +def _get_message_metadata(name): try: import IPython ip_shell = IPython.get_ipython() - return ip_shell.parent_header["metadata"] + return ip_shell.parent_header["metadata"][name] except Exception: return None -def _use_message_metadata_if_exists(key): +def _use_message_metadata_if_available(name): """ Creates a decorator to insert a short circuit that returns specified Jupyter message metadata - if it exists. + if it's available. 
:param key: Metadata key. :return: A decorator to insert the short circuit. @@ -60,9 +70,9 @@ def _use_message_metadata_if_exists(key): def decorator(f): @functools.wraps(f) def wrapper(*args, **kwargs): - metadata = _get_message_metadata() - if metadata and key in metadata: - return metadata[key] + metadata = _get_message_metadata(name) + if metadata: + return metadata return f(*args, **kwargs) @@ -110,7 +120,7 @@ def _get_context_tag(context_tag_key): return None -@_use_env_var_if_exists("DATABRICKS_ACL_PATH_OF_ACL_ROOT") +@_use_context_metadata_if_available("aclPathOfAclRoot") def acl_path_of_acl_root(): try: return _get_command_context().aclPathOfAclRoot().get() @@ -133,7 +143,7 @@ def is_databricks_default_tracking_uri(tracking_uri): return tracking_uri.lower().strip() == "databricks" -@_use_env_var_if_exists("DATABRICKS_NOTEBOOK_ID", if_exists=_return_true) +@_use_context_metadata_if_available("notebookId", if_available=_return_true) def is_in_databricks_notebook(): if _get_property_from_spark_context("spark.databricks.notebook.id") is not None: return True @@ -173,7 +183,7 @@ def is_dbfs_fuse_available(): return False -@_use_env_var_if_exists("DATABRICKS_CLUSTER_ID", if_exists=_return_true) +@_use_context_metadata_if_available("clusterId", if_available=_return_true) def is_in_cluster(): try: spark_session = _get_active_spark_session() @@ -185,7 +195,7 @@ def is_in_cluster(): return False -@_use_env_var_if_exists("DATABRICKS_NOTEBOOK_ID") +@_use_context_metadata_if_available("notebookId") def get_notebook_id(): """Should only be called if is_in_databricks_notebook is true""" notebook_id = _get_property_from_spark_context("spark.databricks.notebook.id") @@ -197,7 +207,7 @@ def get_notebook_id(): return None -@_use_env_var_if_exists("DATABRICKS_NOTEBOOK_PATH") +@_use_context_metadata_if_available("notebookPath") def get_notebook_path(): """Should only be called if is_in_databricks_notebook is true""" path = _get_property_from_spark_context("spark.databricks.notebook.path") @@ -209,7 +219,7 @@ def get_notebook_path(): return _get_extra_context("notebook_path") -@_use_env_var_if_exists("DATABRICKS_RUNTIME_VERSION") +@_use_context_metadata_if_available("runtimeVersion") def get_databricks_runtime(): if is_in_databricks_runtime(): spark_session = _get_active_spark_session() @@ -220,7 +230,7 @@ def get_databricks_runtime(): return None -@_use_env_var_if_exists("DATABRICKS_CLUSTER_ID") +@_use_context_metadata_if_available("clusterId") def get_cluster_id(): spark_session = _get_active_spark_session() if spark_session is None: @@ -228,7 +238,7 @@ def get_cluster_id(): return spark_session.conf.get("spark.databricks.clusterUsageTags.clusterId") -@_use_env_var_if_exists("DATABRICKS_JOB_GROUP_ID") +@_use_context_metadata_if_available("groupId") def get_job_group_id(): try: dbutils = _get_dbutils() @@ -239,7 +249,7 @@ def get_job_group_id(): return None -@_use_env_var_if_exists("DATABRICKS_REPL_ID") +@_use_context_metadata_if_available("replId") def get_repl_id(): """ :return: The ID of the current Databricks Python REPL @@ -267,7 +277,7 @@ def get_repl_id(): pass -@_use_env_var_if_exists("DATABRICKS_JOB_ID") +@_use_context_metadata_if_available("jobId") def get_job_id(): try: return _get_command_context().jobId().get() @@ -275,7 +285,7 @@ def get_job_id(): return _get_context_tag("jobId") -@_use_env_var_if_exists("DATABRICKS_ID_IN_JOB") +@_use_context_metadata_if_available("idInJob") def get_job_run_id(): try: return _get_command_context().idInJob().get() @@ -283,7 +293,7 @@ def get_job_run_id(): 
return _get_context_tag("idInJob") -@_use_env_var_if_exists("DATABRICKS_JOB_TASK_TYPE") +@_use_context_metadata_if_available("jobTaskType") def get_job_type(): """Should only be called if is_in_databricks_job is true""" try: @@ -292,7 +302,7 @@ def get_job_type(): return _get_context_tag("jobTaskType") -@_use_message_metadata_if_exists("commandRunId") +@_use_message_metadata_if_available("commandRunId") def get_command_run_id(): try: return _get_command_context().commandRunId().get() @@ -301,7 +311,7 @@ def get_command_run_id(): return None -@_use_env_var_if_exists("DATABRICKS_API_URL") +@_use_context_metadata_if_available("apiUrl") def get_webapp_url(): """Should only be called if is_in_databricks_notebook or is_in_databricks_jobs is true""" url = _get_property_from_spark_context("spark.databricks.api.url") @@ -313,7 +323,7 @@ def get_webapp_url(): return _get_extra_context("api_url") -@_use_env_var_if_exists("DATABRICKS_WORKSPACE_ID") +@_use_context_metadata_if_available("workspaceId") def get_workspace_id(): try: return _get_command_context().workspaceId().get() @@ -321,7 +331,7 @@ def get_workspace_id(): return _get_context_tag("orgId") -@_use_env_var_if_exists("DATABRICKS_BROWSER_HOST_NAME") +@_use_context_metadata_if_available("browserHostName") def get_browser_hostname(): try: return _get_command_context().browserHostName().get() From 757f0e2887cc3c7e9b5525e75be00ac85d674925 Mon Sep 17 00:00:00 2001 From: harupy <17039389+harupy@users.noreply.github.com> Date: Mon, 6 Dec 2021 12:52:48 +0900 Subject: [PATCH 11/21] debug Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- mlflow/tracking/fluent.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mlflow/tracking/fluent.py b/mlflow/tracking/fluent.py index 793e8b1020aa2..9fd2408e56bd0 100644 --- a/mlflow/tracking/fluent.py +++ b/mlflow/tracking/fluent.py @@ -285,6 +285,7 @@ def start_run( user_specified_tags[MLFLOW_RUN_NAME] = run_name tags = context_registry.resolve_tags(user_specified_tags) + print(tags) # debug active_run_obj = client.create_run(experiment_id=exp_id_for_run, tags=tags) From 78fa8843b07ea8f7ebc736c50c287a0b51f4a637 Mon Sep 17 00:00:00 2001 From: harupy <17039389+harupy@users.noreply.github.com> Date: Mon, 6 Dec 2021 12:59:08 +0900 Subject: [PATCH 12/21] fix attr name Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- mlflow/utils/databricks_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlflow/utils/databricks_utils.py b/mlflow/utils/databricks_utils.py index 8a66fc1a302ce..bdcdfdd18f4a7 100644 --- a/mlflow/utils/databricks_utils.py +++ b/mlflow/utils/databricks_utils.py @@ -238,7 +238,7 @@ def get_cluster_id(): return spark_session.conf.get("spark.databricks.clusterUsageTags.clusterId") -@_use_context_metadata_if_available("groupId") +@_use_context_metadata_if_available("jobGroupId") def get_job_group_id(): try: dbutils = _get_dbutils() From 409bccefc74233b7cd97ef044834af7bfb7aedb0 Mon Sep 17 00:00:00 2001 From: harupy <17039389+harupy@users.noreply.github.com> Date: Mon, 6 Dec 2021 13:04:20 +0900 Subject: [PATCH 13/21] remove print Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- mlflow/tracking/fluent.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mlflow/tracking/fluent.py b/mlflow/tracking/fluent.py index 9fd2408e56bd0..793e8b1020aa2 100644 --- a/mlflow/tracking/fluent.py +++ b/mlflow/tracking/fluent.py @@ -285,7 +285,6 @@ def start_run( user_specified_tags[MLFLOW_RUN_NAME] = run_name tags = 
context_registry.resolve_tags(user_specified_tags) - print(tags) # debug active_run_obj = client.create_run(experiment_id=exp_id_for_run, tags=tags) From 7d07c30a3e3f0bc54ddb3ddaec8d0d3a2891b29d Mon Sep 17 00:00:00 2001 From: harupy <17039389+harupy@users.noreply.github.com> Date: Mon, 6 Dec 2021 14:46:15 +0900 Subject: [PATCH 14/21] fix tests Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- tests/utils/test_databricks_utils.py | 70 +++++++++++++++++++--------- 1 file changed, 47 insertions(+), 23 deletions(-) diff --git a/tests/utils/test_databricks_utils.py b/tests/utils/test_databricks_utils.py index 97974a38d155e..44b1107de1a07 100644 --- a/tests/utils/test_databricks_utils.py +++ b/tests/utils/test_databricks_utils.py @@ -11,7 +11,7 @@ is_databricks_default_tracking_uri, ) from mlflow.utils.uri import construct_db_uri_from_profile -from tests.helper_functions import mock_method_chain, multi_context +from tests.helper_functions import mock_method_chain def test_no_throw(): @@ -256,28 +256,52 @@ def mock_import(name, *args, **kwargs): assert databricks_utils.get_repl_id() == "testReplId2" -def test_use_env_var_if_exists(): - with mock.patch.dict( - "os.environ", - {"DATABRICKS_NOTEBOOK_ID": "1", "DATABRICKS_CLUSTER_ID": "a"}, - clear=True, - ): - with multi_context( - mock.patch("mlflow.utils.databricks_utils._get_dbutils"), - mock.patch("mlflow.utils.databricks_utils._get_property_from_spark_context"), - mock.patch("mlflow.utils._spark_utils._get_active_spark_session"), - ) as mocks: - assert databricks_utils.get_notebook_id() == "1" - assert databricks_utils.is_in_databricks_notebook() - assert databricks_utils.get_cluster_id() == "a" - assert databricks_utils.is_in_cluster() - assert all(m.call_count == 0 for m in mocks) - - -def test_use_message_metadata_if_exists(): +def test_use_context_metadata_if_available(): + + with mock.patch( + "mlflow.utils.databricks_utils._get_context_metadata", + return_value="job_id", + ) as mock_context_metadata, mock.patch( + "mlflow.utils.databricks_utils._get_dbutils" + ) as mock_dbutils: + assert databricks_utils.get_job_id() == "job_id" + mock_context_metadata.assert_called_once_with("jobId") + mock_dbutils.assert_not_called() + + with mock.patch( + "mlflow.utils.databricks_utils._get_context_metadata", + return_value="notebook_id", + ) as mock_context_metadata, mock.patch( + "mlflow.utils.databricks_utils._get_property_from_spark_context" + ) as mock_spark_context: + assert databricks_utils.get_notebook_id() == "notebook_id" + mock_context_metadata.assert_called_once_with("notebookId") + mock_context_metadata.reset_mock() + assert databricks_utils.is_in_databricks_notebook() + mock_context_metadata.assert_called_once_with("notebookId") + mock_spark_context.assert_not_called() + + with mock.patch( + "mlflow.utils.databricks_utils._get_context_metadata", + return_value="cluster_id", + ) as mock_context_metadata, mock.patch( + "mlflow.utils._spark_utils._get_active_spark_session" + ) as mock_spark_session: + assert databricks_utils.get_cluster_id() == "cluster_id" + mock_context_metadata.assert_called_once_with("clusterId") + mock_context_metadata.reset_mock() + assert databricks_utils.is_in_cluster() + mock_context_metadata.assert_called_once_with("clusterId") + mock_spark_session.assert_not_called() + + +def test_use_message_metadata_if_available(): with mock.patch( "mlflow.utils.databricks_utils._get_message_metadata", - return_value={"commandRunId": "1"}, - ), mock.patch("mlflow.utils.databricks_utils._get_dbutils") as 
mock_dbutils: - assert databricks_utils.get_command_run_id() == "1" + return_value="command_run_id", + ) as mock_message_metadata, mock.patch( + "mlflow.utils.databricks_utils._get_dbutils" + ) as mock_dbutils: + assert databricks_utils.get_command_run_id() == "command_run_id" + mock_message_metadata.assert_called_once_with("commandRunId") mock_dbutils.assert_not_called() From 97654f78378b7f4b6d27e0e996aff83c092ed9f1 Mon Sep 17 00:00:00 2001 From: harupy <17039389+harupy@users.noreply.github.com> Date: Mon, 6 Dec 2021 16:32:43 +0900 Subject: [PATCH 15/21] fix docstring Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- mlflow/utils/databricks_utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mlflow/utils/databricks_utils.py b/mlflow/utils/databricks_utils.py index bdcdfdd18f4a7..6b31dc83cd124 100644 --- a/mlflow/utils/databricks_utils.py +++ b/mlflow/utils/databricks_utils.py @@ -26,9 +26,10 @@ def _use_context_metadata_if_available(name, *, if_available=lambda x: x): Creates a decorator to insert a short circuit that returns `if_available(name)` if the specified metadata is available. - :param key: Metadata name. - :param if_available: A function to evaluate if `env_var` exists. Defaults to `lambda x: x`. - :return: A decorator to insert the short circuit. + :param name: Metadata name. + :param if_available: Function to evaluate when the specified metadata is available. + Defaults to `lambda x: x`. + :return: Decorator to insert the short circuit. """ def decorator(f): @@ -63,8 +64,8 @@ def _use_message_metadata_if_available(name): Creates a decorator to insert a short circuit that returns specified Jupyter message metadata if it's available. - :param key: Metadata key. - :return: A decorator to insert the short circuit. + :param name: Metadata name. + :return: Decorator to insert the short circuit. """ def decorator(f): From d8f7b0eab972da3caff4a5732c218e81b3e33848 Mon Sep 17 00:00:00 2001 From: harupy <17039389+harupy@users.noreply.github.com> Date: Fri, 10 Dec 2021 11:33:41 +0900 Subject: [PATCH 16/21] use get_context Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- mlflow/utils/databricks_utils.py | 52 ++++++++++++++++---------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/mlflow/utils/databricks_utils.py b/mlflow/utils/databricks_utils.py index 6b31dc83cd124..25689717065f1 100644 --- a/mlflow/utils/databricks_utils.py +++ b/mlflow/utils/databricks_utils.py @@ -14,9 +14,9 @@ def _get_context_metadata(name): try: - from dbruntime.set_context import context + from dbruntime.set_context import get_context - return getattr(context, name) + return getattr(get_context(), name) except Exception: return None @@ -49,37 +49,37 @@ def _return_true(_): return True -def _get_message_metadata(name): - try: - import IPython +# def _get_message_metadata(name): +# try: +# import IPython - ip_shell = IPython.get_ipython() - return ip_shell.parent_header["metadata"][name] - except Exception: - return None +# ip_shell = IPython.get_ipython() +# return ip_shell.parent_header["metadata"][name] +# except Exception: +# return None -def _use_message_metadata_if_available(name): - """ - Creates a decorator to insert a short circuit that returns specified Jupyter message metadata - if it's available. +# def _use_message_metadata_if_available(name): +# """ +# Creates a decorator to insert a short circuit that returns specified Jupyter message metadata +# if it's available. - :param name: Metadata name. 
- :return: Decorator to insert the short circuit. - """ +# :param name: Metadata name. +# :return: Decorator to insert the short circuit. +# """ - def decorator(f): - @functools.wraps(f) - def wrapper(*args, **kwargs): - metadata = _get_message_metadata(name) - if metadata: - return metadata +# def decorator(f): +# @functools.wraps(f) +# def wrapper(*args, **kwargs): +# metadata = _get_message_metadata(name) +# if metadata: +# return metadata - return f(*args, **kwargs) +# return f(*args, **kwargs) - return wrapper +# return wrapper - return decorator +# return decorator def _get_dbutils(): @@ -303,7 +303,7 @@ def get_job_type(): return _get_context_tag("jobTaskType") -@_use_message_metadata_if_available("commandRunId") +@_use_context_metadata_if_available("commandRunId") def get_command_run_id(): try: return _get_command_context().commandRunId().get() From 99719897f0f45b1cd3ac34cb0e78f48785bd812b Mon Sep 17 00:00:00 2001 From: harupy <17039389+harupy@users.noreply.github.com> Date: Fri, 10 Dec 2021 12:22:56 +0900 Subject: [PATCH 17/21] rename functions Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- mlflow/utils/databricks_utils.py | 77 ++++++++-------------------- tests/utils/test_databricks_utils.py | 20 ++------ 2 files changed, 26 insertions(+), 71 deletions(-) diff --git a/mlflow/utils/databricks_utils.py b/mlflow/utils/databricks_utils.py index 25689717065f1..d3692cdefebe6 100644 --- a/mlflow/utils/databricks_utils.py +++ b/mlflow/utils/databricks_utils.py @@ -12,7 +12,7 @@ _logger = logging.getLogger(__name__) -def _get_context_metadata(name): +def _get_context_attribute(name): try: from dbruntime.set_context import get_context @@ -21,13 +21,13 @@ def _get_context_metadata(name): return None -def _use_context_metadata_if_available(name, *, if_available=lambda x: x): +def _use_context_attribute_if_available(name, *, if_available=lambda x: x): """ Creates a decorator to insert a short circuit that returns `if_available(name)` - if the specified metadata is available. + if the specified context attribute is available. - :param name: Metadata name. - :param if_available: Function to evaluate when the specified metadata is available. + :param name: Context attribute name. + :param if_available: Function to evaluate when the specified context attribute is available. Defaults to `lambda x: x`. :return: Decorator to insert the short circuit. """ @@ -35,7 +35,7 @@ def _use_context_metadata_if_available(name, *, if_available=lambda x: x): def decorator(f): @functools.wraps(f) def wrapper(*args, **kwargs): - metadata = _get_context_metadata(name) + metadata = _get_context_attribute(name) if metadata: return if_available(metadata) return f(*args, **kwargs) @@ -49,39 +49,6 @@ def _return_true(_): return True -# def _get_message_metadata(name): -# try: -# import IPython - -# ip_shell = IPython.get_ipython() -# return ip_shell.parent_header["metadata"][name] -# except Exception: -# return None - - -# def _use_message_metadata_if_available(name): -# """ -# Creates a decorator to insert a short circuit that returns specified Jupyter message metadata -# if it's available. - -# :param name: Metadata name. -# :return: Decorator to insert the short circuit. 
-# """ - -# def decorator(f): -# @functools.wraps(f) -# def wrapper(*args, **kwargs): -# metadata = _get_message_metadata(name) -# if metadata: -# return metadata - -# return f(*args, **kwargs) - -# return wrapper - -# return decorator - - def _get_dbutils(): try: import IPython @@ -121,7 +88,7 @@ def _get_context_tag(context_tag_key): return None -@_use_context_metadata_if_available("aclPathOfAclRoot") +@_use_context_attribute_if_available("aclPathOfAclRoot") def acl_path_of_acl_root(): try: return _get_command_context().aclPathOfAclRoot().get() @@ -144,7 +111,7 @@ def is_databricks_default_tracking_uri(tracking_uri): return tracking_uri.lower().strip() == "databricks" -@_use_context_metadata_if_available("notebookId", if_available=_return_true) +@_use_context_attribute_if_available("notebookId", if_available=_return_true) def is_in_databricks_notebook(): if _get_property_from_spark_context("spark.databricks.notebook.id") is not None: return True @@ -184,7 +151,7 @@ def is_dbfs_fuse_available(): return False -@_use_context_metadata_if_available("clusterId", if_available=_return_true) +@_use_context_attribute_if_available("clusterId", if_available=_return_true) def is_in_cluster(): try: spark_session = _get_active_spark_session() @@ -196,7 +163,7 @@ def is_in_cluster(): return False -@_use_context_metadata_if_available("notebookId") +@_use_context_attribute_if_available("notebookId") def get_notebook_id(): """Should only be called if is_in_databricks_notebook is true""" notebook_id = _get_property_from_spark_context("spark.databricks.notebook.id") @@ -208,7 +175,7 @@ def get_notebook_id(): return None -@_use_context_metadata_if_available("notebookPath") +@_use_context_attribute_if_available("notebookPath") def get_notebook_path(): """Should only be called if is_in_databricks_notebook is true""" path = _get_property_from_spark_context("spark.databricks.notebook.path") @@ -220,7 +187,7 @@ def get_notebook_path(): return _get_extra_context("notebook_path") -@_use_context_metadata_if_available("runtimeVersion") +@_use_context_attribute_if_available("runtimeVersion") def get_databricks_runtime(): if is_in_databricks_runtime(): spark_session = _get_active_spark_session() @@ -231,7 +198,7 @@ def get_databricks_runtime(): return None -@_use_context_metadata_if_available("clusterId") +@_use_context_attribute_if_available("clusterId") def get_cluster_id(): spark_session = _get_active_spark_session() if spark_session is None: @@ -239,7 +206,7 @@ def get_cluster_id(): return spark_session.conf.get("spark.databricks.clusterUsageTags.clusterId") -@_use_context_metadata_if_available("jobGroupId") +@_use_context_attribute_if_available("jobGroupId") def get_job_group_id(): try: dbutils = _get_dbutils() @@ -250,7 +217,7 @@ def get_job_group_id(): return None -@_use_context_metadata_if_available("replId") +@_use_context_attribute_if_available("replId") def get_repl_id(): """ :return: The ID of the current Databricks Python REPL @@ -278,7 +245,7 @@ def get_repl_id(): pass -@_use_context_metadata_if_available("jobId") +@_use_context_attribute_if_available("jobId") def get_job_id(): try: return _get_command_context().jobId().get() @@ -286,7 +253,7 @@ def get_job_id(): return _get_context_tag("jobId") -@_use_context_metadata_if_available("idInJob") +@_use_context_attribute_if_available("idInJob") def get_job_run_id(): try: return _get_command_context().idInJob().get() @@ -294,7 +261,7 @@ def get_job_run_id(): return _get_context_tag("idInJob") -@_use_context_metadata_if_available("jobTaskType") 
+@_use_context_attribute_if_available("jobTaskType") def get_job_type(): """Should only be called if is_in_databricks_job is true""" try: @@ -303,7 +270,7 @@ def get_job_type(): return _get_context_tag("jobTaskType") -@_use_context_metadata_if_available("commandRunId") +@_use_context_attribute_if_available("commandRunId") def get_command_run_id(): try: return _get_command_context().commandRunId().get() @@ -312,7 +279,7 @@ def get_command_run_id(): return None -@_use_context_metadata_if_available("apiUrl") +@_use_context_attribute_if_available("apiUrl") def get_webapp_url(): """Should only be called if is_in_databricks_notebook or is_in_databricks_jobs is true""" url = _get_property_from_spark_context("spark.databricks.api.url") @@ -324,7 +291,7 @@ def get_webapp_url(): return _get_extra_context("api_url") -@_use_context_metadata_if_available("workspaceId") +@_use_context_attribute_if_available("workspaceId") def get_workspace_id(): try: return _get_command_context().workspaceId().get() @@ -332,7 +299,7 @@ def get_workspace_id(): return _get_context_tag("orgId") -@_use_context_metadata_if_available("browserHostName") +@_use_context_attribute_if_available("browserHostName") def get_browser_hostname(): try: return _get_command_context().browserHostName().get() diff --git a/tests/utils/test_databricks_utils.py b/tests/utils/test_databricks_utils.py index 44b1107de1a07..3b7f7405804a4 100644 --- a/tests/utils/test_databricks_utils.py +++ b/tests/utils/test_databricks_utils.py @@ -256,10 +256,10 @@ def mock_import(name, *args, **kwargs): assert databricks_utils.get_repl_id() == "testReplId2" -def test_use_context_metadata_if_available(): +def test_use_context_attribute_if_available(): with mock.patch( - "mlflow.utils.databricks_utils._get_context_metadata", + "mlflow.utils.databricks_utils._get_context_attribute", return_value="job_id", ) as mock_context_metadata, mock.patch( "mlflow.utils.databricks_utils._get_dbutils" @@ -269,7 +269,7 @@ def test_use_context_metadata_if_available(): mock_dbutils.assert_not_called() with mock.patch( - "mlflow.utils.databricks_utils._get_context_metadata", + "mlflow.utils.databricks_utils._get_context_attribute", return_value="notebook_id", ) as mock_context_metadata, mock.patch( "mlflow.utils.databricks_utils._get_property_from_spark_context" @@ -282,7 +282,7 @@ def test_use_context_metadata_if_available(): mock_spark_context.assert_not_called() with mock.patch( - "mlflow.utils.databricks_utils._get_context_metadata", + "mlflow.utils.databricks_utils._get_context_attribute", return_value="cluster_id", ) as mock_context_metadata, mock.patch( "mlflow.utils._spark_utils._get_active_spark_session" @@ -293,15 +293,3 @@ def test_use_context_metadata_if_available(): assert databricks_utils.is_in_cluster() mock_context_metadata.assert_called_once_with("clusterId") mock_spark_session.assert_not_called() - - -def test_use_message_metadata_if_available(): - with mock.patch( - "mlflow.utils.databricks_utils._get_message_metadata", - return_value="command_run_id", - ) as mock_message_metadata, mock.patch( - "mlflow.utils.databricks_utils._get_dbutils" - ) as mock_dbutils: - assert databricks_utils.get_command_run_id() == "command_run_id" - mock_message_metadata.assert_called_once_with("commandRunId") - mock_dbutils.assert_not_called() From 9f7dd9d9d7a2c2065458abf8683827d7cd0b76e5 Mon Sep 17 00:00:00 2001 From: harupy <17039389+harupy@users.noreply.github.com> Date: Tue, 14 Dec 2021 21:04:41 +0900 Subject: [PATCH 18/21] fix module name Signed-off-by: harupy 
<17039389+harupy@users.noreply.github.com> --- mlflow/utils/databricks_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlflow/utils/databricks_utils.py b/mlflow/utils/databricks_utils.py index d3692cdefebe6..873ed5b96ac88 100644 --- a/mlflow/utils/databricks_utils.py +++ b/mlflow/utils/databricks_utils.py @@ -14,7 +14,7 @@ def _get_context_attribute(name): try: - from dbruntime.set_context import get_context + from dbruntime.databricks_metadata_context import get_context return getattr(get_context(), name) except Exception: From 3c3676c01da7aea9c6829c3a8084133659fd4468 Mon Sep 17 00:00:00 2001 From: harupy <17039389+harupy@users.noreply.github.com> Date: Tue, 14 Dec 2021 21:56:12 +0900 Subject: [PATCH 19/21] fix module name Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- mlflow/utils/databricks_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlflow/utils/databricks_utils.py b/mlflow/utils/databricks_utils.py index 873ed5b96ac88..f5d774b0c915b 100644 --- a/mlflow/utils/databricks_utils.py +++ b/mlflow/utils/databricks_utils.py @@ -14,7 +14,7 @@ def _get_context_attribute(name): try: - from dbruntime.databricks_metadata_context import get_context + from dbruntime.databricks_repl_context import get_context return getattr(get_context(), name) except Exception: From 5ac047586a413a3854f3c6249328c766aa9756c2 Mon Sep 17 00:00:00 2001 From: harupy <17039389+harupy@users.noreply.github.com> Date: Thu, 16 Dec 2021 09:29:27 +0900 Subject: [PATCH 20/21] use boolean attributes Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- mlflow/utils/databricks_utils.py | 34 +++++++++++----------------- tests/utils/test_databricks_utils.py | 8 +++---- 2 files changed, 17 insertions(+), 25 deletions(-) diff --git a/mlflow/utils/databricks_utils.py b/mlflow/utils/databricks_utils.py index f5d774b0c915b..70d84629cc7ff 100644 --- a/mlflow/utils/databricks_utils.py +++ b/mlflow/utils/databricks_utils.py @@ -13,42 +13,33 @@ def _get_context_attribute(name): - try: - from dbruntime.databricks_repl_context import get_context + from dbruntime.databricks_repl_context import get_context - return getattr(get_context(), name) - except Exception: - return None + return getattr(get_context(), name) -def _use_context_attribute_if_available(name, *, if_available=lambda x: x): +def _use_context_attribute_if_available(name): """ - Creates a decorator to insert a short circuit that returns `if_available(name)` - if the specified context attribute is available. + Creates a decorator to insert a short circuit that returns the specified context attribute if + it's available. - :param name: Context attribute name. - :param if_available: Function to evaluate when the specified context attribute is available. - Defaults to `lambda x: x`. + :param name: Context attribute name (e.g. "api_url"). :return: Decorator to insert the short circuit. 
""" def decorator(f): @functools.wraps(f) def wrapper(*args, **kwargs): - metadata = _get_context_attribute(name) - if metadata: - return if_available(metadata) - return f(*args, **kwargs) + try: + return _get_context_attribute(name) + except Exception: + return f(*args, **kwargs) return wrapper return decorator -def _return_true(_): - return True - - def _get_dbutils(): try: import IPython @@ -111,7 +102,7 @@ def is_databricks_default_tracking_uri(tracking_uri): return tracking_uri.lower().strip() == "databricks" -@_use_context_attribute_if_available("notebookId", if_available=_return_true) +@_use_context_attribute_if_available("isInNotebook") def is_in_databricks_notebook(): if _get_property_from_spark_context("spark.databricks.notebook.id") is not None: return True @@ -121,6 +112,7 @@ def is_in_databricks_notebook(): return False +@_use_context_attribute_if_available("isInJob") def is_in_databricks_job(): try: return get_job_id() is not None and get_job_run_id() is not None @@ -151,7 +143,7 @@ def is_dbfs_fuse_available(): return False -@_use_context_attribute_if_available("clusterId", if_available=_return_true) +@_use_context_attribute_if_available("isInCluster") def is_in_cluster(): try: spark_session = _get_active_spark_session() diff --git a/tests/utils/test_databricks_utils.py b/tests/utils/test_databricks_utils.py index 3b7f7405804a4..66d1b510f54b1 100644 --- a/tests/utils/test_databricks_utils.py +++ b/tests/utils/test_databricks_utils.py @@ -270,7 +270,7 @@ def test_use_context_attribute_if_available(): with mock.patch( "mlflow.utils.databricks_utils._get_context_attribute", - return_value="notebook_id", + side_effect={"notebookId": "notebook_id", "isInNotebook": True}.get, ) as mock_context_metadata, mock.patch( "mlflow.utils.databricks_utils._get_property_from_spark_context" ) as mock_spark_context: @@ -278,12 +278,12 @@ def test_use_context_attribute_if_available(): mock_context_metadata.assert_called_once_with("notebookId") mock_context_metadata.reset_mock() assert databricks_utils.is_in_databricks_notebook() - mock_context_metadata.assert_called_once_with("notebookId") + mock_context_metadata.assert_called_once_with("isInNotebook") mock_spark_context.assert_not_called() with mock.patch( "mlflow.utils.databricks_utils._get_context_attribute", - return_value="cluster_id", + side_effect={"clusterId": "cluster_id", "isInCluster": True}.get, ) as mock_context_metadata, mock.patch( "mlflow.utils._spark_utils._get_active_spark_session" ) as mock_spark_session: @@ -291,5 +291,5 @@ def test_use_context_attribute_if_available(): mock_context_metadata.assert_called_once_with("clusterId") mock_context_metadata.reset_mock() assert databricks_utils.is_in_cluster() - mock_context_metadata.assert_called_once_with("clusterId") + mock_context_metadata.assert_called_once_with("isInCluster") mock_spark_session.assert_not_called() From 6804295b866ecb98aa39811bfae4d0ee23311374 Mon Sep 17 00:00:00 2001 From: harupy <17039389+harupy@users.noreply.github.com> Date: Tue, 21 Dec 2021 17:22:17 +0900 Subject: [PATCH 21/21] refactor Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- mlflow/utils/databricks_utils.py | 54 ++++++++++++++-------------- tests/utils/test_databricks_utils.py | 42 +++++++++++----------- 2 files changed, 47 insertions(+), 49 deletions(-) diff --git a/mlflow/utils/databricks_utils.py b/mlflow/utils/databricks_utils.py index 70d84629cc7ff..f9701dca1d0b5 100644 --- a/mlflow/utils/databricks_utils.py +++ b/mlflow/utils/databricks_utils.py @@ -12,18 +12,12 @@ 
_logger = logging.getLogger(__name__) -def _get_context_attribute(name): - from dbruntime.databricks_repl_context import get_context - - return getattr(get_context(), name) - - -def _use_context_attribute_if_available(name): +def _use_repl_context_if_available(name): """ - Creates a decorator to insert a short circuit that returns the specified context attribute if - it's available. + Creates a decorator to insert a short circuit that returns the specified REPL context attribute + if it's available. - :param name: Context attribute name (e.g. "api_url"). + :param name: Attribute name (e.g. "apiUrl"). :return: Decorator to insert the short circuit. """ @@ -31,7 +25,11 @@ def decorator(f): @functools.wraps(f) def wrapper(*args, **kwargs): try: - return _get_context_attribute(name) + from dbruntime.databricks_repl_context import get_context + + context = get_context() + if context is not None and hasattr(context, name): + return getattr(context, name) except Exception: return f(*args, **kwargs) @@ -79,7 +77,7 @@ def _get_context_tag(context_tag_key): return None -@_use_context_attribute_if_available("aclPathOfAclRoot") +@_use_repl_context_if_available("aclPathOfAclRoot") def acl_path_of_acl_root(): try: return _get_command_context().aclPathOfAclRoot().get() @@ -102,7 +100,7 @@ def is_databricks_default_tracking_uri(tracking_uri): return tracking_uri.lower().strip() == "databricks" -@_use_context_attribute_if_available("isInNotebook") +@_use_repl_context_if_available("isInNotebook") def is_in_databricks_notebook(): if _get_property_from_spark_context("spark.databricks.notebook.id") is not None: return True @@ -112,7 +110,7 @@ def is_in_databricks_notebook(): return False -@_use_context_attribute_if_available("isInJob") +@_use_repl_context_if_available("isInJob") def is_in_databricks_job(): try: return get_job_id() is not None and get_job_run_id() is not None @@ -143,7 +141,7 @@ def is_dbfs_fuse_available(): return False -@_use_context_attribute_if_available("isInCluster") +@_use_repl_context_if_available("isInCluster") def is_in_cluster(): try: spark_session = _get_active_spark_session() @@ -155,7 +153,7 @@ def is_in_cluster(): return False -@_use_context_attribute_if_available("notebookId") +@_use_repl_context_if_available("notebookId") def get_notebook_id(): """Should only be called if is_in_databricks_notebook is true""" notebook_id = _get_property_from_spark_context("spark.databricks.notebook.id") @@ -167,7 +165,7 @@ def get_notebook_id(): return None -@_use_context_attribute_if_available("notebookPath") +@_use_repl_context_if_available("notebookPath") def get_notebook_path(): """Should only be called if is_in_databricks_notebook is true""" path = _get_property_from_spark_context("spark.databricks.notebook.path") @@ -179,7 +177,7 @@ def get_notebook_path(): return _get_extra_context("notebook_path") -@_use_context_attribute_if_available("runtimeVersion") +@_use_repl_context_if_available("runtimeVersion") def get_databricks_runtime(): if is_in_databricks_runtime(): spark_session = _get_active_spark_session() @@ -190,7 +188,7 @@ def get_databricks_runtime(): return None -@_use_context_attribute_if_available("clusterId") +@_use_repl_context_if_available("clusterId") def get_cluster_id(): spark_session = _get_active_spark_session() if spark_session is None: @@ -198,7 +196,7 @@ def get_cluster_id(): return spark_session.conf.get("spark.databricks.clusterUsageTags.clusterId") -@_use_context_attribute_if_available("jobGroupId") +@_use_repl_context_if_available("jobGroupId") def 
get_job_group_id(): try: dbutils = _get_dbutils() @@ -209,7 +207,7 @@ def get_job_group_id(): return None -@_use_context_attribute_if_available("replId") +@_use_repl_context_if_available("replId") def get_repl_id(): """ :return: The ID of the current Databricks Python REPL @@ -237,7 +235,7 @@ def get_repl_id(): pass -@_use_context_attribute_if_available("jobId") +@_use_repl_context_if_available("jobId") def get_job_id(): try: return _get_command_context().jobId().get() @@ -245,7 +243,7 @@ def get_job_id(): return _get_context_tag("jobId") -@_use_context_attribute_if_available("idInJob") +@_use_repl_context_if_available("idInJob") def get_job_run_id(): try: return _get_command_context().idInJob().get() @@ -253,7 +251,7 @@ def get_job_run_id(): return _get_context_tag("idInJob") -@_use_context_attribute_if_available("jobTaskType") +@_use_repl_context_if_available("jobTaskType") def get_job_type(): """Should only be called if is_in_databricks_job is true""" try: @@ -262,7 +260,7 @@ def get_job_type(): return _get_context_tag("jobTaskType") -@_use_context_attribute_if_available("commandRunId") +@_use_repl_context_if_available("commandRunId") def get_command_run_id(): try: return _get_command_context().commandRunId().get() @@ -271,7 +269,7 @@ def get_command_run_id(): return None -@_use_context_attribute_if_available("apiUrl") +@_use_repl_context_if_available("apiUrl") def get_webapp_url(): """Should only be called if is_in_databricks_notebook or is_in_databricks_jobs is true""" url = _get_property_from_spark_context("spark.databricks.api.url") @@ -283,7 +281,7 @@ def get_webapp_url(): return _get_extra_context("api_url") -@_use_context_attribute_if_available("workspaceId") +@_use_repl_context_if_available("workspaceId") def get_workspace_id(): try: return _get_command_context().workspaceId().get() @@ -291,7 +289,7 @@ def get_workspace_id(): return _get_context_tag("orgId") -@_use_context_attribute_if_available("browserHostName") +@_use_repl_context_if_available("browserHostName") def get_browser_hostname(): try: return _get_command_context().browserHostName().get() diff --git a/tests/utils/test_databricks_utils.py b/tests/utils/test_databricks_utils.py index 66d1b510f54b1..a45c8cb42cc89 100644 --- a/tests/utils/test_databricks_utils.py +++ b/tests/utils/test_databricks_utils.py @@ -256,40 +256,40 @@ def mock_import(name, *args, **kwargs): assert databricks_utils.get_repl_id() == "testReplId2" -def test_use_context_attribute_if_available(): +def test_use_repl_context_if_available(tmpdir): + # Create a fake databricks_repl_context module + tmpdir.mkdir("dbruntime").join("databricks_repl_context.py").write( + """ +def get_context(): + pass +""" + ) + sys.path.append(tmpdir.strpath) with mock.patch( - "mlflow.utils.databricks_utils._get_context_attribute", - return_value="job_id", - ) as mock_context_metadata, mock.patch( - "mlflow.utils.databricks_utils._get_dbutils" - ) as mock_dbutils: + "dbruntime.databricks_repl_context.get_context", + return_value=mock.MagicMock(jobId="job_id"), + ) as mock_get_context, mock.patch("mlflow.utils.databricks_utils._get_dbutils") as mock_dbutils: assert databricks_utils.get_job_id() == "job_id" - mock_context_metadata.assert_called_once_with("jobId") + mock_get_context.assert_called_once() mock_dbutils.assert_not_called() with mock.patch( - "mlflow.utils.databricks_utils._get_context_attribute", - side_effect={"notebookId": "notebook_id", "isInNotebook": True}.get, - ) as mock_context_metadata, mock.patch( + "dbruntime.databricks_repl_context.get_context", + 
return_value=mock.MagicMock(notebookId="notebook_id"), + ) as mock_get_context, mock.patch( "mlflow.utils.databricks_utils._get_property_from_spark_context" ) as mock_spark_context: assert databricks_utils.get_notebook_id() == "notebook_id" - mock_context_metadata.assert_called_once_with("notebookId") - mock_context_metadata.reset_mock() - assert databricks_utils.is_in_databricks_notebook() - mock_context_metadata.assert_called_once_with("isInNotebook") + mock_get_context.assert_called_once() mock_spark_context.assert_not_called() with mock.patch( - "mlflow.utils.databricks_utils._get_context_attribute", - side_effect={"clusterId": "cluster_id", "isInCluster": True}.get, - ) as mock_context_metadata, mock.patch( + "dbruntime.databricks_repl_context.get_context", + return_value=mock.MagicMock(isInCluster=True), + ) as mock_get_context, mock.patch( "mlflow.utils._spark_utils._get_active_spark_session" ) as mock_spark_session: - assert databricks_utils.get_cluster_id() == "cluster_id" - mock_context_metadata.assert_called_once_with("clusterId") - mock_context_metadata.reset_mock() assert databricks_utils.is_in_cluster() - mock_context_metadata.assert_called_once_with("isInCluster") + mock_get_context.assert_called_once() mock_spark_session.assert_not_called()
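Note: the sketch below is a minimal, self-contained illustration of the short-circuit decorator that the final patch ("refactor") introduces, for readers reviewing the series outside a Databricks runtime. `_FakeReplContext`, `_get_repl_context`, and the sample attribute values are stand-ins for `dbruntime.databricks_repl_context.get_context()`, which only exists on Databricks; the fallback-to-the-wrapped-function behavior shown here reflects the intent described in the decorator's docstring rather than a byte-for-byte copy of the patch.

import functools


class _FakeReplContext:
    """Stand-in for the object returned by dbruntime.databricks_repl_context.get_context()."""

    def __init__(self, **attrs):
        for key, value in attrs.items():
            setattr(self, key, value)


# On Databricks this would be: from dbruntime.databricks_repl_context import get_context
_current_context = _FakeReplContext(jobId="123", isInCluster=True)


def _get_repl_context():
    return _current_context


def use_repl_context_if_available(name):
    """Short-circuit the decorated function with a REPL context attribute when it's present."""

    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            try:
                context = _get_repl_context()
                if context is not None and hasattr(context, name):
                    return getattr(context, name)
            except Exception:
                pass
            # Context unavailable or attribute missing: fall back to the original lookup.
            return f(*args, **kwargs)

        return wrapper

    return decorator


@use_repl_context_if_available("jobId")
def get_job_id():
    # The expensive fallback (dbutils command context, Spark conf, ...) would live here.
    return None


@use_repl_context_if_available("notebookPath")
def get_notebook_path():
    return "/fallback/path"


if __name__ == "__main__":
    print(get_job_id())         # "123" -- served directly by the REPL context
    print(get_notebook_path())  # "/fallback/path" -- attribute absent, falls back

Keeping the `dbruntime` import inside the wrapper (as the patch does) means the module is only resolved when a decorated function is actually called, so importing databricks_utils outside Databricks stays cheap and safe; the tests in PATCH 21 rely on the same property by dropping a fake `dbruntime/databricks_repl_context.py` onto `sys.path` and patching `get_context`.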