diff --git a/airflow/example_dags/tutorial.py b/airflow/example_dags/tutorial.py index 11fc3ce1e00ae..d6b50128e8d18 100644 --- a/airflow/example_dags/tutorial.py +++ b/airflow/example_dags/tutorial.py @@ -97,7 +97,7 @@ """ ) - dag.doc_md = __doc__ # providing that you have a docstring at the beginning of the DAG + dag.doc_md = __doc__ # providing that you have a docstring at the beginning of the DAG; OR dag.doc_md = """ This is a documentation placed anywhere """ # otherwise, type it like this diff --git a/airflow/example_dags/tutorial_taskflow_api.py b/airflow/example_dags/tutorial_taskflow_api.py index 4ff2f68831348..ba8aef09a613e 100644 --- a/airflow/example_dags/tutorial_taskflow_api.py +++ b/airflow/example_dags/tutorial_taskflow_api.py @@ -100,7 +100,7 @@ def load(total_order_value: float): # [START dag_invocation] -tutorial_dag = tutorial_taskflow_api() +tutorial_taskflow_api() # [END dag_invocation] # [END tutorial] diff --git a/airflow/models/dag.py b/airflow/models/dag.py index c7336b8aaa549..f67d4162b6e97 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -28,6 +28,7 @@ import traceback import warnings import weakref +from collections import deque from datetime import datetime, timedelta from inspect import signature from typing import ( @@ -35,6 +36,7 @@ Any, Callable, Collection, + Deque, Dict, FrozenSet, Iterable, @@ -101,6 +103,8 @@ from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType if TYPE_CHECKING: + from types import ModuleType + from airflow.datasets import Dataset from airflow.decorators import TaskDecoratorCollection from airflow.models.slamiss import SlaMiss @@ -329,6 +333,7 @@ class DAG(LoggingMixin): :param owner_links: Dict of owners and their links, that will be clickable on the DAGs view UI. Can be used as an HTTP link (for example the link to your Slack channel), or a mailto link. 
e.g: {"dag_owner": "https://airflow.apache.org/"} + :param auto_register: Automatically register this DAG when it is used in a ``with`` block """ _comps = { @@ -390,6 +395,7 @@ def __init__( render_template_as_native_obj: bool = False, tags: Optional[List[str]] = None, owner_links: Optional[Dict[str, str]] = None, + auto_register: bool = True, ): from airflow.utils.task_group import TaskGroup @@ -565,6 +571,7 @@ def __init__( self._access_control = DAG._upgrade_outdated_dag_access_control(access_control) self.is_paused_upon_creation = is_paused_upon_creation + self.auto_register = auto_register self.jinja_environment_kwargs = jinja_environment_kwargs self.render_template_as_native_obj = render_template_as_native_obj @@ -2860,6 +2867,7 @@ def get_serialized_fields(cls): # has_on_*_callback are only stored if the value is True, as the default is False 'has_on_success_callback', 'has_on_failure_callback', + 'auto_register', } cls.__serialized_fields = frozenset(vars(DAG(dag_id='test')).keys()) - exclusion_list return cls.__serialized_fields @@ -3315,6 +3323,7 @@ def dag( render_template_as_native_obj: bool = False, tags: Optional[List[str]] = None, owner_links: Optional[Dict[str, str]] = None, + auto_register: bool = True, ) -> Callable[[Callable], Callable[..., DAG]]: """ Python dag decorator. Wraps a function into an Airflow DAG. @@ -3367,6 +3376,7 @@ def factory(*args, **kwargs): tags=tags, schedule=schedule, owner_links=owner_links, + auto_register=auto_register, ) as dag_obj: # Set DAG documentation from function documentation. 
if f.__doc__: @@ -3424,24 +3434,28 @@ class DagContext: """ - _context_managed_dag: Optional[DAG] = None - _previous_context_managed_dags: List[DAG] = [] + _context_managed_dags: Deque[DAG] = deque() + autoregistered_dags: Set[Tuple[DAG, "ModuleType"]] = set() + current_autoregister_module_name: Optional[str] = None @classmethod def push_context_managed_dag(cls, dag: DAG): - if cls._context_managed_dag: - cls._previous_context_managed_dags.append(cls._context_managed_dag) - cls._context_managed_dag = dag + cls._context_managed_dags.appendleft(dag) @classmethod def pop_context_managed_dag(cls) -> Optional[DAG]: - old_dag = cls._context_managed_dag - if cls._previous_context_managed_dags: - cls._context_managed_dag = cls._previous_context_managed_dags.pop() - else: - cls._context_managed_dag = None - return old_dag + dag = cls._context_managed_dags.popleft() + + # In a few cases around serialization we explicitly push None in to the stack + if cls.current_autoregister_module_name is not None and dag and dag.auto_register: + mod = sys.modules[cls.current_autoregister_module_name] + cls.autoregistered_dags.add((dag, mod)) + + return dag @classmethod def get_current_dag(cls) -> Optional[DAG]: - return cls._context_managed_dag + try: + return cls._context_managed_dags[0] + except IndexError: + return None diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index 3f8d6e57ca38e..2f07c2437734c 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -260,9 +260,12 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True): Given a path to a python module or zip file, this method imports the module and look for dag objects within it. """ + from airflow.models.dag import DagContext + # if the source file no longer exists in the DB or in the filesystem, # return an empty list # todo: raise exception? 
+ if filepath is None or not os.path.isfile(filepath): return [] @@ -280,6 +283,9 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True): self.log.exception(e) return [] + # Ensure we don't pick up anything else we didn't mean to + DagContext.autoregistered_dags.clear() + if filepath.endswith(".py") or not zipfile.is_zipfile(filepath): mods = self._load_modules_from_file(filepath, safe_mode) else: @@ -291,6 +297,8 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True): return found_dags def _load_modules_from_file(self, filepath, safe_mode): + from airflow.models.dag import DagContext + if not might_contain_dag(filepath, safe_mode): # Don't want to spam user with skip messages if not self.has_logged: @@ -306,6 +314,8 @@ def _load_modules_from_file(self, filepath, safe_mode): if mod_name in sys.modules: del sys.modules[mod_name] + DagContext.current_autoregister_module_name = mod_name + def parse(mod_name, filepath): try: loader = importlib.machinery.SourceFileLoader(mod_name, filepath) @@ -344,6 +354,8 @@ def parse(mod_name, filepath): return parse(mod_name, filepath) def _load_modules_from_zip(self, filepath, safe_mode): + from airflow.models.dag import DagContext + mods = [] with zipfile.ZipFile(filepath) as current_zip_file: for zip_info in current_zip_file.infolist(): @@ -372,6 +384,7 @@ def _load_modules_from_zip(self, filepath, safe_mode): if mod_name in sys.modules: del sys.modules[mod_name] + DagContext.current_autoregister_module_name = mod_name try: sys.path.insert(0, filepath) current_module = importlib.import_module(mod_name) @@ -391,9 +404,14 @@ def _load_modules_from_zip(self, filepath, safe_mode): return mods def _process_modules(self, filepath, mods, file_last_changed_on_disk): - from airflow.models.dag import DAG # Avoid circular import + from airflow.models.dag import DAG, DagContext # Avoid circular import + + top_level_dags = {(o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG)} + + 
top_level_dags.update(DagContext.autoregistered_dags) - top_level_dags = ((o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG)) + DagContext.current_autoregister_module_name = None + DagContext.autoregistered_dags.clear() found_dags = [] diff --git a/docs/apache-airflow/howto/dynamic-dag-generation.rst b/docs/apache-airflow/howto/dynamic-dag-generation.rst index f2f364cd51260..5f1067ac8f3d0 100644 --- a/docs/apache-airflow/howto/dynamic-dag-generation.rst +++ b/docs/apache-airflow/howto/dynamic-dag-generation.rst @@ -74,10 +74,10 @@ Then you can import and use the ``ALL_TASKS`` constant in all your DAGs like tha schedule=None, start_date=datetime(2021, 1, 1), catchup=False, - ) as dag: + ): for task in ALL_TASKS: # create your operators and relations here - pass + ... Don't forget that in this case you need to add empty ``__init__.py`` file in the ``my_company_utils`` folder and you should add the ``my_company_utils/.*`` line to ``.airflowignore`` file (if using the regexp ignore @@ -107,10 +107,11 @@ the meta-data file in your DAG easily. The location of the file to read can be f # Configuration dict is available here -Dynamic DAGs with ``globals()`` -............................... -You can dynamically generate DAGs by working with ``globals()``. -As long as a ``DAG`` object in ``globals()`` is created, Airflow will load it. +Registering dynamic DAGs +........................ + +You can dynamically generate DAGs when using the ``@dag`` decorator or the ``with DAG(..)`` context manager +and Airflow will automatically register them. .. code-block:: python @@ -133,13 +134,18 @@ As long as a ``DAG`` object in ``globals()`` is created, Airflow will load it. print_message(config["message"]) - globals()[dag_id] = dynamic_generated_dag() + dynamic_generated_dag() The code below will generate a DAG for each config: ``dynamic_generated_dag_config1`` and ``dynamic_generated_dag_config2``. 
-Each of them can run separately with related configuration +Each of them can run separately with related configuration. + +If you do not wish to have DAGs auto-registered, you can disable the behavior by setting ``auto_register=False`` on your DAG. + +.. versionchanged:: 2.4 -.. warning:: - Using this practice, pay attention to "late binding" behaviour in Python loops. See `that GitHub discussion `_ for more details + As of version 2.4 DAGs that are created by calling a ``@dag`` decorated function (or that are used in the + ``with DAG(...)`` context manager) are automatically registered, and no longer need to be stored in a + global variable. Optimizing DAG parsing delays during execution @@ -199,5 +205,5 @@ of the context are set to ``None``. if current_dag_id is not None and current_dag_id != dag_id: continue # skip generation of non-selected DAG - dag = DAG(dag_id=dag_id, ...) - globals()[dag_id] = dag + with DAG(dag_id=dag_id, ...): + ... diff --git a/docs/apache-airflow/tutorial/taskflow.rst b/docs/apache-airflow/tutorial/taskflow.rst index e3da69e893ba7..4348e75e6d8b7 100644 --- a/docs/apache-airflow/tutorial/taskflow.rst +++ b/docs/apache-airflow/tutorial/taskflow.rst @@ -62,6 +62,18 @@ as shown below, with the Python function name acting as the DAG identifier. :start-after: [START instantiate_dag] :end-before: [END instantiate_dag] +Now to actually enable this to be run as a DAG, we invoke the Python function +``tutorial_taskflow_api`` set up using the ``@dag`` decorator earlier, as shown below. + +.. exampleinclude:: /../../airflow/example_dags/tutorial_taskflow_api.py + :language: python + :start-after: [START dag_invocation] + :end-before: [END dag_invocation] + +.. versionchanged:: 2.4 + + It's no longer required to "register" the DAG into a global variable for Airflow to be able to detect the dag if that DAG is used inside a ``with`` block, or if it is the result of a ``@dag`` decorated function. 
+ Tasks ----- In this data pipeline, tasks are created based on Python functions using the ``@task`` decorator diff --git a/newsfragments/23592.significant.rst b/newsfragments/23592.significant.rst new file mode 100644 index 0000000000000..2fe559f055da3 --- /dev/null +++ b/newsfragments/23592.significant.rst @@ -0,0 +1,40 @@ +DAGs used in a context manager no longer need to be assigned to a module variable + +Previously you had to assign a DAG to a module-level variable in order for Airflow to pick it up. For example this + +.. code-block:: python + + with DAG(dag_id="example") as dag: + ... + + + @dag + def dag_maker(): + ... + + + dag2 = dag_maker() + + +can become + +.. code-block:: python + + with DAG(dag_id="example"): + ... + + + @dag + def dag_maker(): + ... + + + dag_maker() + +If you want to disable the behaviour for any reason then set ``auto_register=False`` on the dag: + +.. code-block:: + + # This dag will not be picked up by Airflow as it's not assigned to a variable + with DAG(dag_id="example", auto_register=False): + ... 
diff --git a/tests/dags/test_subdag.py b/tests/dags/test_subdag.py index 9046a43519405..94606ef41791d 100644 --- a/tests/dags/test_subdag.py +++ b/tests/dags/test_subdag.py @@ -63,7 +63,7 @@ def subdag(parent_dag_name, child_dag_name, args): max_active_runs=1, default_args=DEFAULT_TASK_ARGS, schedule=timedelta(minutes=1), -) as dag: +): start = EmptyOperator( task_id='start', diff --git a/tests/dags/test_zip.zip b/tests/dags/test_zip.zip index a09aa4a4d3b21..e1a58d27335a9 100644 Binary files a/tests/dags/test_zip.zip and b/tests/dags/test_zip.zip differ diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py index b9759f64a078d..4579273b07025 100644 --- a/tests/models/test_dagbag.py +++ b/tests/models/test_dagbag.py @@ -49,6 +49,8 @@ from tests.test_utils.asserts import assert_queries_count from tests.test_utils.config import conf_vars +example_dags_folder = pathlib.Path(airflow.example_dags.__path__[0]) # type: ignore[attr-defined] + def db_clean_up(): db.clear_db_dags() @@ -319,23 +321,38 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True): assert dagbag.get_dag(dag_id) is not None assert 1 == dagbag.process_file_calls - def test_get_dag_fileloc(self): - """ - Test that fileloc is correctly set when we load example DAGs, - specifically SubDAGs and packaged DAGs. 
- """ - dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=True) - dagbag.process_file(os.path.join(TEST_DAGS_FOLDER, "test_zip.zip")) - - expected = { - 'example_bash_operator': 'airflow/example_dags/example_bash_operator.py', - 'example_subdag_operator': 'airflow/example_dags/example_subdag_operator.py', - 'example_subdag_operator.section-1': 'airflow/example_dags/example_subdag_operator.py', - 'test_zip_dag': 'dags/test_zip.zip/test_zip.py', - } - + @pytest.mark.parametrize( + ("file_to_load", "expected"), + ( + pytest.param( + TEST_DAGS_FOLDER / "test_zip.zip", + { + 'test_zip_dag': 'dags/test_zip.zip/test_zip.py', + 'test_zip_autoregister': 'dags/test_zip.zip/test_zip.py', + }, + id='test_zip.zip', + ), + pytest.param( + pathlib.Path(example_dags_folder) / 'example_bash_operator.py', + {'example_bash_operator': 'airflow/example_dags/example_bash_operator.py'}, + id='example_bash_operator', + ), + pytest.param( + TEST_DAGS_FOLDER / 'test_subdag.py', + { + 'test_subdag_operator': 'dags/test_subdag.py', + 'test_subdag_operator.section-1': 'dags/test_subdag.py', + }, + id='test_subdag_operator', + ), + ), + ) + def test_get_dag_registration(self, file_to_load, expected): + dagbag = models.DagBag(dag_folder=os.devnull, include_examples=False) + dagbag.process_file(str(file_to_load)) for dag_id, path in expected.items(): dag = dagbag.get_dag(dag_id) + assert dag, f"{dag_id} was bagged" assert dag.fileloc.endswith(path) @patch.object(DagModel, "get_current") @@ -343,10 +360,9 @@ def test_refresh_py_dag(self, mock_dagmodel): """ Test that we can refresh an ordinary .py DAG """ - example_dags_folder = airflow.example_dags.__path__[0] dag_id = "example_bash_operator" - fileloc = os.path.realpath(os.path.join(example_dags_folder, "example_bash_operator.py")) + fileloc = str(example_dags_folder / "example_bash_operator.py") mock_dagmodel.return_value = DagModel() mock_dagmodel.return_value.last_expired = datetime.max.replace(tzinfo=timezone.utc) @@ -944,8 
+960,7 @@ def test_get_dag_with_dag_serialization(self): def test_collect_dags_from_db(self): """DAGs are collected from Database""" db.clear_db_dags() - example_dags_folder = airflow.example_dags.__path__[0] - dagbag = DagBag(example_dags_folder) + dagbag = DagBag(str(example_dags_folder)) example_dags = dagbag.dags for dag in example_dags.values():