Skip to content

Commit

Permalink
Refactor amazon providers tests which use moto (#27214)
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis committed Oct 27, 2022
1 parent 5e6cec8 commit 671029b
Show file tree
Hide file tree
Showing 33 changed files with 246 additions and 447 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def write_version(filename: str = str(AIRFLOW_SOURCES_ROOT / "airflow" / "git_ve
"jira",
"jsondiff",
"mongomock",
"moto[cloudformation, glue]>=3.1.12",
"moto[cloudformation, glue]>=4.0",
"parameterized",
"paramiko",
"pipdeptree",
Expand Down
34 changes: 0 additions & 34 deletions tests/providers/amazon/aws/hooks/conftest.py

This file was deleted.

16 changes: 2 additions & 14 deletions tests/providers/amazon/aws/hooks/test_base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,8 @@
import pytest
from botocore.config import Config
from botocore.credentials import ReadOnlyCredentials

try:
from moto.core import DEFAULT_ACCOUNT_ID
except ImportError:
from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID

from moto import mock_dynamodb, mock_emr, mock_iam, mock_sts
from moto.core import DEFAULT_ACCOUNT_ID

from airflow.models.connection import Connection
from airflow.providers.amazon.aws.hooks.base_aws import (
Expand Down Expand Up @@ -226,7 +221,6 @@ def test_create_session_from_credentials(self, mock_boto3_session, region_name,
mock_boto3_session.assert_called_once_with(**expected_arguments)
assert session == MOCK_BOTO3_SESSION

@pytest.mark.skipif(mock_sts is None, reason="mock_sts package not present")
@mock_sts
@pytest.mark.parametrize(
"conn_id, conn_extra",
Expand Down Expand Up @@ -260,7 +254,6 @@ def test_get_credentials_from_role_arn(self, conn_id, conn_extra, region_name):


class TestAwsBaseHook:
@unittest.skipIf(mock_emr is None, "mock_emr package not present")
@mock_emr
def test_get_client_type_set_in_class_attribute(self):
client = boto3.client("emr", region_name="us-east-1")
Expand Down Expand Up @@ -308,7 +301,6 @@ def test_get_session_returns_a_boto3_session(self):

assert table.item_count == 0

@unittest.skipIf(mock_sts is None, "mock_sts package not present")
@mock.patch.object(AwsBaseHook, "get_connection")
@mock_sts
def test_assume_role(self, mock_get_connection):
Expand Down Expand Up @@ -439,7 +431,6 @@ def import_mock(name, *args):
[mock.call.get_default_id_token_credentials(target_audience="aws-federation.airflow.apache.org")]
)

@unittest.skipIf(mock_sts is None, "mock_sts package not present")
@mock.patch.object(AwsBaseHook, "get_connection")
@mock_sts
def test_assume_role_with_saml(self, mock_get_connection):
Expand Down Expand Up @@ -531,7 +522,6 @@ def mock_assume_role_with_saml(**kwargs):
]
mock_boto3.assert_has_calls(calls_assume_role_with_saml)

@unittest.skipIf(mock_iam is None, "mock_iam package not present")
@mock_iam
def test_expand_role(self):
conn = boto3.client("iam", region_name="us-east-1")
Expand All @@ -547,7 +537,6 @@ def test_use_default_boto3_behaviour_without_conn_id(self):
# should cause no exception
hook.get_client_type("s3")

@unittest.skipIf(mock_sts is None, "mock_sts package not present")
@mock.patch.object(AwsBaseHook, "get_connection")
@mock_sts
def test_refreshable_credentials(self, mock_get_connection):
Expand Down Expand Up @@ -668,7 +657,6 @@ def test_connection_client_resource_types_check(self, client_type, resource_type
with pytest.raises(ValueError, match="Either client_type=.* or resource_type=.* must be provided"):
hook.get_conn()

@unittest.skipIf(mock_sts is None, "mock_sts package not present")
@mock_sts
def test_hook_connection_test(self):
hook = AwsBaseHook(client_type="s3")
Expand Down Expand Up @@ -849,7 +837,7 @@ def _non_retryable_test(thing):
return thing()


class TestRetryDecorator(unittest.TestCase): # ptlint: disable=invalid-name
class TestRetryDecorator: # ptlint: disable=invalid-name
def test_do_nothing_on_non_exception(self):
result = _retryable_test(lambda: 42)
assert result, 42
Expand Down
13 changes: 4 additions & 9 deletions tests/providers/amazon/aws/hooks/test_cloud_formation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,14 @@
from __future__ import annotations

import json
import unittest

from airflow.providers.amazon.aws.hooks.cloud_formation import CloudFormationHook
from moto import mock_cloudformation

try:
from moto import mock_cloudformation
except ImportError:
mock_cloudformation = None
from airflow.providers.amazon.aws.hooks.cloud_formation import CloudFormationHook


@unittest.skipIf(mock_cloudformation is None, "moto package not present")
class TestCloudFormationHook(unittest.TestCase):
def setUp(self):
class TestCloudFormationHook:
def setup_method(self):
self.hook = CloudFormationHook(aws_conn_id="aws_default")

def create_stack(self, stack_name):
Expand Down
96 changes: 36 additions & 60 deletions tests/providers/amazon/aws/hooks/test_ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,14 @@
# under the License.
from __future__ import annotations

import unittest

import pytest
from moto import mock_ec2

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook


class TestEC2Hook(unittest.TestCase):
class TestEC2Hook:
def test_init(self):
ec2_hook = EC2Hook(
aws_conn_id="aws_conn_test",
Expand All @@ -37,6 +36,27 @@ def test_init(self):
assert ec2_hook.aws_conn_id == "aws_conn_test"
assert ec2_hook.region_name == "region-test"

@classmethod
def _create_instances(cls, hook: EC2Hook, max_count=1, min_count=1):
"""Create Instances and return all instance ids."""
conn = hook.get_conn()
try:
ec2_client = conn.meta.client
except AttributeError:
ec2_client = conn

# We need existed AMI Image ID otherwise `moto` will raise DeprecationWarning.
images = ec2_client.describe_images()["Images"]
response = ec2_client.run_instances(
MaxCount=max_count, MinCount=min_count, ImageId=images[0]["ImageId"]
)
return [instance["InstanceId"] for instance in response["Instances"]]

@classmethod
def _create_instance(cls, hook: EC2Hook):
"""Create Instance and return instance id."""
return cls._create_instances(hook)[0]

@mock_ec2
def test_get_conn_returns_boto3_resource(self):
ec2_hook = EC2Hook()
Expand All @@ -52,24 +72,15 @@ def test_client_type_get_conn_returns_boto3_resource(self):
@mock_ec2
def test_get_instance(self):
ec2_hook = EC2Hook()
created_instances = ec2_hook.conn.create_instances(
MaxCount=1,
MinCount=1,
)
created_instance_id = created_instances[0].instance_id
created_instance_id = self._create_instance(ec2_hook)
# test get_instance method
existing_instance = ec2_hook.get_instance(instance_id=created_instance_id)
assert created_instance_id == existing_instance.instance_id

@mock_ec2
def test_get_instance_state(self):
ec2_hook = EC2Hook()
created_instances = ec2_hook.conn.create_instances(
MaxCount=1,
MinCount=1,
)

created_instance_id = created_instances[0].instance_id
created_instance_id = self._create_instance(ec2_hook)
all_instances = list(ec2_hook.conn.instances.all())
created_instance_state = all_instances[0].state["Name"]
# test get_instance_state method
Expand All @@ -79,12 +90,7 @@ def test_get_instance_state(self):
@mock_ec2
def test_client_type_get_instance_state(self):
ec2_hook = EC2Hook(api_type="client_type")
created_instances = ec2_hook.conn.run_instances(
MaxCount=1,
MinCount=1,
)

created_instance_id = created_instances["Instances"][0]["InstanceId"]
created_instance_id = self._create_instance(ec2_hook)
all_instances = ec2_hook.get_instances()
created_instance_state = all_instances[0]["State"]["Name"]

Expand All @@ -94,12 +100,7 @@ def test_client_type_get_instance_state(self):
@mock_ec2
def test_client_type_start_instances(self):
ec2_hook = EC2Hook(api_type="client_type")
created_instances = ec2_hook.conn.run_instances(
MaxCount=1,
MinCount=1,
)

created_instance_id = created_instances["Instances"][0]["InstanceId"]
created_instance_id = self._create_instance(ec2_hook)
response = ec2_hook.start_instances(instance_ids=[created_instance_id])

assert response["StartingInstances"][0]["InstanceId"] == created_instance_id
Expand All @@ -108,12 +109,7 @@ def test_client_type_start_instances(self):
@mock_ec2
def test_client_type_stop_instances(self):
ec2_hook = EC2Hook(api_type="client_type")
created_instances = ec2_hook.conn.run_instances(
MaxCount=1,
MinCount=1,
)

created_instance_id = created_instances["Instances"][0]["InstanceId"]
created_instance_id = self._create_instance(ec2_hook)
response = ec2_hook.stop_instances(instance_ids=[created_instance_id])

assert response["StoppingInstances"][0]["InstanceId"] == created_instance_id
Expand All @@ -122,12 +118,7 @@ def test_client_type_stop_instances(self):
@mock_ec2
def test_client_type_terminate_instances(self):
ec2_hook = EC2Hook(api_type="client_type")
created_instances = ec2_hook.conn.run_instances(
MaxCount=1,
MinCount=1,
)

created_instance_id = created_instances["Instances"][0]["InstanceId"]
created_instance_id = self._create_instance(ec2_hook)
response = ec2_hook.terminate_instances(instance_ids=[created_instance_id])

assert response["TerminatingInstances"][0]["InstanceId"] == created_instance_id
Expand All @@ -136,12 +127,7 @@ def test_client_type_terminate_instances(self):
@mock_ec2
def test_client_type_describe_instances(self):
ec2_hook = EC2Hook(api_type="client_type")
created_instances = ec2_hook.conn.run_instances(
MaxCount=1,
MinCount=1,
)

created_instance_id = created_instances["Instances"][0]["InstanceId"]
created_instance_id = self._create_instance(ec2_hook)

# Without filter
response = ec2_hook.describe_instances(instance_ids=[created_instance_id])
Expand All @@ -168,13 +154,8 @@ def test_client_type_describe_instances(self):
@mock_ec2
def test_client_type_get_instances(self):
ec2_hook = EC2Hook(api_type="client_type")
created_instances = ec2_hook.conn.run_instances(
MaxCount=2,
MinCount=2,
)

created_instance_id_1 = created_instances["Instances"][0]["InstanceId"]
created_instance_id_2 = created_instances["Instances"][1]["InstanceId"]
created_instances = self._create_instances(ec2_hook, max_count=2, min_count=2)
created_instance_id_1, created_instance_id_2 = created_instances

# Without filter
response = ec2_hook.get_instances(instance_ids=[created_instance_id_1, created_instance_id_2])
Expand Down Expand Up @@ -210,13 +191,8 @@ def test_client_type_get_instances(self):
@mock_ec2
def test_client_type_get_instance_ids(self):
ec2_hook = EC2Hook(api_type="client_type")
created_instances = ec2_hook.conn.run_instances(
MaxCount=2,
MinCount=2,
)

created_instance_id_1 = created_instances["Instances"][0]["InstanceId"]
created_instance_id_2 = created_instances["Instances"][1]["InstanceId"]
created_instances = self._create_instances(ec2_hook, max_count=2, min_count=2)
created_instance_id_1, created_instance_id_2 = created_instances

# Without filter
response = ec2_hook.get_instance_ids()
Expand Down Expand Up @@ -244,12 +220,12 @@ def test_decorator_only_client_type(self):
ec2_hook = EC2Hook()

# Try calling a method which is only supported by client_type API
with self.assertRaises(AirflowException):
with pytest.raises(AirflowException):
ec2_hook.get_instances()

# Explicitly provide resource_type as api_type
ec2_hook = EC2Hook(api_type="resource_type")

# Try calling a method which is only supported by client_type API
with self.assertRaises(AirflowException):
with pytest.raises(AirflowException):
ec2_hook.describe_instances()
6 changes: 0 additions & 6 deletions tests/providers/amazon/aws/hooks/test_ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,6 @@
from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart
from airflow.providers.amazon.aws.hooks.ecs import EcsHook, EcsTaskLogFetcher, should_retry, should_retry_eni

try:
from moto import mock_ecs
except ImportError:
mock_ecs = None

DEFAULT_CONN_ID: str = "aws_default"
REGION: str = "us-east-1"

Expand All @@ -41,7 +36,6 @@ def mock_conn():
yield _conn


@pytest.mark.skipif(mock_ecs is None, reason="mock_ecs package not present")
class TestEksHooks:
def test_hook(self) -> None:
hook = EcsHook(region_name=REGION)
Expand Down

0 comments on commit 671029b

Please sign in to comment.