From ee8b8651057f44107a215bfb482277c9986fdfd5 Mon Sep 17 00:00:00 2001 From: Jens Scheffler Date: Sun, 28 Mar 2021 11:19:42 +0200 Subject: [PATCH 01/12] Add validation for dag_run conf to be a dict to WWW and API endpoints incl. pytests #15023 --- airflow/www/api/experimental/endpoints.py | 6 + airflow/www/templates/airflow/trigger.html | 2 +- airflow/www/views.py | 7 +- .../endpoints/test_dag_run_endpoint.py | 21 + tests/www/api/experimental/test_endpoints.py | 26 +- tests/www/test_views.py | 761 ++++++++++++++++++ 6 files changed, 814 insertions(+), 9 deletions(-) diff --git a/airflow/www/api/experimental/endpoints.py b/airflow/www/api/experimental/endpoints.py index 78cacde9123ff..9e8c59dc3a5d4 100644 --- a/airflow/www/api/experimental/endpoints.py +++ b/airflow/www/api/experimental/endpoints.py @@ -88,6 +88,12 @@ def trigger_dag(dag_id): conf = None if 'conf' in data: conf = data['conf'] + if type(conf) is not dict: + error_message = 'Dag Run conf must be a dictionary object, other types are not supported' + log.error(error_message) + response = jsonify({'error': error_message}) + response.status_code = 400 + return response execution_date = None if 'execution_date' in data and data['execution_date'] is not None: diff --git a/airflow/www/templates/airflow/trigger.html b/airflow/www/templates/airflow/trigger.html index c4187f1fcbf3f..80164b81c677a 100644 --- a/airflow/www/templates/airflow/trigger.html +++ b/airflow/www/templates/airflow/trigger.html @@ -35,7 +35,7 @@

Trigger DAG: {{ dag_id }}

- +

diff --git a/airflow/www/views.py b/airflow/www/views.py index 5a6faece479e6..d8458ccd8a0fb 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -1519,8 +1519,13 @@ def trigger(self, session=None): if request_conf: try: run_conf = json.loads(request_conf) + if type(run_conf) is not dict: + flash("Invalid JSON configuration, must be a dict", "error") + return self.render_template( + 'airflow/trigger.html', dag_id=dag_id, origin=origin, conf=request_conf + ) except json.decoder.JSONDecodeError: - flash("Invalid JSON configuration", "error") + flash("Invalid JSON configuration, not parseable", "error") return self.render_template( 'airflow/trigger.html', dag_id=dag_id, origin=origin, conf=request_conf ) diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index 4ea91107eaf3c..134ef14c30b99 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -945,6 +945,27 @@ def test_should_response_400_for_naive_datetime_and_bad_datetime(self, data, exp assert response.status_code == 400 assert response.json['detail'] == expected + @parameterized.expand( + [ + ( + "Conf is an array, not a dict", + { + "dag_run_id": "TEST_DAG_RUN", + "execution_date": "2020-06-11T18:00:00+00:00", + "conf": "some string" + }, + "'some string' is not of type 'object' - 'conf'" + ) + ] + ) + def test_should_response_400_for_non_dict_dagrun_conf(self, name, data, expected): + self._create_dag("TEST_DAG_ID") + response = self.client.post( + "api/v1/dags/TEST_DAG_ID/dagRuns", json=data, environ_overrides={'REMOTE_USER': "test"} + ) + assert response.status_code == 400 + assert response.json['detail'] == expected + def test_response_404(self): response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns", diff --git a/tests/www/api/experimental/test_endpoints.py b/tests/www/api/experimental/test_endpoints.py index 8c6734c7a4ca8..6ef24888883dd 100644 --- a/tests/www/api/experimental/test_endpoints.py +++ b/tests/www/api/experimental/test_endpoints.py @@ -148,9 +148,27 @@ def test_dag_paused(self): def test_trigger_dag(self): url_template = '/api/experimental/dags/{}/dag_runs' run_id = 'my_run' + utcnow().isoformat() + + # Test error for nonexistent dag + response = self.client.post( + url_template.format('does_not_exist_dag'), data=json.dumps({}), content_type="application/json" + ) + assert 404 == response.status_code + + # Test error for bad conf data + response = self.client.post( + url_template.format('example_bash_operator'), data=json.dumps({'conf': 'This is a string not a dict'}), content_type="application/json" + ) + assert 400 == response.status_code + + # Test OK case response = self.client.post( url_template.format('example_bash_operator'), - data=json.dumps({'run_id': run_id}), + data=json.dumps({ + 'run_id': run_id, + 'conf': { + 'param': 'value' + }}), content_type="application/json", ) self.assert_deprecated(response) @@ -168,12 +186,6 @@ def test_trigger_dag(self): assert run_id == dag_run_id assert dag_run_id == response['run_id'] - # Test error for nonexistent dag - response = self.client.post( - url_template.format('does_not_exist_dag'), data=json.dumps({}), content_type="application/json" - ) - assert 404 == response.status_code - def test_trigger_dag_for_date(self): url_template = '/api/experimental/dags/{}/dag_runs' dag_id = 'example_bash_operator' diff --git a/tests/www/test_views.py b/tests/www/test_views.py index db6f5616928bc..7c7439b7a5d93 100644 --- a/tests/www/test_views.py +++ b/tests/www/test_views.py @@ -2630,3 +2630,764 @@ def test_refresh_failure_for_viewer(self): self.login(username='test_viewer', password='test_viewer') resp = self.client.post('refresh?dag_id=example_bash_operator') self.check_content_in_response('Redirecting', resp, resp_code=302) + + +class TestTaskInstanceView(TestBase): + TI_ENDPOINT = '/taskinstance/list/?_flt_0_execution_date={}' + + def test_start_date_filter(self): + resp = self.client.get(self.TI_ENDPOINT.format(self.percent_encode('2018-10-09 22:44:31'))) + # We aren't checking the logic of the date filter itself (that is built + # in to FAB) but simply that our UTC conversion was run - i.e. it + # doesn't blow up! + self.check_content_in_response('List Task Instance', resp) + + +class TestTaskRescheduleView(TestBase): + TI_ENDPOINT = '/taskreschedule/list/?_flt_0_execution_date={}' + + def test_start_date_filter(self): + resp = self.client.get(self.TI_ENDPOINT.format(self.percent_encode('2018-10-09 22:44:31'))) + # We aren't checking the logic of the date filter itself (that is built + # in to FAB) but simply that our UTC conversion was run - i.e. it + # doesn't blow up! + self.check_content_in_response('List Task Reschedule', resp) + + +class TestRenderedView(TestBase): + def setUp(self): + + self.default_date = datetime(2020, 3, 1) + self.dag = DAG( + "testdag", + start_date=self.default_date, + user_defined_filters={"hello": lambda name: f'Hello {name}'}, + user_defined_macros={"fullname": lambda fname, lname: f'{fname} {lname}'}, + ) + self.task1 = BashOperator(task_id='task1', bash_command='{{ task_instance_key_str }}', dag=self.dag) + self.task2 = BashOperator( + task_id='task2', bash_command='echo {{ fullname("Apache", "Airflow") | hello }}', dag=self.dag + ) + SerializedDagModel.write_dag(self.dag) + with create_session() as session: + session.query(RTIF).delete() + + self.app.dag_bag = mock.MagicMock(**{'get_dag.return_value': self.dag}) + super().setUp() + + def tearDown(self) -> None: + super().tearDown() + with create_session() as session: + session.query(RTIF).delete() + + def test_rendered_template_view(self): + """ + Test that the Rendered View contains the values from RenderedTaskInstanceFields + """ + assert self.task1.bash_command == '{{ task_instance_key_str }}' + ti = TaskInstance(self.task1, self.default_date) + + with create_session() as session: + session.add(RTIF(ti)) + + url = 'rendered-templates?task_id=task1&dag_id=testdag&execution_date={}'.format( + self.percent_encode(self.default_date) + ) + + resp = self.client.get(url, follow_redirects=True) + self.check_content_in_response("testdag__task1__20200301", resp) + + def test_rendered_template_view_for_unexecuted_tis(self): + """ + Test that the Rendered View is able to show rendered values + even for TIs that have not yet executed + """ + assert self.task1.bash_command == '{{ task_instance_key_str }}' + + url = 'rendered-templates?task_id=task1&dag_id=task1&execution_date={}'.format( + self.percent_encode(self.default_date) + ) + + resp = self.client.get(url, follow_redirects=True) + self.check_content_in_response("testdag__task1__20200301", resp) + + def test_user_defined_filter_and_macros_raise_error(self): + """ + Test that the Rendered View is able to show rendered values + even for TIs that have not yet executed + """ + self.app.dag_bag = mock.MagicMock( + **{'get_dag.return_value': SerializedDagModel.get(self.dag.dag_id).dag} + ) + assert self.task2.bash_command == 'echo {{ fullname("Apache", "Airflow") | hello }}' + + url = 'rendered-templates?task_id=task2&dag_id=testdag&execution_date={}'.format( + self.percent_encode(self.default_date) + ) + + resp = self.client.get(url, follow_redirects=True) + self.check_content_not_in_response("echo Hello Apache Airflow", resp) + self.check_content_in_response( + "Webserver does not have access to User-defined Macros or Filters " + "when Dag Serialization is enabled. Hence for the task that have not yet " + "started running, please use 'airflow tasks render' for debugging the " + "rendering of template_fields.

OriginalError: no filter named 'hello'", + resp, + ) + + +class TestTriggerDag(TestBase): + def setUp(self): + super().setUp() + models.DagBag().get_dag("example_bash_operator").sync_to_db(session=self.session) + self.session.commit() + + def test_trigger_dag_button_normal_exist(self): + resp = self.client.get('/', follow_redirects=True) + assert '/trigger?dag_id=example_bash_operator' in resp.data.decode('utf-8') + assert "return confirmDeleteDag(this, 'example_bash_operator')" in resp.data.decode('utf-8') + + @pytest.mark.quarantined + def test_trigger_dag_button(self): + + test_dag_id = "example_bash_operator" + + DR = models.DagRun # pylint: disable=invalid-name + self.session.query(DR).delete() + self.session.commit() + + self.client.post(f'trigger?dag_id={test_dag_id}') + + run = self.session.query(DR).filter(DR.dag_id == test_dag_id).first() + assert run is not None + assert DagRunType.MANUAL in run.run_id + assert run.run_type == DagRunType.MANUAL + + @pytest.mark.quarantined + def test_trigger_dag_conf(self): + + test_dag_id = "example_bash_operator" + conf_dict = {'string': 'Hello, World!'} + + DR = models.DagRun # pylint: disable=invalid-name + self.session.query(DR).delete() + self.session.commit() + + self.client.post(f'trigger?dag_id={test_dag_id}', data={'conf': json.dumps(conf_dict)}) + + run = self.session.query(DR).filter(DR.dag_id == test_dag_id).first() + assert run is not None + assert DagRunType.MANUAL in run.run_id + assert run.run_type == DagRunType.MANUAL + assert run.conf == conf_dict + + def test_trigger_dag_conf_malformed(self): + test_dag_id = "example_bash_operator" + + DR = models.DagRun # pylint: disable=invalid-name + self.session.query(DR).delete() + self.session.commit() + + response = self.client.post(f'trigger?dag_id={test_dag_id}', data={'conf': '{"a": "b"'}) + self.check_content_in_response('Invalid JSON configuration', response) + + run = self.session.query(DR).filter(DR.dag_id == test_dag_id).first() + assert run is None + + def test_trigger_dag_conf_not_dict(self): + test_dag_id = "example_bash_operator" + + DR = models.DagRun # pylint: disable=invalid-name + self.session.query(DR).delete() + self.session.commit() + + response = self.client.post(f'trigger?dag_id={test_dag_id}', data={'conf': 'string and not a dict'}) + self.check_content_in_response('must be a dict', response) + + run = self.session.query(DR).filter(DR.dag_id == test_dag_id).first() + assert run is None + + def test_trigger_dag_form(self): + test_dag_id = "example_bash_operator" + resp = self.client.get(f'trigger?dag_id={test_dag_id}') + self.check_content_in_response(f'Trigger DAG: {test_dag_id}', resp) + + @parameterized.expand( + [ + ("javascript:alert(1)", "/home"), + ("http://google.com", "/home"), + ("36539'%3balert(1)%2f%2f166", "/home"), + ( + "%2Ftree%3Fdag_id%3Dexample_bash_operator';alert(33)//", + "/home", + ), + ("%2Ftree%3Fdag_id%3Dexample_bash_operator", "/tree?dag_id=example_bash_operator"), + ("%2Fgraph%3Fdag_id%3Dexample_bash_operator", "/graph?dag_id=example_bash_operator"), + ] + ) + def test_trigger_dag_form_origin_url(self, test_origin, expected_origin): + test_dag_id = "example_bash_operator" + + resp = self.client.get(f'trigger?dag_id={test_dag_id}&origin={test_origin}') + self.check_content_in_response( + '