From 075622cbeb7e9d20b6936de37c8f5abccc6f882e Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Wed, 7 Jul 2021 15:06:21 +0100 Subject: [PATCH] Fix impersonation issue with LocalTaskJob (#16852) Running a task with run_as_user fails because PIDs are not matched correctly. This change fixes it by matching the parent process ID (the `sudo` process) of the task instance to the current process ID of the task_runner process when we use impersonation Co-authored-by: Ash Berlin-Taylor (cherry picked from commit feea38057ae16b5c09dfdda19552a5e75c01a2dd) (cherry picked from commit 26a2bebb02f37b927ed743218747a8a924aaf7b6) --- airflow/jobs/local_task_job.py | 5 ++- tests/jobs/test_local_task_job.py | 54 +++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py index efd84d6efcf7c..c697a852b18df 100644 --- a/airflow/jobs/local_task_job.py +++ b/airflow/jobs/local_task_job.py @@ -20,6 +20,8 @@ import signal from typing import Optional +import psutil + from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.jobs.base_job import BaseJob @@ -188,9 +190,10 @@ def heartbeat_callback(self, session=None): fqdn, ) raise AirflowException("Hostname of job runner does not match") - current_pid = self.task_runner.process.pid same_process = ti.pid == current_pid + if ti.run_as_user: + same_process = psutil.Process(ti.pid).ppid() == current_pid if ti.pid is not None and not same_process: self.log.warning("Recorded pid %s does not match " "the current pid %s", ti.pid, current_pid) raise AirflowException("PID of job runner does not match") diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py index 97645e457398a..1eaad46dd0992 100644 --- a/tests/jobs/test_local_task_job.py +++ b/tests/jobs/test_local_task_job.py @@ -49,6 +49,8 @@ from tests.test_utils.db import clear_db_jobs, clear_db_runs from tests.test_utils.mock_executor import MockExecutor +# pylint: skip-file + DEFAULT_DATE = timezone.datetime(2016, 1, 1) TEST_DAG_FOLDER = os.environ['AIRFLOW__CORE__DAGS_FOLDER'] @@ -135,6 +137,58 @@ def test_localtaskjob_heartbeat(self): with pytest.raises(AirflowException): job1.heartbeat_callback() # pylint: disable=no-value-for-parameter + @mock.patch('airflow.jobs.local_task_job.psutil') + def test_localtaskjob_heartbeat_with_run_as_user(self, psutil_mock): + session = settings.Session() + dag = DAG('test_localtaskjob_heartbeat', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) + + with dag: + op1 = DummyOperator(task_id='op1', run_as_user='myuser') + + dag.clear() + dr = dag.create_dagrun( + run_id="test", + state=State.SUCCESS, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session, + ) + + ti = dr.get_task_instance(task_id=op1.task_id, session=session) + ti.state = State.RUNNING + ti.pid = 2 + ti.hostname = get_hostname() + session.commit() + + job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) + ti.task = op1 + ti.refresh_from_task(op1) + job1.task_runner = StandardTaskRunner(job1) + job1.task_runner.process = mock.Mock() + job1.task_runner.process.pid = 2 + # Here, ti.pid is 2, the parent process of ti.pid is a mock(different). + # And task_runner process is 2. Should fail + with pytest.raises(AirflowException, match='PID of job runner does not match'): + job1.heartbeat_callback() + + job1.task_runner.process.pid = 1 + # We make the parent process of ti.pid to equal the task_runner process id + psutil_mock.Process.return_value.ppid.return_value = 1 + ti.state = State.RUNNING + ti.pid = 2 + # The task_runner process id is 1, same as the parent process of ti.pid + # as seen above + assert ti.run_as_user + session.merge(ti) + session.commit() + job1.heartbeat_callback(session=None) + + # Here the task_runner process id is changed to 2 + # while parent process of ti.pid is kept at 1, which is different + job1.task_runner.process.pid = 2 + with pytest.raises(AirflowException, match='PID of job runner does not match'): + job1.heartbeat_callback() + def test_heartbeat_failed_fast(self): """ Test that task heartbeat will sleep when it fails fast