Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix impersonation issue with LocalTaskJob #16852

Merged
1 commit merged on Jul 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion airflow/jobs/local_task_job.py
Expand Up @@ -19,6 +19,7 @@
import signal
from typing import Optional

import psutil
from sqlalchemy.exc import OperationalError

from airflow.configuration import conf
Expand Down Expand Up @@ -190,9 +191,10 @@ def heartbeat_callback(self, session=None):
fqdn,
)
raise AirflowException("Hostname of job runner does not match")

current_pid = self.task_runner.process.pid
same_process = ti.pid == current_pid
if ti.run_as_user:
same_process = psutil.Process(ti.pid).ppid() == current_pid
if ti.pid is not None and not same_process:
self.log.warning("Recorded pid %s does not match " "the current pid %s", ti.pid, current_pid)
raise AirflowException("PID of job runner does not match")
Expand Down
52 changes: 52 additions & 0 deletions tests/jobs/test_local_task_job.py
Expand Up @@ -147,6 +147,58 @@ def test_localtaskjob_heartbeat(self):
with pytest.raises(AirflowException):
job1.heartbeat_callback()

@mock.patch('airflow.jobs.local_task_job.psutil')
def test_localtaskjob_heartbeat_with_run_as_user(self, psutil_mock):
    """Exercise the heartbeat PID check for a task that sets ``run_as_user``.

    ``psutil`` is patched inside ``airflow.jobs.local_task_job``, so the
    parent pid reported for ``ti.pid`` is whatever the mock is configured
    to return. Three scenarios are walked through in order:

    1. Parent pid is an unconfigured mock, runner pid is 2 -> mismatch.
    2. Parent pid is 1, runner pid is 1 -> heartbeat passes.
    3. Parent pid is 1, runner pid is 2 -> mismatch again.
    """
    expected_error = 'PID of job runner does not match'

    db_session = settings.Session()
    dag = DAG('test_localtaskjob_heartbeat', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})

    with dag:
        task = DummyOperator(task_id='op1', run_as_user='myuser')

    dag.clear()
    dag_run = dag.create_dagrun(
        run_id="test",
        state=State.SUCCESS,
        execution_date=DEFAULT_DATE,
        start_date=DEFAULT_DATE,
        session=db_session,
    )

    # Persist a running task instance recorded with pid 2 on this host.
    ti = dag_run.get_task_instance(task_id=task.task_id, session=db_session)
    ti.state = State.RUNNING
    ti.pid = 2
    ti.hostname = get_hostname()
    db_session.commit()

    job = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
    ti.task = task
    ti.refresh_from_task(task)
    job.task_runner = StandardTaskRunner(job)

    # Scenario 1: the runner process has pid 2 (same as ti.pid), but the
    # parent of ti.pid is still an unconfigured mock, so they differ.
    job.task_runner.process = mock.Mock(pid=2)
    with pytest.raises(AirflowException, match=expected_error):
        job.heartbeat_callback()

    # Scenario 2: make the parent of ti.pid report pid 1 and give the
    # runner process that same pid, so the check is satisfied.
    psutil_mock.Process.return_value.ppid.return_value = 1
    job.task_runner.process.pid = 1
    ti.state = State.RUNNING
    ti.pid = 2
    assert ti.run_as_user
    db_session.merge(ti)
    db_session.commit()
    job.heartbeat_callback(session=None)

    # Scenario 3: move the runner process to pid 2 while the parent of
    # ti.pid stays at 1; the heartbeat must be rejected once more.
    job.task_runner.process.pid = 2
    with pytest.raises(AirflowException, match=expected_error):
        job.heartbeat_callback()

def test_heartbeat_failed_fast(self):
"""
Test that task heartbeat will sleep when it fails fast
Expand Down