diff --git a/airflow/jobs/base_job.py b/airflow/jobs/base_job.py index 1837c27992963..4edb692251c6b 100644 --- a/airflow/jobs/base_job.py +++ b/airflow/jobs/base_job.py @@ -25,6 +25,7 @@ from sqlalchemy.orm import backref, foreign, relationship from sqlalchemy.orm.session import make_transient +from airflow.compat.functools import cached_property from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.executors.executor_loader import ExecutorLoader @@ -94,8 +95,11 @@ class BaseJob(Base, LoggingMixin): def __init__(self, executor=None, heartrate=None, *args, **kwargs): self.hostname = get_hostname() - self.executor = executor or ExecutorLoader.get_default_executor() - self.executor_class = self.executor.__class__.__name__ + if executor: + self.executor = executor + self.executor_class = executor.__class__.__name__ + else: + self.executor_class = conf.get('core', 'EXECUTOR') self.start_date = timezone.utcnow() self.latest_heartbeat = timezone.utcnow() if heartrate is not None: @@ -104,6 +108,10 @@ def __init__(self, executor=None, heartrate=None, *args, **kwargs): self.max_tis_per_query = conf.getint('scheduler', 'max_tis_per_query') super().__init__(*args, **kwargs) + @cached_property + def executor(self): + return ExecutorLoader.get_default_executor() + @classmethod @provide_session def most_recent_job(cls, session=None) -> Optional['BaseJob']: diff --git a/tests/jobs/test_base_job.py b/tests/jobs/test_base_job.py index 093386b3021cd..93f26309700e0 100644 --- a/tests/jobs/test_base_job.py +++ b/tests/jobs/test_base_job.py @@ -118,7 +118,12 @@ def test_heartbeat_failed(self, mock_create_session): assert job.latest_heartbeat == when, "attribute not updated when heartbeat fails" - @conf_vars({('scheduler', 'max_tis_per_query'): '100'}) + @conf_vars( + { + ('scheduler', 'max_tis_per_query'): '100', + ('core', 'executor'): 'SequentialExecutor', + } + ) @patch('airflow.jobs.base_job.ExecutorLoader.get_default_executor') @patch('airflow.jobs.base_job.get_hostname') @patch('airflow.jobs.base_job.getuser')