Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correctly set json_provider_class on Flask app so it uses our encoder #26554

Merged
merged 1 commit into from Sep 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 18 additions & 2 deletions airflow/utils/json.py
Expand Up @@ -17,11 +17,12 @@
# under the License.
from __future__ import annotations

import json
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At some point in the past, if user had simplejson installed, then webserver would blow up, that's why we imported from flask.json. It might be worth doing pip install simplejson just to make sure that it is compatible

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it was this commit: ea3d42a

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tested that, doesn't blow up (by which I mean tests still create the app and pass)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cool -- i just tried it too, installed it and launched webserver and navigated to a view or two

import logging
from datetime import date, datetime
from decimal import Decimal

from flask.json import JSONEncoder
from flask.json.provider import JSONProvider

from airflow.utils.timezone import convert_to_utc, is_naive

Expand All @@ -40,7 +41,7 @@
log = logging.getLogger(__name__)


class AirflowJsonEncoder(JSONEncoder):
class AirflowJsonEncoder(json.JSONEncoder):
"""Custom Airflow json encoder implementation."""

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -107,3 +108,18 @@ def safe_get_name(pod):
return {}

raise TypeError(f"Object of type '{obj.__class__.__name__}' is not JSON serializable")


class AirflowJsonProvider(JSONProvider):
"""JSON Provider for Flask app to use AirflowJsonEncoder."""

ensure_ascii: bool = True
sort_keys: bool = True

def dumps(self, obj, **kwargs):
kwargs.setdefault('ensure_ascii', self.ensure_ascii)
kwargs.setdefault('sort_keys', self.sort_keys)
return json.dumps(obj, **kwargs, cls=AirflowJsonEncoder)

def loads(self, s: str | bytes, **kwargs):
return json.loads(s, **kwargs)
5 changes: 3 additions & 2 deletions airflow/www/app.py
Expand Up @@ -32,7 +32,7 @@
from airflow.exceptions import AirflowConfigException, RemovedInAirflow3Warning
from airflow.logging_config import configure_logging
from airflow.models import import_all_models
from airflow.utils.json import AirflowJsonEncoder
from airflow.utils.json import AirflowJsonProvider
from airflow.www.extensions.init_appbuilder import init_appbuilder
from airflow.www.extensions.init_appbuilder_links import init_appbuilder_links
from airflow.www.extensions.init_dagbag import init_dagbag
Expand Down Expand Up @@ -109,7 +109,8 @@ def create_app(config=None, testing=False):
flask_app.config['SQLALCHEMY_ENGINE_OPTIONS'] = settings.prepare_engine_args()

# Configure the JSON encoder used by `|tojson` filter from Flask
flask_app.json_provider_class = AirflowJsonEncoder
flask_app.json_provider_class = AirflowJsonProvider
flask_app.json = AirflowJsonProvider(flask_app)

csrf.init_app(flask_app)

Expand Down
10 changes: 1 addition & 9 deletions airflow/www/utils.py
Expand Up @@ -24,7 +24,7 @@
from urllib.parse import urlencode

import sqlalchemy as sqla
from flask import Response, request, url_for
from flask import request, url_for
from flask.helpers import flash
from flask_appbuilder.forms import FieldConverter
from flask_appbuilder.models.filters import BaseFilter
Expand All @@ -47,7 +47,6 @@
from airflow.utils import timezone
from airflow.utils.code_utils import get_python_source
from airflow.utils.helpers import alchemy_to_dict
from airflow.utils.json import AirflowJsonEncoder
from airflow.utils.state import State, TaskInstanceState
from airflow.www.forms import DateTimeWithTimezoneField
from airflow.www.widgets import AirflowDateTimePickerWidget
Expand Down Expand Up @@ -322,13 +321,6 @@ def epoch(dttm):
return (int(time.mktime(dttm.timetuple())) * 1000,)


def json_response(obj):
"""Returns a json response from a json serializable python object"""
return Response(
response=json.dumps(obj, indent=4, cls=AirflowJsonEncoder), status=200, mimetype="application/json"
)


def make_cache_key(*args, **kwargs):
"""Used by cache to get a unique key per URL"""
path = request.path
Expand Down
33 changes: 17 additions & 16 deletions airflow/www/views.py
Expand Up @@ -37,6 +37,7 @@
from urllib.parse import parse_qsl, unquote, urlencode, urlparse

import configupdater
import flask.json
import lazy_object_proxy
import markupsafe
import nvd3
Expand Down Expand Up @@ -107,7 +108,7 @@
from airflow.ti_deps.dependencies_deps import RUNNING_DEPS, SCHEDULER_QUEUED_DEPS
from airflow.timetables.base import DataInterval, TimeRestriction
from airflow.timetables.interval import CronDataIntervalTimetable
from airflow.utils import json as utils_json, timezone, yaml
from airflow.utils import timezone, yaml
from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.dag_edges import dag_edges
from airflow.utils.dates import infer_time_unit, scale_time_units
Expand Down Expand Up @@ -575,7 +576,7 @@ def health(self):
'latest_scheduler_heartbeat': latest_scheduler_heartbeat,
}

return wwwutils.json_response(payload)
return flask.json.jsonify(payload)

@expose('/home')
@auth.has_access(
Expand Down Expand Up @@ -856,7 +857,7 @@ def dag_stats(self, session=None):
filter_dag_ids = allowed_dag_ids

if not filter_dag_ids:
return wwwutils.json_response({})
return flask.json.jsonify({})

payload = {}
dag_state_stats = dag_state_stats.filter(dr.dag_id.in_(filter_dag_ids))
Expand All @@ -873,7 +874,7 @@ def dag_stats(self, session=None):
count = data.get(dag_id, {}).get(state, 0)
payload[dag_id].append({'state': state, 'count': count})

return wwwutils.json_response(payload)
return flask.json.jsonify(payload)

@expose('/task_stats', methods=['POST'])
@auth.has_access(
Expand All @@ -889,7 +890,7 @@ def task_stats(self, session=None):
allowed_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user)

if not allowed_dag_ids:
return wwwutils.json_response({})
return flask.json.jsonify({})

# Filter by post parameters
selected_dag_ids = {unquote(dag_id) for dag_id in request.form.getlist('dag_ids') if dag_id}
Expand Down Expand Up @@ -983,7 +984,7 @@ def task_stats(self, session=None):
for state in State.task_states:
count = data.get(dag_id, {}).get(state, 0)
payload[dag_id].append({'state': state, 'count': count})
return wwwutils.json_response(payload)
return flask.json.jsonify(payload)

@expose('/last_dagruns', methods=['POST'])
@auth.has_access(
Expand All @@ -1006,7 +1007,7 @@ def last_dagruns(self, session=None):
filter_dag_ids = allowed_dag_ids

if not filter_dag_ids:
return wwwutils.json_response({})
return flask.json.jsonify({})

last_runs_subquery = (
session.query(
Expand Down Expand Up @@ -1046,7 +1047,7 @@ def last_dagruns(self, session=None):
}
for r in query
}
return wwwutils.json_response(resp)
return flask.json.jsonify(resp)

@expose('/code')
@auth.has_access(
Expand Down Expand Up @@ -2106,7 +2107,7 @@ def blocked(self, session=None):
filter_dag_ids = allowed_dag_ids

if not filter_dag_ids:
return wwwutils.json_response([])
return flask.json.jsonify([])

dags = (
session.query(DagRun.dag_id, sqla.func.count(DagRun.id))
Expand All @@ -2129,7 +2130,7 @@ def blocked(self, session=None):
'max_active_runs': max_active_runs,
}
)
return wwwutils.json_response(payload)
return flask.json.jsonify(payload)

def _mark_dagrun_state_as_failed(self, dag_id, dag_run_id, confirmed):
if not dag_run_id:
Expand Down Expand Up @@ -3412,7 +3413,7 @@ def task_instances(self):
for ti in dag.get_task_instances(dttm, dttm)
}

return json.dumps(task_instances, cls=utils_json.AirflowJsonEncoder)
return flask.json.jsonify(task_instances)

@expose('/object/grid_data')
@auth.has_access(
Expand Down Expand Up @@ -3467,7 +3468,7 @@ def grid_data(self):
}
# avoid spaces to reduce payload size
return (
htmlsafe_json_dumps(data, separators=(',', ':'), cls=utils_json.AirflowJsonEncoder),
htmlsafe_json_dumps(data, separators=(',', ':'), dumps=flask.json.dumps),
{'Content-Type': 'application/json; charset=utf-8'},
)

Expand Down Expand Up @@ -3510,7 +3511,7 @@ def next_run_datasets(self, dag_id):
.all()
]
return (
htmlsafe_json_dumps(data, separators=(',', ':'), cls=utils_json.AirflowJsonEncoder),
htmlsafe_json_dumps(data, separators=(',', ':'), dumps=flask.json.dumps),
{'Content-Type': 'application/json; charset=utf-8'},
)

Expand Down Expand Up @@ -3547,7 +3548,7 @@ def dataset_dependencies(self):
}

return (
htmlsafe_json_dumps(data, separators=(',', ':'), cls=utils_json.AirflowJsonEncoder),
htmlsafe_json_dumps(data, separators=(',', ':'), dumps=flask.json.dumps),
{'Content-Type': 'application/json; charset=utf-8'},
)

Expand Down Expand Up @@ -5207,7 +5208,7 @@ def autocomplete(self, session=None):
query = unquote(request.args.get('query', ''))

if not query:
return wwwutils.json_response([])
return flask.json.jsonify([])

# Provide suggestions of dag_ids and owners
dag_ids_query = session.query(
Expand Down Expand Up @@ -5241,7 +5242,7 @@ def autocomplete(self, session=None):
payload = [
row._asdict() for row in dag_ids_query.union(owners_query).order_by('name').limit(10).all()
]
return wwwutils.json_response(payload)
return flask.json.jsonify(payload)


class DagDependenciesView(AirflowBaseView):
Expand Down
10 changes: 10 additions & 0 deletions tests/www/test_app.py
Expand Up @@ -240,3 +240,13 @@ def test_flask_cli_should_display_routes(self, capsys):

output = capsys.readouterr()
assert "/login/" in output.out


def test_app_can_json_serialize_k8s_pod():
# This is mostly testing that we have correctly configured the JSON provider to use. Testing the k8s pos
ashb marked this conversation as resolved.
Show resolved Hide resolved
# is a side-effect of that.
k8s = pytest.importorskip('kubernetes.client.models')

pod = k8s.V1Pod(spec=k8s.V1PodSpec(containers=[k8s.V1Container(name="base")]))
app = application.cached_app(testing=True)
assert app.json.dumps(pod) == '{"spec": {"containers": [{"name": "base"}]}}'