Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only allow webserver to request from the worker log server #16754

Merged
merged 2 commits into from Jul 1, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 8 additions & 1 deletion airflow/utils/log/file_task_handler.py
Expand Up @@ -22,6 +22,7 @@
from typing import TYPE_CHECKING, Optional

import httpx
from itsdangerous import TimedJSONWebSignatureSerializer

from airflow.configuration import AirflowConfigException, conf
from airflow.utils.helpers import parse_template_string
Expand Down Expand Up @@ -172,7 +173,13 @@ def _read(self, ti, try_number, metadata=None):
except (AirflowConfigException, ValueError):
pass

response = httpx.get(url, timeout=timeout)
signer = TimedJSONWebSignatureSerializer(
secret_key=conf.get('webserver', 'secret_key'),
algorithm_name='HS512',
expires_in=conf.getint('webserver', 'log_request_clock_grace', fallback=30),
)

response = httpx.get(url, timeout=timeout, headers={'Authorization': signer.dumps({})})
response.encoding = "utf-8"

# Check if the resource was properly fetched
Expand Down
58 changes: 47 additions & 11 deletions airflow/utils/serve_logs.py
Expand Up @@ -17,25 +17,61 @@

"""Serve logs process"""
import os
import time

import flask
from flask import Flask, abort, request, send_from_directory
from itsdangerous import TimedJSONWebSignatureSerializer
from setproctitle import setproctitle

from airflow.configuration import conf


def flask_app():
    """Build the Flask application that serves task log files.

    Every request must carry an ``Authorization`` header containing a token
    signed (HS512) with the ``[webserver] secret_key`` setting. The token's
    ``iat`` and ``exp`` header claims must each lie within
    ``[webserver] log_request_clock_grace`` seconds (fallback 30) of the
    current time, and ``exp`` must not precede ``iat`` nor exceed it by more
    than that same grace period. Any other request is rejected with HTTP 403.

    :return: the configured Flask application
    """
    import logging

    log = logging.getLogger(__name__)

    app = Flask(__name__)
    max_request_age = conf.getint('webserver', 'log_request_clock_grace', fallback=30)
    log_directory = os.path.expanduser(conf.get('logging', 'BASE_LOG_FOLDER'))

    signer = TimedJSONWebSignatureSerializer(
        secret_key=conf.get('webserver', 'secret_key'),
        algorithm_name='HS512',
        expires_in=max_request_age,
    )

    # Prevent direct access to the logs port
    @app.before_request
    def validate_pre_signed_url():
        try:
            auth = request.headers['Authorization']

            # We don't actually care about the payload, just that the signature
            # was valid and the `exp` claim is correct
            _, headers = signer.loads(auth, return_header=True)

            issued_at = int(headers['iat'])
            expires_at = int(headers['exp'])
        except Exception as e:
            # Fail closed: a missing header, a bad signature, or a malformed
            # claim all end in the same 403 response.
            log.warning("Rejecting log request, invalid Authorization token: %s", e)
            abort(403)

        # Validate the `iat` and `exp` are within `max_request_age` of now.
        now = int(time.time())
        if abs(now - issued_at) > max_request_age:
            abort(403)
        if abs(now - expires_at) > max_request_age:
            abort(403)
        # The token's own validity window must be sane and no longer than the grace period.
        if issued_at > expires_at or expires_at - issued_at > max_request_age:
            abort(403)

    @app.route('/log/<path:filename>')
    def serve_logs_view(filename):
        return send_from_directory(log_directory, filename, mimetype="application/json", as_attachment=False)

    return app


def serve_logs():
    """Serves logs generated by Worker"""
    setproctitle("airflow serve-logs")
    app = flask_app()

    worker_log_server_port = conf.getint('celery', 'WORKER_LOG_SERVER_PORT')
    # Listens on all interfaces; requests are authenticated by the
    # before_request hook installed in flask_app().
    app.run(host='0.0.0.0', port=worker_log_server_port)
99 changes: 77 additions & 22 deletions tests/utils/test_serve_logs.py
Expand Up @@ -14,33 +14,88 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import unittest
from multiprocessing import Process
from os.path import basename
from tempfile import NamedTemporaryFile
from time import sleep
from typing import TYPE_CHECKING

import pytest
import requests
from itsdangerous import TimedJSONWebSignatureSerializer

from airflow.configuration import conf
from airflow.utils.serve_logs import serve_logs
from airflow.utils.serve_logs import flask_app
from tests.test_utils.config import conf_vars

if TYPE_CHECKING:
from flask.testing import FlaskClient

LOG_DATA = "Airflow log data" * 20


def _auth_headers(expires_in=30, now=None):
    """Return request headers carrying a signed ``Authorization`` token.

    :param expires_in: validity period, in seconds, given to the signer
    :param now: when not ``None``, the timestamp the signer should believe is
        the current time (used to mint tokens issued in the past)
    :return: a dict suitable for the ``headers`` argument of the test client
    """
    signer = TimedJSONWebSignatureSerializer(
        secret_key=conf.get('webserver', 'secret_key'),
        algorithm_name='HS512',
        expires_in=expires_in,
    )
    if now is not None:
        # Fake the time we think we are
        signer.now = lambda: now
    return {'Authorization': signer.dumps({})}


@pytest.fixture
def client(tmpdir):
    """Yield a Flask test client for a log server whose log folder is ``tmpdir``."""
    with conf_vars({('logging', 'base_log_folder'): str(tmpdir)}):
        app = flask_app()

        yield app.test_client()


@pytest.fixture
def sample_log(tmpdir):
    """Create ``sample.log`` containing ``LOG_DATA`` in ``tmpdir`` and return its path."""
    f = tmpdir / 'sample.log'
    f.write(LOG_DATA.encode())

    return f


@pytest.mark.usefixtures('sample_log')
class TestServeLogs:
    """Access-control tests for the worker log server's ``/log/<path>`` endpoint."""

    def test_forbidden_no_auth(self, client: "FlaskClient"):
        # No Authorization header at all.
        assert client.get('/log/sample.log').status_code == 403

    def test_should_serve_file(self, client: "FlaskClient"):
        response = client.get('/log/sample.log', headers=_auth_headers(expires_in=30))
        assert response.status_code == 200
        assert response.data.decode() == LOG_DATA

    def test_forbidden_too_long_validity(self, client: "FlaskClient"):
        # Correctly signed, but valid for an hour.
        response = client.get('/log/sample.log', headers=_auth_headers(expires_in=3600))
        assert response.status_code == 403

    def test_forbidden_expired(self, client: "FlaskClient"):
        # Correctly signed, but issued at timestamp 0.
        response = client.get('/log/sample.log', headers=_auth_headers(expires_in=30, now=0))
        assert response.status_code == 403