New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use REPL context attributes if available to avoid calling JVM methods #5132
Changes from 5 commits
a035719
c1d440c
70c5c7e
4c981cb
e413f16
b83e7e2
bf30d27
a610852
6278492
a2213e3
757f0e2
78fa884
409bcce
7d07c30
97654f7
d8f7b0e
9971989
9f7dd9d
3c3676c
5ac0475
6804295
3ca6b30
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
import os | ||
import logging | ||
import subprocess | ||
import functools | ||
|
||
from mlflow.exceptions import MlflowException | ||
from mlflow.utils.rest_utils import MlflowHostCreds | ||
|
@@ -10,6 +11,34 @@ | |
|
||
_logger = logging.getLogger(__name__) | ||
|
||
_env_var_prefix = "DATABRICKS_" | ||
|
||
|
||
def _use_env_var_if_exists(env_var, *, if_exists=os.getenv): | ||
""" | ||
Creates a decorator to insert a short circuit that returns `if_exists(env_var)` if | ||
the environment variable `env_var` exists. | ||
|
||
:param env_var: The name of an environment variable to use. | ||
:param if_exists: A function to evaluate if `env_var` exists. Defaults to `os.getenv`. | ||
:return: A decorator to insert the short circuit. | ||
""" | ||
|
||
def decorator(f): | ||
@functools.wraps(f) | ||
def wrapper(*args, **kwargs): | ||
if env_var in os.environ: | ||
return if_exists(env_var) | ||
return f(*args, **kwargs) | ||
|
||
return wrapper | ||
|
||
return decorator | ||
|
||
|
||
def _returns_true(_): | ||
return True | ||
|
||
|
||
def _get_dbutils(): | ||
try: | ||
|
@@ -50,6 +79,7 @@ def _get_context_tag(context_tag_key): | |
return None | ||
|
||
|
||
@_use_env_var_if_exists(_env_var_prefix + "ACL_PATH_OF_ACL_ROOT") | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

Reviewer: Can we prefix these environment variables with …

Reviewer: We probably don't need …

Author: @dbczumar Thanks for the comment!

Reviewer: Doh. Sorry - missed that.

Author: Makes sense! |
||
def acl_path_of_acl_root(): | ||
try: | ||
return _get_command_context().aclPathOfAclRoot().get() | ||
|
@@ -72,6 +102,7 @@ def is_databricks_default_tracking_uri(tracking_uri): | |
return tracking_uri.lower().strip() == "databricks" | ||
|
||
|
||
@_use_env_var_if_exists(_env_var_prefix + "NOTEBOOK_ID", if_exists=_returns_true) | ||
def is_in_databricks_notebook(): | ||
if _get_property_from_spark_context("spark.databricks.notebook.id") is not None: | ||
return True | ||
|
@@ -111,6 +142,7 @@ def is_dbfs_fuse_available(): | |
return False | ||
|
||
|
||
@_use_env_var_if_exists(_env_var_prefix + "CLUSTER_ID", if_exists=_returns_true) | ||
def is_in_cluster(): | ||
try: | ||
spark_session = _get_active_spark_session() | ||
|
@@ -122,6 +154,7 @@ def is_in_cluster(): | |
return False | ||
|
||
|
||
@_use_env_var_if_exists(_env_var_prefix + "NOTEBOOK_ID") | ||
def get_notebook_id(): | ||
"""Should only be called if is_in_databricks_notebook is true""" | ||
notebook_id = _get_property_from_spark_context("spark.databricks.notebook.id") | ||
|
@@ -133,6 +166,7 @@ def get_notebook_id(): | |
return None | ||
|
||
|
||
@_use_env_var_if_exists(_env_var_prefix + "NOTEBOOK_PATH") | ||
def get_notebook_path(): | ||
"""Should only be called if is_in_databricks_notebook is true""" | ||
path = _get_property_from_spark_context("spark.databricks.notebook.path") | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

Reviewer: Does this work with ephemeral notebooks, both within and outside of jobs? |
||
|
@@ -144,6 +178,7 @@ def get_notebook_path(): | |
return _get_extra_context("notebook_path") | ||
|
||
|
||
@_use_env_var_if_exists(_env_var_prefix + "RUNTIME_VERSION") | ||
def get_databricks_runtime(): | ||
if is_in_databricks_runtime(): | ||
spark_session = _get_active_spark_session() | ||
|
@@ -154,13 +189,15 @@ def get_databricks_runtime(): | |
return None | ||
|
||
|
||
@_use_env_var_if_exists(_env_var_prefix + "CLUSTER_ID")
def get_cluster_id():
    # Decorator short-circuits to the DATABRICKS_CLUSTER_ID env var when it is
    # set; otherwise read the cluster ID from the active Spark session's conf.
    # Returns None when no Spark session is active.
    spark_session = _get_active_spark_session()
    if spark_session is None:
        return None
    return spark_session.conf.get("spark.databricks.clusterUsageTags.clusterId")
|
||
|
||
@_use_env_var_if_exists(_env_var_prefix + "JOB_GROUP_ID") | ||
def get_job_group_id(): | ||
try: | ||
dbutils = _get_dbutils() | ||
|
@@ -171,6 +208,7 @@ def get_job_group_id(): | |
return None | ||
|
||
|
||
@_use_env_var_if_exists(_env_var_prefix + "REPL_ID") | ||
def get_repl_id(): | ||
""" | ||
:return: The ID of the current Databricks Python REPL | ||
|
@@ -198,20 +236,23 @@ def get_repl_id(): | |
pass | ||
|
||
|
||
@_use_env_var_if_exists(_env_var_prefix + "JOB_ID")
def get_job_id():
    # Decorator short-circuits to the DATABRICKS_JOB_ID env var when it is set.
    # Otherwise prefer the JVM command context; if that call fails for any
    # reason, fall back to the "jobId" context tag (best-effort, may be None).
    try:
        return _get_command_context().jobId().get()
    except Exception:
        return _get_context_tag("jobId")
|
||
|
||
@_use_env_var_if_exists(_env_var_prefix + "JOB_RUN_ID")
def get_job_run_id():
    # Decorator short-circuits to the DATABRICKS_JOB_RUN_ID env var when set.
    # Otherwise prefer the JVM command context's idInJob; on any failure fall
    # back to the "idInJob" context tag (best-effort, may be None).
    try:
        return _get_command_context().idInJob().get()
    except Exception:
        return _get_context_tag("idInJob")
|
||
|
||
@_use_env_var_if_exists(_env_var_prefix + "JOB_TYPE") | ||
def get_job_type(): | ||
"""Should only be called if is_in_databricks_job is true""" | ||
try: | ||
|
@@ -228,6 +269,7 @@ def get_command_run_id(): | |
return None | ||
|
||
|
||
@_use_env_var_if_exists(_env_var_prefix + "API_URL") | ||
def get_webapp_url(): | ||
"""Should only be called if is_in_databricks_notebook or is_in_databricks_jobs is true""" | ||
url = _get_property_from_spark_context("spark.databricks.api.url") | ||
|
@@ -239,13 +281,15 @@ def get_webapp_url(): | |
return _get_extra_context("api_url") | ||
|
||
|
||
@_use_env_var_if_exists(_env_var_prefix + "WORKSPACE_ID")
def get_workspace_id():
    # Decorator short-circuits to the DATABRICKS_WORKSPACE_ID env var when set.
    # Otherwise prefer the JVM command context's workspaceId; on any failure
    # fall back to the "orgId" context tag (best-effort, may be None).
    try:
        return _get_command_context().workspaceId().get()
    except Exception:
        return _get_context_tag("orgId")
|
||
|
||
@_use_env_var_if_exists(_env_var_prefix + "BROWSER_HOST_NAME") | ||
def get_browser_hostname(): | ||
try: | ||
return _get_command_context().browserHostName().get() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice use of the decorator factory here. +1