Skip to content

Commit

Permalink
Remove the Parser cache
Browse files Browse the repository at this point in the history
Because the cache is no longer used field-by-field to fetch data,
there's significantly less value in keeping it. Combined with the fact
that each parser instantiation was already clearing the cache to avoid
a security bug (marshmallow-code#371), the cache is no longer actually used at all
in most (any?) contexts.

Remove the cache and all of the machinery associated with it
(Parser._clear_cache, Parser._clone, and relevant checks).

Resolves marshmallow-code#374
  • Loading branch information
sirosen committed Jan 24, 2020
1 parent 3046191 commit f2071b0
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 120 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.rst
@@ -1,6 +1,14 @@
Changelog
---------

6.0.0b4 (Unreleased)
********************

Refactoring:

* Remove the cache attached to webargs parsers. Due to changes between webargs
v5 and v6, the cache is no longer considered useful.

6.0.0b3 (2020-01-21)
********************

Expand Down
29 changes: 11 additions & 18 deletions src/webargs/aiohttpparser.py
Expand Up @@ -84,10 +84,8 @@ def load_querystring(self, req: Request, schema: Schema) -> MultiDictProxy:

async def load_form(self, req: Request, schema: Schema) -> MultiDictProxy:
    """Return form values from the request as a MultiDictProxy.

    :param req: the incoming aiohttp request
    :param schema: the schema used to detect multi-value fields
    """
    # Read the POST body directly on every call; no per-parser caching.
    post_data = await req.post()
    return MultiDictProxy(post_data, schema)

async def load_json_or_form(
self, req: Request, schema: Schema
Expand All @@ -99,22 +97,17 @@ async def load_json_or_form(

async def load_json(self, req: Request, schema: Schema) -> typing.Dict:
    """Return a parsed json payload from the request.

    Returns ``core.missing`` when the request has no body or is not a
    JSON request; malformed JSON is routed to the invalid-JSON handler.
    """
    # No body, or wrong content type: nothing to parse.
    if not (req.body_exists and is_json_request(req)):
        return core.missing
    try:
        return await req.json(loads=json.loads)
    except json.JSONDecodeError as e:
        # An empty document is treated the same as a missing body.
        if e.doc == "":
            return core.missing
        return self._handle_invalid_json_error(e, req)
    except UnicodeDecodeError as e:
        return self._handle_invalid_json_error(e, req)

def load_headers(self, req: Request, schema: Schema) -> MultiDictProxy:
"""Return headers from the request as a MultiDictProxy."""
Expand Down
1 change: 0 additions & 1 deletion src/webargs/asyncparser.py
Expand Up @@ -35,7 +35,6 @@ async def parse(
Receives the same arguments as `webargs.core.Parser.parse`.
"""
self.clear_cache() # in case someone used `location_load_*()`
req = req if req is not None else self.get_default_request()
if req is None:
raise ValueError("Must pass req object")
Expand Down
51 changes: 11 additions & 40 deletions src/webargs/core.py
Expand Up @@ -2,7 +2,6 @@
import inspect
import logging
import warnings
from copy import copy
from collections.abc import Mapping
import json

Expand Down Expand Up @@ -130,8 +129,6 @@ def __init__(self, location=None, error_handler=None, schema_class=None):
self.location = location or self.DEFAULT_LOCATION
self.error_callback = _callable_or_raise(error_handler)
self.schema_class = schema_class or self.DEFAULT_SCHEMA_CLASS
#: A short-lived cache to store results from processing request bodies.
self._cache = {}

def _get_loader(self, location):
"""Get the loader function for the given location.
Expand Down Expand Up @@ -207,15 +204,6 @@ def _get_schema(self, argmap, req):
)
return schema

def _clone(self):
"""Clone the current parser in order to ensure that it has a fresh and
independent cache. This is used whenever `Parser.parse` is called, so
that these methods always have separate caches.
"""
clone = copy(self)
clone.clear_cache()
return clone

def parse(
self,
argmap,
Expand Down Expand Up @@ -250,32 +238,20 @@ def parse(
raise ValueError("Must pass req object")
data = None
validators = _ensure_list_of_callables(validate)
parser = self._clone()
schema = self._get_schema(argmap, req)
try:
location_data = parser._load_location_data(
location_data = self._load_location_data(
schema=schema, req=req, location=location or self.location
)
result = schema.load(location_data)
data = result.data if MARSHMALLOW_VERSION_INFO[0] < 3 else result
parser._validate_arguments(data, validators)
self._validate_arguments(data, validators)
except ma.exceptions.ValidationError as error:
parser._on_validation_error(
self._on_validation_error(
error, req, schema, error_status_code, error_headers
)
return data

def clear_cache(self):
    """Reset this parser's request-body cache to an empty state.

    Usually a no-op now, because parsing works on a discarded clone; it
    still matters when ``parse_*`` methods are called manually and have
    populated the cache on the main Parser instance.
    """
    self._cache = dict()
    return None

def get_default_request(self):
"""Optional override. Provides a hook for frameworks that use thread-local
request objects.
def load_json(self, req, schema):
    """Return parsed JSON from the request, or ``missing`` when the
    request carries no JSON body.
    """
    # `_handle_invalid_json_error` and `_raw_load_json`
    # these methods are not part of the public API and are used to simplify
    # code sharing amongst the built-in webargs parsers
    try:
        return self._raw_load_json(req)
    except json.JSONDecodeError as e:
        # An empty document means "no JSON body", not a client error.
        if e.doc == "":
            return missing
        return self._handle_invalid_json_error(e, req)
    except UnicodeDecodeError as e:
        # Undecodable bytes are reported the same way as malformed JSON.
        return self._handle_invalid_json_error(e, req)

def load_json_or_form(self, req, schema):
"""Load data from a request, accepting either JSON or form-encoded
Expand Down
4 changes: 1 addition & 3 deletions src/webargs/falconparser.py
def load_form(self, req, schema):
    """Return form values from the request as a MultiDictProxy.

    .. note::
        The request stream will be read and left at EOF.
    """
    # Parse the form body on every call; no per-parser caching.
    form = parse_form_body(req)
    if form is core.missing:
        return form
    return MultiDictProxy(form, schema)
Expand Down
36 changes: 0 additions & 36 deletions tests/test_flaskparser.py
@@ -1,5 +1,3 @@
import threading

from werkzeug.exceptions import HTTPException
import pytest

Expand Down Expand Up @@ -126,37 +124,3 @@ def test_abort_has_serializable_data():
error = json.loads(serialized_error)
assert isinstance(error, dict)
assert error["message"] == "custom error message"


def test_json_cache_race_condition():
    # Regression test: two requests parsed concurrently must not leak
    # parsed JSON bodies into each other (see marshmallow-code#371).
    app = Flask("testapp")
    lock = threading.Lock()
    # Hold the lock so both worker threads pause inside deserialization,
    # widening the window in which cross-request contamination could occur.
    lock.acquire()

    class MyField(fields.Field):
        def _deserialize(self, value, attr, data, **kwargs):
            # Block here until the main thread releases the lock below.
            with lock:
                return value

    argmap = {"value": MyField()}
    results = {}

    def thread_fn(value):
        # Each thread parses its own request body inside its own context.
        with app.test_request_context(
            "/foo",
            method="post",
            data=json.dumps({"value": value}),
            content_type="application/json",
        ):
            results[value] = parser.parse(argmap)["value"]

    t1 = threading.Thread(target=thread_fn, args=(42,))
    t2 = threading.Thread(target=thread_fn, args=(23,))
    t1.start()
    t2.start()
    lock.release()
    t1.join()
    t2.join()
    # ensure we didn't get contaminated by a parallel request
    assert results[42] == 42
    assert results[23] == 23
22 changes: 0 additions & 22 deletions tests/test_tornadoparser.py
Expand Up @@ -51,9 +51,6 @@ def test_tornado_multidictproxy():


class TestQueryArgs:
def setup_method(self, method):
parser.clear_cache()

def test_it_should_get_single_values(self):
query = [("name", "Aeschylus")]
request = make_get_request(query)
Expand All @@ -75,9 +72,6 @@ def test_it_should_return_missing_if_not_present(self):


class TestFormArgs:
def setup_method(self, method):
parser.clear_cache()

def test_it_should_get_single_values(self):
query = [("name", "Aristophanes")]
request = make_form_request(query)
Expand All @@ -99,9 +93,6 @@ def test_it_should_return_missing_if_not_present(self):


class TestJSONArgs:
def setup_method(self, method):
parser.clear_cache()

def test_it_should_get_single_values(self):
query = {"name": "Euripides"}
request = make_json_request(query)
Expand Down Expand Up @@ -162,14 +153,10 @@ def test_it_should_handle_type_error_on_load_json(self):
def test_it_should_handle_value_error_on_parse_json(self):
    # Malformed JSON should be reported as a missing value, not raise.
    request = make_request("this is json not")
    result = parser.load_json(request, author_schema)
    assert result is missing


class TestHeadersArgs:
def setup_method(self, method):
parser.clear_cache()

def test_it_should_get_single_values(self):
query = {"name": "Euphorion"}
request = make_request(headers=query)
Expand All @@ -190,9 +177,6 @@ def test_it_should_return_missing_if_not_present(self):


class TestFilesArgs:
def setup_method(self, method):
parser.clear_cache()

def test_it_should_get_single_values(self):
query = [("name", "Sappho")]
request = make_files_request(query)
Expand Down Expand Up @@ -221,9 +205,6 @@ def test_it_should_raise_httperror_on_failed_validation(self):


class TestParse:
def setup_method(self, method):
parser.clear_cache()

def test_it_should_parse_query_arguments(self):
attrs = {"string": fields.Field(), "integer": fields.List(fields.Int())}

Expand Down Expand Up @@ -322,9 +303,6 @@ def test_it_should_parse_multiple_arg_required(self):


class TestUseArgs:
def setup_method(self, method):
parser.clear_cache()

def test_it_should_pass_parsed_as_first_argument(self):
class Handler:
request = make_json_request({"key": "value"})
Expand Down

0 comments on commit f2071b0

Please sign in to comment.