Skip to content

Commit

Permalink
Resolve Amazon Hook's region_name and config in wrapper (#25336)
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis committed Aug 5, 2022
1 parent df5a54d commit 4193558
Show file tree
Hide file tree
Showing 4 changed files with 306 additions and 73 deletions.
111 changes: 66 additions & 45 deletions airflow/providers/amazon/aws/hooks/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@

from airflow.compat.functools import cached_property
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.hooks.base import BaseHook
from airflow.models.connection import Connection
from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
Expand All @@ -66,19 +66,24 @@ class BaseSessionFactory(LoggingMixin):
"""

def __init__(
self, conn: Union[Connection, AwsConnectionWrapper], region_name: Optional[str], config: Config
self,
conn: Optional[Union[Connection, AwsConnectionWrapper]],
region_name: Optional[str] = None,
config: Optional[Config] = None,
) -> None:
super().__init__()
self._conn = conn
self._region_name = region_name
self.config = config
self._config = config

@cached_property
def conn(self) -> AwsConnectionWrapper:
"""Cached AWS Connection Wrapper."""
if isinstance(self._conn, AwsConnectionWrapper):
return self._conn
return AwsConnectionWrapper(self._conn)
return AwsConnectionWrapper(
conn=self._conn,
region_name=self._region_name,
botocore_config=self._config,
)

@cached_property
def basic_session(self) -> boto3.session.Session:
Expand All @@ -92,21 +97,29 @@ def extra_config(self) -> Dict[str, Any]:

@property
def region_name(self) -> Optional[str]:
"""Resolve region name.
"""AWS Region Name read-only property."""
return self.conn.region_name

1. SessionFactory region_name
2. Connection region_name
"""
return self._region_name or self.conn.region_name
@property
def config(self) -> Optional[Config]:
"""Configuration for botocore client read-only property."""
return self.conn.botocore_config

@property
def role_arn(self) -> Optional[str]:
"""Assume Role ARN from AWS Connection"""
return self.conn.role_arn

def create_session(self) -> boto3.session.Session:
"""Create AWS session."""
if not self.role_arn:
"""Create boto3 Session from connection config."""
if not self.conn:
self.log.info(
"No connection ID provided. Fallback on boto3 credential strategy (region_name=%r). "
"See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html",
self.region_name,
)
return boto3.session.Session(region_name=self.region_name)
elif not self.role_arn:
return self.basic_session
return self._create_session_with_assume_role(session_kwargs=self.conn.session_kwargs)

Expand Down Expand Up @@ -381,45 +394,50 @@ def __init__(
self.verify = verify
self.client_type = client_type
self.resource_type = resource_type
self.region_name = region_name
self.config = config

def _get_credentials(self, region_name: Optional[str]) -> Tuple[boto3.session.Session, Optional[str]]:

if not self.aws_conn_id:
session = boto3.session.Session(region_name=region_name)
return session, None
self._region_name = region_name
self._config = config

self.log.debug("Airflow Connection: aws_conn_id=%s", self.aws_conn_id)
@cached_property
def conn_config(self) -> AwsConnectionWrapper:
"""Get the Airflow Connection object and wrap it in helper (cached)."""
connection = None
if self.aws_conn_id:
try:
connection = self.get_connection(self.aws_conn_id)
except AirflowNotFoundException:
warnings.warn(
f"Unable to find AWS Connection ID '{self.aws_conn_id}', switching to empty. "
"This behaviour is deprecated and will be removed in a future releases. "
"Please provide existed AWS connection ID or if required boto3 credential strategy "
"explicit set AWS Connection ID to None.",
DeprecationWarning,
stacklevel=2,
)

try:
# Fetch the Airflow connection object and wrap it in helper
connection_object = AwsConnectionWrapper(self.get_connection(self.aws_conn_id))
return AwsConnectionWrapper(
conn=connection or Connection(conn_id=None, conn_type="aws"),
region_name=self._region_name,
botocore_config=self._config,
)

if connection_object.botocore_config:
# For historical reason botocore.config.Config from connection overwrites
# config which explicitly set in Hook.
self.config = connection_object.botocore_config
@property
def region_name(self) -> Optional[str]:
"""AWS Region Name read-only property."""
return self.conn_config.region_name

session = SessionFactory(
conn=connection_object, region_name=region_name, config=self.config
).create_session()
@property
def config(self) -> Optional[Config]:
"""Configuration for botocore client read-only property."""
return self.conn_config.botocore_config

return session, connection_object.endpoint_url
def _get_credentials(self, region_name: Optional[str]) -> Tuple[boto3.session.Session, Optional[str]]:
self.log.debug("Airflow Connection: aws_conn_id=%s", self.aws_conn_id)

except AirflowException:
self.log.warning(
"Unable to use Airflow Connection for credentials. "
"Fallback on boto3 credential strategy. See: "
"https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html"
)
session = SessionFactory(
conn=self.conn_config, region_name=region_name, config=self.config
).create_session()

self.log.debug(
"Creating session using boto3 credential strategy region_name=%s",
region_name,
)
session = boto3.session.Session(region_name=region_name)
return session, None
return session, self.conn_config.endpoint_url

def get_client_type(
self,
Expand Down Expand Up @@ -491,17 +509,20 @@ def conn(self) -> Union[boto3.client, boto3.resource]:

@cached_property
def conn_client_meta(self) -> ClientMeta:
"""Get botocore client metadata from Hook connection (cached)."""
conn = self.conn
if isinstance(conn, botocore.client.BaseClient):
return conn.meta
return conn.meta.client.meta

@property
def conn_region_name(self) -> str:
"""Get actual AWS Region Name from Hook connection (cached)."""
return self.conn_client_meta.region_name

@property
def conn_partition(self) -> str:
"""Get associated AWS Region Partition from Hook connection (cached)."""
return self.conn_client_meta.partition

def get_conn(self) -> BaseAwsConnection:
Expand Down
64 changes: 53 additions & 11 deletions airflow/providers/amazon/aws/utils/connection_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

import warnings
from copy import deepcopy
from dataclasses import InitVar, dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from dataclasses import MISSING, InitVar, dataclass, field, fields
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

from botocore.config import Config

Expand All @@ -35,25 +35,43 @@ class AwsConnectionWrapper(LoggingMixin):
"""
AWS Connection Wrapper class helper.
Use for validate and resolve AWS Connection parameters.
``conn`` reference to Airflow Connection object or AwsConnectionWrapper
if it set to ``None`` than default values would use.
The precedence rules for ``region_name``
1. Explicit set (in Hook) ``region_name``.
2. Airflow Connection Extra 'region_name'.
The precedence rules for ``botocore_config``
1. Explicit set (in Hook) ``botocore_config``.
2. Construct from Airflow Connection Extra 'botocore_kwargs'.
3. The wrapper's default value
"""

conn: InitVar[Optional["Connection"]]
conn: InitVar[Optional[Union["Connection", "AwsConnectionWrapper"]]]
region_name: Optional[str] = field(default=None)
botocore_config: Optional[Config] = field(default=None)

# Reference to Airflow Connection attributes
# ``extra_config`` contains original Airflow Connection Extra.
conn_id: Optional[str] = field(init=False, default=None)
conn_type: Optional[str] = field(init=False, default=None)
login: Optional[str] = field(init=False, repr=False, default=None)
password: Optional[str] = field(init=False, repr=False, default=None)
extra_config: Dict[str, Any] = field(init=False, repr=False, default_factory=dict)

aws_access_key_id: Optional[str] = field(init=False)
aws_secret_access_key: Optional[str] = field(init=False)
aws_session_token: Optional[str] = field(init=False)
# AWS Credentials from connection.
aws_access_key_id: Optional[str] = field(init=False, default=None)
aws_secret_access_key: Optional[str] = field(init=False, default=None)
aws_session_token: Optional[str] = field(init=False, default=None)

region_name: Optional[str] = field(init=False, default=None)
# Additional boto3.session.Session keyword arguments.
session_kwargs: Dict[str, Any] = field(init=False, default_factory=dict)
botocore_config: Optional[Config] = field(init=False, default=None)
# Custom endpoint_url for boto3.client and boto3.resource
endpoint_url: Optional[str] = field(init=False, default=None)

# Assume Role Configurations
role_arn: Optional[str] = field(init=False, default=None)
assume_role_method: Optional[str] = field(init=False, default=None)
assume_role_kwargs: Dict[str, Any] = field(init=False, default_factory=dict)
Expand All @@ -63,7 +81,30 @@ def conn_repr(self):
return f"AWS Connection (conn_id={self.conn_id!r}, conn_type={self.conn_type!r})"

def __post_init__(self, conn: "Connection"):
if not conn:
if isinstance(conn, type(self)):
# For every field with init=False we copy reference value from original wrapper
# For every field with init=True we use init values if it not equal default
# We can't use ``dataclasses.replace`` in classmethod because
# we limited by InitVar arguments since it not stored in object,
# and also we do not want to run __post_init__ method again which print all logs/warnings again.
for fl in fields(conn):
value = getattr(conn, fl.name)
if not fl.init:
setattr(self, fl.name, value)
else:
if fl.default is not MISSING:
default = fl.default
elif fl.default_factory is not MISSING:
default = fl.default_factory() # zero-argument callable
else:
continue # Value mandatory, skip

orig_value = getattr(self, fl.name)
if orig_value == default:
# Only replace value if it not equal default value
setattr(self, fl.name, value)
return
elif not conn:
return

extra = deepcopy(conn.extra_dejson)
Expand All @@ -86,7 +127,7 @@ def __post_init__(self, conn: "Connection"):
init_credentials = self._get_credentials(**extra)
self.aws_access_key_id, self.aws_secret_access_key, self.aws_session_token = init_credentials

if "region_name" in extra:
if not self.region_name and "region_name" in extra:
self.region_name = extra["region_name"]
self.log.info("Retrieving region_name=%s from %s extra.", self.region_name, self.conn_repr)

Expand All @@ -106,7 +147,7 @@ def __post_init__(self, conn: "Connection"):
)

config_kwargs = extra.get("config_kwargs")
if config_kwargs:
if not self.botocore_config and config_kwargs:
# https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
self.log.info("Retrieving botocore config=%s from %s extra.", config_kwargs, self.conn_repr)
self.botocore_config = Config(**config_kwargs)
Expand All @@ -119,6 +160,7 @@ def __post_init__(self, conn: "Connection"):

@property
def extra_dejson(self):
"""Compatibility with `airflow.models.Connection.extra_dejson` property."""
return self.extra_config

def __bool__(self):
Expand Down

0 comments on commit 4193558

Please sign in to comment.