Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use REPL context attributes if available to avoid calling JVM methods #5132

Merged
merged 22 commits into from Dec 22, 2021
Merged
40 changes: 40 additions & 0 deletions mlflow/utils/databricks_utils.py
@@ -1,6 +1,7 @@
import os
import logging
import subprocess
import functools

from mlflow.exceptions import MlflowException
from mlflow.utils.rest_utils import MlflowHostCreds
Expand All @@ -10,6 +11,30 @@

_logger = logging.getLogger(__name__)

# Common prefix for the Databricks-provided environment variables consulted below
# (e.g. "DATABRICKS_" + "CLUSTER_ID" -> "DATABRICKS_CLUSTER_ID").
_env_var_prefix = "DATABRICKS_"


def _use_env_var_if_exists(env_var, *, if_exists=lambda x: os.environ[x]):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Introduced this decorator to make it easier to preserve the existing logic for older runtime versions.

"""
Creates a decorator to insert a short circuit that's activated when the specified environment
variable exists.

:param env_var: The name of an environment variable to use.
:param if_exists: A function to evalute if `env_var` exists. Defaults to
`lambda x: os.environ[x]`.
"""

def decorator(f):
@functools.wraps(f)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice use of the decorator factory here. +1

def wrapper(*args, **kwargs):
if env_var in os.environ:
return if_exists(env_var)
return f(*args, **kwargs)

return wrapper

return decorator


def _get_dbutils():
try:
Expand Down Expand Up @@ -50,6 +75,7 @@ def _get_context_tag(context_tag_key):
return None


@_use_env_var_if_exists(_env_var_prefix + "ACL_PATH_OF_ACL_ROOT")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we prefix these environment variables with DATABRICKS?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We probably don't need ACL_PATH_OF_ACL_ROOT, since this is used for is_in_databricks_notebook / get_notebook_id. We can rely on DATABRICKS_NOTEBOOK_ID for those.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dbczumar Thanks for the comment! _env_var_prefix adds DATABRICKS_ or am I missing something?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doh. Sorry - missed that.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We probably don't need ACL_PATH_OF_ACL_ROOT, since this is used for is_in_databricks_notebook / get_notebook_id. We can rely on DATABRICKS_NOTEBOOK_ID for those.

Makes sense!

def acl_path_of_acl_root():
try:
return _get_command_context().aclPathOfAclRoot().get()
Expand All @@ -72,6 +98,7 @@ def is_databricks_default_tracking_uri(tracking_uri):
return tracking_uri.lower().strip() == "databricks"


@_use_env_var_if_exists(_env_var_prefix + "NOTEBOOK_ID", if_exists=lambda x: x in os.environ)
def is_in_databricks_notebook():
if _get_property_from_spark_context("spark.databricks.notebook.id") is not None:
return True
Expand Down Expand Up @@ -111,6 +138,7 @@ def is_dbfs_fuse_available():
return False


@_use_env_var_if_exists(_env_var_prefix + "CLUSTER_ID", if_exists=lambda x: x in os.environ)
def is_in_cluster():
try:
spark_session = _get_active_spark_session()
Expand All @@ -122,6 +150,7 @@ def is_in_cluster():
return False


@_use_env_var_if_exists(_env_var_prefix + "NOTEBOOK_ID")
def get_notebook_id():
"""Should only be called if is_in_databricks_notebook is true"""
notebook_id = _get_property_from_spark_context("spark.databricks.notebook.id")
Expand All @@ -133,6 +162,7 @@ def get_notebook_id():
return None


@_use_env_var_if_exists(_env_var_prefix + "NOTEBOOK_PATH")
def get_notebook_path():
"""Should only be called if is_in_databricks_notebook is true"""
path = _get_property_from_spark_context("spark.databricks.notebook.path")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this work with ephemeral notebooks within and without jobs?

Expand All @@ -144,6 +174,7 @@ def get_notebook_path():
return _get_extra_context("notebook_path")


@_use_env_var_if_exists(_env_var_prefix + "RUNTIME_VERSION")
def get_databricks_runtime():
if is_in_databricks_runtime():
spark_session = _get_active_spark_session()
Expand All @@ -154,13 +185,15 @@ def get_databricks_runtime():
return None


@_use_env_var_if_exists(_env_var_prefix + "CLUSTER_ID")
def get_cluster_id():
    """
    Returns the ID of the current Databricks cluster from the active Spark session's
    configuration, or None when no Spark session is active. Short-circuited by the
    DATABRICKS_CLUSTER_ID environment variable when it is set.
    """
    session = _get_active_spark_session()
    if session is not None:
        return session.conf.get("spark.databricks.clusterUsageTags.clusterId")
    return None


@_use_env_var_if_exists(_env_var_prefix + "JOB_GROUP_ID")
def get_job_group_id():
try:
dbutils = _get_dbutils()
Expand All @@ -171,6 +204,7 @@ def get_job_group_id():
return None


@_use_env_var_if_exists(_env_var_prefix + "REPL_ID")
def get_repl_id():
"""
:return: The ID of the current Databricks Python REPL
Expand Down Expand Up @@ -198,20 +232,23 @@ def get_repl_id():
pass


@_use_env_var_if_exists(_env_var_prefix + "JOB_ID")
def get_job_id():
    """
    Returns the ID of the current Databricks job, preferring the command context and
    falling back to the "jobId" context tag. Short-circuited by the DATABRICKS_JOB_ID
    environment variable when it is set.
    """
    try:
        job_id = _get_command_context().jobId().get()
    except Exception:
        job_id = _get_context_tag("jobId")
    return job_id


@_use_env_var_if_exists(_env_var_prefix + "JOB_RUN_ID")
def get_job_run_id():
    """
    Returns the run ID of the current Databricks job, preferring the command context and
    falling back to the "idInJob" context tag. Short-circuited by the DATABRICKS_JOB_RUN_ID
    environment variable when it is set.
    """
    try:
        run_id = _get_command_context().idInJob().get()
    except Exception:
        run_id = _get_context_tag("idInJob")
    return run_id


@_use_env_var_if_exists(_env_var_prefix + "JOB_TYPE")
def get_job_type():
"""Should only be called if is_in_databricks_job is true"""
try:
Expand All @@ -228,6 +265,7 @@ def get_command_run_id():
return None


@_use_env_var_if_exists(_env_var_prefix + "API_URL")
def get_webapp_url():
"""Should only be called if is_in_databricks_notebook or is_in_databricks_jobs is true"""
url = _get_property_from_spark_context("spark.databricks.api.url")
Expand All @@ -239,13 +277,15 @@ def get_webapp_url():
return _get_extra_context("api_url")


@_use_env_var_if_exists(_env_var_prefix + "WORKSPACE_ID")
def get_workspace_id():
    """
    Returns the ID of the current Databricks workspace, preferring the command context and
    falling back to the "orgId" context tag. Short-circuited by the DATABRICKS_WORKSPACE_ID
    environment variable when it is set.
    """
    try:
        workspace_id = _get_command_context().workspaceId().get()
    except Exception:
        workspace_id = _get_context_tag("orgId")
    return workspace_id


@_use_env_var_if_exists(_env_var_prefix + "BROWSER_HOST_NAME")
def get_browser_hostname():
try:
return _get_command_context().browserHostName().get()
Expand Down