diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 83b3ed0a..e50bab3d 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -16,7 +16,7 @@ Refactoring: * *Backwards-incompatible*: Schema fields may not specify a location any longer, and `Parser.use_args` and `Parser.use_kwargs` now accept `location` - (singular) instead of `locations` plural. Instead of using a single field or + (singular) instead of `locations` (plural). Instead of using a single field or schema with multiple `locations`, users are recommended to make multiple calls to `use_args` or `use_kwargs` with a distinct schema per location. For example, code should be rewritten like this: @@ -24,27 +24,29 @@ Refactoring: .. code-block:: python # under webargs v5 - class CompoundSchema: - q1 = ma.fields.Int(location="query") - q2 = ma.fields.Int(location="query") - h1 = ma.fields.Int(location="headers") - - @parser.use_args(CompoundSchema(), locations=("query", "headers")) + @parser.use_args( + { + "q1": ma.fields.Int(location="query"), + "q2": ma.fields.Int(location="query"), + "h1": ma.fields.Int(location="headers"), + }, + locations=("query", "headers"), + ) def foo(q1, q2, h1): ... + # should be split up like so under webargs v6 - class QueryParamSchema: - q1 = ma.fields.Int() - q2 = ma.fields.Int() - class HeaderSchema: - h1 = ma.fields.Int() - - @parser.use_args(QueryParamSchema(), location="query") - @parser.use_args(HeaderSchema(), location="headers") + @parser.use_args({"q1": ma.fields.Int(), "q2": ma.fields.Int()}, location="query") + @parser.use_args({"h1": ma.fields.Int()}, location="headers") def foo(q1, q2, h1): ... +* The `location_handler` decorator has been removed and replaced with + `location_loader`. `location_loader` serves the same purpose (letting you + write custom hooks for loading data) but its expected method signature is + different. See the docs on `location_loader` for proper usage. + 5.5.2 (2019-10-06) ****************** diff --git a/src/webargs/asyncparser.py b/src/webargs/asyncparser.py index 82369018..8a8ac66d 100644 --- a/src/webargs/asyncparser.py +++ b/src/webargs/asyncparser.py @@ -96,6 +96,25 @@ async def parse( ) return data + async def _load_location_data(self, schema, req, location): + """Return a dictionary-like object for the location on the given request. + + Needs to have the schema in hand in order to correctly handle loading + lists from multidict objects and `many=True` schemas. + """ + loader_func = self._get_loader(location) + if asyncio.iscoroutinefunction(loader_func): + data = await loader_func(req, schema) + else: + data = loader_func(req, schema) + + # when the desired location is empty (no data), provide an empty + # dict as the default so that optional arguments in a location + # (e.g. optional JSON body) work smoothly + if data is core.missing: + data = {} + return data + async def _on_validation_error( self, error: ValidationError, diff --git a/src/webargs/core.py b/src/webargs/core.py index fe2f39b6..6be9b8a3 100644 --- a/src/webargs/core.py +++ b/src/webargs/core.py @@ -14,9 +14,9 @@ import marshmallow as ma from marshmallow import ValidationError -from marshmallow.utils import missing, is_collection +from marshmallow.utils import missing -from webargs.compat import Mapping, iteritems, MARSHMALLOW_VERSION_INFO +from webargs.compat import Mapping, MARSHMALLOW_VERSION_INFO from webargs.dict2schema import dict2schema from webargs.fields import DelimitedList @@ -28,7 +28,6 @@ "dict2schema", "is_multiple", "Parser", - "get_value", "missing", "parse_json", ] @@ -74,42 +73,6 @@ def is_json(mimetype): return False -def get_value(data, name, field, allow_many_nested=False): - """Get a value from a dictionary. Handles ``MultiDict`` types when - ``field`` handles repeated/multi-value arguments. - If the value is not found, return `missing`. - - :param object data: Mapping (e.g. `dict`) or list-like instance to - pull the value from. - :param str name: Name of the key. - :param bool allow_many_nested: Whether to allow a list of nested objects - (it is valid only for JSON format, so it is set to True in ``parse_json`` - methods). - """ - missing_value = missing - if allow_many_nested and isinstance(field, ma.fields.Nested) and field.many: - if is_collection(data): - return data - - if not hasattr(data, "get"): - return missing_value - - multiple = is_multiple(field) - val = data.get(name, missing_value) - if multiple and val is not missing: - if hasattr(data, "getlist"): - return data.getlist(name) - elif hasattr(data, "getall"): - return data.getall(name) - elif isinstance(val, (list, tuple)): - return val - if val is None: - return None - else: - return [val] - return val - - def parse_json(s, encoding="utf-8"): if isinstance(s, bytes): try: @@ -142,15 +105,16 @@ class Parser(object): """Base parser class that provides high-level implementation for parsing a request. - Descendant classes must provide lower-level implementations for parsing - different locations, e.g. ``parse_json``, ``parse_querystring``, etc. + Descendant classes must provide lower-level implementations for reading + data from different locations, e.g. ``load_json``, ``load_querystring``, + etc. - :param tuple locations: Default locations to parse. + :param str location: Default location to use for data :param callable error_handler: Custom error handler function. """ - #: Default locations to check for data - DEFAULT_LOCATIONS = ("querystring", "form", "json") + #: Default location to check for data + DEFAULT_LOCATION = "json" #: The marshmallow Schema class to use when creating new schemas DEFAULT_SCHEMA_CLASS = ma.Schema #: Default status code to return for validation errors @@ -160,38 +124,32 @@ class Parser(object): #: Maps location => method name __location_map__ = { - "json": "parse_json", - "querystring": "parse_querystring", - "query": "parse_querystring", - "form": "parse_form", - "headers": "parse_headers", - "cookies": "parse_cookies", - "files": "parse_files", + "json": "load_json", + "querystring": "load_querystring", + "query": "load_querystring", + "form": "load_form", + "headers": "load_headers", + "cookies": "load_cookies", + "files": "load_files", } - def __init__(self, locations=None, error_handler=None, schema_class=None): - self.locations = locations or self.DEFAULT_LOCATIONS + def __init__(self, location=None, error_handler=None, schema_class=None): + self.location = location or self.DEFAULT_LOCATION self.error_callback = _callable_or_raise(error_handler) self.schema_class = schema_class or self.DEFAULT_SCHEMA_CLASS #: A short-lived cache to store results from processing request bodies. self._cache = {} - def _validated_locations(self, locations): - """Ensure that the given locations argument is valid. + def _get_loader(self, location): + """Get the loader function for the given location. - :raises: ValueError if a given locations includes an invalid location. + :raises: ValueError if a given location is invalid. """ - # The set difference between the given locations and the available locations - # will be the set of invalid locations valid_locations = set(self.__location_map__.keys()) - given = set(locations) - invalid_locations = given - valid_locations - if len(invalid_locations): - msg = "Invalid locations arguments: {0}".format(list(invalid_locations)) + if location not in valid_locations: + msg = "Invalid location argument: {0}".format(location) raise ValueError(msg) - return locations - def _get_handler(self, location): # Parsing function to call # May be a method name (str) or a function func = self.__location_map__.get(location) @@ -204,73 +162,20 @@ def _get_handler(self, location): raise ValueError('Invalid location: "{0}"'.format(location)) return function - def _get_value(self, name, argobj, req, location): - function = self._get_handler(location) - return function(req, name, argobj) + def _load_location_data(self, schema, req, location): + """Return a dictionary-like object for the location on the given request. - def parse_arg(self, name, field, req, locations=None): - """Parse a single argument from a request. - - .. note:: - This method does not perform validation on the argument. - - :param str name: The name of the value. - :param marshmallow.fields.Field field: The marshmallow `Field` for the request - parameter. - :param req: The request object to parse. - :param tuple locations: The locations ('json', 'querystring', etc.) where - to search for the value. - :return: The unvalidated argument value or `missing` if the value cannot - be found on the request. + Needs to have the schema in hand in order to correctly handle loading + lists from multidict objects and `many=True` schemas. """ - location = field.metadata.get("location") - if location: - locations_to_check = self._validated_locations([location]) - else: - locations_to_check = self._validated_locations(locations or self.locations) - - for location in locations_to_check: - value = self._get_value(name, field, req=req, location=location) - # Found the value; validate and return it - if value is not missing: - return value - return missing - - def _parse_request(self, schema, req, locations): - """Return a parsed arguments dictionary for the current request.""" - if schema.many: - assert ( - "json" in locations - ), "schema.many=True is only supported for JSON location" - # The ad hoc Nested field is more like a workaround or a helper, - # and it servers its purpose fine. However, if somebody has a desire - # to re-design the support of bulk-type arguments, go ahead. - parsed = self.parse_arg( - name="json", - field=ma.fields.Nested(schema, many=True), - req=req, - locations=locations, - ) - if parsed is missing: - parsed = [] - else: - argdict = schema.fields - parsed = {} - for argname, field_obj in iteritems(argdict): - if MARSHMALLOW_VERSION_INFO[0] < 3: - parsed_value = self.parse_arg(argname, field_obj, req, locations) - # If load_from is specified on the field, try to parse from that key - if parsed_value is missing and field_obj.load_from: - parsed_value = self.parse_arg( - field_obj.load_from, field_obj, req, locations - ) - argname = field_obj.load_from - else: - argname = field_obj.data_key or argname - parsed_value = self.parse_arg(argname, field_obj, req, locations) - if parsed_value is not missing: - parsed[argname] = parsed_value - return parsed + loader_func = self._get_loader(location) + data = loader_func(req, schema) + # when the desired location is empty (no data), provide an empty + # dict as the default so that optional arguments in a location + # (e.g. optional JSON body) work smoothly + if data is missing: + data = {} + return data def _on_validation_error( self, error, req, schema, error_status_code, error_headers @@ -310,6 +215,10 @@ def _get_schema(self, argmap, req): return schema def _clone(self): + """Clone the current parser in order to ensure that it has a fresh and + independent cache. This is used whenever `Parser.parse` is called, so + that these methods always have separate caches. + """ clone = copy(self) clone.clear_cache() return clone @@ -318,7 +227,7 @@ def parse( self, argmap, req=None, - locations=None, + location=None, validate=None, error_status_code=None, error_headers=None, @@ -329,9 +238,9 @@ def parse( of argname -> `marshmallow.fields.Field` pairs, or a callable which accepts a request and returns a `marshmallow.Schema`. :param req: The request object to parse. - :param tuple locations: Where on the request to search for values. - Can include one or more of ``('json', 'querystring', 'form', - 'headers', 'cookies', 'files')``. + :param str location: Where on the request to load values. + Can be one of ``('json', 'querystring', 'form', 'headers', 'cookies', + 'files')``. :param callable validate: Validation function or list of validation functions that receives the dictionary of parsed arguments. Validator either returns a boolean or raises a :exc:`ValidationError`. @@ -342,18 +251,18 @@ def parse( :return: A dictionary of parsed arguments """ - self.clear_cache() # in case someone used `parse_*()` req = req if req is not None else self.get_default_request() - assert req is not None, "Must pass req object" + if req is None: + raise ValueError("Must pass req object") data = None validators = _ensure_list_of_callables(validate) parser = self._clone() schema = self._get_schema(argmap, req) try: - parsed = parser._parse_request( - schema=schema, req=req, locations=locations or self.locations + location_data = parser._load_location_data( + schema=schema, req=req, location=location or self.location ) - result = schema.load(parsed) + result = schema.load(location_data) data = result.data if MARSHMALLOW_VERSION_INFO[0] < 3 else result parser._validate_arguments(data, validators) except ma.exceptions.ValidationError as error: @@ -397,7 +306,7 @@ def use_args( self, argmap, req=None, - locations=None, + location=None, as_kwargs=False, validate=None, error_status_code=None, @@ -408,14 +317,14 @@ def use_args( Example usage with Flask: :: @app.route('/echo', methods=['get', 'post']) - @parser.use_args({'name': fields.Str()}) + @parser.use_args({'name': fields.Str()}, location="querystring") def greet(args): return 'Hello ' + args['name'] :param argmap: Either a `marshmallow.Schema`, a `dict` of argname -> `marshmallow.fields.Field` pairs, or a callable which accepts a request and returns a `marshmallow.Schema`. - :param tuple locations: Where on the request to search for values. + :param str locations: Where on the request to load values. :param bool as_kwargs: Whether to insert arguments as keyword arguments. :param callable validate: Validation function that receives the dictionary of parsed arguments. If the function returns ``False``, the parser @@ -425,7 +334,7 @@ def greet(args): :param dict error_headers: Headers passed to error handler functions when a a `ValidationError` is raised. """ - locations = locations or self.locations + location = location or self.location request_obj = req # Optimization: If argmap is passed as a dictionary, we only need # to generate a Schema once @@ -441,11 +350,12 @@ def wrapper(*args, **kwargs): if not req_obj: req_obj = self.get_request_from_view_args(func, args, kwargs) + # NOTE: At this point, argmap may be a Schema, or a callable parsed_args = self.parse( argmap, req=req_obj, - locations=locations, + location=location, validate=validate, error_status_code=error_status_code, error_headers=error_headers, @@ -481,19 +391,23 @@ def greet(name): kwargs["as_kwargs"] = True return self.use_args(*args, **kwargs) - def location_handler(self, name): - """Decorator that registers a function for parsing a request location. - The wrapped function receives a request, the name of the argument, and - the corresponding `Field ` object. + def location_loader(self, name): + """Decorator that registers a function for loading a request location. + The wrapped function receives a schema and a request. + + The schema will usually not be relevant, but it's important in some + cases -- most notably in order to correctly load multidict values into + list fields. Without the schema, there would be no way to know whether + to simply `.get()` or `.getall()` from a multidict for a given value. Example: :: from webargs import core parser = core.Parser() - @parser.location_handler("name") - def parse_data(request, name, field): - return request.data.get(name) + @parser.location_loader("name") + def load_data(request, schema): + return request.data :param str name: The name of the location to register. """ @@ -533,39 +447,38 @@ def handle_error(error, req, schema, status_code, headers): # Abstract Methods - def parse_json(self, req, name, arg): - """Pull a JSON value from a request object or return `missing` if the - value cannot be found. + def load_json(self, req, schema): + """Load JSON from a request object or return `missing` if no value can + be found. """ return missing - def parse_querystring(self, req, name, arg): - """Pull a value from the query string of a request object or return `missing` if - the value cannot be found. + def load_querystring(self, req, schema): + """Load the query string of a request object or return `missing` if no + value can be found. """ return missing - def parse_form(self, req, name, arg): - """Pull a value from the form data of a request object or return - `missing` if the value cannot be found. + def load_form(self, req, schema): + """Load the form data of a request object or return `missing` if no + value can be found. """ return missing - def parse_headers(self, req, name, arg): - """Pull a value from the headers or return `missing` if the value - cannot be found. + def load_headers(self, req, schema): + """Load the headers or return `missing` if no value can be found. """ return missing - def parse_cookies(self, req, name, arg): - """Pull a cookie value from the request or return `missing` if the value - cannot be found. + def load_cookies(self, req, schema): + """Load the cookies from the request or return `missing` if no value + can be found. """ return missing - def parse_files(self, req, name, arg): - """Pull a file from the request or return `missing` if the value file - cannot be found. + def load_files(self, req, schema): + """Load files from the request or return `missing` if no values can be + found. """ return missing diff --git a/src/webargs/falconparser.py b/src/webargs/falconparser.py index b8c5ec76..1113a6f5 100644 --- a/src/webargs/falconparser.py +++ b/src/webargs/falconparser.py @@ -6,6 +6,7 @@ from webargs import core from webargs.core import json +from webargs.multidictproxy import MultiDictProxy HTTP_422 = "422 Unprocessable Entity" @@ -69,7 +70,8 @@ def parse_form_body(req): return parse_query_string( body, keep_blank_qs_values=req.options.keep_blank_qs_values ) - return {} + + return core.missing class HTTPError(falcon.HTTPError): @@ -95,7 +97,7 @@ def parse_querystring(self, req, name, field): """Pull a querystring value from the request.""" return core.get_value(req.params, name, field) - def parse_form(self, req, name, field): + def location_load_form(self, req, schema): """Pull a form value from the request. .. note:: @@ -105,7 +107,9 @@ def parse_form(self, req, name, field): form = self._cache.get("form") if form is None: self._cache["form"] = form = parse_form_body(req) - return core.get_value(form, name, field) + if form is core.missing: + return form + return MultiDictProxy(form, schema) def parse_json(self, req, name, field): """Pull a JSON body value from the request. diff --git a/src/webargs/flaskparser.py b/src/webargs/flaskparser.py index 9f6b38f7..9f956dcb 100644 --- a/src/webargs/flaskparser.py +++ b/src/webargs/flaskparser.py @@ -24,6 +24,7 @@ def index(args): from webargs import core from webargs.core import json +from webargs.multidictproxy import MultiDictProxy def abort(http_status_code, exc=None, **kwargs): @@ -48,17 +49,23 @@ class FlaskParser(core.Parser): """Flask request argument parser.""" __location_map__ = dict( - view_args="parse_view_args", - path="parse_view_args", + view_args="load_view_args", + path="load_view_args", **core.Parser.__location_map__ ) - def parse_view_args(self, req, name, field): - """Pull a value from the request's ``view_args``.""" - return core.get_value(req.view_args, name, field) + def load_view_args(self, req, schema): + """Read the request's ``view_args`` or ``missing`` if there are none.""" + return req.view_args or core.missing + + def load_json(self, req, schema): + """Read a json payload from the request. + + Checks the input mimetype and may return 'missing' if the mimetype is + non-json, even if the request body is parseable as json.""" + if not is_json_request(req): + return core.missing - def parse_json(self, req, name, field): - """Pull a json value from the request.""" json_data = self._cache.get("json") if json_data is None: # We decode the json manually here instead of @@ -72,31 +79,36 @@ def parse_json(self, req, name, field): return core.missing else: return self.handle_invalid_json_error(e, req) - return core.get_value(json_data, name, field, allow_many_nested=True) - - def parse_querystring(self, req, name, field): - """Pull a querystring value from the request.""" - return core.get_value(req.args, name, field) - - def parse_form(self, req, name, field): - """Pull a form value from the request.""" - try: - return core.get_value(req.form, name, field) - except AttributeError: - pass - return core.missing - - def parse_headers(self, req, name, field): - """Pull a value from the header data.""" - return core.get_value(req.headers, name, field) - - def parse_cookies(self, req, name, field): - """Pull a value from the cookiejar.""" - return core.get_value(req.cookies, name, field) - - def parse_files(self, req, name, field): - """Pull a file from the request.""" - return core.get_value(req.files, name, field) + + return json_data + + def load_querystring(self, req, schema): + """Read query params from the request. + + Is a multidict.""" + return MultiDictProxy(req.args, schema) + + def load_form(self, req, schema): + """Read form values from the request. + + Is a multidict.""" + return MultiDictProxy(req.form, schema) + + def load_headers(self, req, schema): + """Read headers from the request. + + Is a multidict.""" + return MultiDictProxy(req.headers, schema) + + def load_cookies(self, req, schema): + """Read cookies from the request.""" + return req.cookies + + def load_files(self, req, schema): + """Read files from the request. + + Is a multidict.""" + return MultiDictProxy(req.files, schema) def handle_error(self, error, req, schema, error_status_code, error_headers): """Handles errors during parsing. Aborts the current HTTP request and diff --git a/src/webargs/multidictproxy.py b/src/webargs/multidictproxy.py new file mode 100644 index 00000000..9ac4fa5c --- /dev/null +++ b/src/webargs/multidictproxy.py @@ -0,0 +1,67 @@ +from webargs.compat import MARSHMALLOW_VERSION_INFO, Mapping +from webargs.core import missing, is_multiple + + +class MultiDictProxy(Mapping): + """ + A proxy object which wraps multidict types along with a matching schema + Whenever a value is looked up, it is checked against the schema to see if + there is a matching field where `is_multiple` is True. If there is, then + the data should be loaded as a list or tuple. + + In all other cases, __getitem__ proxies directly to the input multidict. + """ + + def __init__(self, multidict, schema): + self.data = multidict + self.multiple_keys = self._collect_multiple_keys(schema) + + def _collect_multiple_keys(self, schema): + result = set() + for name, field in schema.fields.items(): + if not is_multiple(field): + continue + if MARSHMALLOW_VERSION_INFO[0] < 3: + result.add(field.load_from if field.load_from is not None else name) + else: + result.add(field.data_key if field.data_key is not None else name) + return result + + def __getitem__(self, key): + val = self.data.get(key, missing) + if val is not missing and key in self.multiple_keys: + if hasattr(self.data, "getlist"): + return self.data.getlist(key) + elif hasattr(self.data, "getall"): + return self.data.getall(key) + elif isinstance(val, (list, tuple)): + return val + if val is None: + return None + else: + return [val] + return val + + def __delitem__(self, key): + del self.data[key] + + def __setitem__(self, key, value): + self.data[key] = value + + def __getattr__(self, name): + return getattr(self.data, name) + + def __iter__(self): + return iter(self.data) + + def __contains__(self, x): + return x in self.data + + def __len__(self): + return len(self.data) + + def __eq__(self, other): + return self.data == other + + def __ne__(self, other): + return self.data != other diff --git a/src/webargs/testing.py b/src/webargs/testing.py index 922bc473..50cf27cf 100644 --- a/src/webargs/testing.py +++ b/src/webargs/testing.py @@ -40,24 +40,23 @@ def testapp(self): def test_parse_querystring_args(self, testapp): assert testapp.get("/echo?name=Fred").json == {"name": "Fred"} - def test_parse_querystring_with_query_location_specified(self, testapp): - assert testapp.get("/echo_query?name=Steve").json == {"name": "Steve"} - def test_parse_form(self, testapp): - assert testapp.post("/echo", {"name": "Joe"}).json == {"name": "Joe"} + assert testapp.post("/echo_form", {"name": "Joe"}).json == {"name": "Joe"} def test_parse_json(self, testapp): - assert testapp.post_json("/echo", {"name": "Fred"}).json == {"name": "Fred"} + assert testapp.post_json("/echo_json", {"name": "Fred"}).json == { + "name": "Fred" + } def test_parse_querystring_default(self, testapp): assert testapp.get("/echo").json == {"name": "World"} def test_parse_json_default(self, testapp): - assert testapp.post_json("/echo", {}).json == {"name": "World"} + assert testapp.post_json("/echo_json", {}).json == {"name": "World"} def test_parse_json_with_charset(self, testapp): res = testapp.post( - "/echo", + "/echo_json", json.dumps({"name": "Steve"}), content_type="application/json;charset=UTF-8", ) @@ -65,23 +64,27 @@ def test_parse_json_with_charset(self, testapp): def test_parse_json_with_vendor_media_type(self, testapp): res = testapp.post( - "/echo", + "/echo_json", json.dumps({"name": "Steve"}), content_type="application/vnd.api+json;charset=UTF-8", ) assert res.json == {"name": "Steve"} def test_parse_json_ignores_extra_data(self, testapp): - assert testapp.post_json("/echo", {"extra": "data"}).json == {"name": "World"} + assert testapp.post_json("/echo_json", {"extra": "data"}).json == { + "name": "World" + } - def test_parse_json_blank(self, testapp): - assert testapp.post_json("/echo", None).json == {"name": "World"} + def test_parse_json_empty(self, testapp): + assert testapp.post_json("/echo_json", {}).json == {"name": "World"} - def test_parse_json_ignore_unexpected_int(self, testapp): - assert testapp.post_json("/echo", 1).json == {"name": "World"} + def test_parse_json_error_unexpected_int(self, testapp): + res = testapp.post_json("/echo_json", 1, expect_errors=True) + assert res.status_code == 422 - def test_parse_json_ignore_unexpected_list(self, testapp): - assert testapp.post_json("/echo", [{"extra": "data"}]).json == {"name": "World"} + def test_parse_json_error_unexpected_list(self, testapp): + res = testapp.post_json("/echo_json", [{"extra": "data"}], expect_errors=True) + assert res.status_code == 422 def test_parse_json_many_schema_invalid_input(self, testapp): res = testapp.post_json( @@ -93,11 +96,14 @@ def test_parse_json_many_schema(self, testapp): res = testapp.post_json("/echo_many_schema", [{"name": "Steve"}]).json assert res == [{"name": "Steve"}] - def test_parse_json_many_schema_ignore_malformed_data(self, testapp): - assert testapp.post_json("/echo_many_schema", {"extra": "data"}).json == [] + def test_parse_json_many_schema_error_malformed_data(self, testapp): + res = testapp.post_json( + "/echo_many_schema", {"extra": "data"}, expect_errors=True + ) + assert res.status_code == 422 def test_parsing_form_default(self, testapp): - assert testapp.post("/echo", {}).json == {"name": "World"} + assert testapp.post("/echo_form", {}).json == {"name": "World"} def test_parse_querystring_multiple(self, testapp): expected = {"name": ["steve", "Loria"]} @@ -106,16 +112,25 @@ def test_parse_querystring_multiple(self, testapp): def test_parse_form_multiple(self, testapp): expected = {"name": ["steve", "Loria"]} assert ( - testapp.post("/echo_multi", {"name": ["steve", "Loria"]}).json == expected + testapp.post("/echo_multi_form", {"name": ["steve", "Loria"]}).json + == expected ) def test_parse_json_list(self, testapp): expected = {"name": ["Steve"]} - assert testapp.post_json("/echo_multi", {"name": "Steve"}).json == expected + assert ( + testapp.post_json("/echo_multi_json", {"name": ["Steve"]}).json == expected + ) + + def test_parse_json_list_error_malformed_data(self, testapp): + res = testapp.post_json( + "/echo_multi_json", {"name": "Steve"}, expect_errors=True + ) + assert res.status_code == 422 def test_parse_json_with_nonascii_chars(self, testapp): text = u"øˆƒ£ºº∆ƒˆ∆" - assert testapp.post_json("/echo", {"name": text}).json == {"name": text} + assert testapp.post_json("/echo_json", {"name": text}).json == {"name": text} # https://github.com/marshmallow-code/webargs/issues/427 def test_parse_json_with_nonutf8_chars(self, testapp): @@ -130,7 +145,7 @@ def test_parse_json_with_nonutf8_chars(self, testapp): assert res.json == {"json": ["Invalid JSON body."]} def test_validation_error_returns_422_response(self, testapp): - res = testapp.post("/echo", {"name": "b"}, expect_errors=True) + res = testapp.post("/echo_json", {"name": "b"}, expect_errors=True) assert res.status_code == 422 def test_user_validation_error_returns_422_response_by_default(self, testapp): @@ -187,10 +202,6 @@ def test_parse_nested_many_missing(self, testapp): res = testapp.post_json("/echo_nested_many", in_data) assert res.json == {} - def test_parse_json_if_no_json(self, testapp): - res = testapp.post("/echo") - assert res.json == {"name": "World"} - def test_parse_files(self, testapp): res = testapp.post( "/echo_file", {"myfile": webtest.Upload("README.rst", b"data")} @@ -199,8 +210,14 @@ def test_parse_files(self, testapp): # https://github.com/sloria/webargs/pull/297 def test_empty_json(self, testapp): + res = testapp.post("/echo_json") + assert res.status_code == 200 + assert res.json == {"name": "World"} + + # https://github.com/sloria/webargs/pull/297 + def test_empty_json_with_headers(self, testapp): res = testapp.post( - "/echo", + "/echo_json", "", headers={"Accept": "application/json", "Content-Type": "application/json"}, ) @@ -210,7 +227,7 @@ def test_empty_json(self, testapp): # https://github.com/sloria/webargs/issues/329 def test_invalid_json(self, testapp): res = testapp.post( - "/echo", + "/echo_json", '{"foo": "bar", }', headers={"Accept": "application/json", "Content-Type": "application/json"}, expect_errors=True, diff --git a/tests/apps/flask_app.py b/tests/apps/flask_app.py index 019cb9f9..abbc9589 100644 --- a/tests/apps/flask_app.py +++ b/tests/apps/flask_app.py @@ -3,77 +3,100 @@ from flask.views import MethodView import marshmallow as ma -from webargs import fields +from webargs import fields, dict2schema from webargs.flaskparser import parser, use_args, use_kwargs from webargs.core import MARSHMALLOW_VERSION_INFO +if MARSHMALLOW_VERSION_INFO[0] < 3: + schema_kwargs = {"strict": True} +else: + schema_kwargs = {"unknown": ma.EXCLUDE} + class TestAppConfig: TESTING = True -hello_args = {"name": fields.Str(missing="World", validate=lambda n: len(n) >= 3)} -hello_multiple = {"name": fields.List(fields.Str())} +hello_args = dict2schema( + {"name": fields.Str(missing="World", validate=lambda n: len(n) >= 3)} +)(**schema_kwargs) +hello_multiple = dict2schema({"name": fields.List(fields.Str())})(**schema_kwargs) class HelloSchema(ma.Schema): name = fields.Str(missing="World", validate=lambda n: len(n) >= 3) -strict_kwargs = {"strict": True} if MARSHMALLOW_VERSION_INFO[0] < 3 else {} -hello_many_schema = HelloSchema(many=True, **strict_kwargs) +hello_many_schema = HelloSchema(many=True, **schema_kwargs) app = Flask(__name__) app.config.from_object(TestAppConfig) -@app.route("/echo", methods=["GET", "POST"]) +@app.route("/echo", methods=["GET"]) def echo(): - return J(parser.parse(hello_args)) + return J(parser.parse(hello_args, location="query")) + +@app.route("/echo_form", methods=["POST"]) +def echo_form(): + return J(parser.parse(hello_args, location="form")) -@app.route("/echo_query") -def echo_query(): - return J(parser.parse(hello_args, request, locations=("query",))) + +@app.route("/echo_json", methods=["POST"]) +def echo_json(): + return J(parser.parse(hello_args)) -@app.route("/echo_use_args", methods=["GET", "POST"]) -@use_args(hello_args) +@app.route("/echo_use_args", methods=["GET"]) +@use_args(hello_args, location="query") def echo_use_args(args): return J(args) -@app.route("/echo_use_args_validated", methods=["GET", "POST"]) -@use_args({"value": fields.Int()}, validate=lambda args: args["value"] > 42) +@app.route("/echo_use_args_validated", methods=["POST"]) +@use_args( + {"value": fields.Int()}, validate=lambda args: args["value"] > 42, location="form" +) def echo_use_args_validated(args): return J(args) -@app.route("/echo_use_kwargs", methods=["GET", "POST"]) -@use_kwargs(hello_args) +@app.route("/echo_use_kwargs", methods=["GET"]) +@use_kwargs(hello_args, location="query") def echo_use_kwargs(name): return J({"name": name}) -@app.route("/echo_multi", methods=["GET", "POST"]) +@app.route("/echo_multi", methods=["GET"]) def multi(): + return J(parser.parse(hello_multiple, location="query")) + + +@app.route("/echo_multi_form", methods=["POST"]) +def multi_form(): + return J(parser.parse(hello_multiple, location="form")) + + +@app.route("/echo_multi_json", methods=["POST"]) +def multi_json(): return J(parser.parse(hello_multiple)) @app.route("/echo_many_schema", methods=["GET", "POST"]) def many_nested(): - arguments = parser.parse(hello_many_schema, locations=("json",)) + arguments = parser.parse(hello_many_schema) return Response(json.dumps(arguments), content_type="application/json") @app.route("/echo_use_args_with_path_param/") -@use_args({"value": fields.Int()}) +@use_args({"value": fields.Int()}, location="query") def echo_use_args_with_path(args, name): return J(args) @app.route("/echo_use_kwargs_with_path_param/") -@use_kwargs({"value": fields.Int()}) +@use_kwargs({"value": fields.Int()}, location="query") def echo_use_kwargs_with_path(name, value): return J({"value": value}) @@ -89,18 +112,18 @@ def always_fail(value): @app.route("/echo_headers") def echo_headers(): - return J(parser.parse(hello_args, locations=("headers",))) + return J(parser.parse(hello_args, location="headers")) @app.route("/echo_cookie") def echo_cookie(): - return J(parser.parse(hello_args, request, locations=("cookies",))) + return J(parser.parse(hello_args, request, location="cookies")) @app.route("/echo_file", methods=["POST"]) def echo_file(): args = {"myfile": fields.Field()} - result = parser.parse(args, locations=("files",)) + result = parser.parse(args, location="files") fp = result["myfile"] content = fp.read().decode("utf8") return J({"myfile": content}) @@ -108,11 +131,11 @@ def echo_file(): @app.route("/echo_view_arg/") def echo_view_arg(view_arg): - return J(parser.parse({"view_arg": fields.Int()}, locations=("view_args",))) + return J(parser.parse({"view_arg": fields.Int()}, location="view_args")) @app.route("/echo_view_arg_use_args/") -@use_args({"view_arg": fields.Int(location="view_args")}) +@use_args({"view_arg": fields.Int()}, location="view_args") def echo_view_arg_with_use_args(args, **kwargs): return J(args) diff --git a/tests/test_core.py b/tests/test_core.py index be8038b7..c4c1ba14 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -9,15 +9,15 @@ from django.utils.datastructures import MultiValueDict as DjMultiDict from bottle import MultiDict as BotMultiDict -from webargs import fields, missing, ValidationError +from webargs import fields, ValidationError from webargs.core import ( Parser, - get_value, dict2schema, is_json, get_mimetype, MARSHMALLOW_VERSION_INFO, ) +from webargs.multidictproxy import MultiDictProxy strict_kwargs = {"strict": True} if MARSHMALLOW_VERSION_INFO[0] < 3 else {} @@ -33,14 +33,14 @@ def __init__(self, status_code, headers): class MockRequestParser(Parser): """A minimal parser implementation that parses mock requests.""" - def parse_querystring(self, req, name, field): - return get_value(req.query, name, field) + def load_querystring(self, req, schema): + return MultiDictProxy(req.query, schema) - def parse_json(self, req, name, field): - return get_value(req.json, name, field) + def load_json(self, req, schema): + return req.json - def parse_cookies(self, req, name, field): - return get_value(req.cookies, name, field) + def load_cookies(self, req, schema): + return req.cookies @pytest.yield_fixture(scope="function") @@ -59,66 +59,70 @@ def parser(): # Parser tests -@mock.patch("webargs.core.Parser.parse_json") -def test_parse_json_called_by_parse_arg(parse_json, web_request): - field = fields.Field() +@mock.patch("webargs.core.Parser.load_json") +def test_load_json_called_by_parse_default(load_json, web_request): + schema = dict2schema({"foo": fields.Field()})() + load_json.return_value = {"foo": 1} p = Parser() - p.parse_arg("foo", field, web_request) - parse_json.assert_called_with(web_request, "foo", field) + p.parse(schema, web_request) + load_json.assert_called_with(web_request, schema) -@mock.patch("webargs.core.Parser.parse_querystring") -def test_parse_querystring_called_by_parse_arg(parse_querystring, web_request): - field = fields.Field() - p = Parser() - p.parse_arg("foo", field, web_request) - assert parse_querystring.called_once() - +@pytest.mark.parametrize( + "location", ["querystring", "form", "headers", "cookies", "files"] +) +def test_load_nondefault_called_by_parse_with_location(location, web_request): + with mock.patch( + "webargs.core.Parser.load_{}".format(location) + ) as mock_loadfunc, mock.patch("webargs.core.Parser.load_json") as load_json: + mock_loadfunc.return_value = {} + load_json.return_value = {} + p = Parser() + + # ensure that without location=..., the loader is not called (json is + # called) + p.parse({"foo": fields.Field()}, web_request) + assert mock_loadfunc.call_count == 0 + assert load_json.call_count == 1 -@mock.patch("webargs.core.Parser.parse_form") -def test_parse_form_called_by_parse_arg(parse_form, web_request): - field = fields.Field() - p = Parser() - p.parse_arg("foo", field, web_request) - assert parse_form.called_once() + # but when location=... is given, the loader *is* called and json is + # not called + p.parse({"foo": fields.Field()}, web_request, location=location) + assert mock_loadfunc.call_count == 1 + # it was already 1, should not go up + assert load_json.call_count == 1 -@mock.patch("webargs.core.Parser.parse_json") -def test_parse_json_not_called_when_json_not_a_location(parse_json, web_request): - field = fields.Field() - p = Parser() - p.parse_arg("foo", field, web_request, locations=("form", "querystring")) - assert parse_json.call_count == 0 - +def test_parse(parser, web_request): + web_request.json = {"username": 42, "password": 42} + argmap = {"username": fields.Field(), "password": fields.Field()} + ret = parser.parse(argmap, web_request) + assert {"username": 42, "password": 42} == ret -@mock.patch("webargs.core.Parser.parse_headers") -def test_parse_headers_called_when_headers_is_a_location(parse_headers, web_request): - field = fields.Field() - p = Parser() - p.parse_arg("foo", field, web_request) - assert parse_headers.call_count == 0 - p.parse_arg("foo", field, web_request, locations=("headers",)) - parse_headers.assert_called() +@pytest.mark.skipif( + MARSHMALLOW_VERSION_INFO[0] < 3, reason="unknown=EXCLUDE added in marshmallow3" +) +def test_parse_with_excluding_schema(parser, web_request): + """ + This is new in webargs 6.x ; it's the way you can "get back" the behavior + of webargs 5.x in which extra args are ignored + """ + from marshmallow import EXCLUDE -@mock.patch("webargs.core.Parser.parse_cookies") -def test_parse_cookies_called_when_cookies_is_a_location(parse_cookies, web_request): - field = fields.Field() - p = Parser() - p.parse_arg("foo", field, web_request) - assert parse_cookies.call_count == 0 - p.parse_arg("foo", field, web_request, locations=("cookies",)) - parse_cookies.assert_called() + web_request.json = {"username": 42, "password": 42, "fjords": 42} + class CustomSchema(Schema): + username = fields.Field() + password = fields.Field() -@mock.patch("webargs.core.Parser.parse_json") -def test_parse(parse_json, web_request): - parse_json.return_value = 42 - argmap = {"username": fields.Field(), "password": fields.Field()} - p = Parser() - ret = p.parse(argmap, web_request) + ret = parser.parse(CustomSchema(unknown=EXCLUDE), web_request) assert {"username": 42, "password": 42} == ret + # but without unknown=EXCLUDE, it blows up + with pytest.raises(ValidationError, match="Unknown field."): + parser.parse(CustomSchema(), web_request) + def test_parse_required_arg_raises_validation_error(parser, web_request): web_request.json = {} @@ -141,13 +145,10 @@ def test_arg_allow_none(parser, web_request): assert result == {"first": "Steve", "last": None} -@mock.patch("webargs.core.Parser.parse_json") -def test_parse_required_arg(parse_json, web_request): - arg = fields.Field(required=True) - parse_json.return_value = 42 - p = Parser() - result = p.parse_arg("foo", arg, web_request, locations=("json",)) - assert result == 42 +def test_parse_required_arg(parser, web_request): + web_request.json = {"foo": 42} + result = parser.parse({"foo": fields.Field(required=True)}, web_request) + assert result == {"foo": 42} def test_parse_required_list(parser, web_request): @@ -185,21 +186,21 @@ def test_parse_missing_list(parser, web_request): assert parser.parse(args, web_request) == {} -def test_default_locations(): - assert set(Parser.DEFAULT_LOCATIONS) == set(["json", "querystring", "form"]) +def test_default_location(): + assert Parser.DEFAULT_LOCATION == "json" def test_missing_with_default(parser, web_request): web_request.json = {} args = {"val": fields.Field(missing="pizza")} - result = parser.parse(args, web_request, locations=("json",)) + result = parser.parse(args, web_request) assert result["val"] == "pizza" def test_default_can_be_none(parser, web_request): web_request.json = {} args = {"val": fields.Field(missing=None, allow_none=True)} - result = parser.parse(args, web_request, locations=("json",)) + result = parser.parse(args, web_request) assert result["val"] is None @@ -217,34 +218,22 @@ def test_arg_with_default_and_location(parser, web_request): assert parser.parse(args, web_request) == {"p": 1} -def test_value_error_raised_if_parse_arg_called_with_invalid_location(web_request): +def test_value_error_raised_if_parse_called_with_invalid_location(parser, web_request): field = fields.Field() - p = Parser() - with pytest.raises(ValueError) as excinfo: - p.parse_arg("foo", field, web_request, locations=("invalidlocation", "headers")) - msg = "Invalid locations arguments: {0}".format(["invalidlocation"]) - assert msg in str(excinfo.value) - - -def test_value_error_raised_if_invalid_location_on_field(web_request, parser): - with pytest.raises(ValueError) as excinfo: - parser.parse({"foo": fields.Field(location="invalidlocation")}, web_request) - msg = "Invalid locations arguments: {0}".format(["invalidlocation"]) - assert msg in str(excinfo.value) + with pytest.raises(ValueError, match="Invalid location argument: invalidlocation"): + parser.parse({"foo": field}, web_request, location="invalidlocation") @mock.patch("webargs.core.Parser.handle_error") -@mock.patch("webargs.core.Parser.parse_json") -def test_handle_error_called_when_parsing_raises_error( - parse_json, handle_error, web_request -): - val_err = ValidationError("error occurred") - parse_json.side_effect = val_err +def test_handle_error_called_when_parsing_raises_error(handle_error, web_request): + def always_fail(*args, **kwargs): + raise ValidationError("error occurred") + p = Parser() - p.parse({"foo": fields.Field()}, web_request, locations=("json",)) - handle_error.assert_called() - parse_json.side_effect = ValidationError("another exception") - p.parse({"foo": fields.Field()}, web_request, locations=("json",)) + assert handle_error.call_count == 0 + p.parse({"foo": fields.Field()}, web_request, validate=always_fail) + assert handle_error.call_count == 1 + p.parse({"foo": fields.Field()}, web_request, validate=always_fail) assert handle_error.call_count == 2 @@ -254,22 +243,15 @@ def test_handle_error_reraises_errors(web_request): p.handle_error(ValidationError("error raised"), web_request, Schema()) -@mock.patch("webargs.core.Parser.parse_headers") -def test_locations_as_init_arguments(parse_headers, web_request): - p = Parser(locations=("headers",)) +@mock.patch("webargs.core.Parser.load_headers") +def test_location_as_init_argument(load_headers, web_request): + p = Parser(location="headers") + load_headers.return_value = {} p.parse({"foo": fields.Field()}, web_request) - assert parse_headers.called - - -@mock.patch("webargs.core.Parser.parse_files") -def test_parse_files(parse_files, web_request): - p = Parser() - p.parse({"foo": fields.Field()}, web_request, locations=("files",)) - assert parse_files.called + assert load_headers.called -@mock.patch("webargs.core.Parser.parse_json") -def test_custom_error_handler(parse_json, web_request): +def test_custom_error_handler(web_request): class CustomError(Exception): pass @@ -277,19 +259,27 @@ def error_handler(error, req, schema, status_code, headers): assert isinstance(schema, Schema) raise CustomError(error) - parse_json.side_effect = ValidationError("parse_json failed") + def failing_validate_func(args): + raise ValidationError("parsing failed") + + class MySchema(Schema): + foo = fields.Int() + + myschema = MySchema(**strict_kwargs) + web_request.json = {"foo": "hello world"} + p = Parser(error_handler=error_handler) with pytest.raises(CustomError): - p.parse({"foo": fields.Field()}, web_request) + p.parse(myschema, web_request, validate=failing_validate_func) -@mock.patch("webargs.core.Parser.parse_json") -def test_custom_error_handler_decorator(parse_json, web_request): +def test_custom_error_handler_decorator(web_request): class CustomError(Exception): pass - parse_json.side_effect = ValidationError("parse_json failed") - + mock_schema = mock.Mock(spec=Schema) + mock_schema.strict = True + mock_schema.load.side_effect = ValidationError("parsing json failed") parser = Parser() @parser.error_handler @@ -298,53 +288,47 @@ def handle_error(error, req, schema, status_code, headers): raise CustomError(error) with pytest.raises(CustomError): - parser.parse({"foo": fields.Field()}, web_request) + parser.parse(mock_schema, web_request) -def test_custom_location_handler(web_request): +def test_custom_location_loader(web_request): web_request.data = {"foo": 42} parser = Parser() - @parser.location_handler("data") - def parse_data(req, name, arg): - return req.data.get(name, missing) + @parser.location_loader("data") + def load_data(req, schema): + return req.data - result = parser.parse({"foo": fields.Int()}, web_request, locations=("data",)) + result = parser.parse({"foo": fields.Int()}, web_request, location="data") assert result["foo"] == 42 -def test_custom_location_handler_with_data_key(web_request): +def test_custom_location_loader_with_data_key(web_request): web_request.data = {"X-Foo": 42} parser = Parser() - @parser.location_handler("data") - def parse_data(req, name, arg): - return req.data.get(name, missing) + @parser.location_loader("data") + def load_data(req, schema): + return req.data data_key_kwarg = { "load_from" if (MARSHMALLOW_VERSION_INFO[0] < 3) else "data_key": "X-Foo" } result = parser.parse( - {"x_foo": fields.Int(**data_key_kwarg)}, web_request, locations=("data",) + {"x_foo": fields.Int(**data_key_kwarg)}, web_request, location="data" ) assert result["x_foo"] == 42 -def test_full_input_validation(web_request): +def test_full_input_validation(parser, web_request): web_request.json = {"foo": 41, "bar": 42} - parser = MockRequestParser() args = {"foo": fields.Int(), "bar": fields.Int()} with pytest.raises(ValidationError): # Test that `validate` receives dictionary of args - parser.parse( - args, - web_request, - locations=("json",), - validate=lambda args: args["foo"] > args["bar"], - ) + parser.parse(args, web_request, validate=lambda args: args["foo"] > args["bar"]) def test_full_input_validation_with_multiple_validators(web_request, parser): @@ -360,31 +344,29 @@ def validate2(args): web_request.json = {"a": 2, "b": 1} validators = [validate1, validate2] with pytest.raises(ValidationError, match="b must be > a"): - parser.parse(args, web_request, locations=("json",), validate=validators) + parser.parse(args, web_request, validate=validators) web_request.json = {"a": 1, "b": 2} with pytest.raises(ValidationError, match="a must be > b"): - parser.parse(args, web_request, locations=("json",), validate=validators) + parser.parse(args, web_request, validate=validators) -def test_required_with_custom_error(web_request): +def test_required_with_custom_error(parser, web_request): web_request.json = {} - parser = MockRequestParser() args = { "foo": fields.Str(required=True, error_messages={"required": "We need foo"}) } with pytest.raises(ValidationError) as excinfo: # Test that `validate` receives dictionary of args - parser.parse(args, web_request, locations=("json",)) + parser.parse(args, web_request) assert "We need foo" in excinfo.value.messages["foo"] if MARSHMALLOW_VERSION_INFO[0] < 3: assert "foo" in excinfo.value.field_names -def test_required_with_custom_error_and_validation_error(web_request): +def test_required_with_custom_error_and_validation_error(parser, web_request): web_request.json = {"foo": ""} - parser = MockRequestParser() args = { "foo": fields.Str( required="We need foo", @@ -394,7 +376,7 @@ def test_required_with_custom_error_and_validation_error(web_request): } with pytest.raises(ValidationError) as excinfo: # Test that `validate` receives dictionary of args - parser.parse(args, web_request, locations=("json",)) + parser.parse(args, web_request) assert "foo required length is 3" in excinfo.value.args[0]["foo"] if MARSHMALLOW_VERSION_INFO[0] < 3: @@ -410,7 +392,7 @@ def validate(val): parser = MockRequestParser() args = {"text": fields.Str()} with pytest.raises(ValidationError) as excinfo: - parser.parse(args, web_request, locations=("json",), validate=validate) + parser.parse(args, web_request, validate=validate) assert excinfo.value.messages == ["Invalid value."] @@ -420,14 +402,6 @@ def test_invalid_argument_for_validate(web_request, parser): assert "not a callable or list of callables." in excinfo.value.args[0] -def test_get_value_basic(): - assert get_value({"foo": 42}, "foo", False) == 42 - assert get_value({"foo": 42}, "bar", False) is missing - assert get_value({"foos": ["a", "b"]}, "foos", True) == ["a", "b"] - # https://github.com/marshmallow-code/webargs/pull/30 - assert get_value({"foos": ["a", "b"]}, "bar", True) is missing - - def create_bottle_multi_dict(): d = BotMultiDict() d["foos"] = "a" @@ -443,9 +417,24 @@ def create_bottle_multi_dict(): @pytest.mark.parametrize("input_dict", multidicts) -def test_get_value_multidict(input_dict): - field = fields.List(fields.Str()) - assert get_value(input_dict, "foos", field) == ["a", "b"] +def test_multidict_proxy(input_dict): + class ListSchema(Schema): + foos = fields.List(fields.Str()) + + class StrSchema(Schema): + foos = fields.Str() + + # this MultiDictProxy is aware that "foos" is a list field and will + # therefore produce a list with __getitem__ + list_wrapped_multidict = MultiDictProxy(input_dict, ListSchema()) + + # this MultiDictProxy is under the impression that "foos" is just a string + # and it should return "a" or "b" + # the decision between "a" and "b" in this case belongs to the framework + str_wrapped_multidict = MultiDictProxy(input_dict, StrSchema()) + + assert list_wrapped_multidict["foos"] == ["a", "b"] + assert str_wrapped_multidict["foos"] in ("a", "b") def test_parse_with_data_key(web_request): @@ -456,7 +445,7 @@ def test_parse_with_data_key(web_request): "load_from" if (MARSHMALLOW_VERSION_INFO[0] < 3) else "data_key": "Content-Type" } args = {"content_type": fields.Field(**data_key_kwargs)} - parsed = parser.parse(args, web_request, locations=("json",)) + parsed = parser.parse(args, web_request) assert parsed == {"content_type": "application/json"} @@ -470,7 +459,7 @@ def test_load_from_is_checked_after_given_key(web_request): parser = MockRequestParser() args = {"content_type": fields.Field(load_from="Content-Type")} - parsed = parser.parse(args, web_request, locations=("json",)) + parsed = parser.parse(args, web_request) assert parsed == {"content_type": "application/json"} @@ -483,7 +472,7 @@ def test_parse_with_data_key_retains_field_name_in_error(web_request): } args = {"content_type": fields.Str(**data_key_kwargs)} with pytest.raises(ValidationError) as excinfo: - parser.parse(args, web_request, locations=("json",)) + parser.parse(args, web_request) assert "Content-Type" in excinfo.value.messages assert excinfo.value.messages["Content-Type"] == ["Not a valid string."] @@ -496,7 +485,7 @@ def test_parse_nested_with_data_key(web_request): } args = {"nested_arg": fields.Nested({"right": fields.Field(**data_key_kwarg)})} - parsed = parser.parse(args, web_request, locations=("json",)) + parsed = parser.parse(args, web_request) assert parsed == {"nested_arg": {"right": "OK"}} @@ -513,7 +502,7 @@ def test_parse_nested_with_missing_key_and_data_key(web_request): ) } - parsed = parser.parse(args, web_request, locations=("json",)) + parsed = parser.parse(args, web_request) assert parsed == {"nested_arg": {"found": None}} @@ -523,7 +512,7 @@ def test_parse_nested_with_default(web_request): web_request.json = {"nested_arg": {}} args = {"nested_arg": fields.Nested({"miss": fields.Field(missing="")})} - parsed = parser.parse(args, web_request, locations=("json",)) + parsed = parser.parse(args, web_request) assert parsed == {"nested_arg": {"miss": ""}} @@ -554,8 +543,8 @@ def test_use_args_stacked(web_request, parser): web_request.json = {"username": "foo"} web_request.query = {"page": 42} - @parser.use_args(query_args, web_request, locations=("query",)) - @parser.use_args(json_args, web_request, locations=("json",)) + @parser.use_args(query_args, web_request, location="query") + @parser.use_args(json_args, web_request) def viewfunc(query_parsed, json_parsed): return {"json": json_parsed, "query": query_parsed} @@ -570,8 +559,8 @@ def test_use_kwargs_stacked(web_request, parser): web_request.json = {"username": "foo"} web_request.query = {"page": 42} - @parser.use_kwargs(query_args, web_request, locations=("query",)) - @parser.use_kwargs(json_args, web_request, locations=("json",)) + @parser.use_kwargs(query_args, web_request, location="query") + @parser.use_kwargs(json_args, web_request) def viewfunc(page, username): return {"json": {"username": username}, "query": {"page": page}} @@ -592,21 +581,21 @@ def viewfunc(*args, **kwargs): def test_list_allowed_missing(web_request, parser): args = {"name": fields.List(fields.Str())} - web_request.json = {"fakedata": True} + web_request.json = {} result = parser.parse(args, web_request) assert result == {} def test_int_list_allowed_missing(web_request, parser): args = {"name": fields.List(fields.Int())} - web_request.json = {"fakedata": True} + web_request.json = {} result = parser.parse(args, web_request) assert result == {} def test_multiple_arg_required_with_int_conversion(web_request, parser): args = {"ids": fields.List(fields.Int(), required=True)} - web_request.json = {"fakedata": True} + web_request.json = {} with pytest.raises(ValidationError) as excinfo: parser.parse(args, web_request) assert excinfo.value.messages == {"ids": ["Missing data for required field."]} @@ -747,10 +736,22 @@ def test_warning_raised_if_schema_is_not_in_strict_mode(self, web_request, parse assert "strict=True" in str(warning.message) def test_use_kwargs_stacked(self, web_request, parser): + if MARSHMALLOW_VERSION_INFO[0] >= 3: + from marshmallow import EXCLUDE + + class PageSchema(Schema): + page = fields.Int() + + pageschema = PageSchema(unknown=EXCLUDE) + userschema = self.UserSchema(unknown=EXCLUDE) + else: + pageschema = {"page": fields.Int()} + userschema = self.UserSchema(**strict_kwargs) + web_request.json = {"email": "foo@bar.com", "password": "bar", "page": 42} - @parser.use_kwargs({"page": fields.Int()}, web_request) - @parser.use_kwargs(self.UserSchema(**strict_kwargs), web_request) + @parser.use_kwargs(pageschema, web_request) + @parser.use_kwargs(userschema, web_request) def viewfunc(email, password, page): return {"email": email, "password": password, "page": page} @@ -774,18 +775,18 @@ def validate_schema(self, data, original_data, **kwargs): return True web_request.json = {"name": "Eric Cartman"} - res = parser.parse(UserSchema, web_request, locations=("json",)) + res = parser.parse(UserSchema, web_request) assert res == {"name": "Eric Cartman"} -def test_use_args_with_custom_locations_in_parser(web_request, parser): +def test_use_args_with_custom_location_in_parser(web_request, parser): custom_args = {"foo": fields.Str()} web_request.json = {} - parser.locations = ("custom",) + parser.location = "custom" - @parser.location_handler("custom") - def parse_custom(req, name, arg): - return "bar" + @parser.location_loader("custom") + def load_custom(schema, req): + return {"foo": "bar"} @parser.use_args(custom_args, web_request) def viewfunc(args): @@ -913,16 +914,6 @@ def test_type_conversion_with_multiple_required(web_request, parser): parser.parse(args, web_request) -def test_arg_location_param(web_request, parser): - web_request.json = {"foo": 24} - web_request.cookies = {"foo": 42} - args = {"foo": fields.Field(location="cookies")} - - parsed = parser.parse(args, web_request) - - assert parsed["foo"] == 42 - - def test_validation_errors_in_validator_are_passed_to_handle_error(parser, web_request): def validate(value): raise ValidationError("Something went wrong.") @@ -1041,23 +1032,23 @@ def test_parse_with_error_status_code_and_headers(web_request): assert error.headers == {"X-Foo": "bar"} -@mock.patch("webargs.core.Parser.parse_json") -def test_custom_schema_class(parse_json, web_request): +@mock.patch("webargs.core.Parser.load_json") +def test_custom_schema_class(load_json, web_request): class CustomSchema(Schema): @pre_load def pre_load(self, data, **kwargs): data["value"] += " world" return data - parse_json.return_value = "hello" + load_json.return_value = {"value": "hello"} argmap = {"value": fields.Str()} p = Parser(schema_class=CustomSchema) ret = p.parse(argmap, web_request) assert ret == {"value": "hello world"} -@mock.patch("webargs.core.Parser.parse_json") -def test_custom_default_schema_class(parse_json, web_request): +@mock.patch("webargs.core.Parser.load_json") +def test_custom_default_schema_class(load_json, web_request): class CustomSchema(Schema): @pre_load def pre_load(self, data, **kwargs): @@ -1067,7 +1058,7 @@ def pre_load(self, data, **kwargs): class CustomParser(Parser): DEFAULT_SCHEMA_CLASS = CustomSchema - parse_json.return_value = "hello" + load_json.return_value = {"value": "hello"} argmap = {"value": fields.Str()} p = CustomParser() ret = p.parse(argmap, web_request) diff --git a/tests/test_flaskparser.py b/tests/test_flaskparser.py index 5122196c..501075b6 100644 --- a/tests/test_flaskparser.py +++ b/tests/test_flaskparser.py @@ -7,7 +7,7 @@ import pytest from flask import Flask -from webargs import fields, ValidationError, missing +from webargs import fields, ValidationError, missing, dict2schema from webargs.flaskparser import parser, abort from webargs.core import MARSHMALLOW_VERSION_INFO, json @@ -33,23 +33,31 @@ def test_use_args_with_view_args_parsing(self, testapp): assert res.json == {"view_arg": 42} def test_use_args_on_a_method_view(self, testapp): - res = testapp.post("/echo_method_view_use_args", {"val": 42}) + res = testapp.post_json("/echo_method_view_use_args", {"val": 42}) assert res.json == {"val": 42} def test_use_kwargs_on_a_method_view(self, testapp): - res = testapp.post("/echo_method_view_use_kwargs", {"val": 42}) + res = testapp.post_json("/echo_method_view_use_kwargs", {"val": 42}) assert res.json == {"val": 42} def test_use_kwargs_with_missing_data(self, testapp): - res = testapp.post("/echo_use_kwargs_missing", {"username": "foo"}) + res = testapp.post_json("/echo_use_kwargs_missing", {"username": "foo"}) assert res.json == {"username": "foo"} # regression test for https://github.com/marshmallow-code/webargs/issues/145 def test_nested_many_with_data_key(self, testapp): - res = testapp.post_json("/echo_nested_many_data_key", {"x_field": [{"id": 42}]}) - # https://github.com/marshmallow-code/marshmallow/pull/714 + post_with_raw_fieldname_args = ( + "/echo_nested_many_data_key", + {"x_field": [{"id": 42}]}, + ) + # under marhsmallow2 this is allowed and works if MARSHMALLOW_VERSION_INFO[0] < 3: + res = testapp.post_json(*post_with_raw_fieldname_args) assert res.json == {"x_field": [{"id": 42}]} + # but under marshmallow3 , only data_key is checked, field name is ignored + else: + res = testapp.post_json(*post_with_raw_fieldname_args, expect_errors=True) + assert res.status_code == 422 res = testapp.post_json("/echo_nested_many_data_key", {"X-Field": [{"id": 24}]}) assert res.json == {"x_field": [{"id": 24}]} @@ -81,10 +89,13 @@ def validate(x): assert type(abort_kwargs["exc"]) == ValidationError -def test_parse_form_returns_missing_if_no_form(): +@pytest.mark.parametrize("mimetype", [None, "application/json"]) +def test_load_json_returns_missing_if_no_data(mimetype): req = mock.Mock() - req.form.get.side_effect = AttributeError("no form") - assert parser.parse_form(req, "foo", fields.Field()) is missing + req.mimetype = mimetype + req.get_data.return_value = "" + schema = dict2schema({"foo": fields.Field()})() + assert parser.load_json(req, schema) is missing def test_abort_with_message():