From f0e6ede5c134924a01b955af02267acc2c110ff4 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Wed, 7 Jul 2021 11:25:59 +0100 Subject: [PATCH] Fix impersonation issue with LocalTaskJob Running a task with run_as_user fails because PIDs are not matched correctly. This change fixes it by matching the parent process ID of the task instance to the current process ID of the task_runner process when we use impersonation Update tests/jobs/test_local_task_job.py Co-authored-by: Ash Berlin-Taylor fixup! Update tests/jobs/test_local_task_job.py fixup! Fix impersonation issue with LocalTaskJob --- airflow/jobs/local_task_job.py | 4 ++- tests/jobs/test_local_task_job.py | 52 +++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py index 574d45e9e68cd..6852576fc8a90 100644 --- a/airflow/jobs/local_task_job.py +++ b/airflow/jobs/local_task_job.py @@ -19,6 +19,7 @@ import signal from typing import Optional +import psutil from sqlalchemy.exc import OperationalError from airflow.configuration import conf @@ -190,9 +191,10 @@ def heartbeat_callback(self, session=None): fqdn, ) raise AirflowException("Hostname of job runner does not match") - current_pid = self.task_runner.process.pid same_process = ti.pid == current_pid + if ti.run_as_user: + same_process = psutil.Process(ti.pid).ppid() == current_pid if ti.pid is not None and not same_process: self.log.warning("Recorded pid %s does not match " "the current pid %s", ti.pid, current_pid) raise AirflowException("PID of job runner does not match") diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py index 456d1b458f8b4..b2c5a1e46ec33 100644 --- a/tests/jobs/test_local_task_job.py +++ b/tests/jobs/test_local_task_job.py @@ -147,6 +147,58 @@ def test_localtaskjob_heartbeat(self): with pytest.raises(AirflowException): job1.heartbeat_callback() + @mock.patch('airflow.jobs.local_task_job.psutil') + def test_localtaskjob_heartbeat_with_run_as_user(self, psutil_mock): + session = settings.Session() + dag = DAG('test_localtaskjob_heartbeat', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) + + with dag: + op1 = DummyOperator(task_id='op1', run_as_user='myuser') + + dag.clear() + dr = dag.create_dagrun( + run_id="test", + state=State.SUCCESS, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session, + ) + + ti = dr.get_task_instance(task_id=op1.task_id, session=session) + ti.state = State.RUNNING + ti.pid = 2 + ti.hostname = get_hostname() + session.commit() + + job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) + ti.task = op1 + ti.refresh_from_task(op1) + job1.task_runner = StandardTaskRunner(job1) + job1.task_runner.process = mock.Mock() + job1.task_runner.process.pid = 2 + # Here, ti.pid is 2, the parent process of ti.pid is a mock(different). + # And task_runner process is 2. Should fail + with pytest.raises(AirflowException, match='PID of job runner does not match'): + job1.heartbeat_callback() + + job1.task_runner.process.pid = 1 + # We make the parent process of ti.pid to equal the task_runner process id + psutil_mock.Process.return_value.ppid.return_value = 1 + ti.state = State.RUNNING + ti.pid = 2 + # The task_runner process id is 1, same as the parent process of ti.pid + # as seen above + assert ti.run_as_user + session.merge(ti) + session.commit() + job1.heartbeat_callback(session=None) + + # Here the task_runner process id is changed to 2 + # while parent process of ti.pid is kept at 1, which is different + job1.task_runner.process.pid = 2 + with pytest.raises(AirflowException, match='PID of job runner does not match'): + job1.heartbeat_callback() + def test_heartbeat_failed_fast(self): """ Test that task heartbeat will sleep when it fails fast