From fa7e3c996edf2d5482fff8f9d890ac2390dede4d Mon Sep 17 00:00:00 2001 From: Patrick Wang <1263870+patrickkwang@users.noreply.github.com> Date: Mon, 7 Jun 2021 06:46:18 -0400 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20Check=20Content-Type=20request?= =?UTF-8?q?=20header=20before=20assuming=20JSON=20(#2118)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Patrick Wang Co-authored-by: Sebastián Ramírez --- fastapi/routing.py | 19 ++++- .../test_body/test_tutorial001.py | 84 +++++++++++++++++-- .../test_tutorial001.py | 1 + 3 files changed, 92 insertions(+), 12 deletions(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index ac5e19d99835a..9b51f03cac562 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -1,4 +1,5 @@ import asyncio +import email.message import enum import inspect import json @@ -36,7 +37,7 @@ ) from pydantic import BaseModel from pydantic.error_wrappers import ErrorWrapper, ValidationError -from pydantic.fields import ModelField +from pydantic.fields import ModelField, Undefined from starlette import routing from starlette.concurrency import run_in_threadpool from starlette.exceptions import HTTPException @@ -174,14 +175,26 @@ def get_request_handler( async def app(request: Request) -> Response: try: - body = None + body: Any = None if body_field: if is_body_form: body = await request.form() else: body_bytes = await request.body() if body_bytes: - body = await request.json() + json_body: Any = Undefined + content_type_value = request.headers.get("content-type") + if content_type_value: + message = email.message.Message() + message["content-type"] = content_type_value + if message.get_content_maintype() == "application": + subtype = message.get_content_subtype() + if subtype == "json" or subtype.endswith("+json"): + json_body = await request.json() + if json_body != Undefined: + body = json_body + else: + body = body_bytes except json.JSONDecodeError as e: raise RequestValidationError([ErrorWrapper(e, ("body", e.pos))], body=e.doc) except Exception as e: diff --git a/tests/test_tutorial/test_body/test_tutorial001.py b/tests/test_tutorial/test_body/test_tutorial001.py index 38c6dbe876b26..c90240ae4c349 100644 --- a/tests/test_tutorial/test_body/test_tutorial001.py +++ b/tests/test_tutorial/test_body/test_tutorial001.py @@ -173,25 +173,91 @@ def test_post_body(path, body, expected_status, expected_response): def test_post_broken_body(): - response = client.post("/items/", data={"name": "Foo", "price": 50.5}) + response = client.post( + "/items/", + headers={"content-type": "application/json"}, + data="{some broken json}", + ) assert response.status_code == 422, response.text assert response.json() == { "detail": [ { + "loc": ["body", 1], + "msg": "Expecting property name enclosed in double quotes: line 1 column 2 (char 1)", + "type": "value_error.jsondecode", "ctx": { - "colno": 1, - "doc": "name=Foo&price=50.5", + "msg": "Expecting property name enclosed in double quotes", + "doc": "{some broken json}", + "pos": 1, "lineno": 1, - "msg": "Expecting value", - "pos": 0, + "colno": 2, }, - "loc": ["body", 0], - "msg": "Expecting value: line 1 column 1 (char 0)", - "type": "value_error.jsondecode", } ] } + + +def test_post_form_for_json(): + response = client.post("/items/", data={"name": "Foo", "price": 50.5}) + assert response.status_code == 422, response.text + assert response.json() == { + "detail": [ + { + "loc": ["body"], + "msg": "value is not a valid dict", + "type": "type_error.dict", + } + ] + } + + +def test_explicit_content_type(): + response = client.post( + "/items/", + data='{"name": "Foo", "price": 50.5}', + headers={"Content-Type": "application/json"}, + ) + assert response.status_code == 200, response.text + + +def test_geo_json(): + response = client.post( + "/items/", + data='{"name": "Foo", "price": 50.5}', + headers={"Content-Type": "application/geo+json"}, + ) + assert response.status_code == 200, response.text + + +def test_wrong_headers(): + data = '{"name": "Foo", "price": 50.5}' + invalid_dict = { + "detail": [ + { + "loc": ["body"], + "msg": "value is not a valid dict", + "type": "type_error.dict", + } + ] + } + + response = client.post("/items/", data=data, headers={"Content-Type": "text/plain"}) + assert response.status_code == 422, response.text + assert response.json() == invalid_dict + + response = client.post( + "/items/", data=data, headers={"Content-Type": "application/geo+json-seq"} + ) + assert response.status_code == 422, response.text + assert response.json() == invalid_dict + response = client.post( + "/items/", data=data, headers={"Content-Type": "application/not-really-json"} + ) + assert response.status_code == 422, response.text + assert response.json() == invalid_dict + + +def test_other_exceptions(): with patch("json.loads", side_effect=Exception): response = client.post("/items/", json={"test": "test2"}) assert response.status_code == 400, response.text - assert response.json() == {"detail": "There was an error parsing the body"} diff --git a/tests/test_tutorial/test_custom_request_and_route/test_tutorial001.py b/tests/test_tutorial/test_custom_request_and_route/test_tutorial001.py index cc85a8a82a5ac..3eb5822e28816 100644 --- a/tests/test_tutorial/test_custom_request_and_route/test_tutorial001.py +++ b/tests/test_tutorial/test_custom_request_and_route/test_tutorial001.py @@ -25,6 +25,7 @@ def test_gzip_request(compress): if compress: data = gzip.compress(data) headers["Content-Encoding"] = "gzip" + headers["Content-Type"] = "application/json" response = client.post("/sum", data=data, headers=headers) assert response.json() == {"sum": n}