diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py index b8590207f5c03..1b52a4f55d54a 100644 --- a/airflow/utils/log/file_task_handler.py +++ b/airflow/utils/log/file_task_handler.py @@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Optional import httpx +from itsdangerous import TimedJSONWebSignatureSerializer from airflow.configuration import AirflowConfigException, conf from airflow.utils.helpers import parse_template_string @@ -172,7 +173,17 @@ def _read(self, ti, try_number, metadata=None): # pylint: disable=unused-argume except (AirflowConfigException, ValueError): pass - response = httpx.get(url, timeout=timeout) + signer = TimedJSONWebSignatureSerializer( + secret_key=conf.get('webserver', 'secret_key'), + algorithm_name='HS512', + expires_in=conf.getint('webserver', 'log_request_clock_grace', fallback=30), + # This isn't really a "salt", more of a signing context + salt='task-instance-logs', + ) + + response = httpx.get( + url, timeout=timeout, headers={'Authorization': signer.dumps(log_relative_path)} + ) response.encoding = "utf-8" # Check if the resource was properly fetched diff --git a/airflow/utils/serve_logs.py b/airflow/utils/serve_logs.py index 0fefa420b7d84..fd5eadb2ac5a6 100644 --- a/airflow/utils/serve_logs.py +++ b/airflow/utils/serve_logs.py @@ -16,24 +16,70 @@ # under the License. """Serve logs process""" + +# pylint: skip-file + import os +import time -import flask +from flask import Flask, abort, request, send_from_directory +from itsdangerous import TimedJSONWebSignatureSerializer +from setproctitle import setproctitle from airflow.configuration import conf -def serve_logs(): - """Serves logs generated by Worker""" - print("Starting flask") - flask_app = flask.Flask(__name__) +def flask_app(): # noqa: D103 + flask_app = Flask(__name__) + max_request_age = conf.getint('webserver', 'log_request_clock_grace', fallback=30) + log_directory = os.path.expanduser(conf.get('logging', 'BASE_LOG_FOLDER')) + + signer = TimedJSONWebSignatureSerializer( + secret_key=conf.get('webserver', 'secret_key'), + algorithm_name='HS512', + expires_in=max_request_age, + # This isn't really a "salt", more of a signing context + salt='task-instance-logs', + ) + + # Prevent direct access to the logs port + @flask_app.before_request + def validate_pre_signed_url(): + try: + auth = request.headers['Authorization'] + + # We don't actually care about the payload, just that the signature + # was valid and the `exp` claim is correct + filename, headers = signer.loads(auth, return_header=True) + + issued_at = int(headers['iat']) + expires_at = int(headers['exp']) + except Exception: + abort(403) + + if filename != request.view_args['filename']: + abort(403) + + # Validate the `iat` and `exp` are within `max_request_age` of now. + now = int(time.time()) + if abs(now - issued_at) > max_request_age: + abort(403) + if abs(now - expires_at) > max_request_age: + abort(403) + if issued_at > expires_at or expires_at - issued_at > max_request_age: + abort(403) @flask_app.route('/log/') - def serve_logs_view(filename): # pylint: disable=unused-variable - log_directory = os.path.expanduser(conf.get('logging', 'BASE_LOG_FOLDER')) - return flask.send_from_directory( - log_directory, filename, mimetype="application/json", as_attachment=False - ) + def serve_logs_view(filename): + return send_from_directory(log_directory, filename, mimetype="application/json", as_attachment=False) + + return flask_app + + +def serve_logs(): + """Serves logs generated by Worker""" + setproctitle("airflow serve-logs") + app = flask_app() worker_log_server_port = conf.getint('celery', 'WORKER_LOG_SERVER_PORT') - flask_app.run(host='0.0.0.0', port=worker_log_server_port) + app.run(host='0.0.0.0', port=worker_log_server_port) diff --git a/tests/utils/test_serve_logs.py b/tests/utils/test_serve_logs.py index edb4d0991b3e6..080e9c9649594 100644 --- a/tests/utils/test_serve_logs.py +++ b/tests/utils/test_serve_logs.py @@ -14,33 +14,99 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import os -import unittest -from multiprocessing import Process -from os.path import basename -from tempfile import NamedTemporaryFile -from time import sleep +from typing import TYPE_CHECKING import pytest -import requests +from itsdangerous import TimedJSONWebSignatureSerializer from airflow.configuration import conf -from airflow.utils.serve_logs import serve_logs +from airflow.utils.serve_logs import flask_app +from tests.test_utils.config import conf_vars + +# pylint: skip-file + +if TYPE_CHECKING: + from flask.testing import FlaskClient LOG_DATA = "Airflow log data" * 20 -@pytest.mark.quarantined -class TestServeLogs(unittest.TestCase): - def test_should_serve_file(self): - log_dir = os.path.expanduser(conf.get('logging', 'BASE_LOG_FOLDER')) - log_port = conf.get('celery', 'WORKER_LOG_SERVER_PORT') - with NamedTemporaryFile(dir=log_dir) as f: - f.write(LOG_DATA.encode()) - f.flush() - sub_proc = Process(target=serve_logs) - sub_proc.start() - sleep(1) - log_url = f"http://localhost:{log_port}/log/{basename(f.name)}" - assert LOG_DATA == requests.get(log_url).content.decode() - sub_proc.terminate() +@pytest.fixture +def client(tmpdir): + with conf_vars({('logging', 'base_log_folder'): str(tmpdir)}): + app = flask_app() + + yield app.test_client() + + +@pytest.fixture +def sample_log(tmpdir): + f = tmpdir / 'sample.log' + f.write(LOG_DATA.encode()) + + return f + + +@pytest.fixture +def signer(): + return TimedJSONWebSignatureSerializer( + secret_key=conf.get('webserver', 'secret_key'), + algorithm_name='HS512', + expires_in=30, + # This isn't really a "salt", more of a signing context + salt='task-instance-logs', + ) + + +@pytest.mark.usefixtures('sample_log') +class TestServeLogs: + def test_forbidden_no_auth(self, client: "FlaskClient"): + assert 403 == client.get('/log/sample.log').status_code + + def test_should_serve_file(self, client: "FlaskClient", signer): + assert ( + LOG_DATA + == client.get( + '/log/sample.log', + headers={ + 'Authorization': signer.dumps('sample.log'), + }, + ).data.decode() + ) + + def test_forbidden_too_long_validity(self, client: "FlaskClient", signer): + signer.expires_in = 3600 + assert ( + 403 + == client.get( + '/log/sample.log', + headers={ + 'Authorization': signer.dumps('sample.log'), + }, + ).status_code + ) + + def test_forbidden_expired(self, client: "FlaskClient", signer): + # Fake the time we think we are + signer.now = lambda: 0 + assert ( + 403 + == client.get( + '/log/sample.log', + headers={ + 'Authorization': signer.dumps('sample.log'), + }, + ).status_code + ) + + def test_wrong_context(self, client: "FlaskClient", signer): + signer.salt = None + assert ( + 403 + == client.get( + '/log/sample.log', + headers={ + 'Authorization': signer.dumps('sample.log'), + }, + ).status_code + )