diff --git a/tests/always/test_secrets.py b/tests/always/test_secrets.py index df076eead3d6e..4c67a46fe8e0d 100644 --- a/tests/always/test_secrets.py +++ b/tests/always/test_secrets.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock from airflow.configuration import ensure_secrets_loaded, initialize_secrets_backends @@ -26,7 +25,7 @@ from tests.test_utils.db import clear_db_variables -class TestConnectionsFromSecrets(unittest.TestCase): +class TestConnectionsFromSecrets: @mock.patch("airflow.secrets.metastore.MetastoreBackend.get_connection") @mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_connection") def test_get_connection_second_try(self, mock_env_get, mock_meta_get): @@ -112,11 +111,11 @@ def test_backend_fallback_to_env_var(self, mock_get_connection): assert 'mysql://airflow:airflow@host:5432/airflow' == conn.get_uri() -class TestVariableFromSecrets(unittest.TestCase): - def setUp(self) -> None: +class TestVariableFromSecrets: + def setup_method(self) -> None: clear_db_variables() - def tearDown(self) -> None: + def teardown_method(self) -> None: clear_db_variables() @mock.patch("airflow.secrets.metastore.MetastoreBackend.get_variable") diff --git a/tests/always/test_secrets_backends.py b/tests/always/test_secrets_backends.py index 947053848c6ee..212a578310f9c 100644 --- a/tests/always/test_secrets_backends.py +++ b/tests/always/test_secrets_backends.py @@ -18,10 +18,9 @@ from __future__ import annotations import os -import unittest from unittest import mock -from parameterized import parameterized +import pytest from airflow.models.connection import Connection from airflow.models.variable import Variable @@ -41,21 +40,23 @@ def __init__(self, conn_id, variation: str): self.conn = Connection(conn_id=self.conn_id, uri=self.conn_uri) -class TestBaseSecretsBackend(unittest.TestCase): - def setUp(self) -> None: +class TestBaseSecretsBackend: + def setup_method(self) -> None: clear_db_variables() - def tearDown(self) -> None: + def teardown_method(self) -> None: clear_db_connections() clear_db_variables() - @parameterized.expand( + @pytest.mark.parametrize( + "kwargs, output", [ - ('default', {"path_prefix": "PREFIX", "secret_id": "ID"}, "PREFIX/ID"), - ('with_sep', {"path_prefix": "PREFIX", "secret_id": "ID", "sep": "-"}, "PREFIX-ID"), - ] + ({"path_prefix": "PREFIX", "secret_id": "ID"}, "PREFIX/ID"), + ({"path_prefix": "PREFIX", "secret_id": "ID", "sep": "-"}, "PREFIX-ID"), + ], + ids=["default", "with_sep"], ) - def test_build_path(self, _, kwargs, output): + def test_build_path(self, kwargs, output): build_path = BaseSecretsBackend.build_path assert build_path(**kwargs) == output diff --git a/tests/always/test_secrets_local_filesystem.py b/tests/always/test_secrets_local_filesystem.py index 2d2a49a1a6f34..bb2c40abf708a 100644 --- a/tests/always/test_secrets_local_filesystem.py +++ b/tests/always/test_secrets_local_filesystem.py @@ -18,13 +18,11 @@ import json import re -import unittest from contextlib import contextmanager from tempfile import NamedTemporaryFile from unittest import mock import pytest -from parameterized import parameterized from airflow.configuration import ensure_secrets_loaded from airflow.exceptions import AirflowException, AirflowFileParseException, ConnectionNotUnique @@ -42,24 +40,26 @@ def mock_local_file(content): yield file_mock -class FileParsers(unittest.TestCase): - @parameterized.expand( - ( +class TestFileParsers: + @pytest.mark.parametrize( + "content, expected_message", + [ ("AA", 'Invalid line format. The line should contain at least one equal sign ("=")'), ("=", "Invalid line format. Key is empty."), - ) + ], ) def test_env_file_invalid_format(self, content, expected_message): with mock_local_file(content): with pytest.raises(AirflowFileParseException, match=re.escape(expected_message)): local_filesystem.load_variables("a.env") - @parameterized.expand( - ( + @pytest.mark.parametrize( + "content, expected_message", + [ ("[]", "The file should contain the object."), ("{AAAAA}", "Expecting property name enclosed in double quotes"), ("", "The file is empty."), - ) + ], ) def test_json_file_invalid_format(self, content, expected_message): with mock_local_file(content): @@ -67,34 +67,41 @@ def test_json_file_invalid_format(self, content, expected_message): local_filesystem.load_variables("a.json") -class TestLoadVariables(unittest.TestCase): - @parameterized.expand( - ( +class TestLoadVariables: + @pytest.mark.parametrize( + "file_content, expected_variables", + [ ("", {}), ("KEY=AAA", {"KEY": "AAA"}), ("KEY_A=AAA\nKEY_B=BBB", {"KEY_A": "AAA", "KEY_B": "BBB"}), ("KEY_A=AAA\n # AAAA\nKEY_B=BBB", {"KEY_A": "AAA", "KEY_B": "BBB"}), ("\n\n\n\nKEY_A=AAA\n\n\n\n\nKEY_B=BBB\n\n\n", {"KEY_A": "AAA", "KEY_B": "BBB"}), - ) + ], ) def test_env_file_should_load_variables(self, file_content, expected_variables): with mock_local_file(file_content): variables = local_filesystem.load_variables("a.env") assert expected_variables == variables - @parameterized.expand((("AA=A\nAA=B", "The \"a.env\" file contains multiple values for keys: ['AA']"),)) + @pytest.mark.parametrize( + "content, expected_message", + [ + ("AA=A\nAA=B", "The \"a.env\" file contains multiple values for keys: ['AA']"), + ], + ) def test_env_file_invalid_logic(self, content, expected_message): with mock_local_file(content): with pytest.raises(AirflowException, match=re.escape(expected_message)): local_filesystem.load_variables("a.env") - @parameterized.expand( - ( + @pytest.mark.parametrize( + "file_content, expected_variables", + [ ({}, {}), ({"KEY": "AAA"}, {"KEY": "AAA"}), ({"KEY_A": "AAA", "KEY_B": "BBB"}, {"KEY_A": "AAA", "KEY_B": "BBB"}), ({"KEY_A": "AAA", "KEY_B": "BBB"}, {"KEY_A": "AAA", "KEY_B": "BBB"}), - ) + ], ) def test_json_file_should_load_variables(self, file_content, expected_variables): with mock_local_file(json.dumps(file_content)): @@ -109,8 +116,9 @@ def test_missing_file(self, mock_exists): ): local_filesystem.load_variables("a.json") - @parameterized.expand( - ( + @pytest.mark.parametrize( + "file_content, expected_variables", + [ ("KEY: AAA", {"KEY": "AAA"}), ( """ @@ -119,7 +127,7 @@ def test_missing_file(self, mock_exists): """, {"KEY_A": "AAA", "KEY_B": "BBB"}, ), - ) + ], ) def test_yaml_file_should_load_variables(self, file_content, expected_variables): with mock_local_file(file_content): @@ -128,9 +136,10 @@ def test_yaml_file_should_load_variables(self, file_content, expected_variables) assert expected_variables == vars_yaml == vars_yml -class TestLoadConnection(unittest.TestCase): - @parameterized.expand( - ( +class TestLoadConnection: + @pytest.mark.parametrize( + "file_content, expected_connection_uris", + [ ("CONN_ID=mysql://host_1/", {"CONN_ID": "mysql://host_1"}), ( "CONN_ID1=mysql://host_1/\nCONN_ID2=mysql://host_2/", @@ -144,7 +153,7 @@ class TestLoadConnection(unittest.TestCase): "\n\n\n\nCONN_ID1=mysql://host_1/\n\n\n\n\nCONN_ID2=mysql://host_2/\n\n\n", {"CONN_ID1": "mysql://host_1", "CONN_ID2": "mysql://host_2"}, ), - ) + ], ) def test_env_file_should_load_connection(self, file_content, expected_connection_uris): with mock_local_file(file_content): @@ -155,13 +164,14 @@ def test_env_file_should_load_connection(self, file_content, expected_connection assert expected_connection_uris == connection_uris_by_conn_id - @parameterized.expand( - ( + @pytest.mark.parametrize( + "content, expected_connection_uris", + [ ( "CONN_ID=mysql://host_1/?param1=val1¶m2=val2", {"CONN_ID": "mysql://host_1/?param1=val1¶m2=val2"}, ), - ) + ], ) def test_parsing_with_params(self, content, expected_connection_uris): with mock_local_file(content): @@ -172,24 +182,26 @@ def test_parsing_with_params(self, content, expected_connection_uris): assert expected_connection_uris == connection_uris_by_conn_id - @parameterized.expand( - ( + @pytest.mark.parametrize( + "content, expected_message", + [ ("AA", 'Invalid line format. The line should contain at least one equal sign ("=")'), ("=", "Invalid line format. Key is empty."), - ) + ], ) def test_env_file_invalid_format(self, content, expected_message): with mock_local_file(content): with pytest.raises(AirflowFileParseException, match=re.escape(expected_message)): local_filesystem.load_connections_dict("a.env") - @parameterized.expand( - ( + @pytest.mark.parametrize( + "file_content, expected_connection_uris", + [ ({"CONN_ID": "mysql://host_1"}, {"CONN_ID": "mysql://host_1"}), ({"CONN_ID": ["mysql://host_1"]}, {"CONN_ID": "mysql://host_1"}), ({"CONN_ID": {"uri": "mysql://host_1"}}, {"CONN_ID": "mysql://host_1"}), ({"CONN_ID": [{"uri": "mysql://host_1"}]}, {"CONN_ID": "mysql://host_1"}), - ) + ], ) def test_json_file_should_load_connection(self, file_content, expected_connection_uris): with mock_local_file(json.dumps(file_content)): @@ -200,8 +212,9 @@ def test_json_file_should_load_connection(self, file_content, expected_connectio assert expected_connection_uris == connection_uris_by_conn_id - @parameterized.expand( - ( + @pytest.mark.parametrize( + "file_content, expected_connection_uris", + [ ({"CONN_ID": None}, "Unexpected value type: ."), ({"CONN_ID": 1}, "Unexpected value type: ."), ({"CONN_ID": [2]}, "Unexpected value type: ."), @@ -209,7 +222,7 @@ def test_json_file_should_load_connection(self, file_content, expected_connectio ({"CONN_ID": {"AAA": "mysql://host_1"}}, "The object have illegal keys: AAA."), ({"CONN_ID": {"conn_id": "BBBB"}}, "Mismatch conn_id."), ({"CONN_ID": ["mysql://", "mysql://"]}, "Found multiple values for CONN_ID in a.json."), - ) + ], ) def test_env_file_invalid_input(self, file_content, expected_connection_uris): with mock_local_file(json.dumps(file_content)): @@ -224,8 +237,9 @@ def test_missing_file(self, mock_exists): ): local_filesystem.load_connections_dict("a.json") - @parameterized.expand( - ( + @pytest.mark.parametrize( + "file_content, expected_attrs_dict", + [ ( """CONN_A: 'mysql://host_a'""", {"CONN_A": {'conn_type': 'mysql', 'host': 'host_a'}}, @@ -262,7 +276,7 @@ def test_missing_file(self, mock_exists): }, }, ), - ) + ], ) def test_yaml_file_should_load_connection(self, file_content, expected_attrs_dict): with mock_local_file(file_content): @@ -272,8 +286,9 @@ def test_yaml_file_should_load_connection(self, file_content, expected_attrs_dic actual_attrs = {k: getattr(connection, k) for k in expected_attrs.keys()} assert actual_attrs == expected_attrs - @parameterized.expand( - ( + @pytest.mark.parametrize( + "file_content, expected_extras", + [ ( """ conn_c: @@ -323,7 +338,7 @@ def test_yaml_file_should_load_connection(self, file_content, expected_attrs_dic """, {"conn_d": {"extra__google_cloud_platform__keyfile_dict": {"a": "b"}}}, ), - ) + ], ) def test_yaml_file_should_load_connection_extras(self, file_content, expected_extras): with mock_local_file(file_content): @@ -333,8 +348,9 @@ def test_yaml_file_should_load_connection_extras(self, file_content, expected_ex } assert expected_extras == connection_uris_by_conn_id - @parameterized.expand( - ( + @pytest.mark.parametrize( + "file_content, expected_message", + [ ( """conn_c: conn_type: scheme @@ -351,50 +367,46 @@ def test_yaml_file_should_load_connection_extras(self, file_content, expected_ex """, "The extra and extra_dejson parameters are mutually exclusive.", ), - ) + ], ) def test_yaml_invalid_extra(self, file_content, expected_message): with mock_local_file(file_content): with pytest.raises(AirflowException, match=re.escape(expected_message)): local_filesystem.load_connections_dict("a.yaml") - @parameterized.expand( - ("CONN_ID=mysql://host_1/\nCONN_ID=mysql://host_2/",), - ) + @pytest.mark.parametrize("file_content", ["CONN_ID=mysql://host_1/\nCONN_ID=mysql://host_2/"]) def test_ensure_unique_connection_env(self, file_content): with mock_local_file(file_content): with pytest.raises(ConnectionNotUnique): local_filesystem.load_connections_dict("a.env") - @parameterized.expand( - ( - ({"CONN_ID": ["mysql://host_1", "mysql://host_2"]},), - ({"CONN_ID": [{"uri": "mysql://host_1"}, {"uri": "mysql://host_2"}]},), - ) + @pytest.mark.parametrize( + "file_content", + [ + {"CONN_ID": ["mysql://host_1", "mysql://host_2"]}, + {"CONN_ID": [{"uri": "mysql://host_1"}, {"uri": "mysql://host_2"}]}, + ], ) def test_ensure_unique_connection_json(self, file_content): with mock_local_file(json.dumps(file_content)): with pytest.raises(ConnectionNotUnique): local_filesystem.load_connections_dict("a.json") - @parameterized.expand( - ( - ( - """ + @pytest.mark.parametrize( + "file_content", + [ + """ conn_a: - mysql://hosta - mysql://hostb""" - ), - ), + ], ) def test_ensure_unique_connection_yaml(self, file_content): with mock_local_file(file_content): with pytest.raises(ConnectionNotUnique): local_filesystem.load_connections_dict("a.yaml") - @parameterized.expand( - (("conn_a: mysql://hosta"),), - ) + @pytest.mark.parametrize("file_content", ["conn_a: mysql://hosta"]) def test_yaml_extension_parsers_return_same_result(self, file_content): with mock_local_file(file_content): conn_uri_by_conn_id_yaml = { @@ -408,7 +420,7 @@ def test_yaml_extension_parsers_return_same_result(self, file_content): assert conn_uri_by_conn_id_yaml == conn_uri_by_conn_id_yml -class TestLocalFileBackend(unittest.TestCase): +class TestLocalFileBackend: def test_should_read_variable(self): with NamedTemporaryFile(suffix="var.env") as tmp_file: tmp_file.write(b"KEY_A=VAL_A") diff --git a/tests/api/auth/test_client.py b/tests/api/auth/test_client.py index d0454d623f4e5..0f6758ad2b8a4 100644 --- a/tests/api/auth/test_client.py +++ b/tests/api/auth/test_client.py @@ -16,14 +16,13 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock from airflow.api.client import get_current_api_client from tests.test_utils.config import conf_vars -class TestGetCurrentApiClient(unittest.TestCase): +class TestGetCurrentApiClient: @mock.patch("airflow.api.client.json_client.Client") @mock.patch("airflow.api.auth.backend.default.CLIENT_AUTH", "CLIENT_AUTH") @conf_vars( diff --git a/tests/api/client/test_local_client.py b/tests/api/client/test_local_client.py index 2210bc2f3c9b6..3fecedabb6289 100644 --- a/tests/api/client/test_local_client.py +++ b/tests/api/client/test_local_client.py @@ -20,7 +20,6 @@ import json import random import string -import unittest from unittest.mock import patch import pendulum @@ -42,20 +41,17 @@ EXECDATE_ISO = EXECDATE_NOFRACTIONS.isoformat() -class TestLocalClient(unittest.TestCase): +class TestLocalClient: @classmethod - def setUpClass(cls): - super().setUpClass() + def setup_class(cls): DagBag(example_bash_operator.__file__).get_dag("example_bash_operator").sync_to_db() - def setUp(self): - super().setUp() + def setup_method(self): clear_db_pools() self.client = Client(api_base_url=None, auth=None) - def tearDown(self): + def teardown_method(self): clear_db_pools() - super().tearDown() @patch.object(DAG, 'create_dagrun') def test_trigger_dag(self, mock): diff --git a/tests/api/common/test_trigger_dag.py b/tests/api/common/test_trigger_dag.py index 5d410d47a9918..43b6eb2f15c14 100644 --- a/tests/api/common/test_trigger_dag.py +++ b/tests/api/common/test_trigger_dag.py @@ -17,11 +17,9 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock import pytest -from parameterized import parameterized from airflow.api.common.trigger_dag import _trigger_dag from airflow.exceptions import AirflowException @@ -30,11 +28,11 @@ from tests.test_utils import db -class TestTriggerDag(unittest.TestCase): - def setUp(self) -> None: +class TestTriggerDag: + def setup_method(self) -> None: db.clear_db_runs() - def tearDown(self) -> None: + def teardown_method(self) -> None: db.clear_db_runs() @mock.patch('airflow.models.DagBag') @@ -108,15 +106,16 @@ def test_trigger_dag_with_valid_start_date(self, dag_bag_mock): assert len(triggers) == 1 - @parameterized.expand( + @pytest.mark.parametrize( + "conf, expected_conf", [ (None, {}), ({"foo": "bar"}, {"foo": "bar"}), ('{"foo": "bar"}', {"foo": "bar"}), - ] + ], ) @mock.patch('airflow.models.DagBag') - def test_trigger_dag_with_conf(self, conf, expected_conf, dag_bag_mock): + def test_trigger_dag_with_conf(self, dag_bag_mock, conf, expected_conf): dag_id = "trigger_dag_with_conf" dag = DAG(dag_id) dag_bag_mock.dags = [dag_id] diff --git a/tests/api_connexion/schemas/test_common_schema.py b/tests/api_connexion/schemas/test_common_schema.py index 00a7f05d1ef75..759859c9a80a3 100644 --- a/tests/api_connexion/schemas/test_common_schema.py +++ b/tests/api_connexion/schemas/test_common_schema.py @@ -17,7 +17,6 @@ from __future__ import annotations import datetime -import unittest import pytest from dateutil import relativedelta @@ -31,7 +30,7 @@ ) -class TestTimeDeltaSchema(unittest.TestCase): +class TestTimeDeltaSchema: def test_should_serialize(self): instance = datetime.timedelta(days=12) schema_instance = TimeDeltaSchema() @@ -46,7 +45,7 @@ def test_should_deserialize(self): assert expected_instance == result -class TestRelativeDeltaSchema(unittest.TestCase): +class TestRelativeDeltaSchema: def test_should_serialize(self): instance = relativedelta.relativedelta(days=+12) schema_instance = RelativeDeltaSchema() @@ -78,7 +77,7 @@ def test_should_deserialize(self): assert expected_instance == result -class TestCronExpressionSchema(unittest.TestCase): +class TestCronExpressionSchema: def test_should_deserialize(self): instance = {"__type": "CronExpression", "value": "5 4 * * *"} schema_instance = CronExpressionSchema() @@ -87,7 +86,7 @@ def test_should_deserialize(self): assert expected_instance == result -class TestScheduleIntervalSchema(unittest.TestCase): +class TestScheduleIntervalSchema: def test_should_serialize_timedelta(self): instance = datetime.timedelta(days=12) schema_instance = ScheduleIntervalSchema() diff --git a/tests/api_connexion/schemas/test_connection_schema.py b/tests/api_connexion/schemas/test_connection_schema.py index 531c86ee05df5..1383a68e09749 100644 --- a/tests/api_connexion/schemas/test_connection_schema.py +++ b/tests/api_connexion/schemas/test_connection_schema.py @@ -17,7 +17,6 @@ from __future__ import annotations import re -import unittest import marshmallow import pytest @@ -34,12 +33,12 @@ from tests.test_utils.db import clear_db_connections -class TestConnectionCollectionItemSchema(unittest.TestCase): - def setUp(self) -> None: +class TestConnectionCollectionItemSchema: + def setup_method(self) -> None: with create_session() as session: session.query(Connection).delete() - def tearDown(self) -> None: + def teardown_method(self) -> None: clear_db_connections() @provide_session @@ -106,12 +105,12 @@ def test_deserialize_required_fields(self): connection_collection_item_schema.load(connection_dump_1) -class TestConnectionCollectionSchema(unittest.TestCase): - def setUp(self) -> None: +class TestConnectionCollectionSchema: + def setup_method(self) -> None: with create_session() as session: session.query(Connection).delete() - def tearDown(self) -> None: + def teardown_method(self) -> None: clear_db_connections() @provide_session @@ -148,12 +147,12 @@ def test_serialize(self, session): } -class TestConnectionSchema(unittest.TestCase): - def setUp(self) -> None: +class TestConnectionSchema: + def setup_method(self) -> None: with create_session() as session: session.query(Connection).delete() - def tearDown(self) -> None: + def teardown_method(self) -> None: clear_db_connections() @provide_session @@ -205,7 +204,7 @@ def test_deserialize(self): } -class TestConnectionTestSchema(unittest.TestCase): +class TestConnectionTestSchema: def test_response(self): data = { 'status': True, diff --git a/tests/api_connexion/schemas/test_dag_run_schema.py b/tests/api_connexion/schemas/test_dag_run_schema.py index ae10ba3b6f299..395ce1874e3df 100644 --- a/tests/api_connexion/schemas/test_dag_run_schema.py +++ b/tests/api_connexion/schemas/test_dag_run_schema.py @@ -16,11 +16,8 @@ # under the License. from __future__ import annotations -import unittest - import pytest from dateutil.parser import parse -from parameterized import parameterized from airflow.api_connexion.exceptions import BadRequest from airflow.api_connexion.schemas.dag_run_schema import ( @@ -39,13 +36,13 @@ SECOND_TIME = "2020-06-10T13:59:56.336000+00:00" -class TestDAGRunBase(unittest.TestCase): - def setUp(self) -> None: +class TestDAGRunBase: + def setup_method(self) -> None: clear_db_runs() self.default_time = DEFAULT_TIME self.second_time = SECOND_TIME - def tearDown(self) -> None: + def teardown_method(self) -> None: clear_db_runs() @@ -82,7 +79,8 @@ def test_serialize(self, session): "run_type": "manual", } - @parameterized.expand( + @pytest.mark.parametrize( + "serialized_dagrun, expected_result", [ ( # Conf not provided {"dag_run_id": "my-dag-run", "execution_date": DEFAULT_TIME}, @@ -112,7 +110,7 @@ def test_serialize(self, session): "conf": {"start": "stop"}, }, ), - ] + ], ) def test_deserialize(self, serialized_dagrun, expected_result): result = dagrun_schema.load(serialized_dagrun) diff --git a/tests/api_connexion/schemas/test_error_schema.py b/tests/api_connexion/schemas/test_error_schema.py index 38d5762bf7d01..ca150ac6f2a20 100644 --- a/tests/api_connexion/schemas/test_error_schema.py +++ b/tests/api_connexion/schemas/test_error_schema.py @@ -16,8 +16,6 @@ # under the License. from __future__ import annotations -import unittest - from airflow.api_connexion.schemas.error_schema import ( ImportErrorCollection, import_error_collection_schema, @@ -29,12 +27,12 @@ from tests.test_utils.db import clear_db_import_errors -class TestErrorSchemaBase(unittest.TestCase): - def setUp(self) -> None: +class TestErrorSchemaBase: + def setup_method(self) -> None: clear_db_import_errors() self.timestamp = "2020-06-10T12:02:44" - def tearDown(self) -> None: + def teardown_method(self) -> None: clear_db_import_errors() diff --git a/tests/api_connexion/schemas/test_health_schema.py b/tests/api_connexion/schemas/test_health_schema.py index f1920e95c51df..fe0c83e261572 100644 --- a/tests/api_connexion/schemas/test_health_schema.py +++ b/tests/api_connexion/schemas/test_health_schema.py @@ -16,13 +16,11 @@ # under the License. from __future__ import annotations -import unittest - from airflow.api_connexion.schemas.health_schema import health_schema -class TestHealthSchema(unittest.TestCase): - def setUp(self): +class TestHealthSchema: + def setup_method(self): self.default_datetime = "2020-06-10T12:02:44+00:00" def test_serialize(self): diff --git a/tests/api_connexion/schemas/test_plugin_schema.py b/tests/api_connexion/schemas/test_plugin_schema.py index a0cdd00254cb8..e4f8fe1388793 100644 --- a/tests/api_connexion/schemas/test_plugin_schema.py +++ b/tests/api_connexion/schemas/test_plugin_schema.py @@ -16,8 +16,6 @@ # under the License. from __future__ import annotations -import unittest - from airflow.api_connexion.schemas.plugin_schema import ( PluginCollection, plugin_collection_schema, @@ -26,8 +24,8 @@ from airflow.plugins_manager import AirflowPlugin -class TestPluginBase(unittest.TestCase): - def setUp(self) -> None: +class TestPluginBase: + def setup_method(self) -> None: self.mock_plugin = AirflowPlugin() self.mock_plugin.name = "test_plugin" diff --git a/tests/api_connexion/schemas/test_pool_schemas.py b/tests/api_connexion/schemas/test_pool_schemas.py index 48c0f8ee15db2..f0eb0c0a492f5 100644 --- a/tests/api_connexion/schemas/test_pool_schemas.py +++ b/tests/api_connexion/schemas/test_pool_schemas.py @@ -16,19 +16,17 @@ # under the License. from __future__ import annotations -import unittest - from airflow.api_connexion.schemas.pool_schema import PoolCollection, pool_collection_schema, pool_schema from airflow.models.pool import Pool from airflow.utils.session import provide_session from tests.test_utils.db import clear_db_pools -class TestPoolSchema(unittest.TestCase): - def setUp(self) -> None: +class TestPoolSchema: + def setup_method(self) -> None: clear_db_pools() - def tearDown(self) -> None: + def teardown_method(self) -> None: clear_db_pools() @provide_session @@ -56,11 +54,11 @@ def test_deserialize(self, session): assert not isinstance(deserialized_pool, Pool) # Checks if load_instance is set to True -class TestPoolCollectionSchema(unittest.TestCase): - def setUp(self) -> None: +class TestPoolCollectionSchema: + def setup_method(self) -> None: clear_db_pools() - def tearDown(self) -> None: + def teardown_method(self) -> None: clear_db_pools() def test_serialize(self): diff --git a/tests/api_connexion/schemas/test_task_instance_schema.py b/tests/api_connexion/schemas/test_task_instance_schema.py index 4e01ca57c87d3..321664b4bc05f 100644 --- a/tests/api_connexion/schemas/test_task_instance_schema.py +++ b/tests/api_connexion/schemas/test_task_instance_schema.py @@ -17,11 +17,9 @@ from __future__ import annotations import datetime as dt -import unittest import pytest from marshmallow import ValidationError -from parameterized import parameterized from airflow.api_connexion.schemas.task_instance_schema import ( clear_task_instance_form, @@ -150,70 +148,59 @@ def test_task_instance_schema_with_sla_and_rendered(self, session): assert serialized_ti == expected_json -class TestClearTaskInstanceFormSchema(unittest.TestCase): - @parameterized.expand( +class TestClearTaskInstanceFormSchema: + @pytest.mark.parametrize( + "payload", [ ( - [ - { - "dry_run": False, - "reset_dag_runs": True, - "only_failed": True, - "only_running": True, - } - ] + { + "dry_run": False, + "reset_dag_runs": True, + "only_failed": True, + "only_running": True, + } ), ( - [ - { - "dry_run": False, - "reset_dag_runs": True, - "end_date": "2020-01-01T00:00:00+00:00", - "start_date": "2020-01-02T00:00:00+00:00", - } - ] + { + "dry_run": False, + "reset_dag_runs": True, + "end_date": "2020-01-01T00:00:00+00:00", + "start_date": "2020-01-02T00:00:00+00:00", + } ), ( - [ - { - "dry_run": False, - "reset_dag_runs": True, - "task_ids": [], - } - ] + { + "dry_run": False, + "reset_dag_runs": True, + "task_ids": [], + } ), ( - [ - { - "dry_run": False, - "reset_dag_runs": True, - "dag_run_id": "scheduled__2022-06-19T00:00:00+00:00", - "start_date": "2022-08-03T00:00:00+00:00", - } - ] + { + "dry_run": False, + "reset_dag_runs": True, + "dag_run_id": "scheduled__2022-06-19T00:00:00+00:00", + "start_date": "2022-08-03T00:00:00+00:00", + } ), ( - [ - { - "dry_run": False, - "reset_dag_runs": True, - "dag_run_id": "scheduled__2022-06-19T00:00:00+00:00", - "end_date": "2022-08-03T00:00:00+00:00", - } - ] + { + "dry_run": False, + "reset_dag_runs": True, + "dag_run_id": "scheduled__2022-06-19T00:00:00+00:00", + "end_date": "2022-08-03T00:00:00+00:00", + } ), ( - [ - { - "dry_run": False, - "reset_dag_runs": True, - "dag_run_id": "scheduled__2022-06-19T00:00:00+00:00", - "end_date": "2022-08-04T00:00:00+00:00", - "start_date": "2022-08-03T00:00:00+00:00", - } - ] + { + "dry_run": False, + "reset_dag_runs": True, + "dag_run_id": "scheduled__2022-06-19T00:00:00+00:00", + "end_date": "2022-08-04T00:00:00+00:00", + "start_date": "2022-08-03T00:00:00+00:00", + } ), - ] + ], ) def test_validation_error(self, payload): with pytest.raises(ValidationError): @@ -246,14 +233,15 @@ def test_success(self): } assert expected_result == result - @parameterized.expand( + @pytest.mark.parametrize( + "override_data", [ - ({"task_id": None},), - ({"include_future": "foo"},), - ({"execution_date": "NOW"},), - ({"new_state": "INVALID_STATE"},), - ({"execution_date": "2020-01-01T00:00:00+00:00", "dag_run_id": "some-run-id"},), - ] + {"task_id": None}, + {"include_future": "foo"}, + {"execution_date": "NOW"}, + {"new_state": "INVALID_STATE"}, + {"execution_date": "2020-01-01T00:00:00+00:00", "dag_run_id": "some-run-id"}, + ], ) def test_validation_error(self, override_data): self.current_input.update(override_data) diff --git a/tests/api_connexion/schemas/test_version_schema.py b/tests/api_connexion/schemas/test_version_schema.py index 16ed1660f7e43..57db9d3dfa4ab 100644 --- a/tests/api_connexion/schemas/test_version_schema.py +++ b/tests/api_connexion/schemas/test_version_schema.py @@ -16,21 +16,14 @@ # under the License. from __future__ import annotations -import unittest - -from parameterized import parameterized +import pytest from airflow.api_connexion.endpoints.version_endpoint import VersionInfo from airflow.api_connexion.schemas.version_schema import version_info_schema -class TestVersionInfoSchema(unittest.TestCase): - @parameterized.expand( - [ - ("GIT_COMMIT",), - (None,), - ] - ) +class TestVersionInfoSchema: + @pytest.mark.parametrize("git_commit", ["GIT_COMMIT", None]) def test_serialize(self, git_commit): version_info = VersionInfo("VERSION", git_commit) current_data = version_info_schema.dump(version_info) diff --git a/tests/api_connexion/test_parameters.py b/tests/api_connexion/test_parameters.py index a73a112d1de5d..a0c3b3a185b4a 100644 --- a/tests/api_connexion/test_parameters.py +++ b/tests/api_connexion/test_parameters.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock import pytest @@ -34,8 +33,8 @@ from tests.test_utils.config import conf_vars -class TestValidateIsTimezone(unittest.TestCase): - def setUp(self) -> None: +class TestValidateIsTimezone: + def setup_method(self) -> None: from datetime import datetime self.naive = datetime.now() @@ -49,8 +48,8 @@ def test_timezone_passes(self): assert validate_istimezone(self.timezoned) is None -class TestDateTimeParser(unittest.TestCase): - def setUp(self) -> None: +class TestDateTimeParser: + def setup_method(self) -> None: self.default_time = '2020-06-13T22:44:00+00:00' self.default_time_2 = '2020-06-13T22:44:00Z' @@ -72,7 +71,7 @@ def test_raises_400_for_invalid_arg(self): format_datetime(invalid_datetime) -class TestMaximumPagelimit(unittest.TestCase): +class TestMaximumPagelimit: @conf_vars({("api", "maximum_page_limit"): "320"}) def test_maximum_limit_return_val(self): limit = check_limit(300) @@ -99,7 +98,7 @@ def test_negative_limit_raises(self): check_limit(-1) -class TestFormatParameters(unittest.TestCase): +class TestFormatParameters: def test_should_works_with_datetime_formatter(self): decorator = format_parameters({"param_a": format_datetime}) endpoint = mock.MagicMock() diff --git a/tests/cli/commands/test_celery_command.py b/tests/cli/commands/test_celery_command.py index 79f373588f5cf..fee0162ba5ea5 100644 --- a/tests/cli/commands/test_celery_command.py +++ b/tests/cli/commands/test_celery_command.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from argparse import Namespace from tempfile import NamedTemporaryFile from unittest import mock @@ -32,7 +31,7 @@ from tests.test_utils.config import conf_vars -class TestWorkerPrecheck(unittest.TestCase): +class TestWorkerPrecheck: @mock.patch('airflow.settings.validate_session') def test_error(self, mock_validate_session): """ @@ -65,9 +64,9 @@ def test_validate_session_dbapi_exception(self, mock_session): @pytest.mark.integration("redis") @pytest.mark.integration("rabbitmq") @pytest.mark.backend("mysql", "postgres") -class TestWorkerServeLogs(unittest.TestCase): +class TestWorkerServeLogs: @classmethod - def setUpClass(cls): + def setup_class(cls): cls.parser = cli_parser.get_parser() @mock.patch('airflow.cli.commands.celery_command.celery_app') @@ -94,9 +93,9 @@ def test_skip_serve_logs_on_worker_start(self, mock_celery_app): @pytest.mark.backend("mysql", "postgres") -class TestCeleryStopCommand(unittest.TestCase): +class TestCeleryStopCommand: @classmethod - def setUpClass(cls): + def setup_class(cls): cls.parser = cli_parser.get_parser() @mock.patch("airflow.cli.commands.celery_command.setup_locations") @@ -176,9 +175,9 @@ def test_custom_pid_file_is_used_in_start_and_stop( @pytest.mark.backend("mysql", "postgres") -class TestWorkerStart(unittest.TestCase): +class TestWorkerStart: @classmethod - def setUpClass(cls): + def setup_class(cls): cls.parser = cli_parser.get_parser() @mock.patch("airflow.cli.commands.celery_command.setup_locations") @@ -237,9 +236,9 @@ def test_worker_started_with_required_arguments(self, mock_celery_app, mock_pope @pytest.mark.backend("mysql", "postgres") -class TestWorkerFailure(unittest.TestCase): +class TestWorkerFailure: @classmethod - def setUpClass(cls): + def setup_class(cls): cls.parser = cli_parser.get_parser() @mock.patch('airflow.cli.commands.celery_command.Process') @@ -255,9 +254,9 @@ def test_worker_failure_gracefull_shutdown(self, mock_celery_app, mock_popen): @pytest.mark.backend("mysql", "postgres") -class TestFlowerCommand(unittest.TestCase): +class TestFlowerCommand: @classmethod - def setUpClass(cls): + def setup_class(cls): cls.parser = cli_parser.get_parser() @mock.patch('airflow.cli.commands.celery_command.celery_app') diff --git a/tests/cli/commands/test_cheat_sheet_command.py b/tests/cli/commands/test_cheat_sheet_command.py index 56e5f7707dbda..d3afbf03e7c24 100644 --- a/tests/cli/commands/test_cheat_sheet_command.py +++ b/tests/cli/commands/test_cheat_sheet_command.py @@ -18,7 +18,6 @@ import contextlib import io -import unittest from unittest import mock from airflow.cli import cli_parser @@ -89,9 +88,9 @@ def noop(): """ -class TestCheatSheetCommand(unittest.TestCase): +class TestCheatSheetCommand: @classmethod - def setUpClass(cls): + def setup_class(cls): cls.parser = cli_parser.get_parser() @mock.patch('airflow.cli.cli_parser.airflow_commands', MOCK_COMMANDS) diff --git a/tests/cli/commands/test_config_command.py b/tests/cli/commands/test_config_command.py index f93889be30a3f..ae7895d0d671f 100644 --- a/tests/cli/commands/test_config_command.py +++ b/tests/cli/commands/test_config_command.py @@ -18,7 +18,6 @@ import contextlib import io -import unittest from unittest import mock import pytest @@ -28,9 +27,9 @@ from tests.test_utils.config import conf_vars -class TestCliConfigList(unittest.TestCase): +class TestCliConfigList: @classmethod - def setUpClass(cls): + def setup_class(cls): cls.parser = cli_parser.get_parser() @mock.patch("airflow.cli.commands.config_command.io.StringIO") @@ -47,9 +46,9 @@ def test_cli_show_config_should_display_key(self): assert 'testkey = test_value' in temp_stdout.getvalue() -class TestCliConfigGetValue(unittest.TestCase): +class TestCliConfigGetValue: @classmethod - def setUpClass(cls): + def setup_class(cls): cls.parser = cli_parser.get_parser() @conf_vars({('core', 'test_key'): 'test_value'}) diff --git a/tests/cli/commands/test_dag_command.py b/tests/cli/commands/test_dag_command.py index e380d7f16a41e..c748149cc9cd0 100644 --- a/tests/cli/commands/test_dag_command.py +++ b/tests/cli/commands/test_dag_command.py @@ -21,7 +21,6 @@ import io import os import tempfile -import unittest from datetime import datetime, timedelta from unittest import mock from unittest.mock import MagicMock @@ -48,15 +47,15 @@ # TODO: Check if tests needs side effects - locally there's missing DAG -class TestCliDags(unittest.TestCase): +class TestCliDags: @classmethod - def setUpClass(cls): + def setup_class(cls): cls.dagbag = DagBag(include_examples=True) cls.dagbag.sync_to_db() cls.parser = cli_parser.get_parser() @classmethod - def tearDownClass(cls) -> None: + def teardown_class(cls) -> None: clear_db_runs() clear_db_dags() diff --git a/tests/cli/commands/test_dag_processor_command.py b/tests/cli/commands/test_dag_processor_command.py index 23f9980cb6fd1..8c42594d9fb21 100644 --- a/tests/cli/commands/test_dag_processor_command.py +++ b/tests/cli/commands/test_dag_processor_command.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock import pytest @@ -28,13 +27,13 @@ from tests.test_utils.config import conf_vars -class TestDagProcessorCommand(unittest.TestCase): +class TestDagProcessorCommand: """ Tests the CLI interface and that it correctly calls the DagProcessor """ @classmethod - def setUpClass(cls): + def setup_class(cls): cls.parser = cli_parser.get_parser() @conf_vars( diff --git a/tests/cli/commands/test_info_command.py b/tests/cli/commands/test_info_command.py index 5812742951053..6a638b78f7b47 100644 --- a/tests/cli/commands/test_info_command.py +++ b/tests/cli/commands/test_info_command.py @@ -21,10 +21,8 @@ import io import logging import os -import unittest import pytest -from parameterized import parameterized from rich.console import Console from airflow.cli import cli_parser @@ -42,15 +40,16 @@ def capture_show_output(instance): return capture.get() -class TestPiiAnonymizer(unittest.TestCase): - def setUp(self) -> None: +class TestPiiAnonymizer: + def setup_method(self) -> None: self.instance = info_command.PiiAnonymizer() def test_should_remove_pii_from_path(self): home_path = os.path.expanduser("~/airflow/config") assert "${HOME}/airflow/config" == self.instance.process_path(home_path) - @parameterized.expand( + @pytest.mark.parametrize( + "before, after", [ ( "postgresql+psycopg2://postgres:airflow@postgres/airflow", @@ -68,7 +67,7 @@ def test_should_remove_pii_from_path(self): "postgresql+psycopg2://postgres/airflow", "postgresql+psycopg2://postgres/airflow", ), - ] + ], ) def test_should_remove_pii_from_url(self, before, after): assert after == self.instance.process_url(before) @@ -77,7 +76,6 @@ def test_should_remove_pii_from_url(self, before, after): class TestAirflowInfo: @classmethod def setup_class(cls): - cls.parser = cli_parser.get_parser() @classmethod diff --git a/tests/cli/commands/test_jobs_command.py b/tests/cli/commands/test_jobs_command.py index f35d1fe8abec2..3e97ea0f1e67c 100644 --- a/tests/cli/commands/test_jobs_command.py +++ b/tests/cli/commands/test_jobs_command.py @@ -18,7 +18,6 @@ import contextlib import io -import unittest import pytest @@ -30,16 +29,16 @@ from tests.test_utils.db import clear_db_jobs -class TestCliConfigList(unittest.TestCase): +class TestCliConfigList: @classmethod - def setUpClass(cls): + def setup_class(cls): cls.parser = cli_parser.get_parser() - def setUp(self) -> None: + def setup_method(self) -> None: clear_db_jobs() self.scheduler_job = None - def tearDown(self) -> None: + def teardown_method(self) -> None: if self.scheduler_job and self.scheduler_job.processor_agent: self.scheduler_job.processor_agent.end() clear_db_jobs() @@ -54,7 +53,7 @@ def test_should_report_success_for_one_working_scheduler(self): with contextlib.redirect_stdout(io.StringIO()) as temp_stdout: jobs_command.check(self.parser.parse_args(['jobs', 'check', '--job-type', 'SchedulerJob'])) - self.assertIn("Found one alive job.", temp_stdout.getvalue()) + assert "Found one alive job." in temp_stdout.getvalue() def test_should_report_success_for_one_working_scheduler_with_hostname(self): with create_session() as session: @@ -71,7 +70,7 @@ def test_should_report_success_for_one_working_scheduler_with_hostname(self): ['jobs', 'check', '--job-type', 'SchedulerJob', '--hostname', 'HOSTNAME'] ) ) - self.assertIn("Found one alive job.", temp_stdout.getvalue()) + assert "Found one alive job." in temp_stdout.getvalue() def test_should_report_success_for_ha_schedulers(self): scheduler_jobs = [] @@ -90,7 +89,7 @@ def test_should_report_success_for_ha_schedulers(self): ['jobs', 'check', '--job-type', 'SchedulerJob', '--limit', '100', '--allow-multiple'] ) ) - self.assertIn("Found 3 alive jobs.", temp_stdout.getvalue()) + assert "Found 3 alive jobs." in temp_stdout.getvalue() for scheduler_job in scheduler_jobs: if scheduler_job.processor_agent: scheduler_job.processor_agent.end() diff --git a/tests/cli/commands/test_kerberos_command.py b/tests/cli/commands/test_kerberos_command.py index b64c6941fcf92..2acf43dd429f9 100644 --- a/tests/cli/commands/test_kerberos_command.py +++ b/tests/cli/commands/test_kerberos_command.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock from airflow.cli import cli_parser @@ -24,9 +23,9 @@ from tests.test_utils.config import conf_vars -class TestKerberosCommand(unittest.TestCase): +class TestKerberosCommand: @classmethod - def setUpClass(cls): + def setup_class(cls): cls.parser = cli_parser.get_parser() @mock.patch('airflow.cli.commands.kerberos_command.krb') diff --git a/tests/cli/commands/test_kubernetes_command.py b/tests/cli/commands/test_kubernetes_command.py index 8a0045f27d86d..d30cbb8ce763d 100644 --- a/tests/cli/commands/test_kubernetes_command.py +++ b/tests/cli/commands/test_kubernetes_command.py @@ -18,7 +18,6 @@ import os import tempfile -import unittest from unittest import mock from unittest.mock import MagicMock, call @@ -29,9 +28,9 @@ from airflow.cli.commands import kubernetes_command -class TestGenerateDagYamlCommand(unittest.TestCase): +class TestGenerateDagYamlCommand: @classmethod - def setUpClass(cls): + def setup_class(cls): cls.parser = cli_parser.get_parser() def test_generate_dag_yaml(self): @@ -56,12 +55,12 @@ def test_generate_dag_yaml(self): assert os.stat(out_dir + file_name).st_size > 0 -class TestCleanUpPodsCommand(unittest.TestCase): +class TestCleanUpPodsCommand: label_selector = ','.join(['dag_id', 'task_id', 'try_number', 'airflow_version']) @classmethod - def setUpClass(cls): + def setup_class(cls): cls.parser = cli_parser.get_parser() @mock.patch('kubernetes.client.CoreV1Api.delete_namespaced_pod') diff --git a/tests/cli/commands/test_legacy_commands.py b/tests/cli/commands/test_legacy_commands.py index 9cf927588bcd8..7db800f8568ce 100644 --- a/tests/cli/commands/test_legacy_commands.py +++ b/tests/cli/commands/test_legacy_commands.py @@ -18,7 +18,6 @@ import contextlib import io -import unittest from argparse import ArgumentError from unittest.mock import MagicMock @@ -59,9 +58,9 @@ ] -class TestCliDeprecatedCommandsValue(unittest.TestCase): +class TestCliDeprecatedCommandsValue: @classmethod - def setUpClass(cls): + def setup_class(cls): cls.parser = cli_parser.get_parser() def test_should_display_value(self): diff --git a/tests/cli/commands/test_plugins_command.py b/tests/cli/commands/test_plugins_command.py index e3d4dc9e901d3..6b955eb9e461c 100644 --- a/tests/cli/commands/test_plugins_command.py +++ b/tests/cli/commands/test_plugins_command.py @@ -19,7 +19,6 @@ import io import json import textwrap -import unittest from contextlib import redirect_stdout from airflow.cli import cli_parser @@ -40,9 +39,9 @@ class TestPlugin(AirflowPlugin): hooks = [PluginHook] -class TestPluginsCommand(unittest.TestCase): +class TestPluginsCommand: @classmethod - def setUpClass(cls): + def setup_class(cls): cls.parser = cli_parser.get_parser() @mock_plugin_manager(plugins=[]) @@ -118,4 +117,4 @@ def test_should_display_one_plugins_as_table(self): test-plugin-cli | tests.cli.commands.test_plugins_command.PluginHook """ ) - self.assertEqual(stdout, expected_output) + assert stdout == expected_output diff --git a/tests/cli/commands/test_rotate_fernet_key_command.py b/tests/cli/commands/test_rotate_fernet_key_command.py index 2423857cdaeb8..36da0760f5b86 100644 --- a/tests/cli/commands/test_rotate_fernet_key_command.py +++ b/tests/cli/commands/test_rotate_fernet_key_command.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock from cryptography.fernet import Fernet @@ -30,16 +29,16 @@ from tests.test_utils.db import clear_db_connections, clear_db_variables -class TestRotateFernetKeyCommand(unittest.TestCase): +class TestRotateFernetKeyCommand: @classmethod - def setUpClass(cls): + def setup_class(cls): cls.parser = cli_parser.get_parser() - def setUp(self) -> None: + def setup_method(self) -> None: clear_db_connections(add_default_connections_back=False) clear_db_variables() - def tearDown(self) -> None: + def teardown_method(self) -> None: clear_db_connections(add_default_connections_back=False) clear_db_variables() diff --git a/tests/cli/commands/test_scheduler_command.py b/tests/cli/commands/test_scheduler_command.py index a321e857b3bdc..cb820f358379c 100644 --- a/tests/cli/commands/test_scheduler_command.py +++ b/tests/cli/commands/test_scheduler_command.py @@ -17,12 +17,11 @@ # under the License. from __future__ import annotations -import unittest from http.server import BaseHTTPRequestHandler from unittest import mock from unittest.mock import MagicMock -from parameterized import parameterized +import pytest from airflow.cli import cli_parser from airflow.cli.commands import scheduler_command @@ -31,27 +30,28 @@ from tests.test_utils.config import conf_vars -class TestSchedulerCommand(unittest.TestCase): +class TestSchedulerCommand: @classmethod - def setUpClass(cls): + def setup_class(cls): cls.parser = cli_parser.get_parser() - @parameterized.expand( + @pytest.mark.parametrize( + "executor, expect_serve_logs", [ ("CeleryExecutor", False), ("LocalExecutor", True), ("SequentialExecutor", True), ("KubernetesExecutor", False), - ] + ], ) @mock.patch("airflow.cli.commands.scheduler_command.SchedulerJob") @mock.patch("airflow.cli.commands.scheduler_command.Process") def test_serve_logs_on_scheduler( self, - executor, - expect_serve_logs, mock_process, mock_scheduler_job, + executor, + expect_serve_logs, ): args = self.parser.parse_args(['scheduler']) @@ -60,33 +60,23 @@ def test_serve_logs_on_scheduler( if expect_serve_logs: mock_process.assert_has_calls([mock.call(target=serve_logs)]) else: - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): mock_process.assert_has_calls([mock.call(target=serve_logs)]) - @parameterized.expand( - [ - ("LocalExecutor",), - ("SequentialExecutor",), - ] - ) @mock.patch("airflow.cli.commands.scheduler_command.SchedulerJob") @mock.patch("airflow.cli.commands.scheduler_command.Process") - def test_skip_serve_logs(self, executor, mock_process, mock_scheduler_job): + @pytest.mark.parametrize("executor", ["LocalExecutor", "SequentialExecutor"]) + def test_skip_serve_logs(self, mock_process, mock_scheduler_job, executor): args = self.parser.parse_args(['scheduler', '--skip-serve-logs']) with conf_vars({("core", "executor"): executor}): scheduler_command.scheduler(args) - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): mock_process.assert_has_calls([mock.call(target=serve_logs)]) - @parameterized.expand( - [ - ("LocalExecutor",), - ("SequentialExecutor",), - ] - ) @mock.patch("airflow.cli.commands.scheduler_command.SchedulerJob") @mock.patch("airflow.cli.commands.scheduler_command.Process") - def test_graceful_shutdown(self, executor, mock_process, mock_scheduler_job): + @pytest.mark.parametrize("executor", ["LocalExecutor", "SequentialExecutor"]) + def test_graceful_shutdown(self, mock_process, mock_scheduler_job, executor): args = self.parser.parse_args(['scheduler']) with conf_vars({("core", "executor"): executor}): mock_scheduler_job.run.side_effect = Exception('Mock exception to trigger runtime error') @@ -116,7 +106,7 @@ def test_disable_scheduler_health( ): args = self.parser.parse_args(['scheduler']) scheduler_command.scheduler(args) - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): mock_process.assert_has_calls([mock.call(target=serve_health_check)]) @@ -131,8 +121,8 @@ def do_GET(self, path): super().do_GET() -class TestSchedulerHealthServer(unittest.TestCase): - def setUp(self) -> None: +class TestSchedulerHealthServer: + def setup_method(self) -> None: self.mock_server = MockServer() @mock.patch.object(BaseHTTPRequestHandler, "send_error") diff --git a/tests/cli/commands/test_sync_perm_command.py b/tests/cli/commands/test_sync_perm_command.py index f567a84ea5ef8..bcc7e60d72817 100644 --- a/tests/cli/commands/test_sync_perm_command.py +++ b/tests/cli/commands/test_sync_perm_command.py @@ -17,16 +17,15 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock from airflow.cli import cli_parser from airflow.cli.commands import sync_perm_command -class TestCliSyncPerm(unittest.TestCase): +class TestCliSyncPerm: @classmethod - def setUpClass(cls): + def setup_class(cls): cls.parser = cli_parser.get_parser() @mock.patch("airflow.cli.commands.sync_perm_command.cached_app") diff --git a/tests/cli/commands/test_triggerer_command.py b/tests/cli/commands/test_triggerer_command.py index 6edee751da835..51086323c4cc2 100644 --- a/tests/cli/commands/test_triggerer_command.py +++ b/tests/cli/commands/test_triggerer_command.py @@ -17,20 +17,19 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock from airflow.cli import cli_parser from airflow.cli.commands import triggerer_command -class TestTriggererCommand(unittest.TestCase): +class TestTriggererCommand: """ Tests the CLI interface and that it correctly calls the TriggererJob """ @classmethod - def setUpClass(cls): + def setup_class(cls): cls.parser = cli_parser.get_parser() @mock.patch("airflow.cli.commands.triggerer_command.TriggererJob") diff --git a/tests/cli/commands/test_variable_command.py b/tests/cli/commands/test_variable_command.py index 673323a3172db..ee61c67e318cf 100644 --- a/tests/cli/commands/test_variable_command.py +++ b/tests/cli/commands/test_variable_command.py @@ -20,7 +20,6 @@ import io import os import tempfile -import unittest.mock from contextlib import redirect_stdout import pytest @@ -32,16 +31,16 @@ from tests.test_utils.db import clear_db_variables -class TestCliVariables(unittest.TestCase): +class TestCliVariables: @classmethod - def setUpClass(cls): + def setup_class(cls): cls.dagbag = models.DagBag(include_examples=True) cls.parser = cli_parser.get_parser() - def setUp(self): + def setup_method(self): clear_db_variables() - def tearDown(self): + def teardown_method(self): clear_db_variables() def test_variables_set(self): diff --git a/tests/cli/commands/test_version_command.py b/tests/cli/commands/test_version_command.py index ed9c655d4b650..98a19010d4d75 100644 --- a/tests/cli/commands/test_version_command.py +++ b/tests/cli/commands/test_version_command.py @@ -17,7 +17,6 @@ from __future__ import annotations import io -import unittest from contextlib import redirect_stdout import airflow.cli.commands.version_command @@ -25,9 +24,9 @@ from airflow.version import version -class TestCliVersion(unittest.TestCase): +class TestCliVersion: @classmethod - def setUpClass(cls): + def setup_class(cls): cls.parser = cli_parser.get_parser() def test_cli_version(self): diff --git a/tests/cli/commands/test_webserver_command.py b/tests/cli/commands/test_webserver_command.py index 9aa8e97885060..4a4e70ce658c3 100644 --- a/tests/cli/commands/test_webserver_command.py +++ b/tests/cli/commands/test_webserver_command.py @@ -21,7 +21,6 @@ import sys import tempfile import time -import unittest from unittest import mock import psutil @@ -35,8 +34,8 @@ from tests.test_utils.config import conf_vars -class TestGunicornMonitor(unittest.TestCase): - def setUp(self) -> None: +class TestGunicornMonitor: + def setup_method(self) -> None: self.monitor = GunicornMonitor( gunicorn_master_pid=1, num_workers_expected=4, @@ -127,7 +126,7 @@ def test_should_reload_when_plugin_has_been_changed(self, mock_sleep): assert abs(self.monitor._last_refresh_time - time.monotonic()) < 5 -class TestGunicornMonitorGeneratePluginState(unittest.TestCase): +class TestGunicornMonitorGeneratePluginState: @staticmethod def _prepare_test_file(filepath: str, size: int): os.makedirs(os.path.dirname(filepath), exist_ok=True) @@ -184,12 +183,12 @@ def test_should_detect_changes_in_directory(self): assert 4 == len(state_d) -class TestCLIGetNumReadyWorkersRunning(unittest.TestCase): +class TestCLIGetNumReadyWorkersRunning: @classmethod - def setUpClass(cls): + def setup_class(cls): cls.parser = cli_parser.get_parser() - def setUp(self): + def setup_method(self): self.children = mock.MagicMock() self.child = mock.MagicMock() self.process = mock.MagicMock() diff --git a/tests/cli/test_cli_parser.py b/tests/cli/test_cli_parser.py index a3e8323b2e656..26b917d096bbf 100644 --- a/tests/cli/test_cli_parser.py +++ b/tests/cli/test_cli_parser.py @@ -23,10 +23,8 @@ import io import re from collections import Counter -from unittest import TestCase import pytest -from parameterized import parameterized from airflow.cli import cli_parser from tests.test_utils.config import conf_vars @@ -39,7 +37,7 @@ cli_args = {k: v for k, v in cli_parser.__dict__.items() if k.startswith("ARG_")} -class TestCli(TestCase): +class TestCli: def test_arg_option_long_only(self): """ Test if the name of cli.args long option valid @@ -151,11 +149,11 @@ def test_dag_parser_commands_and_comamnd_group_sections(self): parser = cli_parser.get_parser(dag_parser=True) with contextlib.redirect_stdout(io.StringIO()) as stdout: - with self.assertRaises(SystemExit): + with pytest.raises(SystemExit): parser.parse_args(['--help']) stdout = stdout.getvalue() - self.assertIn("Commands", stdout) - self.assertIn("Groups", stdout) + assert "Commands" in stdout + assert "Groups" in stdout def test_should_display_help(self): parser = cli_parser.get_parser() @@ -186,7 +184,7 @@ def test_dag_cli_should_display_help(self): ) ] for cmd_args in all_command_as_args: - with self.assertRaises(SystemExit): + with pytest.raises(SystemExit): parser.parse_args([*cmd_args, '--help']) def test_positive_int(self): @@ -202,7 +200,7 @@ def test_dag_parser_celery_command_require_celery_executor(self): io.StringIO() ) as stderr: parser = cli_parser.get_parser() - with self.assertRaises(SystemExit): + with pytest.raises(SystemExit): parser.parse_args(['celery']) stderr = stderr.getvalue() assert ( @@ -211,18 +209,19 @@ def test_dag_parser_celery_command_require_celery_executor(self): "your current executor: SequentialExecutor, subclassed from: BaseExecutor, see help above." ) in stderr - @parameterized.expand( + @pytest.mark.parametrize( + "executor", [ "CeleryExecutor", "CeleryKubernetesExecutor", "custom_executor.CustomCeleryExecutor", "custom_executor.CustomCeleryKubernetesExecutor", - ] + ], ) def test_dag_parser_celery_command_accept_celery_executor(self, executor): with conf_vars({('core', 'executor'): executor}), contextlib.redirect_stderr(io.StringIO()) as stderr: parser = cli_parser.get_parser() - with self.assertRaises(SystemExit): + with pytest.raises(SystemExit): parser.parse_args(['celery']) stderr = stderr.getvalue() assert ( diff --git a/tests/core/test_config_templates.py b/tests/core/test_config_templates.py index 1293a24ddaada..ecef25a4a4abe 100644 --- a/tests/core/test_config_templates.py +++ b/tests/core/test_config_templates.py @@ -18,9 +18,8 @@ import configparser import os -import unittest -from parameterized import parameterized +import pytest from tests.test_utils import AIRFLOW_MAIN_FOLDER @@ -70,19 +69,15 @@ ] -class TestAirflowCfg(unittest.TestCase): - @parameterized.expand( - [ - ("default_airflow.cfg",), - ("default_test.cfg",), - ] - ) +class TestAirflowCfg: + @pytest.mark.parametrize("filename", ["default_airflow.cfg", "default_test.cfg"]) def test_should_be_ascii_file(self, filename: str): with open(os.path.join(CONFIG_TEMPLATES_FOLDER, filename), "rb") as f: content = f.read().decode("ascii") assert content - @parameterized.expand( + @pytest.mark.parametrize( + "filename, expected_sections", [ ( "default_airflow.cfg", @@ -92,7 +87,7 @@ def test_should_be_ascii_file(self, filename: str): "default_test.cfg", DEFAULT_TEST_SECTIONS, ), - ] + ], ) def test_should_be_ini_file(self, filename: str, expected_sections): filepath = os.path.join(CONFIG_TEMPLATES_FOLDER, filename) diff --git a/tests/core/test_logging_config.py b/tests/core/test_logging_config.py index 7cd51e78a8b8f..4d69ac59de6f5 100644 --- a/tests/core/test_logging_config.py +++ b/tests/core/test_logging_config.py @@ -24,11 +24,9 @@ import pathlib import sys import tempfile -import unittest from unittest.mock import patch import pytest -from parameterized import parameterized from airflow.configuration import conf from tests.test_utils.config import conf_vars @@ -170,12 +168,12 @@ def settings_context(content, directory=None, name='LOGGING_CONFIG'): sys.path.remove(settings_root) -class TestLoggingSettings(unittest.TestCase): +class TestLoggingSettings: # Make sure that the configure_logging is not cached - def setUp(self): + def setup_method(self): self.old_modules = dict(sys.modules) - def tearDown(self): + def teardown_method(self): # Remove any new modules imported during the test run. This lets us # import the same source files for more than one test. from airflow.config_templates import airflow_local_settings @@ -281,7 +279,8 @@ def test_loading_remote_logging_with_wasb_handler(self): logger = logging.getLogger('airflow.task') assert isinstance(logger.handlers[0], WasbTaskHandler) - @parameterized.expand( + @pytest.mark.parametrize( + "remote_base_log_folder, log_group_arn", [ ( 'cloudwatch://arn:aws:logs:aaaa:bbbbb:log-group:ccccc', @@ -295,7 +294,7 @@ def test_loading_remote_logging_with_wasb_handler(self): 'cloudwatch://arn:aws:logs:aaaa:bbbbb:log-group:/aws/ecs/ccccc', 'arn:aws:logs:aaaa:bbbbb:log-group:/aws/ecs/ccccc', ), - ] + ], ) def test_log_group_arns_remote_logging_with_cloudwatch_handler( self, remote_base_log_folder, log_group_arn diff --git a/tests/core/test_settings.py b/tests/core/test_settings.py index 54ae31eb55a58..229279534371d 100644 --- a/tests/core/test_settings.py +++ b/tests/core/test_settings.py @@ -20,7 +20,7 @@ import os import sys import tempfile -import unittest +from unittest import mock from unittest.mock import MagicMock, call import pytest @@ -78,32 +78,32 @@ def __exit__(self, *exc_info): sys.path.remove(self.settings_root) -class TestLocalSettings(unittest.TestCase): +class TestLocalSettings: # Make sure that the configure_logging is not cached - def setUp(self): + def setup_method(self): self.old_modules = dict(sys.modules) - def tearDown(self): + def teardown_method(self): # Remove any new modules imported during the test run. This lets us # import the same source files for more than one test. for mod in [m for m in sys.modules if m not in self.old_modules]: del sys.modules[mod] - @unittest.mock.patch("airflow.settings.import_local_settings") - @unittest.mock.patch("airflow.settings.prepare_syspath") + @mock.patch("airflow.settings.import_local_settings") + @mock.patch("airflow.settings.prepare_syspath") def test_initialize_order(self, prepare_syspath, import_local_settings): """ Tests that import_local_settings is called after prepare_classpath """ - mock = unittest.mock.Mock() - mock.attach_mock(prepare_syspath, "prepare_syspath") - mock.attach_mock(import_local_settings, "import_local_settings") + mock_local_settings = mock.Mock() + mock_local_settings.attach_mock(prepare_syspath, "prepare_syspath") + mock_local_settings.attach_mock(import_local_settings, "import_local_settings") import airflow.settings airflow.settings.initialize() - mock.assert_has_calls([call.prepare_syspath(), call.import_local_settings()]) + mock_local_settings.assert_has_calls([call.prepare_syspath(), call.import_local_settings()]) def test_import_with_dunder_all_not_specified(self): """ @@ -133,7 +133,7 @@ def test_import_with_dunder_all(self): assert task_instance.run_as_user == "myself" - @unittest.mock.patch("airflow.settings.log.debug") + @mock.patch("airflow.settings.log.debug") def test_import_local_settings_without_syspath(self, log_mock): """ Tests that an ImportError is raised in import_local_settings @@ -186,7 +186,7 @@ def test_custom_policy(self): settings.task_must_have_owners(task_instance) -class TestUpdatedConfigNames(unittest.TestCase): +class TestUpdatedConfigNames: @conf_vars( {("webserver", "session_lifetime_days"): '5', ("webserver", "session_lifetime_minutes"): '43200'} ) diff --git a/tests/core/test_sqlalchemy_config.py b/tests/core/test_sqlalchemy_config.py index 12e774a0f4474..8000edb106b56 100644 --- a/tests/core/test_sqlalchemy_config.py +++ b/tests/core/test_sqlalchemy_config.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest.mock import patch import pytest @@ -30,14 +29,14 @@ SQL_ALCHEMY_CONNECT_ARGS = {'test': 43503, 'dict': {'is': 1, 'supported': 'too'}} -class TestSqlAlchemySettings(unittest.TestCase): - def setUp(self): +class TestSqlAlchemySettings: + def setup_method(self): self.old_engine = settings.engine self.old_session = settings.Session self.old_conn = settings.SQL_ALCHEMY_CONN settings.SQL_ALCHEMY_CONN = "mysql+foobar://user:pass@host/dbname?inline=param&another=param" - def tearDown(self): + def teardown_method(self): settings.engine = self.old_engine settings.Session = self.old_session settings.SQL_ALCHEMY_CONN = self.old_conn diff --git a/tests/core/test_stats.py b/tests/core/test_stats.py index eae5a36e35258..3373bbfac7d8f 100644 --- a/tests/core/test_stats.py +++ b/tests/core/test_stats.py @@ -19,7 +19,6 @@ import importlib import re -import unittest from unittest import mock from unittest.mock import Mock @@ -46,8 +45,8 @@ def __init__(self, host=None, port=None, prefix=None): pass -class TestStats(unittest.TestCase): - def setUp(self): +class TestStats: + def setup_method(self): self.statsd_client = Mock(spec=statsd.StatsClient) self.stats = SafeStatsdLogger(self.statsd_client) @@ -128,8 +127,8 @@ def test_load_invalid_custom_stats_client(self): importlib.reload(airflow.stats) -class TestDogStats(unittest.TestCase): - def setUp(self): +class TestDogStats: + def setup_method(self): pytest.importorskip('datadog') from datadog import DogStatsd @@ -167,6 +166,7 @@ def test_does_send_stats_using_dogstatsd_with_tags(self): ) def test_does_send_stats_using_dogstatsd_when_statsd_and_dogstatsd_both_on(self): + # ToDo: Figure out why it identical to test_does_send_stats_using_dogstatsd_when_dogstatsd_on self.dogstatsd.incr("empty_key") self.dogstatsd_client.increment.assert_called_once_with( metric='empty_key', sample_rate=1, tags=[], value=1 @@ -222,8 +222,8 @@ def test_does_not_send_stats_using_statsd_when_statsd_and_dogstatsd_both_on(self importlib.reload(airflow.stats) -class TestStatsWithAllowList(unittest.TestCase): - def setUp(self): +class TestStatsWithAllowList: + def setup_method(self): self.statsd_client = Mock(spec=statsd.StatsClient) self.stats = SafeStatsdLogger(self.statsd_client, AllowListValidator("stats_one, stats_two")) @@ -240,8 +240,8 @@ def test_not_increment_counter_if_not_allowed(self): self.statsd_client.assert_not_called() -class TestDogStatsWithAllowList(unittest.TestCase): - def setUp(self): +class TestDogStatsWithAllowList: + def setup_method(self): pytest.importorskip('datadog') from datadog import DogStatsd @@ -273,7 +273,7 @@ def always_valid(stat_name): return stat_name -class TestCustomStatsName(unittest.TestCase): +class TestCustomStatsName: @conf_vars( { ('metrics', 'statsd_on'): 'True', @@ -324,6 +324,6 @@ def test_does_send_stats_using_dogstatsd_when_the_name_is_valid(self, mock_dogst metric='empty_key', sample_rate=1, tags=[], value=1 ) - def tearDown(self) -> None: + def teardown_method(self) -> None: # To avoid side-effect importlib.reload(airflow.stats) diff --git a/tests/dag_processing/test_manager.py b/tests/dag_processing/test_manager.py index 458ab18f1004f..b4446a0e60eab 100644 --- a/tests/dag_processing/test_manager.py +++ b/tests/dag_processing/test_manager.py @@ -25,7 +25,6 @@ import socket import sys import threading -import unittest from datetime import datetime, timedelta from logging.config import dictConfig from tempfile import TemporaryDirectory @@ -1059,12 +1058,12 @@ def test_callback_queue(self, tmpdir): assert manager._callback_to_execute[dag1_req1.full_filepath] == [dag1_req1, dag1_sla1, dag1_req2] -class TestDagFileProcessorAgent(unittest.TestCase): - def setUp(self): +class TestDagFileProcessorAgent: + def setup_method(self): # Make sure that the configure_logging is not cached self.old_modules = dict(sys.modules) - def tearDown(self): + def teardown_method(self): # Remove any new modules imported during the test run. This lets us # import the same source files for more than one test. remove_list = [] diff --git a/tests/executors/test_dask_executor.py b/tests/executors/test_dask_executor.py index 51c125699324c..df83e354a6b47 100644 --- a/tests/executors/test_dask_executor.py +++ b/tests/executors/test_dask_executor.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from datetime import timedelta from unittest import mock @@ -55,7 +54,7 @@ @pytest.mark.skipif(skip_dask_tests, reason="The tests are skipped because it needs testing from Dask team") -class TestBaseDask(unittest.TestCase): +class TestBaseDask: def assert_tasks_on_executor(self, executor, timeout_executor=120): # start the executor @@ -87,7 +86,7 @@ def assert_tasks_on_executor(self, executor, timeout_executor=120): @pytest.mark.skipif(skip_dask_tests, reason="The tests are skipped because it needs testing from Dask team") class TestDaskExecutor(TestBaseDask): - def setUp(self): + def setup_method(self): self.dagbag = DagBag(include_examples=True) self.cluster = LocalCluster() @@ -110,7 +109,7 @@ def test_backfill_integration(self): ) job.run() - def tearDown(self): + def teardown_method(self): self.cluster.close(timeout=5) @@ -118,7 +117,7 @@ def tearDown(self): skip_tls_tests, reason="The tests are skipped because distributed framework could not be imported" ) class TestDaskExecutorTLS(TestBaseDask): - def setUp(self): + def setup_method(self): self.dagbag = DagBag(include_examples=True) @conf_vars( @@ -160,13 +159,13 @@ def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock @pytest.mark.skipif(skip_dask_tests, reason="The tests are skipped because it needs testing from Dask team") -class TestDaskExecutorQueue(unittest.TestCase): +class TestDaskExecutorQueue: def test_dask_queues_no_resources(self): self.cluster = LocalCluster() executor = DaskExecutor(cluster_address=self.cluster.scheduler_address) executor.start() - with self.assertRaises(AirflowException): + with pytest.raises(AirflowException): executor.execute_async(key='success', command=SUCCESS_COMMAND, queue='queue1') def test_dask_queues_not_available(self): @@ -174,7 +173,7 @@ def test_dask_queues_not_available(self): executor = DaskExecutor(cluster_address=self.cluster.scheduler_address) executor.start() - with self.assertRaises(AirflowException): + with pytest.raises(AirflowException): # resource 'queue2' doesn't exist on cluster executor.execute_async(key='success', command=SUCCESS_COMMAND, queue='queue2') @@ -219,5 +218,5 @@ def test_dask_queues_no_queue_specified(self): assert success_future.done() assert success_future.exception() is None - def tearDown(self): + def teardown_method(self): self.cluster.close(timeout=5) diff --git a/tests/executors/test_executor_loader.py b/tests/executors/test_executor_loader.py index fec7391807ab3..180e7b961c8c6 100644 --- a/tests/executors/test_executor_loader.py +++ b/tests/executors/test_executor_loader.py @@ -16,10 +16,9 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock -from parameterized import parameterized +import pytest from airflow import plugins_manager from airflow.executors.executor_loader import ExecutorLoader @@ -38,21 +37,22 @@ class FakePlugin(plugins_manager.AirflowPlugin): executors = [FakeExecutor] -class TestExecutorLoader(unittest.TestCase): - def setUp(self) -> None: +class TestExecutorLoader: + def setup_method(self) -> None: ExecutorLoader._default_executor = None - def tearDown(self) -> None: + def teardown_method(self) -> None: ExecutorLoader._default_executor = None - @parameterized.expand( + @pytest.mark.parametrize( + "executor_name", [ - ("CeleryExecutor",), - ("CeleryKubernetesExecutor",), - ("DebugExecutor",), - ("KubernetesExecutor",), - ("LocalExecutor",), - ] + "CeleryExecutor", + "CeleryKubernetesExecutor", + "DebugExecutor", + "KubernetesExecutor", + "LocalExecutor", + ], ) def test_should_support_executor_from_core(self, executor_name): with conf_vars({("core", "executor"): executor_name}): diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py index 574bd0afadcbf..48843cd098e35 100644 --- a/tests/executors/test_kubernetes_executor.py +++ b/tests/executors/test_kubernetes_executor.py @@ -55,7 +55,7 @@ AirflowKubernetesScheduler = None # type: ignore -class TestAirflowKubernetesScheduler(unittest.TestCase): +class TestAirflowKubernetesScheduler: @staticmethod def _gen_random_string(seed, str_len): char_list = [] @@ -895,8 +895,8 @@ def test_clear_not_launched_queued_tasks_not_launched_other_queue( assert mock_kube_client.list_namespaced_pod.call_count == 0 -class TestKubernetesJobWatcher(unittest.TestCase): - def setUp(self): +class TestKubernetesJobWatcher: + def setup_method(self): self.watcher = KubernetesJobWatcher( namespace="airflow", multi_namespace_mode=False, @@ -1009,12 +1009,12 @@ def test_process_error_event_for_raise_if_not_410(self): self.pod.status.phase = 'Pending' raw_object = {"code": 422, "message": message, "reason": "Test"} self.events.append({"type": "ERROR", "object": self.pod, "raw_object": raw_object}) - with self.assertRaises(AirflowException) as e: - self._run() - assert str(e.exception) == ( - f"Kubernetes failure for {raw_object['reason']} " - f"with code {raw_object['code']} and message: {raw_object['message']}" + error_message = ( + fr"Kubernetes failure for {raw_object['reason']} " + fr"with code {raw_object['code']} and message: {raw_object['message']}" ) + with pytest.raises(AirflowException, match=error_message): + self._run() def test_recover_from_resource_too_old(self): # too old resource diff --git a/tests/executors/test_local_executor.py b/tests/executors/test_local_executor.py index a5e8bbbaec1e4..cf7f37b1dd9c5 100644 --- a/tests/executors/test_local_executor.py +++ b/tests/executors/test_local_executor.py @@ -19,7 +19,6 @@ import datetime import subprocess -import unittest from unittest import mock from airflow import settings @@ -28,7 +27,7 @@ from airflow.utils.state import State -class TestLocalExecutor(unittest.TestCase): +class TestLocalExecutor: TEST_SUCCESS_COMMANDS = 5 diff --git a/tests/executors/test_sequential_executor.py b/tests/executors/test_sequential_executor.py index 0e016dbf92457..e52281ff8079c 100644 --- a/tests/executors/test_sequential_executor.py +++ b/tests/executors/test_sequential_executor.py @@ -17,13 +17,12 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock from airflow.executors.sequential_executor import SequentialExecutor -class TestSequentialExecutor(unittest.TestCase): +class TestSequentialExecutor: @mock.patch('airflow.executors.sequential_executor.SequentialExecutor.sync') @mock.patch('airflow.executors.base_executor.BaseExecutor.trigger_tasks') @mock.patch('airflow.executors.base_executor.Stats.gauge') diff --git a/tests/hooks/test_subprocess.py b/tests/hooks/test_subprocess.py index 56f010c73308f..c3df0e5e2049f 100644 --- a/tests/hooks/test_subprocess.py +++ b/tests/hooks/test_subprocess.py @@ -17,14 +17,13 @@ # under the License. from __future__ import annotations -import unittest from pathlib import Path from subprocess import PIPE, STDOUT from tempfile import TemporaryDirectory from unittest import mock from unittest.mock import MagicMock -from parameterized import parameterized +import pytest from airflow.hooks.subprocess import SubprocessHook @@ -32,15 +31,17 @@ OS_ENV_VAL = 'this-is-from-os-environ' -class TestSubprocessHook(unittest.TestCase): - @parameterized.expand( +class TestSubprocessHook: + @pytest.mark.parametrize( + "env,expected", [ - ('with env', {'ABC': '123', 'AAA': '456'}, {'ABC': '123', 'AAA': '456', OS_ENV_KEY: ''}), - ('empty env', {}, {OS_ENV_KEY: ''}), - ('no env', None, {OS_ENV_KEY: OS_ENV_VAL}), - ] + ({"ABC": "123", "AAA": "456"}, {"ABC": "123", "AAA": "456", OS_ENV_KEY: ""}), + ({}, {OS_ENV_KEY: ""}), + (None, {OS_ENV_KEY: OS_ENV_VAL}), + ], + ids=["with env", "empty env", "no env"], ) - def test_env(self, name, env, expected): + def test_env(self, env, expected): """ Test that env variables are exported correctly to the command environment. When ``env`` is ``None``, ``os.environ`` should be passed to ``Popen``. @@ -63,13 +64,14 @@ def build_cmd(keys, filename): actual = dict([x.split('=') for x in tmp_file.read_text().splitlines()]) assert actual == expected - @parameterized.expand( + @pytest.mark.parametrize( + "val,expected", [ ('test-val', 'test-val'), ('test-val\ntest-val\n', ''), ('test-val\ntest-val', 'test-val'), ('', ''), - ] + ], ) def test_return_value(self, val, expected): hook = SubprocessHook() diff --git a/tests/kubernetes/models/test_secret.py b/tests/kubernetes/models/test_secret.py index 882b4850aa07b..90636ae4f91dc 100644 --- a/tests/kubernetes/models/test_secret.py +++ b/tests/kubernetes/models/test_secret.py @@ -17,7 +17,6 @@ from __future__ import annotations import sys -import unittest import uuid from unittest import mock @@ -28,7 +27,7 @@ from airflow.kubernetes.secret import Secret -class TestSecret(unittest.TestCase): +class TestSecret: def test_to_env_secret(self): secret = Secret('env', 'name', 'secret', 'key') assert secret.to_env_secret() == k8s.V1EnvVar( diff --git a/tests/kubernetes/test_client.py b/tests/kubernetes/test_client.py index 2f702adf2ff08..13fa8a34179a4 100644 --- a/tests/kubernetes/test_client.py +++ b/tests/kubernetes/test_client.py @@ -17,7 +17,6 @@ from __future__ import annotations import socket -import unittest from unittest import mock from kubernetes.client import Configuration @@ -26,7 +25,7 @@ from airflow.kubernetes.kube_client import _disable_verify_ssl, _enable_tcp_keepalive, get_kube_client -class TestClient(unittest.TestCase): +class TestClient: @mock.patch('airflow.kubernetes.kube_client.config') def test_load_cluster_config(self, config): get_kube_client(in_cluster=True) @@ -50,7 +49,7 @@ def test_load_config_disable_ssl(self, conf, config): configuration = Configuration.get_default_copy() else: configuration = Configuration() - self.assertFalse(configuration.verify_ssl) + assert not configuration.verify_ssl def test_enable_tcp_keepalive(self): socket_options = [ @@ -69,7 +68,7 @@ def test_enable_tcp_keepalive(self): def test_disable_verify_ssl(self): configuration = Configuration() - self.assertTrue(configuration.verify_ssl) + assert configuration.verify_ssl _disable_verify_ssl() @@ -78,4 +77,4 @@ def test_disable_verify_ssl(self): configuration = Configuration.get_default_copy() else: configuration = Configuration() - self.assertFalse(configuration.verify_ssl) + assert not configuration.verify_ssl diff --git a/tests/macros/test_hive.py b/tests/macros/test_hive.py index c0a47c794a02d..f231724620444 100644 --- a/tests/macros/test_hive.py +++ b/tests/macros/test_hive.py @@ -17,13 +17,12 @@ # under the License. from __future__ import annotations -import unittest from datetime import datetime from airflow.macros import hive -class TestHive(unittest.TestCase): +class TestHive: def test_closest_ds_partition(self): date1 = datetime.strptime('2017-04-24', '%Y-%m-%d') date2 = datetime.strptime('2017-04-25', '%Y-%m-%d') diff --git a/tests/models/test_dagcode.py b/tests/models/test_dagcode.py index fe53601a3e7a7..6f3c5d64cdf66 100644 --- a/tests/models/test_dagcode.py +++ b/tests/models/test_dagcode.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from datetime import timedelta from unittest.mock import patch @@ -39,13 +38,13 @@ def make_example_dags(module): return dagbag.dags -class TestDagCode(unittest.TestCase): +class TestDagCode: """Unit tests for DagCode.""" - def setUp(self): + def setup_method(self): clear_db_dag_code() - def tearDown(self): + def teardown_method(self): clear_db_dag_code() def _write_two_example_dags(self): diff --git a/tests/models/test_param.py b/tests/models/test_param.py index fcfcda4f719f6..b6b565490efeb 100644 --- a/tests/models/test_param.py +++ b/tests/models/test_param.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import unittest from contextlib import nullcontext import pytest @@ -29,7 +28,7 @@ from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_xcom -class TestParam(unittest.TestCase): +class TestParam: def test_param_without_schema(self): p = Param('test') assert p.resolve() == 'test' @@ -47,7 +46,6 @@ def test_null_param(self): assert p.resolve() is None assert p.resolve(None) is None - p = Param(type="null") p = Param(None, type='null') assert p.resolve() is None assert p.resolve(None) is None diff --git a/tests/operators/test_branch_operator.py b/tests/operators/test_branch_operator.py index efc05a647e7a9..64081fbc3914c 100644 --- a/tests/operators/test_branch_operator.py +++ b/tests/operators/test_branch_operator.py @@ -18,7 +18,6 @@ from __future__ import annotations import datetime -import unittest from airflow.models import DAG, DagRun, TaskInstance as TI from airflow.operators.branch import BaseBranchOperator @@ -42,16 +41,14 @@ def choose_branch(self, context): return ['branch_1', 'branch_2'] -class TestBranchOperator(unittest.TestCase): +class TestBranchOperator: @classmethod - def setUpClass(cls): - super().setUpClass() - + def setup_class(cls): with create_session() as session: session.query(DagRun).delete() session.query(TI).delete() - def setUp(self): + def setup_method(self): self.dag = DAG( 'branch_operator_test', default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE}, @@ -63,9 +60,7 @@ def setUp(self): self.branch_3 = None self.branch_op = None - def tearDown(self): - super().tearDown() - + def teardown_method(self): with create_session() as session: session.query(DagRun).delete() session.query(TI).delete() diff --git a/tests/operators/test_trigger_dagrun.py b/tests/operators/test_trigger_dagrun.py index d9b88fedb790d..b396caf43fb34 100644 --- a/tests/operators/test_trigger_dagrun.py +++ b/tests/operators/test_trigger_dagrun.py @@ -20,7 +20,7 @@ import pathlib import tempfile from datetime import datetime -from unittest import TestCase, mock +from unittest import mock import pytest @@ -49,8 +49,8 @@ ).format(dag_id=TRIGGERED_DAG_ID) -class TestDagRunOperator(TestCase): - def setUp(self): +class TestDagRunOperator: + def setup_method(self): # Airflow relies on reading the DAG from disk when triggering it. # Therefore write a temp file holding the DAG to trigger. with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: @@ -67,7 +67,7 @@ def setUp(self): dagbag.bag_dag(self.dag, root_dag=self.dag) dagbag.sync_to_db() - def tearDown(self): + def teardown_method(self): """Cleanup state after testing in DB.""" with create_session() as session: session.query(Log).filter(Log.dag_id == TEST_DAG_ID).delete(synchronize_session=False) diff --git a/tests/operators/test_weekday.py b/tests/operators/test_weekday.py index bfda7e83fda02..45ebe7bc8d68c 100644 --- a/tests/operators/test_weekday.py +++ b/tests/operators/test_weekday.py @@ -18,7 +18,6 @@ from __future__ import annotations import datetime -import unittest import pytest from freezegun import freeze_time @@ -35,21 +34,36 @@ DEFAULT_DATE = timezone.datetime(2020, 2, 5) # Wednesday INTERVAL = datetime.timedelta(hours=12) - - -class TestBranchDayOfWeekOperator(unittest.TestCase): +TEST_CASE_BRANCH_FOLLOW_TRUE = { + "with-string": "Monday", + "with-enum": WeekDay.MONDAY, + "with-enum-set": {WeekDay.MONDAY}, + "with-enum-list": [WeekDay.MONDAY], + "with-enum-dict": {WeekDay.MONDAY: "some_value"}, + "with-enum-set-2-items": {WeekDay.MONDAY, WeekDay.FRIDAY}, + "with-enum-list-2-items": [WeekDay.MONDAY, WeekDay.FRIDAY], + "with-enum-dict-2-items": {WeekDay.MONDAY: "some_value", WeekDay.FRIDAY: "some_value_2"}, + "with-string-set": {"Monday"}, + "with-string-set-2-items": {"Monday", "Friday"}, + "with-set-mix-types": {"Monday", WeekDay.FRIDAY}, + "with-list-mix-types": ["Monday", WeekDay.FRIDAY], + "with-dict-mix-types": {"Monday": "some_value", WeekDay.FRIDAY: "some_value_2"}, +} + + +class TestBranchDayOfWeekOperator: """ Tests for BranchDayOfWeekOperator """ @classmethod - def setUpClass(cls): + def setup_class(cls): with create_session() as session: session.query(DagRun).delete() session.query(TI).delete() session.query(XCom).delete() - def setUp(self): + def setup_method(self): self.dag = DAG( "branch_day_of_week_operator_test", start_date=DEFAULT_DATE, @@ -59,7 +73,7 @@ def setUp(self): self.branch_2 = EmptyOperator(task_id="branch_2", dag=self.dag) self.branch_3 = None - def tearDown(self): + def teardown_method(self): with create_session() as session: session.query(DagRun).delete() session.query(TI).delete() @@ -74,31 +88,14 @@ def _assert_task_ids_match_states(self, dr, task_ids_to_states): except KeyError: raise ValueError(f'Invalid task id {ti.task_id} found!') else: - self.assertEqual( - ti.state, - expected_state, - f"Task {ti.task_id} has state {ti.state} instead of expected {expected_state}", - ) + assert_msg = f"Task {ti.task_id} has state {ti.state} instead of expected {expected_state}" + assert ti.state == expected_state, assert_msg - @parameterized.expand( - [ - ("with-string", "Monday"), - ("with-enum", WeekDay.MONDAY), - ("with-enum-set", {WeekDay.MONDAY}), - ("with-enum-list", [WeekDay.MONDAY]), - ("with-enum-dict", {WeekDay.MONDAY: "some_value"}), - ("with-enum-set-2-items", {WeekDay.MONDAY, WeekDay.FRIDAY}), - ("with-enum-list-2-items", [WeekDay.MONDAY, WeekDay.FRIDAY]), - ("with-enum-dict-2-items", {WeekDay.MONDAY: "some_value", WeekDay.FRIDAY: "some_value_2"}), - ("with-string-set", {"Monday"}), - ("with-string-set-2-items", {"Monday", "Friday"}), - ("with-set-mix-types", {"Monday", WeekDay.FRIDAY}), - ("with-list-mix-types", ["Monday", WeekDay.FRIDAY]), - ("with-dict-mix-types", {"Monday": "some_value", WeekDay.FRIDAY: "some_value_2"}), - ] + @pytest.mark.parametrize( + "weekday", TEST_CASE_BRANCH_FOLLOW_TRUE.values(), ids=TEST_CASE_BRANCH_FOLLOW_TRUE.keys() ) @freeze_time("2021-01-25") # Monday - def test_branch_follow_true(self, _, weekday): + def test_branch_follow_true(self, weekday): """Checks if BranchDayOfWeekOperator follows true branch""" print(datetime.datetime.now()) branch_op = BranchDayOfWeekOperator( @@ -205,7 +202,7 @@ def test_branch_follow_false(self): def test_branch_with_no_weekday(self): """Check if BranchDayOfWeekOperator raises exception on missing weekday""" - with self.assertRaises(AirflowException): + with pytest.raises(AirflowException): BranchDayOfWeekOperator( task_id="make_choice", follow_task_ids_if_true="branch_1", diff --git a/tests/plugins/test_plugin_ignore.py b/tests/plugins/test_plugin_ignore.py index 3db8f44af3dd5..8bff057930995 100644 --- a/tests/plugins/test_plugin_ignore.py +++ b/tests/plugins/test_plugin_ignore.py @@ -20,19 +20,18 @@ import os import shutil import tempfile -import unittest from unittest.mock import patch from airflow import settings from airflow.utils.file import find_path_from_directory -class TestIgnorePluginFile(unittest.TestCase): +class TestIgnorePluginFile: """ Test that the .airflowignore work and whether the file is properly ignored. """ - def setUp(self): + def setup_method(self): """ Make tmp folder and files that should be ignored. And set base path. """ @@ -64,7 +63,7 @@ def setUp(self): settings, 'PLUGINS_FOLDER', return_value=self.plugin_folder_path ) - def tearDown(self): + def teardown_method(self): """ Delete tmp folder """ diff --git a/tests/sensors/test_bash.py b/tests/sensors/test_bash.py index 460e7aa0a2e17..6cef34a5ffc23 100644 --- a/tests/sensors/test_bash.py +++ b/tests/sensors/test_bash.py @@ -18,7 +18,6 @@ from __future__ import annotations import datetime -import unittest import pytest @@ -27,8 +26,8 @@ from airflow.sensors.bash import BashSensor -class TestBashSensor(unittest.TestCase): - def setUp(self): +class TestBashSensor: + def setup_method(self): args = {'owner': 'airflow', 'start_date': datetime.datetime(2017, 1, 1)} dag = DAG('test_dag_id', default_args=args) self.dag = dag diff --git a/tests/sensors/test_filesystem.py b/tests/sensors/test_filesystem.py index 57b462fa3bf3c..0305263caf98c 100644 --- a/tests/sensors/test_filesystem.py +++ b/tests/sensors/test_filesystem.py @@ -20,7 +20,6 @@ import os import shutil import tempfile -import unittest import pytest @@ -33,8 +32,8 @@ DEFAULT_DATE = datetime(2015, 1, 1) -class TestFileSensor(unittest.TestCase): - def setUp(self): +class TestFileSensor: + def setup_method(self): from airflow.hooks.filesystem import FSHook hook = FSHook() diff --git a/tests/sensors/test_time_delta.py b/tests/sensors/test_time_delta.py index 27b95230ab951..9c9e256a12821 100644 --- a/tests/sensors/test_time_delta.py +++ b/tests/sensors/test_time_delta.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from datetime import timedelta from airflow.models import DagBag @@ -30,8 +29,8 @@ TEST_DAG_ID = 'unit_tests' -class TestTimedeltaSensor(unittest.TestCase): - def setUp(self): +class TestTimedeltaSensor: + def setup_method(self): self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True) self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG(TEST_DAG_ID, default_args=self.args) diff --git a/tests/sensors/test_timeout_sensor.py b/tests/sensors/test_timeout_sensor.py index 675a686a31415..59798c827fd25 100644 --- a/tests/sensors/test_timeout_sensor.py +++ b/tests/sensors/test_timeout_sensor.py @@ -18,7 +18,6 @@ from __future__ import annotations import time -import unittest from datetime import timedelta import pytest @@ -63,8 +62,8 @@ def execute(self, context: Context): self.log.info("Success criteria met. Exiting.") -class TestSensorTimeout(unittest.TestCase): - def setUp(self): +class TestSensorTimeout: + def setup_method(self): args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG(TEST_DAG_ID, default_args=args) diff --git a/tests/sensors/test_weekday_sensor.py b/tests/sensors/test_weekday_sensor.py index 63be29aab44bd..868f76ab43aeb 100644 --- a/tests/sensors/test_weekday_sensor.py +++ b/tests/sensors/test_weekday_sensor.py @@ -17,10 +17,7 @@ # under the License. from __future__ import annotations -import unittest - import pytest -from parameterized import parameterized from airflow.exceptions import AirflowSensorTimeout from airflow.models import DagBag @@ -35,42 +32,43 @@ WEEKEND_DATE = datetime(2018, 12, 22) TEST_DAG_ID = 'weekday_sensor_dag' DEV_NULL = '/dev/null' +TEST_CASE_WEEKDAY_SENSOR_TRUE = { + "with-string": "Thursday", + "with-enum": WeekDay.THURSDAY, + "with-enum-set": {WeekDay.THURSDAY}, + "with-enum-list": [WeekDay.THURSDAY], + "with-enum-dict": {WeekDay.THURSDAY: "some_value"}, + "with-enum-set-2-items": {WeekDay.THURSDAY, WeekDay.FRIDAY}, + "with-enum-list-2-items": [WeekDay.THURSDAY, WeekDay.FRIDAY], + "with-enum-dict-2-items": {WeekDay.THURSDAY: "some_value", WeekDay.FRIDAY: "some_value_2"}, + "with-string-set": {"Thursday"}, + "with-string-set-2-items": {"Thursday", "Friday"}, + "with-set-mix-types": {"Thursday", WeekDay.FRIDAY}, + "with-list-mix-types": ["Thursday", WeekDay.FRIDAY], + "with-dict-mix-types": {"Thursday": "some_value", WeekDay.FRIDAY: "some_value_2"}, +} -class TestDayOfWeekSensor(unittest.TestCase): +class TestDayOfWeekSensor: @staticmethod def clean_db(): db.clear_db_runs() db.clear_db_task_fail() - def setUp(self): + def setup_method(self): self.clean_db() self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True) self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} dag = DAG(TEST_DAG_ID, default_args=self.args) self.dag = dag - def tearDown(self): + def teardwon_method(self): self.clean_db() - @parameterized.expand( - [ - ("with-string", "Thursday"), - ("with-enum", WeekDay.THURSDAY), - ("with-enum-set", {WeekDay.THURSDAY}), - ("with-enum-list", [WeekDay.THURSDAY]), - ("with-enum-dict", {WeekDay.THURSDAY: "some_value"}), - ("with-enum-set-2-items", {WeekDay.THURSDAY, WeekDay.FRIDAY}), - ("with-enum-list-2-items", [WeekDay.THURSDAY, WeekDay.FRIDAY]), - ("with-enum-dict-2-items", {WeekDay.THURSDAY: "some_value", WeekDay.FRIDAY: "some_value_2"}), - ("with-string-set", {"Thursday"}), - ("with-string-set-2-items", {"Thursday", "Friday"}), - ("with-set-mix-types", {"Thursday", WeekDay.FRIDAY}), - ("with-list-mix-types", ["Thursday", WeekDay.FRIDAY]), - ("with-dict-mix-types", {"Thursday": "some_value", WeekDay.FRIDAY: "some_value_2"}), - ] + @pytest.mark.parametrize( + "week_day", TEST_CASE_WEEKDAY_SENSOR_TRUE.values(), ids=TEST_CASE_WEEKDAY_SENSOR_TRUE.keys() ) - def test_weekday_sensor_true(self, _, week_day): + def test_weekday_sensor_true(self, week_day): op = DayOfWeekSensor( task_id='weekday_sensor_check_true', week_day=week_day, use_task_logical_date=True, dag=self.dag ) diff --git a/tests/task/task_runner/test_cgroup_task_runner.py b/tests/task/task_runner/test_cgroup_task_runner.py index a9473ce80932f..79ce5dcd84ab3 100644 --- a/tests/task/task_runner/test_cgroup_task_runner.py +++ b/tests/task/task_runner/test_cgroup_task_runner.py @@ -17,13 +17,12 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock from airflow.task.task_runner.cgroup_task_runner import CgroupTaskRunner -class TestCgroupTaskRunner(unittest.TestCase): +class TestCgroupTaskRunner: @mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.__init__") @mock.patch("airflow.task.task_runner.base_task_runner.BaseTaskRunner.on_finish") def test_cgroup_task_runner_super_calls(self, mock_super_on_finish, mock_super_init): diff --git a/tests/task/task_runner/test_task_runner.py b/tests/task/task_runner/test_task_runner.py index afeafd03b216d..c498171563752 100644 --- a/tests/task/task_runner/test_task_runner.py +++ b/tests/task/task_runner/test_task_runner.py @@ -16,10 +16,9 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock -from parameterized import parameterized +import pytest from airflow.task.task_runner import CORE_TASK_RUNNERS, get_task_runner from airflow.utils.module_loading import import_string @@ -27,8 +26,8 @@ custom_task_runner = mock.MagicMock() -class GetTaskRunner(unittest.TestCase): - @parameterized.expand([(import_path,) for import_path in CORE_TASK_RUNNERS.values()]) +class TestGetTaskRunner: + @pytest.mark.parametrize("import_path", CORE_TASK_RUNNERS.values()) def test_should_have_valid_imports(self, import_path): assert import_string(import_path) is not None diff --git a/tests/ti_deps/deps/test_dag_ti_slots_available_dep.py b/tests/ti_deps/deps/test_dag_ti_slots_available_dep.py index 6c716d3a5d41c..61fc43c9666e2 100644 --- a/tests/ti_deps/deps/test_dag_ti_slots_available_dep.py +++ b/tests/ti_deps/deps/test_dag_ti_slots_available_dep.py @@ -17,14 +17,13 @@ # under the License. from __future__ import annotations -import unittest from unittest.mock import Mock from airflow.models import TaskInstance from airflow.ti_deps.deps.dag_ti_slots_available_dep import DagTISlotsAvailableDep -class TestDagTISlotsAvailableDep(unittest.TestCase): +class TestDagTISlotsAvailableDep: def test_concurrency_reached(self): """ Test max_active_tasks reached should fail dep diff --git a/tests/ti_deps/deps/test_dag_unpaused_dep.py b/tests/ti_deps/deps/test_dag_unpaused_dep.py index 2aeaed40c585d..514c070ac47fa 100644 --- a/tests/ti_deps/deps/test_dag_unpaused_dep.py +++ b/tests/ti_deps/deps/test_dag_unpaused_dep.py @@ -17,14 +17,13 @@ # under the License. from __future__ import annotations -import unittest from unittest.mock import Mock from airflow.models import TaskInstance from airflow.ti_deps.deps.dag_unpaused_dep import DagUnpausedDep -class TestDagUnpausedDep(unittest.TestCase): +class TestDagUnpausedDep: def test_concurrency_reached(self): """ Test paused DAG should fail dependency diff --git a/tests/ti_deps/deps/test_dagrun_exists_dep.py b/tests/ti_deps/deps/test_dagrun_exists_dep.py index 54d98b587c239..56347ad187d77 100644 --- a/tests/ti_deps/deps/test_dagrun_exists_dep.py +++ b/tests/ti_deps/deps/test_dagrun_exists_dep.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest.mock import Mock, patch from airflow.models import DAG, DagRun @@ -25,7 +24,7 @@ from airflow.utils.state import State -class TestDagrunRunningDep(unittest.TestCase): +class TestDagrunRunningDep: @patch('airflow.models.DagRun.find', return_value=()) def test_dagrun_doesnt_exist(self, mock_dagrun_find): """ diff --git a/tests/ti_deps/deps/test_dagrun_id_dep.py b/tests/ti_deps/deps/test_dagrun_id_dep.py index 09b8614dd71b7..36a3049e07dee 100644 --- a/tests/ti_deps/deps/test_dagrun_id_dep.py +++ b/tests/ti_deps/deps/test_dagrun_id_dep.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest.mock import Mock from airflow.models.dagrun import DagRun @@ -25,7 +24,7 @@ from airflow.utils.types import DagRunType -class TestDagrunRunningDep(unittest.TestCase): +class TestDagrunRunningDep: def test_run_id_is_backfill(self): """ Task instances whose run_id is a backfill dagrun run_id should fail this dep. diff --git a/tests/ti_deps/deps/test_not_in_retry_period_dep.py b/tests/ti_deps/deps/test_not_in_retry_period_dep.py index 07715a3f97b6d..eb4fc90768e95 100644 --- a/tests/ti_deps/deps/test_not_in_retry_period_dep.py +++ b/tests/ti_deps/deps/test_not_in_retry_period_dep.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from datetime import timedelta from unittest.mock import Mock @@ -29,7 +28,7 @@ from airflow.utils.timezone import datetime -class TestNotInRetryPeriodDep(unittest.TestCase): +class TestNotInRetryPeriodDep: def _get_task_instance(self, state, end_date=None, retry_delay=timedelta(minutes=15)): task = Mock(retry_delay=retry_delay, retry_exponential_backoff=False) ti = TaskInstance(task=task, state=state, execution_date=None) diff --git a/tests/ti_deps/deps/test_pool_slots_available_dep.py b/tests/ti_deps/deps/test_pool_slots_available_dep.py index 2cbb16ef1d283..0aec7d8c8c52e 100644 --- a/tests/ti_deps/deps/test_pool_slots_available_dep.py +++ b/tests/ti_deps/deps/test_pool_slots_available_dep.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest.mock import Mock, patch from airflow.models import Pool @@ -27,15 +26,15 @@ from tests.test_utils import db -class TestPoolSlotsAvailableDep(unittest.TestCase): - def setUp(self): +class TestPoolSlotsAvailableDep: + def setup_method(self): db.clear_db_pools() with create_session() as session: test_pool = Pool(pool='test_pool') session.add(test_pool) session.commit() - def tearDown(self): + def teardown_method(self): db.clear_db_pools() @patch('airflow.models.Pool.open_slots', return_value=0) diff --git a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py index 91dbc0deab315..722847a39123f 100644 --- a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py +++ b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from datetime import timedelta from unittest.mock import Mock, patch @@ -28,7 +27,7 @@ from airflow.utils.timezone import utcnow -class TestNotInReschedulePeriodDep(unittest.TestCase): +class TestNotInReschedulePeriodDep: def _get_task_instance(self, state): dag = DAG('test_dag') task = Mock(dag=dag, reschedule=True, is_mapped=False) diff --git a/tests/ti_deps/deps/test_task_concurrency.py b/tests/ti_deps/deps/test_task_concurrency.py index 55f8bd88582df..f694beb430d8d 100644 --- a/tests/ti_deps/deps/test_task_concurrency.py +++ b/tests/ti_deps/deps/test_task_concurrency.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from datetime import datetime from unittest.mock import Mock @@ -27,7 +26,7 @@ from airflow.ti_deps.deps.task_concurrency_dep import TaskConcurrencyDep -class TestTaskConcurrencyDep(unittest.TestCase): +class TestTaskConcurrencyDep: def _get_task(self, **kwargs): return BaseOperator(task_id='test_task', dag=DAG('test_dag'), **kwargs) diff --git a/tests/ti_deps/deps/test_task_not_running_dep.py b/tests/ti_deps/deps/test_task_not_running_dep.py index 9a401d3e690bb..62a1a1f59d8fd 100644 --- a/tests/ti_deps/deps/test_task_not_running_dep.py +++ b/tests/ti_deps/deps/test_task_not_running_dep.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from datetime import datetime from unittest.mock import Mock @@ -25,7 +24,7 @@ from airflow.utils.state import State -class TestTaskNotRunningDep(unittest.TestCase): +class TestTaskNotRunningDep: def test_not_running_state(self): ti = Mock(state=State.QUEUED, end_date=datetime(2016, 1, 1)) assert TaskNotRunningDep().is_met(ti=ti) diff --git a/tests/ti_deps/deps/test_valid_state_dep.py b/tests/ti_deps/deps/test_valid_state_dep.py index f4528212b4847..9fb862341729f 100644 --- a/tests/ti_deps/deps/test_valid_state_dep.py +++ b/tests/ti_deps/deps/test_valid_state_dep.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from datetime import datetime from unittest.mock import Mock @@ -28,7 +27,7 @@ from airflow.utils.state import State -class TestValidStateDep(unittest.TestCase): +class TestValidStateDep: def test_valid_state(self): """ Valid state should pass this dep diff --git a/tests/utils/log/test_file_processor_handler.py b/tests/utils/log/test_file_processor_handler.py index ed563e5c83c4f..1d479c71dad0c 100644 --- a/tests/utils/log/test_file_processor_handler.py +++ b/tests/utils/log/test_file_processor_handler.py @@ -19,7 +19,6 @@ import os import shutil -import unittest from datetime import timedelta from freezegun import freeze_time @@ -28,9 +27,8 @@ from airflow.utils.log.file_processor_handler import FileProcessorHandler -class TestFileProcessorHandler(unittest.TestCase): - def setUp(self): - super().setUp() +class TestFileProcessorHandler: + def setup_method(self): self.base_log_folder = "/tmp/log_test" self.filename = "{filename}" self.filename_template = "{{ filename }}.log" @@ -109,5 +107,5 @@ def test_symlink_latest_log_directory_exists(self): with freeze_time(date1): handler.set_context(filename=os.path.join(self.dag_dir, "log1")) - def tearDown(self): + def teardown_method(self): shutil.rmtree(self.base_log_folder, ignore_errors=True) diff --git a/tests/utils/log/test_json_formatter.py b/tests/utils/log/test_json_formatter.py index 98b409db6dff7..627da568b0706 100644 --- a/tests/utils/log/test_json_formatter.py +++ b/tests/utils/log/test_json_formatter.py @@ -22,13 +22,12 @@ import json import sys -import unittest from logging import makeLogRecord from airflow.utils.log.json_formatter import JSONFormatter -class TestJSONFormatter(unittest.TestCase): +class TestJSONFormatter: """ TestJSONFormatter class combine all tests for JSONFormatter """ diff --git a/tests/utils/test_dag_cycle.py b/tests/utils/test_dag_cycle.py index f6012bc2eaa70..731ea707f75ea 100644 --- a/tests/utils/test_dag_cycle.py +++ b/tests/utils/test_dag_cycle.py @@ -16,8 +16,6 @@ # under the License. from __future__ import annotations -import unittest - import pytest from airflow import DAG @@ -29,7 +27,7 @@ from tests.models import DEFAULT_DATE -class TestCycleTester(unittest.TestCase): +class TestCycleTester: def test_cycle_empty(self): # test empty dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) diff --git a/tests/utils/test_dates.py b/tests/utils/test_dates.py index ae016cda63b33..029bfb5582fb7 100644 --- a/tests/utils/test_dates.py +++ b/tests/utils/test_dates.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from datetime import datetime, timedelta import pendulum @@ -28,7 +27,7 @@ from airflow.utils import dates, timezone -class TestDates(unittest.TestCase): +class TestDates: @pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_days_ago(self): today = pendulum.today() @@ -106,7 +105,7 @@ def test_scale_time_units(self): assert arr4 == approx([2.3147, 1.1574], rel=1e-3) -class TestUtilsDatesDateRange(unittest.TestCase): +class TestUtilsDatesDateRange: def test_no_delta(self): assert dates.date_range(datetime(2016, 1, 1), datetime(2016, 1, 3)) == [] diff --git a/tests/utils/test_docs.py b/tests/utils/test_docs.py index 852891aca1407..3e3fe6df3320c 100644 --- a/tests/utils/test_docs.py +++ b/tests/utils/test_docs.py @@ -16,16 +16,16 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock -from parameterized import parameterized +import pytest from airflow.utils.docs import get_docs_url -class TestGetDocsUrl(unittest.TestCase): - @parameterized.expand( +class TestGetDocsUrl: + @pytest.mark.parametrize( + "version, page, expected_url", [ ( '2.0.0.dev0', @@ -45,7 +45,7 @@ class TestGetDocsUrl(unittest.TestCase): 'project.html', 'https://airflow.apache.org/docs/apache-airflow/1.10.10/project.html', ), - ] + ], ) def test_should_return_link(self, version, page, expected_url): with mock.patch('airflow.version.version', version): diff --git a/tests/utils/test_email.py b/tests/utils/test_email.py index 45d70d11613a1..1a29b4f8d4855 100644 --- a/tests/utils/test_email.py +++ b/tests/utils/test_email.py @@ -19,7 +19,6 @@ import os import tempfile -import unittest from email.mime.application import MIMEApplication from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText @@ -37,7 +36,7 @@ send_email_test = mock.MagicMock() -class TestEmail(unittest.TestCase): +class TestEmail: def test_get_email_address_single_email(self): emails_string = 'test1@example.com' @@ -147,7 +146,7 @@ def test_build_mime_message(self): assert msg['To'] == ','.join(recipients) -class TestEmailSmtp(unittest.TestCase): +class TestEmailSmtp: @mock.patch('airflow.utils.email.send_mime_email') def test_send_smtp(self, mock_send_mime): with tempfile.NamedTemporaryFile() as attachment: diff --git a/tests/utils/test_event_scheduler.py b/tests/utils/test_event_scheduler.py index 7e126bad2662a..641d8dd0f909c 100644 --- a/tests/utils/test_event_scheduler.py +++ b/tests/utils/test_event_scheduler.py @@ -17,13 +17,12 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock from airflow.utils.event_scheduler import EventScheduler -class TestEventScheduler(unittest.TestCase): +class TestEventScheduler: def test_call_regular_interval(self): somefunction = mock.MagicMock() diff --git a/tests/utils/test_file.py b/tests/utils/test_file.py index f403408263ee2..2036fbce7a636 100644 --- a/tests/utils/test_file.py +++ b/tests/utils/test_file.py @@ -19,7 +19,6 @@ import os import os.path -import unittest from pathlib import Path from unittest import mock @@ -29,7 +28,7 @@ from tests.models import TEST_DAGS_FOLDER -class TestCorrectMaybeZipped(unittest.TestCase): +class TestCorrectMaybeZipped: @mock.patch("zipfile.is_zipfile") def test_correct_maybe_zipped_normal_file(self, mocked_is_zipfile): path = '/path/to/some/file.txt' @@ -62,7 +61,7 @@ def test_correct_maybe_zipped_archive(self, mocked_is_zipfile): assert dag_folder == '/path/to/archive.zip' -class TestOpenMaybeZipped(unittest.TestCase): +class TestOpenMaybeZipped: def test_open_maybe_zipped_normal_file(self): test_file_path = os.path.join(TEST_DAGS_FOLDER, "no_dags.py") with open_maybe_zipped(test_file_path, 'r') as test_file: diff --git a/tests/utils/test_json.py b/tests/utils/test_json.py index f0e65067f935b..adbf6d1a3a292 100644 --- a/tests/utils/test_json.py +++ b/tests/utils/test_json.py @@ -19,18 +19,16 @@ import decimal import json -import unittest from datetime import date, datetime import numpy as np -import parameterized import pendulum import pytest from airflow.utils import json as utils_json -class TestAirflowJsonEncoder(unittest.TestCase): +class TestAirflowJsonEncoder: def test_encode_datetime(self): obj = datetime.strptime('2017-05-21 00:00:00', '%Y-%m-%d %H:%M:%S') assert json.dumps(obj, cls=utils_json.AirflowJsonEncoder) == '"2017-05-21T00:00:00+00:00"' @@ -42,7 +40,8 @@ def test_encode_pendulum(self): def test_encode_date(self): assert json.dumps(date(2017, 5, 21), cls=utils_json.AirflowJsonEncoder) == '"2017-05-21"' - @parameterized.parameterized.expand( + @pytest.mark.parametrize( + "expr, expected", [("1", "1"), ("52e4", "520000"), ("2e0", "2"), ("12e-2", "0.12"), ("12.34", "12.34")], ) def test_encode_decimal(self, expr, expected): diff --git a/tests/utils/test_logging_mixin.py b/tests/utils/test_logging_mixin.py index ca736e9cfc100..567729bdd8369 100644 --- a/tests/utils/test_logging_mixin.py +++ b/tests/utils/test_logging_mixin.py @@ -17,15 +17,14 @@ # under the License. from __future__ import annotations -import unittest import warnings from unittest import mock from airflow.utils.log.logging_mixin import StreamLogWriter, set_context -class TestLoggingMixin(unittest.TestCase): - def setUp(self): +class TestLoggingMixin: + def setup_method(self): warnings.filterwarnings(action='always') def test_set_context(self): @@ -53,7 +52,7 @@ def tearDown(self): warnings.resetwarnings() -class TestStreamLogWriter(unittest.TestCase): +class TestStreamLogWriter: def test_write(self): logger = mock.MagicMock() logger.log = mock.MagicMock() diff --git a/tests/utils/test_module_loading.py b/tests/utils/test_module_loading.py index bdb2c3af4aa5e..2c52b662360f7 100644 --- a/tests/utils/test_module_loading.py +++ b/tests/utils/test_module_loading.py @@ -17,14 +17,12 @@ # under the License. from __future__ import annotations -import unittest - import pytest from airflow.utils.module_loading import import_string -class TestModuleImport(unittest.TestCase): +class TestModuleImport: def test_import_string(self): cls = import_string('airflow.utils.module_loading.import_string') assert cls == import_string diff --git a/tests/utils/test_net.py b/tests/utils/test_net.py index bdd5a0ac9e345..c633f511cb3ff 100644 --- a/tests/utils/test_net.py +++ b/tests/utils/test_net.py @@ -18,7 +18,6 @@ from __future__ import annotations import re -import unittest from unittest import mock import pytest @@ -32,7 +31,7 @@ def get_hostname(): return 'awesomehostname' -class TestGetHostname(unittest.TestCase): +class TestGetHostname: @mock.patch('airflow.utils.net.getfqdn', return_value='first') @conf_vars({('core', 'hostname_callable'): None}) def test_get_hostname_unset(self, mock_getfqdn): diff --git a/tests/utils/test_operator_helpers.py b/tests/utils/test_operator_helpers.py index c590b51f685ed..96368eaf7fc95 100644 --- a/tests/utils/test_operator_helpers.py +++ b/tests/utils/test_operator_helpers.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from datetime import datetime from unittest import mock @@ -26,9 +25,8 @@ from airflow.utils import operator_helpers -class TestOperatorHelpers(unittest.TestCase): - def setUp(self): - super().setUp() +class TestOperatorHelpers: + def setup_method(self): self.dag_id = 'dag_id' self.task_id = 'task_id' self.try_number = 1 diff --git a/tests/utils/test_operator_resources.py b/tests/utils/test_operator_resources.py index 64b5af32954b1..dd81f4a281140 100644 --- a/tests/utils/test_operator_resources.py +++ b/tests/utils/test_operator_resources.py @@ -17,12 +17,10 @@ # under the License. from __future__ import annotations -import unittest - from airflow.utils.operator_resources import Resources -class TestResources(unittest.TestCase): +class TestResources: def test_resource_eq(self): r = Resources(cpus=0.1, ram=2048) assert r not in [{}, [], None] diff --git a/tests/utils/test_preexisting_python_virtualenv_decorator.py b/tests/utils/test_preexisting_python_virtualenv_decorator.py index 1b54fa45c049c..3934342062344 100644 --- a/tests/utils/test_preexisting_python_virtualenv_decorator.py +++ b/tests/utils/test_preexisting_python_virtualenv_decorator.py @@ -17,12 +17,10 @@ # under the License. from __future__ import annotations -import unittest - from airflow.utils.decorators import remove_task_decorator -class TestExternalPythonDecorator(unittest.TestCase): +class TestExternalPythonDecorator: def test_remove_task_decorator(self): py_source = "@task.external_python(use_dill=True)\ndef f():\nimport funcsigs" res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python") diff --git a/tests/utils/test_python_virtualenv.py b/tests/utils/test_python_virtualenv.py index 89877de7936c7..d7d7970bf1ad7 100644 --- a/tests/utils/test_python_virtualenv.py +++ b/tests/utils/test_python_virtualenv.py @@ -18,14 +18,13 @@ from __future__ import annotations import sys -import unittest from unittest import mock from airflow.utils.decorators import remove_task_decorator from airflow.utils.python_virtualenv import prepare_virtualenv -class TestPrepareVirtualenv(unittest.TestCase): +class TestPrepareVirtualenv: @mock.patch('airflow.utils.python_virtualenv.execute_in_subprocess') def test_should_create_virtualenv(self, mock_execute_in_subprocess): python_bin = prepare_virtualenv( diff --git a/tests/utils/test_sqlalchemy.py b/tests/utils/test_sqlalchemy.py index f8d9f21ab763e..86c320f66635b 100644 --- a/tests/utils/test_sqlalchemy.py +++ b/tests/utils/test_sqlalchemy.py @@ -19,13 +19,11 @@ import datetime import pickle -import unittest from unittest import mock from unittest.mock import MagicMock import pytest from kubernetes.client import models as k8s -from parameterized import parameterized from pytest import param from sqlalchemy.exc import StatementError @@ -41,8 +39,8 @@ TEST_POD = k8s.V1Pod(spec=k8s.V1PodSpec(containers=[k8s.V1Container(name="base")])) -class TestSqlAlchemyUtils(unittest.TestCase): - def setUp(self): +class TestSqlAlchemyUtils: + def setup_method(self): session = Session() # make sure NOT to run in UTC. Only postgres supports storing @@ -108,7 +106,8 @@ def test_process_bind_param_naive(self): ) dag.clear() - @parameterized.expand( + @pytest.mark.parametrize( + "dialect, supports_for_update_of, expected_return_value", [ ( "postgresql", @@ -130,7 +129,7 @@ def test_process_bind_param_naive(self): False, {'skip_locked': True}, ), - ] + ], ) def test_skip_locked(self, dialect, supports_for_update_of, expected_return_value): session = mock.Mock() @@ -138,7 +137,8 @@ def test_skip_locked(self, dialect, supports_for_update_of, expected_return_valu session.bind.dialect.supports_for_update_of = supports_for_update_of assert skip_locked(session=session) == expected_return_value - @parameterized.expand( + @pytest.mark.parametrize( + "dialect, supports_for_update_of, expected_return_value", [ ( "postgresql", @@ -162,7 +162,7 @@ def test_skip_locked(self, dialect, supports_for_update_of, expected_return_valu 'nowait': True, }, ), - ] + ], ) def test_nowait(self, dialect, supports_for_update_of, expected_return_value): session = mock.Mock() @@ -170,7 +170,8 @@ def test_nowait(self, dialect, supports_for_update_of, expected_return_value): session.bind.dialect.supports_for_update_of = supports_for_update_of assert nowait(session=session) == expected_return_value - @parameterized.expand( + @pytest.mark.parametrize( + "dialect, supports_for_update_of, use_row_level_lock_conf, expected_use_row_level_lock", [ ("postgresql", True, True, True), ("postgresql", True, False, False), @@ -179,7 +180,7 @@ def test_nowait(self, dialect, supports_for_update_of, expected_return_value): ("mysql", True, True, True), ("mysql", True, False, False), ("sqlite", False, True, True), - ] + ], ) def test_with_row_locks( self, dialect, supports_for_update_of, use_row_level_lock_conf, expected_use_row_level_lock @@ -232,7 +233,7 @@ def test_prohibit_commit_specific_session_only(self): other_session.execute('SELECT 1') other_session.commit() - def tearDown(self): + def teardown_method(self): self.session.close() settings.engine.dispose() diff --git a/tests/utils/test_timezone.py b/tests/utils/test_timezone.py index 729e514fa1379..e006d990dfc35 100644 --- a/tests/utils/test_timezone.py +++ b/tests/utils/test_timezone.py @@ -18,7 +18,6 @@ from __future__ import annotations import datetime -import unittest import pendulum import pytest @@ -32,7 +31,7 @@ UTC = timezone.utc -class TestTimezone(unittest.TestCase): +class TestTimezone: def test_is_aware(self): assert timezone.is_localized(datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT)) assert not timezone.is_localized(datetime.datetime(2011, 9, 1, 13, 20, 30)) diff --git a/tests/utils/test_trigger_rule.py b/tests/utils/test_trigger_rule.py index f8f598c39806d..44866f8a8b081 100644 --- a/tests/utils/test_trigger_rule.py +++ b/tests/utils/test_trigger_rule.py @@ -17,14 +17,12 @@ # under the License. from __future__ import annotations -import unittest - import pytest from airflow.utils.trigger_rule import TriggerRule -class TestTriggerRule(unittest.TestCase): +class TestTriggerRule: def test_valid_trigger_rules(self): assert TriggerRule.is_valid(TriggerRule.ALL_SUCCESS) assert TriggerRule.is_valid(TriggerRule.ALL_FAILED) diff --git a/tests/utils/test_weekday.py b/tests/utils/test_weekday.py index c478e834de2ab..0dcadfd1ac0cb 100644 --- a/tests/utils/test_weekday.py +++ b/tests/utils/test_weekday.py @@ -17,16 +17,14 @@ # under the License. from __future__ import annotations -import unittest from enum import Enum import pytest -from parameterized import parameterized from airflow.utils.weekday import WeekDay -class TestWeekDay(unittest.TestCase): +class TestWeekDay: def test_weekday_enum_length(self): assert len(WeekDay) == 7 @@ -44,36 +42,44 @@ def test_weekday_name_value(self): assert isinstance(weekday_enum, int) assert isinstance(weekday_enum, Enum) - @parameterized.expand( + @pytest.mark.parametrize( + "weekday, expected", [ - ("with-string", "Monday", 1), - ("with-enum", WeekDay.MONDAY, 1), - ] + ("Monday", 1), + (WeekDay.MONDAY, 1), + ], + ids=["with-string", "with-enum"], ) - def test_convert(self, _, weekday, expected): + def test_convert(self, weekday, expected): result = WeekDay.convert(weekday) - self.assertEqual(result, expected) + assert result == expected def test_convert_with_incorrect_input(self): invalid = "Sun" - with self.assertRaisesRegex( - AttributeError, - f'Invalid Week Day passed: "{invalid}"', - ): + error_message = fr'Invalid Week Day passed: "{invalid}"' + with pytest.raises(AttributeError, match=error_message): WeekDay.convert(invalid) - @parameterized.expand( + @pytest.mark.parametrize( + "weekday, expected", [ - ("with-string", "Monday", {WeekDay.MONDAY}), - ("with-enum", WeekDay.MONDAY, {WeekDay.MONDAY}), - ("with-dict", {"Thursday": "1"}, {WeekDay.THURSDAY}), - ("with-list", ["Thursday"], {WeekDay.THURSDAY}), - ("with-mix", ["Thursday", WeekDay.MONDAY], {WeekDay.MONDAY, WeekDay.THURSDAY}), - ] + ("Monday", {WeekDay.MONDAY}), + (WeekDay.MONDAY, {WeekDay.MONDAY}), + ({"Thursday": "1"}, {WeekDay.THURSDAY}), + (["Thursday"], {WeekDay.THURSDAY}), + (["Thursday", WeekDay.MONDAY], {WeekDay.MONDAY, WeekDay.THURSDAY}), + ], + ids=[ + "with-string", + "with-enum", + "with-dict", + "with-list", + "with-mix", + ], ) - def test_validate_week_day(self, _, weekday, expected): + def test_validate_week_day(self, weekday, expected): result = WeekDay.validate_week_day(weekday) - self.assertEqual(expected, result) + assert expected == result def test_validate_week_day_with_invalid_type(self): invalid_week_day = 5 diff --git a/tests/utils/test_weight_rule.py b/tests/utils/test_weight_rule.py index 7be17e7604a7f..73abafe782b86 100644 --- a/tests/utils/test_weight_rule.py +++ b/tests/utils/test_weight_rule.py @@ -17,14 +17,12 @@ # under the License. from __future__ import annotations -import unittest - import pytest from airflow.utils.weight_rule import WeightRule -class TestWeightRule(unittest.TestCase): +class TestWeightRule: def test_valid_weight_rules(self): assert WeightRule.is_valid(WeightRule.DOWNSTREAM) assert WeightRule.is_valid(WeightRule.UPSTREAM) diff --git a/tests/www/test_app.py b/tests/www/test_app.py index d82dda1d7ae34..106001cc840f5 100644 --- a/tests/www/test_app.py +++ b/tests/www/test_app.py @@ -19,7 +19,6 @@ import runpy import sys -import unittest from datetime import timedelta from unittest import mock @@ -33,9 +32,9 @@ from tests.test_utils.decorators import dont_initialize_flask_app_submodules -class TestApp(unittest.TestCase): +class TestApp: @classmethod - def setUpClass(cls) -> None: + def setup_class(cls) -> None: from airflow import settings settings.configure_orm() diff --git a/tests/www/test_init_views.py b/tests/www/test_init_views.py index 7f23e439e125a..49d06ee04eb72 100644 --- a/tests/www/test_init_views.py +++ b/tests/www/test_init_views.py @@ -17,7 +17,6 @@ from __future__ import annotations import re -import unittest from unittest import mock import pytest @@ -26,7 +25,7 @@ from tests.test_utils.config import conf_vars -class TestInitApiExperimental(unittest.TestCase): +class TestInitApiExperimental: @conf_vars({('api', 'enable_experimental_api'): 'true'}) def test_should_raise_deprecation_warning_when_enabled(self): app = mock.MagicMock() diff --git a/tests/www/test_utils.py b/tests/www/test_utils.py index 1fc59567f7607..f1a2012a94ffa 100644 --- a/tests/www/test_utils.py +++ b/tests/www/test_utils.py @@ -18,7 +18,6 @@ from __future__ import annotations import re -import unittest from datetime import datetime from urllib.parse import parse_qs @@ -28,7 +27,7 @@ from airflow.www.utils import wrapped_markdown -class TestUtils(unittest.TestCase): +class TestUtils: def check_generate_pages_html(self, current_page, total_pages, window=7, check_middle=False): extra_links = 4 # first, prev, next, last search = "'>\"/>" @@ -156,8 +155,8 @@ def test_dag_run_link(self): assert '' not in html -class TestAttrRenderer(unittest.TestCase): - def setUp(self): +class TestAttrRenderer: + def setup_method(self): self.attr_renderer = utils.get_attr_renderer() def test_python_callable(self): @@ -178,11 +177,11 @@ def test_markdown(self): assert "
  • bar
  • " in rendered def test_markdown_none(self): - rendered = self.attr_renderer["python_callable"](None) - assert "" == rendered + rendered = self.attr_renderer["doc_md"](None) + assert rendered is None -class TestWrappedMarkdown(unittest.TestCase): +class TestWrappedMarkdown: def test_wrapped_markdown_with_docstring_curly_braces(self): rendered = wrapped_markdown("{braces}", css_class="a_class") assert ( diff --git a/tests/www/test_validators.py b/tests/www/test_validators.py index eaac039436426..499c9a860809a 100644 --- a/tests/www/test_validators.py +++ b/tests/www/test_validators.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock import pytest @@ -25,9 +24,8 @@ from airflow.www import validators -class TestGreaterEqualThan(unittest.TestCase): - def setUp(self): - super().setUp() +class TestGreaterEqualThan: + def setup_method(self): self.form_field_mock = mock.MagicMock(data='2017-05-06') self.form_field_mock.gettext.side_effect = lambda msg: msg self.other_field_mock = mock.MagicMock(data='2017-05-05') @@ -89,9 +87,8 @@ def test_validation_raises_custom_message(self): ) -class TestValidJson(unittest.TestCase): - def setUp(self): - super().setUp() +class TestValidJson: + def setup_method(self): self.form_field_mock = mock.MagicMock(data='{"valid":"True"}') self.form_field_mock.gettext.side_effect = lambda msg: msg self.form_mock = mock.MagicMock(spec_set=dict)