Skip to content

Commit

Permalink
Allow and prefer non-prefixed extra fields for remaining azure (#27220)
Browse files Browse the repository at this point in the history
Since Airflow 2.3, the `extra__<conn_type>__` prefix is no longer required on extra fields, so the hooks below now accept (and prefer) the non-prefixed field names.

Hooks updated:
* batch
* container volume
* cosmos
* data lake
* synapse
  • Loading branch information
dstandish committed Oct 28, 2022
1 parent c49740e commit 5df1d6e
Show file tree
Hide file tree
Showing 10 changed files with 233 additions and 42 deletions.
15 changes: 11 additions & 4 deletions airflow/providers/microsoft/azure/hooks/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.models import Connection
from airflow.providers.microsoft.azure.utils import get_field
from airflow.utils import timezone


Expand All @@ -43,6 +44,14 @@ class AzureBatchHook(BaseHook):
conn_type = "azure_batch"
hook_name = "Azure Batch Service"

def _get_field(self, extras, name):
    """
    Return the extra field ``name`` from ``extras`` for this hook's connection.

    Delegates to ``get_field``, supplying this hook's ``conn_id`` and ``conn_type``.

    :param extras: the connection extras dict to search
    :param name: the non-prefixed field name
    """
    return get_field(conn_id=self.conn_id, conn_type=self.conn_type, extras=extras, field_name=name)

@staticmethod
def get_connection_form_widgets() -> dict[str, Any]:
"""Returns connection widgets to add to connection form"""
Expand All @@ -51,9 +60,7 @@ def get_connection_form_widgets() -> dict[str, Any]:
from wtforms import StringField

return {
"extra__azure_batch__account_url": StringField(
lazy_gettext("Batch Account URL"), widget=BS3TextFieldWidget()
),
"account_url": StringField(lazy_gettext("Batch Account URL"), widget=BS3TextFieldWidget()),
}

@staticmethod
Expand Down Expand Up @@ -85,7 +92,7 @@ def get_conn(self):
"""
conn = self._connection()

batch_account_url = conn.extra_dejson.get("extra__azure_batch__account_url")
batch_account_url = self._get_field(conn.extra_dejson, "account_url")
if not batch_account_url:
raise AirflowException("Batch Account URL parameter is missing.")

Expand Down
22 changes: 16 additions & 6 deletions airflow/providers/microsoft/azure/hooks/container_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from azure.mgmt.containerinstance.models import AzureFileVolume, Volume

from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import _ensure_prefixes, get_field


class AzureContainerVolumeHook(BaseHook):
Expand All @@ -42,6 +43,14 @@ def __init__(self, azure_container_volume_conn_id: str = "azure_container_volume
super().__init__()
self.conn_id = azure_container_volume_conn_id

def _get_field(self, extras, name):
    """
    Return the extra field ``name`` from ``extras`` for this hook's connection.

    Delegates to ``get_field``, supplying this hook's ``conn_id`` and ``conn_type``.

    :param extras: the connection extras dict to search
    :param name: the non-prefixed field name
    """
    return get_field(conn_id=self.conn_id, conn_type=self.conn_type, extras=extras, field_name=name)

@staticmethod
def get_connection_form_widgets() -> dict[str, Any]:
"""Returns connection widgets to add to connection form"""
Expand All @@ -50,12 +59,13 @@ def get_connection_form_widgets() -> dict[str, Any]:
from wtforms import PasswordField

return {
"extra__azure_container_volume__connection_string": PasswordField(
"connection_string": PasswordField(
lazy_gettext("Blob Storage Connection String (optional)"), widget=BS3PasswordFieldWidget()
),
}

@staticmethod
@_ensure_prefixes(conn_type="azure_container_volume")
def get_ui_field_behaviour() -> dict[str, Any]:
"""Returns custom field behaviour"""
return {
Expand All @@ -67,17 +77,17 @@ def get_ui_field_behaviour() -> dict[str, Any]:
"placeholders": {
"login": "client_id (token credentials auth)",
"password": "secret (token credentials auth)",
"extra__azure_container_volume__connection_string": "connection string auth",
"connection_string": "connection string auth",
},
}

def get_storagekey(self) -> str:
"""Get Azure File Volume storage key"""
conn = self.get_connection(self.conn_id)
service_options = conn.extra_dejson

if "extra__azure_container_volume__connection_string" in service_options:
for keyvalue in service_options["extra__azure_container_volume__connection_string"].split(";"):
extras = conn.extra_dejson
connection_string = self._get_field(extras, "connection_string")
if connection_string:
for keyvalue in connection_string.split(";"):
key, value = keyvalue.split("=", 1)
if key == "AccountKey":
return value
Expand Down
26 changes: 16 additions & 10 deletions airflow/providers/microsoft/azure/hooks/cosmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

from airflow.exceptions import AirflowBadRequest
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import _ensure_prefixes, get_field


class AzureCosmosDBHook(BaseHook):
Expand Down Expand Up @@ -61,15 +62,16 @@ def get_connection_form_widgets() -> dict[str, Any]:
from wtforms import StringField

return {
"extra__azure_cosmos__database_name": StringField(
"database_name": StringField(
lazy_gettext("Cosmos Database Name (optional)"), widget=BS3TextFieldWidget()
),
"extra__azure_cosmos__collection_name": StringField(
"collection_name": StringField(
lazy_gettext("Cosmos Collection Name (optional)"), widget=BS3TextFieldWidget()
),
}

@staticmethod
@_ensure_prefixes(conn_type="azure_cosmos") # todo: remove when min airflow version >= 2.5
def get_ui_field_behaviour() -> dict[str, Any]:
"""Returns custom field behaviour"""
return {
Expand All @@ -81,8 +83,8 @@ def get_ui_field_behaviour() -> dict[str, Any]:
"placeholders": {
"login": "endpoint uri",
"password": "master key",
"extra__azure_cosmos__database_name": "database name",
"extra__azure_cosmos__collection_name": "collection name",
"database_name": "database name",
"collection_name": "collection name",
},
}

Expand All @@ -94,6 +96,14 @@ def __init__(self, azure_cosmos_conn_id: str = default_conn_name) -> None:
self.default_database_name = None
self.default_collection_name = None

def _get_field(self, extras, name):
    """
    Return the extra field ``name`` from ``extras`` for this hook's connection.

    Delegates to ``get_field``, supplying this hook's ``conn_id`` and ``conn_type``.

    :param extras: the connection extras dict to search
    :param name: the non-prefixed field name
    """
    return get_field(conn_id=self.conn_id, conn_type=self.conn_type, extras=extras, field_name=name)

def get_conn(self) -> CosmosClient:
"""Return a cosmos db client."""
if not self._conn:
Expand All @@ -102,12 +112,8 @@ def get_conn(self) -> CosmosClient:
endpoint_uri = conn.login
master_key = conn.password

self.default_database_name = extras.get("database_name") or extras.get(
"extra__azure_cosmos__database_name"
)
self.default_collection_name = extras.get("collection_name") or extras.get(
"extra__azure_cosmos__collection_name"
)
self.default_database_name = self._get_field(extras, "database_name")
self.default_collection_name = self._get_field(extras, "collection_name")

# Initialize the Python Azure Cosmos DB client
self._conn = CosmosClient(endpoint_uri, {"masterKey": master_key})
Expand Down
29 changes: 17 additions & 12 deletions airflow/providers/microsoft/azure/hooks/data_lake.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import _ensure_prefixes, get_field


class AzureDataLakeHook(BaseHook):
Expand All @@ -57,15 +58,14 @@ def get_connection_form_widgets() -> dict[str, Any]:
from wtforms import StringField

return {
"extra__azure_data_lake__tenant": StringField(
lazy_gettext("Azure Tenant ID"), widget=BS3TextFieldWidget()
),
"extra__azure_data_lake__account_name": StringField(
"tenant": StringField(lazy_gettext("Azure Tenant ID"), widget=BS3TextFieldWidget()),
"account_name": StringField(
lazy_gettext("Azure DataLake Store Name"), widget=BS3TextFieldWidget()
),
}

@staticmethod
@_ensure_prefixes(conn_type="azure_data_lake")
def get_ui_field_behaviour() -> dict[str, Any]:
"""Returns custom field behaviour"""
return {
Expand All @@ -77,8 +77,8 @@ def get_ui_field_behaviour() -> dict[str, Any]:
"placeholders": {
"login": "client id",
"password": "secret",
"extra__azure_data_lake__tenant": "tenant id",
"extra__azure_data_lake__account_name": "datalake store",
"tenant": "tenant id",
"account_name": "datalake store",
},
}

Expand All @@ -88,16 +88,21 @@ def __init__(self, azure_data_lake_conn_id: str = default_conn_name) -> None:
self._conn: core.AzureDLFileSystem | None = None
self.account_name: str | None = None

def _get_field(self, extras, name):
    """
    Return the extra field ``name`` from ``extras`` for this hook's connection.

    Delegates to ``get_field``, supplying this hook's ``conn_id`` and ``conn_type``.

    :param extras: the connection extras dict to search
    :param name: the non-prefixed field name
    """
    return get_field(conn_id=self.conn_id, conn_type=self.conn_type, extras=extras, field_name=name)

def get_conn(self) -> core.AzureDLFileSystem:
"""Return a AzureDLFileSystem object."""
if not self._conn:
conn = self.get_connection(self.conn_id)
service_options = conn.extra_dejson
self.account_name = service_options.get("account_name") or service_options.get(
"extra__azure_data_lake__account_name"
)
tenant = service_options.get("tenant") or service_options.get("extra__azure_data_lake__tenant")

extras = conn.extra_dejson
self.account_name = self._get_field(extras, "account_name")
tenant = self._get_field(extras, "tenant")
adl_creds = lib.auth(tenant_id=tenant, client_secret=conn.password, client_id=conn.login)
self._conn = core.AzureDLFileSystem(adl_creds, store_name=self.account_name)
self._conn.connect()
Expand Down
25 changes: 15 additions & 10 deletions airflow/providers/microsoft/azure/hooks/synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from airflow.exceptions import AirflowTaskTimeout
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import get_field

Credentials = Union[ClientSecretCredential, DefaultAzureCredential]

Expand Down Expand Up @@ -66,12 +67,8 @@ def get_connection_form_widgets() -> dict[str, Any]:
from wtforms import StringField

return {
"extra__azure_synapse__tenantId": StringField(
lazy_gettext("Tenant ID"), widget=BS3TextFieldWidget()
),
"extra__azure_synapse__subscriptionId": StringField(
lazy_gettext("Subscription ID"), widget=BS3TextFieldWidget()
),
"tenantId": StringField(lazy_gettext("Tenant ID"), widget=BS3TextFieldWidget()),
"subscriptionId": StringField(lazy_gettext("Subscription ID"), widget=BS3TextFieldWidget()),
}

@staticmethod
Expand All @@ -89,18 +86,26 @@ def __init__(self, azure_synapse_conn_id: str = default_conn_name, spark_pool: s
self.spark_pool = spark_pool
super().__init__()

def _get_field(self, extras, name):
    """
    Return the extra field ``name`` from ``extras`` for this hook's connection.

    Delegates to ``get_field``, supplying this hook's ``conn_id`` and ``conn_type``.

    :param extras: the connection extras dict to search
    :param name: the non-prefixed field name
    """
    return get_field(conn_id=self.conn_id, conn_type=self.conn_type, extras=extras, field_name=name)

def get_conn(self) -> SparkClient:
if self._conn is not None:
return self._conn

conn = self.get_connection(self.conn_id)
tenant = conn.extra_dejson.get("extra__azure_synapse__tenantId")
extras = conn.extra_dejson
tenant = self._get_field(extras, "tenantId")
spark_pool = self.spark_pool
livy_api_version = "2022-02-22-preview"

try:
subscription_id = conn.extra_dejson["extra__azure_synapse__subscriptionId"]
except KeyError:
subscription_id = self._get_field(extras, "subscriptionId")
if not subscription_id:
raise ValueError("A Subscription ID is required to connect to Azure Synapse.")

credential: Credentials
Expand Down
26 changes: 26 additions & 0 deletions airflow/providers/microsoft/azure/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from __future__ import annotations

import warnings
from functools import wraps


Expand Down Expand Up @@ -46,3 +47,28 @@ def _ensure_prefix(field):
return inner

return dec


def get_field(*, conn_id: str, conn_type: str, extras: dict, field_name: str):
    """
    Get a field from connection extras, preferring the short (non-prefixed) name.

    The short name ``field_name`` is checked first. For backward compatibility, the
    prefixed key ``extra__<conn_type>__<field_name>`` is used when the short name is
    absent. When both keys are present the short name wins and a warning is emitted.

    :param conn_id: connection id; only used in the conflict warning message
    :param conn_type: connection type; used to build the backcompat prefix
    :param extras: the connection extras dict to search
    :param field_name: the non-prefixed field name
    :return: the value found, or ``None`` when neither key is present or the
        selected value is an empty string
    :raises ValueError: if ``field_name`` starts with ``extra__``
    """
    backcompat_prefix = f"extra__{conn_type}__"
    if field_name.startswith("extra__"):
        raise ValueError(
            f"Got prefixed name {field_name}; please remove the '{backcompat_prefix}' prefix "
            "when using this method."
        )
    backcompat_key = f"{backcompat_prefix}{field_name}"
    ret = None
    if field_name in extras:
        if backcompat_key in extras:
            # stacklevel=2 attributes the warning to the caller rather than to this helper.
            warnings.warn(
                f"Conflicting params `{field_name}` and `{backcompat_key}` found in extras for conn "
                f"{conn_id}. Using value for `{field_name}`. Please ensure this is the correct "
                f"value and remove the backcompat key `{backcompat_key}`.",
                stacklevel=2,
            )
        ret = extras[field_name]
    elif backcompat_key in extras:
        ret = extras[backcompat_key]
    # An empty string is treated the same as a missing value.
    if ret == "":
        return None
    return ret
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.container_volume import AzureContainerVolumeHook
from airflow.utils import db
from tests.test_utils.providers import get_provider_min_airflow_version


class TestAzureContainerVolumeHook(unittest.TestCase):
Expand Down Expand Up @@ -65,3 +66,20 @@ def test_get_file_volume_connection_string(self):
assert volume.azure_file.storage_account_key == "1"
assert volume.azure_file.storage_account_name == "storage"
assert volume.azure_file.read_only is True

def test_get_ui_field_behaviour_placeholders(self):
    """
    Verify that the ``_ensure_prefixes`` decorator prefixes the custom placeholder keys.

    Note: remove this test and the ``_ensure_prefixes`` decorator once the provider's
    min airflow version is >= 2.5.0; the final check below fails as a reminder.
    """
    placeholders = AzureContainerVolumeHook.get_ui_field_behaviour()["placeholders"]
    expected_keys = [
        "login",
        "password",
        "extra__azure_container_volume__connection_string",
    ]
    assert list(placeholders.keys()) == expected_keys

    min_airflow_version = get_provider_min_airflow_version("apache-airflow-providers-microsoft-azure")
    if min_airflow_version >= (2, 5):
        raise Exception(
            "You must now remove `_ensure_prefixes` from azure utils."
            " The functionality is now taken care of by providers manager."
        )
19 changes: 19 additions & 0 deletions tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook
from airflow.utils import db
from tests.test_utils.providers import get_provider_min_airflow_version


class TestAzureCosmosDbHook(unittest.TestCase):
Expand Down Expand Up @@ -254,3 +255,21 @@ def test_connection_failure(self, mock_cosmos):
status, msg = hook.test_connection()
assert status is False
assert msg == "Authentication failed."

def test_get_ui_field_behaviour_placeholders(self):
    """
    Verify that the ``_ensure_prefixes`` decorator prefixes the custom placeholder keys.

    Note: remove this test and the ``_ensure_prefixes`` decorator once the provider's
    min airflow version is >= 2.5.0; the final check below fails as a reminder.
    """
    placeholders = AzureCosmosDBHook.get_ui_field_behaviour()["placeholders"]
    expected_keys = [
        "login",
        "password",
        "extra__azure_cosmos__database_name",
        "extra__azure_cosmos__collection_name",
    ]
    assert list(placeholders.keys()) == expected_keys

    min_airflow_version = get_provider_min_airflow_version("apache-airflow-providers-microsoft-azure")
    if min_airflow_version >= (2, 5):
        raise Exception(
            "You must now remove `_ensure_prefixes` from azure utils."
            " The functionality is now taken care of by providers manager."
        )

0 comments on commit 5df1d6e

Please sign in to comment.