From d772f38f843b9add5319a01cf51a844145b01f63 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Thu, 1 Jul 2021 12:22:42 +0100 Subject: [PATCH 1/2] Only allow webserver to request from the worker log server Logs _shouldn't_ contain any sensitive info, but they often do by mistake. As an extra level of protection we shouldn't allow anything other than the webserver to access the logs. (We can't change the bind IP form 0.0.0.0 as for it to be useful it needs to be accessed from different hosts -- i.e. the webserver will almost always be on a different node) --- airflow/utils/log/file_task_handler.py | 9 ++- airflow/utils/serve_logs.py | 58 ++++++++++++--- tests/utils/test_serve_logs.py | 99 ++++++++++++++++++++------ 3 files changed, 132 insertions(+), 34 deletions(-) diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py index 1ad3689073afb..7aa1677425d6c 100644 --- a/airflow/utils/log/file_task_handler.py +++ b/airflow/utils/log/file_task_handler.py @@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Optional import httpx +from itsdangerous import TimedJSONWebSignatureSerializer from airflow.configuration import AirflowConfigException, conf from airflow.utils.helpers import parse_template_string @@ -172,7 +173,13 @@ def _read(self, ti, try_number, metadata=None): except (AirflowConfigException, ValueError): pass - response = httpx.get(url, timeout=timeout) + signer = TimedJSONWebSignatureSerializer( + secret_key=conf.get('webserver', 'secret_key'), + algorithm_name='HS512', + expires_in=conf.getint('webserver', 'log_request_clock_grace', fallback=30), + ) + + response = httpx.get(url, timeout=timeout, headers={'Authorization': signer.dumps({})}) response.encoding = "utf-8" # Check if the resource was properly fetched diff --git a/airflow/utils/serve_logs.py b/airflow/utils/serve_logs.py index 45c4400680a2f..c56204573887d 100644 --- a/airflow/utils/serve_logs.py +++ b/airflow/utils/serve_logs.py @@ -17,25 +17,61 @@ """Serve logs process""" import os +import time -import flask +from flask import Flask, abort, request, send_from_directory +from itsdangerous import TimedJSONWebSignatureSerializer from setproctitle import setproctitle from airflow.configuration import conf -def serve_logs(): - """Serves logs generated by Worker""" - print("Starting flask") - flask_app = flask.Flask(__name__) - setproctitle("airflow serve-logs") +def flask_app(): + flask_app = Flask(__name__) + max_request_age = conf.getint('webserver', 'log_request_clock_grace', fallback=30) + log_directory = os.path.expanduser(conf.get('logging', 'BASE_LOG_FOLDER')) + + signer = TimedJSONWebSignatureSerializer( + secret_key=conf.get('webserver', 'secret_key'), + algorithm_name='HS512', + expires_in=max_request_age, + ) + + # Prevent direct access to the logs port + @flask_app.before_request + def validate_pre_signed_url(): + try: + auth = request.headers['Authorization'] + + # We don't actually care about the payload, just that the signature + # was valid and the `exp` claim is correct + _, headers = signer.loads(auth, return_header=True) + + issued_at = int(headers['iat']) + expires_at = int(headers['exp']) + except Exception as e: + print(e) + abort(403) + # Validate the `iat` and `exp` are within `max_request_age` of now. + now = int(time.time()) + if abs(now - issued_at) > max_request_age: + abort(403) + if abs(now - expires_at) > max_request_age: + abort(403) + if issued_at > expires_at or expires_at - issued_at > max_request_age: + abort(403) @flask_app.route('/log/') def serve_logs_view(filename): - log_directory = os.path.expanduser(conf.get('logging', 'BASE_LOG_FOLDER')) - return flask.send_from_directory( - log_directory, filename, mimetype="application/json", as_attachment=False - ) + return send_from_directory(log_directory, filename, mimetype="application/json", as_attachment=False) + + return flask_app + + +def serve_logs(): + """Serves logs generated by Worker""" + setproctitle("airflow serve-logs") + app = flask_app() worker_log_server_port = conf.getint('celery', 'WORKER_LOG_SERVER_PORT') - flask_app.run(host='0.0.0.0', port=worker_log_server_port) + app.run(host='0.0.0.0', port=worker_log_server_port) diff --git a/tests/utils/test_serve_logs.py b/tests/utils/test_serve_logs.py index edb4d0991b3e6..c3a012b621227 100644 --- a/tests/utils/test_serve_logs.py +++ b/tests/utils/test_serve_logs.py @@ -14,33 +14,88 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import os -import unittest -from multiprocessing import Process -from os.path import basename -from tempfile import NamedTemporaryFile -from time import sleep +from typing import TYPE_CHECKING import pytest -import requests +from itsdangerous import TimedJSONWebSignatureSerializer from airflow.configuration import conf -from airflow.utils.serve_logs import serve_logs +from airflow.utils.serve_logs import flask_app +from tests.test_utils.config import conf_vars + +if TYPE_CHECKING: + from flask.testing import FlaskClient LOG_DATA = "Airflow log data" * 20 -@pytest.mark.quarantined -class TestServeLogs(unittest.TestCase): - def test_should_serve_file(self): - log_dir = os.path.expanduser(conf.get('logging', 'BASE_LOG_FOLDER')) - log_port = conf.get('celery', 'WORKER_LOG_SERVER_PORT') - with NamedTemporaryFile(dir=log_dir) as f: - f.write(LOG_DATA.encode()) - f.flush() - sub_proc = Process(target=serve_logs) - sub_proc.start() - sleep(1) - log_url = f"http://localhost:{log_port}/log/{basename(f.name)}" - assert LOG_DATA == requests.get(log_url).content.decode() - sub_proc.terminate() +@pytest.fixture +def client(tmpdir): + with conf_vars({('logging', 'base_log_folder'): str(tmpdir)}): + app = flask_app() + + yield app.test_client() + + +@pytest.fixture +def sample_log(tmpdir): + f = tmpdir / 'sample.log' + f.write(LOG_DATA.encode()) + + return f + + +@pytest.mark.usefixtures('sample_log') +class TestServeLogs: + def test_forbidden_no_auth(self, client: "FlaskClient"): + assert 403 == client.get('/log/sample.log').status_code + + def test_should_serve_file(self, client: "FlaskClient"): + signer = TimedJSONWebSignatureSerializer( + secret_key=conf.get('webserver', 'secret_key'), + algorithm_name='HS512', + expires_in=30, + ) + assert ( + LOG_DATA + == client.get( + '/log/sample.log', + headers={ + 'Authorization': signer.dumps({}), + }, + ).data.decode() + ) + + def test_forbidden_too_long_validity(self, client: "FlaskClient"): + signer = TimedJSONWebSignatureSerializer( + secret_key=conf.get('webserver', 'secret_key'), + algorithm_name='HS512', + expires_in=3600, + ) + assert ( + 403 + == client.get( + '/log/sample.log', + headers={ + 'Authorization': signer.dumps({}), + }, + ).status_code + ) + + def test_forbidden_expired(self, client: "FlaskClient"): + signer = TimedJSONWebSignatureSerializer( + secret_key=conf.get('webserver', 'secret_key'), + algorithm_name='HS512', + expires_in=30, + ) + # Fake the time we think we are + signer.now = lambda: 0 + assert ( + 403 + == client.get( + '/log/sample.log', + headers={ + 'Authorization': signer.dumps({}), + }, + ).status_code + ) From 27265516d2b897585f5019ecd820cfe5471fd351 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Thu, 1 Jul 2021 15:20:16 +0100 Subject: [PATCH 2/2] fixup! Only allow webserver to request from the worker log server --- airflow/utils/log/file_task_handler.py | 6 ++- airflow/utils/serve_logs.py | 11 ++++-- tests/utils/test_serve_logs.py | 51 +++++++++++++++----------- 3 files changed, 43 insertions(+), 25 deletions(-) diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py index 7aa1677425d6c..2dc9beb57b0ac 100644 --- a/airflow/utils/log/file_task_handler.py +++ b/airflow/utils/log/file_task_handler.py @@ -177,9 +177,13 @@ def _read(self, ti, try_number, metadata=None): secret_key=conf.get('webserver', 'secret_key'), algorithm_name='HS512', expires_in=conf.getint('webserver', 'log_request_clock_grace', fallback=30), + # This isn't really a "salt", more of a signing context + salt='task-instance-logs', ) - response = httpx.get(url, timeout=timeout, headers={'Authorization': signer.dumps({})}) + response = httpx.get( + url, timeout=timeout, headers={'Authorization': signer.dumps(log_relative_path)} + ) response.encoding = "utf-8" # Check if the resource was properly fetched diff --git a/airflow/utils/serve_logs.py b/airflow/utils/serve_logs.py index c56204573887d..463b0e4b7477d 100644 --- a/airflow/utils/serve_logs.py +++ b/airflow/utils/serve_logs.py @@ -35,6 +35,8 @@ def flask_app(): secret_key=conf.get('webserver', 'secret_key'), algorithm_name='HS512', expires_in=max_request_age, + # This isn't really a "salt", more of a signing context + salt='task-instance-logs', ) # Prevent direct access to the logs port @@ -45,13 +47,16 @@ def validate_pre_signed_url(): # We don't actually care about the payload, just that the signature # was valid and the `exp` claim is correct - _, headers = signer.loads(auth, return_header=True) + filename, headers = signer.loads(auth, return_header=True) issued_at = int(headers['iat']) expires_at = int(headers['exp']) - except Exception as e: - print(e) + except Exception: abort(403) + + if filename != request.view_args['filename']: + abort(403) + # Validate the `iat` and `exp` are within `max_request_age` of now. now = int(time.time()) if abs(now - issued_at) > max_request_age: diff --git a/tests/utils/test_serve_logs.py b/tests/utils/test_serve_logs.py index c3a012b621227..9ae76aa358ae3 100644 --- a/tests/utils/test_serve_logs.py +++ b/tests/utils/test_serve_logs.py @@ -45,49 +45,46 @@ def sample_log(tmpdir): return f +@pytest.fixture +def signer(): + return TimedJSONWebSignatureSerializer( + secret_key=conf.get('webserver', 'secret_key'), + algorithm_name='HS512', + expires_in=30, + # This isn't really a "salt", more of a signing context + salt='task-instance-logs', + ) + + @pytest.mark.usefixtures('sample_log') class TestServeLogs: def test_forbidden_no_auth(self, client: "FlaskClient"): assert 403 == client.get('/log/sample.log').status_code - def test_should_serve_file(self, client: "FlaskClient"): - signer = TimedJSONWebSignatureSerializer( - secret_key=conf.get('webserver', 'secret_key'), - algorithm_name='HS512', - expires_in=30, - ) + def test_should_serve_file(self, client: "FlaskClient", signer): assert ( LOG_DATA == client.get( '/log/sample.log', headers={ - 'Authorization': signer.dumps({}), + 'Authorization': signer.dumps('sample.log'), }, ).data.decode() ) - def test_forbidden_too_long_validity(self, client: "FlaskClient"): - signer = TimedJSONWebSignatureSerializer( - secret_key=conf.get('webserver', 'secret_key'), - algorithm_name='HS512', - expires_in=3600, - ) + def test_forbidden_too_long_validity(self, client: "FlaskClient", signer): + signer.expires_in = 3600 assert ( 403 == client.get( '/log/sample.log', headers={ - 'Authorization': signer.dumps({}), + 'Authorization': signer.dumps('sample.log'), }, ).status_code ) - def test_forbidden_expired(self, client: "FlaskClient"): - signer = TimedJSONWebSignatureSerializer( - secret_key=conf.get('webserver', 'secret_key'), - algorithm_name='HS512', - expires_in=30, - ) + def test_forbidden_expired(self, client: "FlaskClient", signer): # Fake the time we think we are signer.now = lambda: 0 assert ( @@ -95,7 +92,19 @@ def test_forbidden_expired(self, client: "FlaskClient"): == client.get( '/log/sample.log', headers={ - 'Authorization': signer.dumps({}), + 'Authorization': signer.dumps('sample.log'), + }, + ).status_code + ) + + def test_wrong_context(self, client: "FlaskClient", signer): + signer.salt = None + assert ( + 403 + == client.get( + '/log/sample.log', + headers={ + 'Authorization': signer.dumps('sample.log'), }, ).status_code )