diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 0faa0bd4..fc9f8781 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,58 @@ Changelog --------- +6.0.0 (unreleased) +****************** + +Features: + +* *Backwards-incompatible*: Schemas will now load all data from a location, not + only data specified by fields. As a result, schemas with validators which + examine the full input data may change in behavior. The `unknown` parameter + on schemas may be used to alter this. For example, + `unknown=marshmallow.EXCLUDE` will produce behavior similar to webargs v5 + +Bug fixes: + +* *Backwards-incompatible*: all parsers now require the Content-Type to be set + correctly when processing JSON request bodies. This impacts ``DjangoParser``, + ``FalconParser``, ``FlaskParser``, and ``PyramidParser`` + +Refactoring: + +* *Backwards-incompatible*: Schema fields may not specify a location any + longer, and `Parser.use_args` and `Parser.use_kwargs` now accept `location` + (singular) instead of `locations` (plural). Instead of using a single field or + schema with multiple `locations`, users are recommended to make multiple + calls to `use_args` or `use_kwargs` with a distinct schema per location. For + example, code should be rewritten like this: + +.. code-block:: python + + # under webargs v5 + @parser.use_args( + { + "q1": ma.fields.Int(location="query"), + "q2": ma.fields.Int(location="query"), + "h1": ma.fields.Int(location="headers"), + }, + locations=("query", "headers"), + ) + def foo(q1, q2, h1): + ... + + + # should be split up like so under webargs v6 + @parser.use_args({"q1": ma.fields.Int(), "q2": ma.fields.Int()}, location="query") + @parser.use_args({"h1": ma.fields.Int()}, location="headers") + def foo(q1, q2, h1): + ... + +* The `location_handler` decorator has been removed and replaced with + `location_loader`. `location_loader` serves the same purpose (letting you + write custom hooks for loading data) but its expected method signature is + different. See the docs on `location_loader` for proper usage. + 5.5.2 (2019-10-06) ****************** diff --git a/docs/advanced.rst b/docs/advanced.rst index 4b16bafb..e29c9cd1 100644 --- a/docs/advanced.rst +++ b/docs/advanced.rst @@ -6,7 +6,7 @@ This section includes guides for advanced usage patterns. Custom Location Handlers ------------------------ -To add your own custom location handler, write a function that receives a request, an argument name, and a :class:`Field `, then decorate that function with :func:`Parser.location_handler `. +To add your own custom location handler, write a function that receives a request, and a :class:`Schema `, then decorate that function with :func:`Parser.location_loader `. .. code-block:: python @@ -15,17 +15,78 @@ To add your own custom location handler, write a function that receives a reques from webargs.flaskparser import parser - @parser.location_handler("data") - def parse_data(request, name, field): - return request.data.get(name) + @parser.location_loader("data") + def load_data(request, schema): + return request.data # Now 'data' can be specified as a location - @parser.use_args({"per_page": fields.Int()}, locations=("data",)) + @parser.use_args({"per_page": fields.Int()}, location="data") def posts(args): return "displaying {} posts".format(args["per_page"]) +.. NOTE:: + + The schema is passed so that it can be used to wrap multidict types and + unpack List fields correctly. If you are writing a loader for a multidict + type, consider looking at + :class:`MultiDictProxy ` for an + example of how to do this. + +"meta" Locations +~~~~~~~~~~~~~~~~ + +You can define your own locations which mix data from several existing +locations. + +The `json_or_form` location does this -- first trying to load data as JSON and +then falling back to a form body -- and its implementation is quite simple: + + +.. code-block:: python + + def load_json_or_form(self, req, schema): + """Load data from a request, accepting either JSON or form-encoded + data. + + The data will first be loaded as JSON, and, if that fails, it will be + loaded as a form post. + """ + data = self.load_json(req, schema) + if data is not missing: + return data + return self.load_form(req, schema) + + +You can imagine your own locations with custom behaviors like this. +For example, to mix query parameters and form body data, you might write the +following: + +.. code-block:: python + + from webargs import fields + from webargs.multidictproxy import MultiDictProxy + from webargs.flaskparser import parser + + + @parser.location_loader("query_and_form") + def load_data(request, schema): + # relies on the Flask (werkzeug) MultiDict type's implementation of + # these methods, but when you're extending webargs, you may know things + # about your framework of choice + newdata = request.args.copy() + newdata.update(request.form) + return MultiDictProxy(newdata, schema) + + + # Now 'query_and_form' means you can send these values in either location, + # and they will be *mixed* together into a new dict to pass to your schema + @parser.use_args({"favorite_food": fields.String()}, location="query_and_form") + def set_favorite_food(args): + ... # do stuff + return "your favorite food is now set to {}".format(args["favorite_food"]) + marshmallow Integration ----------------------- @@ -64,7 +125,7 @@ When you need more flexibility in defining input schemas, you can pass a marshma # You can add additional parameters - @use_kwargs({"posts_per_page": fields.Int(missing=10, location="query")}) + @use_kwargs({"posts_per_page": fields.Int(missing=10)}, location="query") @use_args(UserSchema()) def profile_posts(args, posts_per_page): username = args["username"] @@ -211,12 +272,12 @@ Using the :class:`Method ` and :class:`Function ` and implement the `parse_*` method(s) you need to override. For example, here is a custom Flask parser that handles nested query string arguments. +To add your own parser, extend :class:`Parser ` and implement the `load_*` method(s) you need to override. For example, here is a custom Flask parser that handles nested query string arguments. .. code-block:: python @@ -245,8 +306,8 @@ To add your own parser, extend :class:`Parser ` and impleme } """ - def parse_querystring(self, req, name, field): - return core.get_value(_structure_dict(req.args), name, field) + def load_querystring(self, req, schema): + return _structure_dict(req.args) def _structure_dict(dict_): @@ -309,7 +370,7 @@ For example, you might implement JSON PATCH according to `RFC 6902 ` call: .. code-block:: python + # "json" is the default, used explicitly below @app.route("/stacked", methods=["POST"]) - @use_args( - { - "page": fields.Int(location="query"), - "q": fields.Str(location="query"), - "name": fields.Str(location="json"), - } - ) - def viewfunc(args): - page = args["page"] - # ... - -Alternatively, you can pass multiple locations to `use_args `: - -.. code-block:: python - - @app.route("/stacked", methods=["POST"]) - @use_args( - {"page": fields.Int(), "q": fields.Str(), "name": fields.Str()}, - locations=("query", "json"), - ) - def viewfunc(args): - page = args["page"] - # ... - -However, this allows ``page`` and ``q`` to be passed in the request body and ``name`` to be passed as a query parameter. - -To restrict the arguments to single locations without having to pass ``location`` to every field, you can call the `use_args ` multiple times: - -.. code-block:: python - - query_args = {"page": fields.Int(), "q": fields.Int()} - json_args = {"name": fields.Str()} - - - @app.route("/stacked", methods=["POST"]) - @use_args(query_args, locations=("query",)) - @use_args(json_args, locations=("json",)) + @use_args({"page": fields.Int(), "q": fields.Str()}, location="query") + @use_args({"name": fields.Str()}, location="json") def viewfunc(query_parsed, json_parsed): page = query_parsed["page"] name = json_parsed["name"] @@ -377,12 +404,12 @@ To reduce boilerplate, you could create shortcuts, like so: import functools - query = functools.partial(use_args, locations=("query",)) - body = functools.partial(use_args, locations=("json",)) + query = functools.partial(use_args, location="query") + body = functools.partial(use_args, location="json") - @query(query_args) - @body(json_args) + @query({"page": fields.Int(), "q": fields.Int()}) + @body({"name": fields.Str()}) def viewfunc(query_parsed, json_parsed): page = query_parsed["page"] name = json_parsed["name"] diff --git a/docs/framework_support.rst b/docs/framework_support.rst index c1b2f9a2..e58e5cd1 100644 --- a/docs/framework_support.rst +++ b/docs/framework_support.rst @@ -22,9 +22,9 @@ When using the :meth:`use_args ` decor @app.route("/user/") - @use_args({"per_page": fields.Int()}) + @use_args({"per_page": fields.Int()}, location="query") def user_detail(args, uid): - return ("The user page for user {uid}, " "showing {per_page} posts.").format( + return ("The user page for user {uid}, showing {per_page} posts.").format( uid=uid, per_page=args["per_page"] ) @@ -64,7 +64,7 @@ The `FlaskParser` supports parsing values from a request's ``view_args``. @app.route("/greeting//") - @use_args({"name": fields.Str(location="view_args")}) + @use_args({"name": fields.Str()}, location="view_args") def greeting(args, **kwargs): return "Hello {}".format(args["name"]) @@ -95,7 +95,7 @@ When using the :meth:`use_args ` dec } - @use_args(account_args) + @use_args(account_args, location="form") def login_user(request, args): if request.method == "POST": login(args["username"], args["password"]) @@ -114,7 +114,7 @@ When using the :meth:`use_args ` dec class BlogPostView(View): - @use_args(blog_args) + @use_args(blog_args, location="query") def get(self, request, args): blog_post = Post.objects.get(title__iexact=args["title"], author=args["author"]) return render_to_response("post_template.html", {"post": blog_post}) @@ -239,7 +239,7 @@ When using the :meth:`use_args ` d from webargs.pyramidparser import use_args - @use_args({"uid": fields.Str(), "per_page": fields.Int()}) + @use_args({"uid": fields.Str(), "per_page": fields.Int()}, location="query") def user_detail(request, args): uid = args["uid"] return Response( @@ -261,7 +261,7 @@ The `PyramidParser` supports parsing values from a request's matchdict. from webargs.pyramidparser import use_args - @use_args({"mymatch": fields.Int()}, locations=("matchdict",)) + @use_args({"mymatch": fields.Int()}, location="matchdict") def matched(request, args): return Response("The value for mymatch is {}".format(args["mymatch"])) @@ -317,7 +317,7 @@ You can easily implement hooks by using `parser.parse ` supports parsing value from webargs.aiohttpparser import use_args - @parser.use_args({"slug": fields.Str(location="match_info")}) + @parser.use_args({"slug": fields.Str()}, location="match_info") def article_detail(request, args): return web.Response(body="Slug: {}".format(args["slug"]).encode("utf-8")) diff --git a/docs/index.rst b/docs/index.rst index e152b9f3..6e2ae812 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -17,7 +17,7 @@ webargs is a Python library for parsing and validating HTTP request objects, wit @app.route("/") - @use_args({"name": fields.Str(required=True)}) + @use_args({"name": fields.Str(required=True)}, location="query") def index(args): return "Hello " + args["name"] @@ -28,13 +28,15 @@ webargs is a Python library for parsing and validating HTTP request objects, wit # curl http://localhost:5000/\?name\='World' # Hello World -Webargs will automatically parse: +By default Webargs will automatically parse JSON request bodies. But it also +has support for: **Query Parameters** :: + $ curl http://localhost:5000/\?name\='Freddie' + Hello Freddie - $ curl http://localhost:5000/\?name\='Freddie' - Hello Freddie + # pass location="query" to use_args **Form Data** :: @@ -42,12 +44,16 @@ Webargs will automatically parse: $ curl -d 'name=Brian' http://localhost:5000/ Hello Brian + # pass location="form" to use_args + **JSON Data** :: $ curl -X POST -H "Content-Type: application/json" -d '{"name":"Roger"}' http://localhost:5000/ Hello Roger + # pass location="json" (or omit location) to use_args + and, optionally: - Headers diff --git a/docs/quickstart.rst b/docs/quickstart.rst index b43371d6..e5ff7ee1 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -23,17 +23,11 @@ Arguments are specified as a dictionary of name -> :class:`Field ` to add nested field functionality to the other locations. + Of the default supported locations in webargs, only the ``json`` request location supports nested datastructures. You can, however, :ref:`implement your own data loader ` to add nested field functionality to the other locations. Next Steps ---------- diff --git a/src/webargs/aiohttpparser.py b/src/webargs/aiohttpparser.py index 9e76e05a..ad170ce4 100644 --- a/src/webargs/aiohttpparser.py +++ b/src/webargs/aiohttpparser.py @@ -28,11 +28,11 @@ def index(request, args): from aiohttp.web import Request from aiohttp import web_exceptions from marshmallow import Schema, ValidationError -from marshmallow.fields import Field from webargs import core from webargs.core import json from webargs.asyncparser import AsyncParser +from webargs.multidictproxy import MultiDictProxy def is_json_request(req: Request) -> bool: @@ -73,24 +73,32 @@ class AIOHTTPParser(AsyncParser): """aiohttp request argument parser.""" __location_map__ = dict( - match_info="parse_match_info", - path="parse_match_info", + match_info="load_match_info", + path="load_match_info", **core.Parser.__location_map__ ) - def parse_querystring(self, req: Request, name: str, field: Field) -> typing.Any: - """Pull a querystring value from the request.""" - return core.get_value(req.query, name, field) + def load_querystring(self, req: Request, schema: Schema) -> MultiDictProxy: + """Return query params from the request as a MultiDictProxy.""" + return MultiDictProxy(req.query, schema) - async def parse_form(self, req: Request, name: str, field: Field) -> typing.Any: - """Pull a form value from the request.""" + async def load_form(self, req: Request, schema: Schema) -> MultiDictProxy: + """Return form values from the request as a MultiDictProxy.""" post_data = self._cache.get("post") if post_data is None: self._cache["post"] = await req.post() - return core.get_value(self._cache["post"], name, field) - - async def parse_json(self, req: Request, name: str, field: Field) -> typing.Any: - """Pull a json value from the request.""" + return MultiDictProxy(self._cache["post"], schema) + + async def load_json_or_form( + self, req: Request, schema: Schema + ) -> typing.Union[typing.Dict, MultiDictProxy]: + data = await self.load_json(req, schema) + if data is not core.missing: + return data + return await self.load_form(req, schema) + + async def load_json(self, req: Request, schema: Schema) -> typing.Dict: + """Return a parsed json payload from the request.""" json_data = self._cache.get("json") if json_data is None: if not (req.body_exists and is_json_request(req)): @@ -101,30 +109,30 @@ async def parse_json(self, req: Request, name: str, field: Field) -> typing.Any: if e.doc == "": return core.missing else: - return self.handle_invalid_json_error(e, req) + return self._handle_invalid_json_error(e, req) except UnicodeDecodeError as e: - return self.handle_invalid_json_error(e, req) + return self._handle_invalid_json_error(e, req) self._cache["json"] = json_data - return core.get_value(json_data, name, field, allow_many_nested=True) + return json_data - def parse_headers(self, req: Request, name: str, field: Field) -> typing.Any: - """Pull a value from the header data.""" - return core.get_value(req.headers, name, field) + def load_headers(self, req: Request, schema: Schema) -> MultiDictProxy: + """Return headers from the request as a MultiDictProxy.""" + return MultiDictProxy(req.headers, schema) - def parse_cookies(self, req: Request, name: str, field: Field) -> typing.Any: - """Pull a value from the cookiejar.""" - return core.get_value(req.cookies, name, field) + def load_cookies(self, req: Request, schema: Schema) -> MultiDictProxy: + """Return cookies from the request as a MultiDictProxy.""" + return MultiDictProxy(req.cookies, schema) - def parse_files(self, req: Request, name: str, field: Field) -> None: + def load_files(self, req: Request, schema: Schema) -> "typing.NoReturn": raise NotImplementedError( - "parse_files is not implemented. You may be able to use parse_form for " + "load_files is not implemented. You may be able to use load_form for " "parsing upload data." ) - def parse_match_info(self, req: Request, name: str, field: Field) -> typing.Any: - """Pull a value from the request's ``match_info``.""" - return core.get_value(req.match_info, name, field) + def load_match_info(self, req: Request, schema: Schema) -> typing.Mapping: + """Load the request's ``match_info``.""" + return req.match_info def get_request_from_view_args( self, view: typing.Callable, args: typing.Iterable, kwargs: typing.Mapping @@ -140,7 +148,8 @@ def get_request_from_view_args( elif isinstance(arg, web.View): req = arg.request break - assert isinstance(req, web.Request), "Request argument not found for handler" + if not isinstance(req, web.Request): + raise ValueError("Request argument not found for handler") return req def handle_error( @@ -166,7 +175,7 @@ def handle_error( content_type="application/json", ) - def handle_invalid_json_error( + def _handle_invalid_json_error( self, error: typing.Union[json.JSONDecodeError, UnicodeDecodeError], req: Request, diff --git a/src/webargs/asyncparser.py b/src/webargs/asyncparser.py index 82369018..f67dc57c 100644 --- a/src/webargs/asyncparser.py +++ b/src/webargs/asyncparser.py @@ -8,7 +8,6 @@ from marshmallow import Schema, ValidationError from marshmallow.fields import Field import marshmallow as ma -from marshmallow.utils import missing from webargs import core @@ -22,53 +21,12 @@ class AsyncParser(core.Parser): either coroutines or regular methods. """ - async def _parse_request( - self, schema: Schema, req: Request, locations: typing.Iterable - ) -> typing.Union[dict, list]: - if schema.many: - assert ( - "json" in locations - ), "schema.many=True is only supported for JSON location" - # The ad hoc Nested field is more like a workaround or a helper, - # and it servers its purpose fine. However, if somebody has a desire - # to re-design the support of bulk-type arguments, go ahead. - parsed = await self.parse_arg( - name="json", - field=ma.fields.Nested(schema, many=True), - req=req, - locations=locations, - ) - if parsed is missing: - parsed = [] - else: - argdict = schema.fields - parsed = {} - for argname, field_obj in argdict.items(): - if core.MARSHMALLOW_VERSION_INFO[0] < 3: - parsed_value = await self.parse_arg( - argname, field_obj, req, locations - ) - # If load_from is specified on the field, try to parse from that key - if parsed_value is missing and field_obj.load_from: - parsed_value = await self.parse_arg( - field_obj.load_from, field_obj, req, locations - ) - argname = field_obj.load_from - else: - argname = field_obj.data_key or argname - parsed_value = await self.parse_arg( - argname, field_obj, req, locations - ) - if parsed_value is not missing: - parsed[argname] = parsed_value - return parsed - # TODO: Lots of duplication from core.Parser here. Rethink. async def parse( self, argmap: ArgMap, req: Request = None, - locations: typing.Iterable = None, + location: str = None, validate: Validate = None, error_status_code: typing.Union[int, None] = None, error_headers: typing.Union[typing.Mapping[str, str], None] = None, @@ -77,17 +35,18 @@ async def parse( Receives the same arguments as `webargs.core.Parser.parse`. """ - self.clear_cache() # in case someone used `parse_*()` + self.clear_cache() # in case someone used `location_load_*()` req = req if req is not None else self.get_default_request() - assert req is not None, "Must pass req object" + if req is None: + raise ValueError("Must pass req object") data = None validators = core._ensure_list_of_callables(validate) schema = self._get_schema(argmap, req) try: - parsed = await self._parse_request( - schema=schema, req=req, locations=locations or self.locations + location_data = await self._load_location_data( + schema=schema, req=req, location=location or self.location ) - result = schema.load(parsed) + result = schema.load(location_data) data = result.data if core.MARSHMALLOW_VERSION_INFO[0] < 3 else result self._validate_arguments(data, validators) except ma.exceptions.ValidationError as error: @@ -96,6 +55,25 @@ async def parse( ) return data + async def _load_location_data(self, schema, req, location): + """Return a dictionary-like object for the location on the given request. + + Needs to have the schema in hand in order to correctly handle loading + lists from multidict objects and `many=True` schemas. + """ + loader_func = self._get_loader(location) + if asyncio.iscoroutinefunction(loader_func): + data = await loader_func(req, schema) + else: + data = loader_func(req, schema) + + # when the desired location is empty (no data), provide an empty + # dict as the default so that optional arguments in a location + # (e.g. optional JSON body) work smoothly + if data is core.missing: + data = {} + return data + async def _on_validation_error( self, error: ValidationError, @@ -111,7 +89,7 @@ def use_args( self, argmap: ArgMap, req: typing.Optional[Request] = None, - locations: typing.Iterable = None, + location: str = None, as_kwargs: bool = False, validate: Validate = None, error_status_code: typing.Optional[int] = None, @@ -121,7 +99,7 @@ def use_args( Receives the same arguments as `webargs.core.Parser.use_args`. """ - locations = locations or self.locations + location = location or self.location request_obj = req # Optimization: If argmap is passed as a dictionary, we only need # to generate a Schema once @@ -143,7 +121,7 @@ async def wrapper(*args, **kwargs): parsed_args = await self.parse( argmap, req=req_obj, - locations=locations, + location=location, validate=validate, error_status_code=error_status_code, error_headers=error_headers, @@ -168,7 +146,7 @@ def wrapper(*args, **kwargs): parsed_args = yield from self.parse( # type: ignore argmap, req=req_obj, - locations=locations, + location=location, validate=validate, error_status_code=error_status_code, error_headers=error_headers, @@ -192,29 +170,3 @@ def use_kwargs(self, *args, **kwargs) -> typing.Callable: """ return super().use_kwargs(*args, **kwargs) - - async def parse_arg( - self, name: str, field: Field, req: Request, locations: typing.Iterable = None - ) -> typing.Any: - location = field.metadata.get("location") - if location: - locations_to_check = self._validated_locations([location]) - else: - locations_to_check = self._validated_locations(locations or self.locations) - - for location in locations_to_check: - value = await self._get_value(name, field, req=req, location=location) - # Found the value; validate and return it - if value is not core.missing: - return value - return core.missing - - async def _get_value( - self, name: str, argobj: Field, req: Request, location: str - ) -> typing.Any: - function = self._get_handler(location) - if asyncio.iscoroutinefunction(function): - value = await function(req, name, argobj) - else: - value = function(req, name, argobj) - return value diff --git a/src/webargs/bottleparser.py b/src/webargs/bottleparser.py index 568dc658..37a501a1 100644 --- a/src/webargs/bottleparser.py +++ b/src/webargs/bottleparser.py @@ -20,51 +20,54 @@ def index(args): import bottle from webargs import core -from webargs.core import json +from webargs.multidictproxy import MultiDictProxy class BottleParser(core.Parser): """Bottle.py request argument parser.""" - def parse_querystring(self, req, name, field): - """Pull a querystring value from the request.""" - return core.get_value(req.query, name, field) - - def parse_form(self, req, name, field): - """Pull a form value from the request.""" - return core.get_value(req.forms, name, field) - - def parse_json(self, req, name, field): - """Pull a json value from the request.""" - json_data = self._cache.get("json") - if json_data is None: - try: - self._cache["json"] = json_data = req.json - except AttributeError: - return core.missing - except json.JSONDecodeError as e: - if e.doc == "": - return core.missing - else: - return self.handle_invalid_json_error(e, req) - except UnicodeDecodeError as e: - return self.handle_invalid_json_error(e, req) - - if json_data is None: - return core.missing - return core.get_value(json_data, name, field, allow_many_nested=True) - - def parse_headers(self, req, name, field): - """Pull a value from the header data.""" - return core.get_value(req.headers, name, field) - - def parse_cookies(self, req, name, field): - """Pull a value from the cookiejar.""" - return req.get_cookie(name) - - def parse_files(self, req, name, field): - """Pull a file from the request.""" - return core.get_value(req.files, name, field) + def _handle_invalid_json_error(self, error, req, *args, **kwargs): + raise bottle.HTTPError( + status=400, body={"json": ["Invalid JSON body."]}, exception=error + ) + + def _raw_load_json(self, req): + """Read a json payload from the request.""" + try: + data = req.json + except AttributeError: + return core.missing + + # unfortunately, bottle does not distinguish between an emtpy body, "", + # and a body containing the valid JSON value null, "null" + # so these can't be properly disambiguated + # as our best-effort solution, treat None as missing and ignore the + # (admittedly unusual) "null" case + # see: https://github.com/bottlepy/bottle/issues/1160 + if data is None: + return core.missing + else: + return data + + def load_querystring(self, req, schema): + """Return query params from the request as a MultiDictProxy.""" + return MultiDictProxy(req.query, schema) + + def load_form(self, req, schema): + """Return form values from the request as a MultiDictProxy.""" + return MultiDictProxy(req.forms, schema) + + def load_headers(self, req, schema): + """Return headers from the request as a MultiDictProxy.""" + return MultiDictProxy(req.headers, schema) + + def load_cookies(self, req, schema): + """Return cookies from the request.""" + return req.cookies + + def load_files(self, req, schema): + """Return files from the request as a MultiDictProxy.""" + return MultiDictProxy(req.files, schema) def handle_error(self, error, req, schema, error_status_code, error_headers): """Handles errors during parsing. Aborts the current request with a @@ -78,11 +81,6 @@ def handle_error(self, error, req, schema, error_status_code, error_headers): exception=error, ) - def handle_invalid_json_error(self, error, req, *args, **kwargs): - raise bottle.HTTPError( - status=400, body={"json": ["Invalid JSON body."]}, exception=error - ) - def get_default_request(self): """Override to use bottle's thread-local request object by default.""" return bottle.request diff --git a/src/webargs/core.py b/src/webargs/core.py index fe2f39b6..ea1b9307 100644 --- a/src/webargs/core.py +++ b/src/webargs/core.py @@ -14,9 +14,9 @@ import marshmallow as ma from marshmallow import ValidationError -from marshmallow.utils import missing, is_collection +from marshmallow.utils import missing -from webargs.compat import Mapping, iteritems, MARSHMALLOW_VERSION_INFO +from webargs.compat import Mapping, MARSHMALLOW_VERSION_INFO from webargs.dict2schema import dict2schema from webargs.fields import DelimitedList @@ -28,7 +28,6 @@ "dict2schema", "is_multiple", "Parser", - "get_value", "missing", "parse_json", ] @@ -74,42 +73,6 @@ def is_json(mimetype): return False -def get_value(data, name, field, allow_many_nested=False): - """Get a value from a dictionary. Handles ``MultiDict`` types when - ``field`` handles repeated/multi-value arguments. - If the value is not found, return `missing`. - - :param object data: Mapping (e.g. `dict`) or list-like instance to - pull the value from. - :param str name: Name of the key. - :param bool allow_many_nested: Whether to allow a list of nested objects - (it is valid only for JSON format, so it is set to True in ``parse_json`` - methods). - """ - missing_value = missing - if allow_many_nested and isinstance(field, ma.fields.Nested) and field.many: - if is_collection(data): - return data - - if not hasattr(data, "get"): - return missing_value - - multiple = is_multiple(field) - val = data.get(name, missing_value) - if multiple and val is not missing: - if hasattr(data, "getlist"): - return data.getlist(name) - elif hasattr(data, "getall"): - return data.getall(name) - elif isinstance(val, (list, tuple)): - return val - if val is None: - return None - else: - return [val] - return val - - def parse_json(s, encoding="utf-8"): if isinstance(s, bytes): try: @@ -142,15 +105,16 @@ class Parser(object): """Base parser class that provides high-level implementation for parsing a request. - Descendant classes must provide lower-level implementations for parsing - different locations, e.g. ``parse_json``, ``parse_querystring``, etc. + Descendant classes must provide lower-level implementations for reading + data from different locations, e.g. ``load_json``, ``load_querystring``, + etc. - :param tuple locations: Default locations to parse. + :param str location: Default location to use for data :param callable error_handler: Custom error handler function. """ - #: Default locations to check for data - DEFAULT_LOCATIONS = ("querystring", "form", "json") + #: Default location to check for data + DEFAULT_LOCATION = "json" #: The marshmallow Schema class to use when creating new schemas DEFAULT_SCHEMA_CLASS = ma.Schema #: Default status code to return for validation errors @@ -160,38 +124,33 @@ class Parser(object): #: Maps location => method name __location_map__ = { - "json": "parse_json", - "querystring": "parse_querystring", - "query": "parse_querystring", - "form": "parse_form", - "headers": "parse_headers", - "cookies": "parse_cookies", - "files": "parse_files", + "json": "load_json", + "querystring": "load_querystring", + "query": "load_querystring", + "form": "load_form", + "headers": "load_headers", + "cookies": "load_cookies", + "files": "load_files", + "json_or_form": "load_json_or_form", } - def __init__(self, locations=None, error_handler=None, schema_class=None): - self.locations = locations or self.DEFAULT_LOCATIONS + def __init__(self, location=None, error_handler=None, schema_class=None): + self.location = location or self.DEFAULT_LOCATION self.error_callback = _callable_or_raise(error_handler) self.schema_class = schema_class or self.DEFAULT_SCHEMA_CLASS #: A short-lived cache to store results from processing request bodies. self._cache = {} - def _validated_locations(self, locations): - """Ensure that the given locations argument is valid. + def _get_loader(self, location): + """Get the loader function for the given location. - :raises: ValueError if a given locations includes an invalid location. + :raises: ValueError if a given location is invalid. """ - # The set difference between the given locations and the available locations - # will be the set of invalid locations valid_locations = set(self.__location_map__.keys()) - given = set(locations) - invalid_locations = given - valid_locations - if len(invalid_locations): - msg = "Invalid locations arguments: {0}".format(list(invalid_locations)) + if location not in valid_locations: + msg = "Invalid location argument: {0}".format(location) raise ValueError(msg) - return locations - def _get_handler(self, location): # Parsing function to call # May be a method name (str) or a function func = self.__location_map__.get(location) @@ -204,73 +163,20 @@ def _get_handler(self, location): raise ValueError('Invalid location: "{0}"'.format(location)) return function - def _get_value(self, name, argobj, req, location): - function = self._get_handler(location) - return function(req, name, argobj) - - def parse_arg(self, name, field, req, locations=None): - """Parse a single argument from a request. - - .. note:: - This method does not perform validation on the argument. + def _load_location_data(self, schema, req, location): + """Return a dictionary-like object for the location on the given request. - :param str name: The name of the value. - :param marshmallow.fields.Field field: The marshmallow `Field` for the request - parameter. - :param req: The request object to parse. - :param tuple locations: The locations ('json', 'querystring', etc.) where - to search for the value. - :return: The unvalidated argument value or `missing` if the value cannot - be found on the request. + Needs to have the schema in hand in order to correctly handle loading + lists from multidict objects and `many=True` schemas. """ - location = field.metadata.get("location") - if location: - locations_to_check = self._validated_locations([location]) - else: - locations_to_check = self._validated_locations(locations or self.locations) - - for location in locations_to_check: - value = self._get_value(name, field, req=req, location=location) - # Found the value; validate and return it - if value is not missing: - return value - return missing - - def _parse_request(self, schema, req, locations): - """Return a parsed arguments dictionary for the current request.""" - if schema.many: - assert ( - "json" in locations - ), "schema.many=True is only supported for JSON location" - # The ad hoc Nested field is more like a workaround or a helper, - # and it servers its purpose fine. However, if somebody has a desire - # to re-design the support of bulk-type arguments, go ahead. - parsed = self.parse_arg( - name="json", - field=ma.fields.Nested(schema, many=True), - req=req, - locations=locations, - ) - if parsed is missing: - parsed = [] - else: - argdict = schema.fields - parsed = {} - for argname, field_obj in iteritems(argdict): - if MARSHMALLOW_VERSION_INFO[0] < 3: - parsed_value = self.parse_arg(argname, field_obj, req, locations) - # If load_from is specified on the field, try to parse from that key - if parsed_value is missing and field_obj.load_from: - parsed_value = self.parse_arg( - field_obj.load_from, field_obj, req, locations - ) - argname = field_obj.load_from - else: - argname = field_obj.data_key or argname - parsed_value = self.parse_arg(argname, field_obj, req, locations) - if parsed_value is not missing: - parsed[argname] = parsed_value - return parsed + loader_func = self._get_loader(location) + data = loader_func(req, schema) + # when the desired location is empty (no data), provide an empty + # dict as the default so that optional arguments in a location + # (e.g. optional JSON body) work smoothly + if data is missing: + data = {} + return data def _on_validation_error( self, error, req, schema, error_status_code, error_headers @@ -310,6 +216,10 @@ def _get_schema(self, argmap, req): return schema def _clone(self): + """Clone the current parser in order to ensure that it has a fresh and + independent cache. This is used whenever `Parser.parse` is called, so + that these methods always have separate caches. + """ clone = copy(self) clone.clear_cache() return clone @@ -318,7 +228,7 @@ def parse( self, argmap, req=None, - locations=None, + location=None, validate=None, error_status_code=None, error_headers=None, @@ -329,9 +239,10 @@ def parse( of argname -> `marshmallow.fields.Field` pairs, or a callable which accepts a request and returns a `marshmallow.Schema`. :param req: The request object to parse. - :param tuple locations: Where on the request to search for values. - Can include one or more of ``('json', 'querystring', 'form', - 'headers', 'cookies', 'files')``. + :param str location: Where on the request to load values. + Can be any of the values in :py:attr:`~__location_map__`. By + default, that means one of ``('json', 'query', 'querystring', + 'form', 'headers', 'cookies', 'files', 'json_or_form')``. :param callable validate: Validation function or list of validation functions that receives the dictionary of parsed arguments. Validator either returns a boolean or raises a :exc:`ValidationError`. @@ -342,18 +253,18 @@ def parse( :return: A dictionary of parsed arguments """ - self.clear_cache() # in case someone used `parse_*()` req = req if req is not None else self.get_default_request() - assert req is not None, "Must pass req object" + if req is None: + raise ValueError("Must pass req object") data = None validators = _ensure_list_of_callables(validate) parser = self._clone() schema = self._get_schema(argmap, req) try: - parsed = parser._parse_request( - schema=schema, req=req, locations=locations or self.locations + location_data = parser._load_location_data( + schema=schema, req=req, location=location or self.location ) - result = schema.load(parsed) + result = schema.load(location_data) data = result.data if MARSHMALLOW_VERSION_INFO[0] < 3 else result parser._validate_arguments(data, validators) except ma.exceptions.ValidationError as error: @@ -397,7 +308,7 @@ def use_args( self, argmap, req=None, - locations=None, + location=None, as_kwargs=False, validate=None, error_status_code=None, @@ -408,14 +319,14 @@ def use_args( Example usage with Flask: :: @app.route('/echo', methods=['get', 'post']) - @parser.use_args({'name': fields.Str()}) + @parser.use_args({'name': fields.Str()}, location="querystring") def greet(args): return 'Hello ' + args['name'] :param argmap: Either a `marshmallow.Schema`, a `dict` of argname -> `marshmallow.fields.Field` pairs, or a callable which accepts a request and returns a `marshmallow.Schema`. - :param tuple locations: Where on the request to search for values. + :param str locations: Where on the request to load values. :param bool as_kwargs: Whether to insert arguments as keyword arguments. :param callable validate: Validation function that receives the dictionary of parsed arguments. If the function returns ``False``, the parser @@ -425,7 +336,7 @@ def greet(args): :param dict error_headers: Headers passed to error handler functions when a a `ValidationError` is raised. """ - locations = locations or self.locations + location = location or self.location request_obj = req # Optimization: If argmap is passed as a dictionary, we only need # to generate a Schema once @@ -441,11 +352,12 @@ def wrapper(*args, **kwargs): if not req_obj: req_obj = self.get_request_from_view_args(func, args, kwargs) + # NOTE: At this point, argmap may be a Schema, or a callable parsed_args = self.parse( argmap, req=req_obj, - locations=locations, + location=location, validate=validate, error_status_code=error_status_code, error_headers=error_headers, @@ -481,19 +393,23 @@ def greet(name): kwargs["as_kwargs"] = True return self.use_args(*args, **kwargs) - def location_handler(self, name): - """Decorator that registers a function for parsing a request location. - The wrapped function receives a request, the name of the argument, and - the corresponding `Field ` object. + def location_loader(self, name): + """Decorator that registers a function for loading a request location. + The wrapped function receives a schema and a request. + + The schema will usually not be relevant, but it's important in some + cases -- most notably in order to correctly load multidict values into + list fields. Without the schema, there would be no way to know whether + to simply `.get()` or `.getall()` from a multidict for a given value. Example: :: from webargs import core parser = core.Parser() - @parser.location_handler("name") - def parse_data(request, name, field): - return request.data.get(name) + @parser.location_loader("name") + def load_data(request, schema): + return request.data :param str name: The name of the location to register. """ @@ -531,41 +447,95 @@ def handle_error(error, req, schema, status_code, headers): self.error_callback = func return func + def _handle_invalid_json_error(self, error, req, *args, **kwargs): + """Internal hook for overriding treatment of JSONDecodeErrors. + + Invoked by default `load_json` implementation. + + External parsers can just implement their own behavior for load_json , + so this is not part of the public parser API. + """ + raise error + + def load_json(self, req, schema): + """Load JSON from a request object or return `missing` if no value can + be found. + """ + # NOTE: although this implementation is real/concrete and used by + # several of the parsers in webargs, it relies on the internal hooks + # `_handle_invalid_json_error` and `_raw_load_json` + # these methods are not part of the public API and are used to simplify + # code sharing amongst the built-in webargs parsers + if "json" not in self._cache: + try: + json_data = self._raw_load_json(req) + except json.JSONDecodeError as e: + if e.doc == "": + json_data = missing + else: + return self._handle_invalid_json_error(e, req) + except UnicodeDecodeError as e: + return self._handle_invalid_json_error(e, req) + self._cache["json"] = json_data + + return self._cache["json"] + + def load_json_or_form(self, req, schema): + """Load data from a request, accepting either JSON or form-encoded + data. + + The data will first be loaded as JSON, and, if that fails, it will be + loaded as a form post. + """ + data = self.load_json(req, schema) + if data is not missing: + return data + return self.load_form(req, schema) + # Abstract Methods - def parse_json(self, req, name, arg): - """Pull a JSON value from a request object or return `missing` if the - value cannot be found. + def _raw_load_json(self, req): + """Internal hook method for implementing load_json() + + Get a request body for feeding in to `load_json`, and parse it either + using core.parse_json() or similar utilities which raise + JSONDecodeErrors. + Ensure consistent behavior when encountering decoding errors. + + The default implementation here simply returns `missing`, and the default + implementation of `load_json` above will pass that value through. + However, by implementing a "mostly concrete" version of load_json with + this as a hook for getting data, we consolidate the logic for handling + those JSONDecodeErrors. """ return missing - def parse_querystring(self, req, name, arg): - """Pull a value from the query string of a request object or return `missing` if - the value cannot be found. + def load_querystring(self, req, schema): + """Load the query string of a request object or return `missing` if no + value can be found. """ return missing - def parse_form(self, req, name, arg): - """Pull a value from the form data of a request object or return - `missing` if the value cannot be found. + def load_form(self, req, schema): + """Load the form data of a request object or return `missing` if no + value can be found. """ return missing - def parse_headers(self, req, name, arg): - """Pull a value from the headers or return `missing` if the value - cannot be found. + def load_headers(self, req, schema): + """Load the headers or return `missing` if no value can be found. """ return missing - def parse_cookies(self, req, name, arg): - """Pull a cookie value from the request or return `missing` if the value - cannot be found. + def load_cookies(self, req, schema): + """Load the cookies from the request or return `missing` if no value + can be found. """ return missing - def parse_files(self, req, name, arg): - """Pull a file from the request or return `missing` if the value file - cannot be found. + def load_files(self, req, schema): + """Load files from the request or return `missing` if no values can be + found. """ return missing diff --git a/src/webargs/djangoparser.py b/src/webargs/djangoparser.py index fd5cc11c..6e65713d 100644 --- a/src/webargs/djangoparser.py +++ b/src/webargs/djangoparser.py @@ -19,7 +19,11 @@ def get(self, args, request): return HttpResponse('Hello ' + args['name']) """ from webargs import core -from webargs.core import json +from webargs.multidictproxy import MultiDictProxy + + +def is_json_request(req): + return core.is_json(req.content_type) class DjangoParser(core.Parser): @@ -33,41 +37,36 @@ class DjangoParser(core.Parser): the parser and returning the appropriate `HTTPResponse`. """ - def parse_querystring(self, req, name, field): - """Pull the querystring value from the request.""" - return core.get_value(req.GET, name, field) - - def parse_form(self, req, name, field): - """Pull the form value from the request.""" - return core.get_value(req.POST, name, field) - - def parse_json(self, req, name, field): - """Pull a json value from the request body.""" - json_data = self._cache.get("json") - if json_data is None: - try: - self._cache["json"] = json_data = core.parse_json(req.body) - except AttributeError: - return core.missing - except json.JSONDecodeError as e: - if e.doc == "": - return core.missing - else: - return self.handle_invalid_json_error(e, req) - return core.get_value(json_data, name, field, allow_many_nested=True) - - def parse_cookies(self, req, name, field): - """Pull the value from the cookiejar.""" - return core.get_value(req.COOKIES, name, field) - - def parse_headers(self, req, name, field): + def _raw_load_json(self, req): + """Read a json payload from the request for the core parser's load_json + + Checks the input mimetype and may return 'missing' if the mimetype is + non-json, even if the request body is parseable as json.""" + if not is_json_request(req): + return core.missing + + return core.parse_json(req.body) + + def load_querystring(self, req, schema): + """Return query params from the request as a MultiDictProxy.""" + return MultiDictProxy(req.GET, schema) + + def load_form(self, req, schema): + """Return form values from the request as a MultiDictProxy.""" + return MultiDictProxy(req.POST, schema) + + def load_cookies(self, req, schema): + """Return cookies from the request.""" + return req.COOKIES + + def load_headers(self, req, schema): raise NotImplementedError( "Header parsing not supported by {0}".format(self.__class__.__name__) ) - def parse_files(self, req, name, field): - """Pull a file from the request.""" - return core.get_value(req.FILES, name, field) + def load_files(self, req, schema): + """Return files from the request as a MultiDictProxy.""" + return MultiDictProxy(req.FILES, schema) def get_request_from_view_args(self, view, args, kwargs): # The first argument is either `self` or `request` @@ -76,9 +75,6 @@ def get_request_from_view_args(self, view, args, kwargs): except AttributeError: # first arg is request return args[0] - def handle_invalid_json_error(self, error, req, *args, **kwargs): - raise error - parser = DjangoParser() use_args = parser.use_args diff --git a/src/webargs/falconparser.py b/src/webargs/falconparser.py index b8c5ec76..2032c917 100644 --- a/src/webargs/falconparser.py +++ b/src/webargs/falconparser.py @@ -5,7 +5,7 @@ from falcon.util.uri import parse_query_string from webargs import core -from webargs.core import json +from webargs.multidictproxy import MultiDictProxy HTTP_422 = "422 Unprocessable Entity" @@ -30,23 +30,6 @@ def is_json_request(req): return content_type and core.is_json(content_type) -def parse_json_body(req): - if req.content_length in (None, 0): - # Nothing to do - return {} - if is_json_request(req): - body = req.stream.read() - if body: - try: - return core.parse_json(body) - except json.JSONDecodeError as e: - if e.doc == "": - return core.missing - else: - raise - return {} - - # NOTE: Adapted from falcon.request.Request._parse_form_urlencoded def parse_form_body(req): if ( @@ -69,7 +52,8 @@ def parse_form_body(req): return parse_query_string( body, keep_blank_qs_values=req.options.keep_blank_qs_values ) - return {} + + return core.missing class HTTPError(falcon.HTTPError): @@ -91,12 +75,20 @@ def to_dict(self, *args, **kwargs): class FalconParser(core.Parser): """Falcon request argument parser.""" - def parse_querystring(self, req, name, field): - """Pull a querystring value from the request.""" - return core.get_value(req.params, name, field) + # Note on the use of MultiDictProxy throughout: + # Falcon parses query strings and form values into ordinary dicts, but with + # the values listified where appropriate + # it is still therefore necessary in these cases to wrap them in + # MultiDictProxy because we need to use the schema to determine when single + # values should be wrapped in lists due to the type of the destination + # field + + def load_querystring(self, req, schema): + """Return query params from the request as a MultiDictProxy.""" + return MultiDictProxy(req.params, schema) - def parse_form(self, req, name, field): - """Pull a form value from the request. + def load_form(self, req, schema): + """Return form values from the request as a MultiDictProxy .. note:: @@ -105,44 +97,46 @@ def parse_form(self, req, name, field): form = self._cache.get("form") if form is None: self._cache["form"] = form = parse_form_body(req) - return core.get_value(form, name, field) + if form is core.missing: + return form + return MultiDictProxy(form, schema) - def parse_json(self, req, name, field): - """Pull a JSON body value from the request. + def _raw_load_json(self, req): + """Return a json payload from the request for the core parser's load_json - .. note:: - - The request stream will be read and left at EOF. - """ - json_data = self._cache.get("json_data") - if json_data is None: - try: - self._cache["json_data"] = json_data = parse_json_body(req) - except json.JSONDecodeError as e: - return self.handle_invalid_json_error(e, req) - return core.get_value(json_data, name, field, allow_many_nested=True) - - def parse_headers(self, req, name, field): - """Pull a header value from the request.""" - # Use req.get_headers rather than req.headers for performance - return req.get_header(name, required=False) or core.missing - - def parse_cookies(self, req, name, field): - """Pull a cookie value from the request.""" - cookies = self._cache.get("cookies") - if cookies is None: - self._cache["cookies"] = cookies = req.cookies - return core.get_value(cookies, name, field) + Checks the input mimetype and may return 'missing' if the mimetype is + non-json, even if the request body is parseable as json.""" + if not is_json_request(req) or req.content_length in (None, 0): + return core.missing + body = req.stream.read() + if body: + return core.parse_json(body) + else: + return core.missing + + def load_headers(self, req, schema): + """Return headers from the request.""" + # Falcon only exposes headers as a dict (not multidict) + return req.headers + + def load_cookies(self, req, schema): + """Return cookies from the request.""" + # Cookies are expressed in Falcon as a dict, but the possibility of + # multiple values for a cookie is preserved internally -- if desired in + # the future, webargs could add a MultiDict type for Cookies here built + # from (req, schema), but Falcon does not provide one out of the box + return req.cookies def get_request_from_view_args(self, view, args, kwargs): """Get request from a resource method's arguments. Assumes that request is the second argument. """ req = args[1] - assert isinstance(req, falcon.Request), "Argument is not a falcon.Request" + if not isinstance(req, falcon.Request): + raise TypeError("Argument is not a falcon.Request") return req - def parse_files(self, req, name, field): + def load_files(self, req, schema): raise NotImplementedError( "Parsing files not yet supported by {0}".format(self.__class__.__name__) ) @@ -154,7 +148,7 @@ def handle_error(self, error, req, schema, error_status_code, error_headers): raise LookupError("Status code {0} not supported".format(error_status_code)) raise HTTPError(status, errors=error.messages, headers=error_headers) - def handle_invalid_json_error(self, error, req, *args, **kwargs): + def _handle_invalid_json_error(self, error, req, *args, **kwargs): status = status_map[400] messages = {"json": ["Invalid JSON body."]} raise HTTPError(status, errors=messages) diff --git a/src/webargs/flaskparser.py b/src/webargs/flaskparser.py index 9f6b38f7..0774a9c7 100644 --- a/src/webargs/flaskparser.py +++ b/src/webargs/flaskparser.py @@ -23,7 +23,8 @@ def index(args): from werkzeug.exceptions import HTTPException from webargs import core -from webargs.core import json +from webargs.compat import MARSHMALLOW_VERSION_INFO +from webargs.multidictproxy import MultiDictProxy def abort(http_status_code, exc=None, **kwargs): @@ -48,61 +49,63 @@ class FlaskParser(core.Parser): """Flask request argument parser.""" __location_map__ = dict( - view_args="parse_view_args", - path="parse_view_args", + view_args="load_view_args", + path="load_view_args", **core.Parser.__location_map__ ) - def parse_view_args(self, req, name, field): - """Pull a value from the request's ``view_args``.""" - return core.get_value(req.view_args, name, field) - - def parse_json(self, req, name, field): - """Pull a json value from the request.""" - json_data = self._cache.get("json") - if json_data is None: - # We decode the json manually here instead of - # using req.get_json() so that we can handle - # JSONDecodeErrors consistently - data = req.get_data(cache=True) - try: - self._cache["json"] = json_data = core.parse_json(data) - except json.JSONDecodeError as e: - if e.doc == "": - return core.missing - else: - return self.handle_invalid_json_error(e, req) - return core.get_value(json_data, name, field, allow_many_nested=True) - - def parse_querystring(self, req, name, field): - """Pull a querystring value from the request.""" - return core.get_value(req.args, name, field) - - def parse_form(self, req, name, field): - """Pull a form value from the request.""" - try: - return core.get_value(req.form, name, field) - except AttributeError: - pass - return core.missing - - def parse_headers(self, req, name, field): - """Pull a value from the header data.""" - return core.get_value(req.headers, name, field) - - def parse_cookies(self, req, name, field): - """Pull a value from the cookiejar.""" - return core.get_value(req.cookies, name, field) - - def parse_files(self, req, name, field): - """Pull a file from the request.""" - return core.get_value(req.files, name, field) + def _raw_load_json(self, req): + """Return a json payload from the request for the core parser's load_json + + Checks the input mimetype and may return 'missing' if the mimetype is + non-json, even if the request body is parseable as json.""" + if not is_json_request(req): + return core.missing + + return core.parse_json(req.get_data(cache=True)) + + def _handle_invalid_json_error(self, error, req, *args, **kwargs): + abort(400, exc=error, messages={"json": ["Invalid JSON body."]}) + + def load_view_args(self, req, schema): + """Return the request's ``view_args`` or ``missing`` if there are none.""" + return req.view_args or core.missing + + def load_querystring(self, req, schema): + """Return query params from the request as a MultiDictProxy.""" + return MultiDictProxy(req.args, schema) + + def load_form(self, req, schema): + """Return form values from the request as a MultiDictProxy.""" + return MultiDictProxy(req.form, schema) + + def load_headers(self, req, schema): + """Return headers from the request as a MultiDictProxy.""" + return MultiDictProxy(req.headers, schema) + + def load_cookies(self, req, schema): + """Return cookies from the request.""" + return req.cookies + + def load_files(self, req, schema): + """Return files from the request as a MultiDictProxy.""" + return MultiDictProxy(req.files, schema) def handle_error(self, error, req, schema, error_status_code, error_headers): """Handles errors during parsing. Aborts the current HTTP request and responds with a 422 error. """ status_code = error_status_code or self.DEFAULT_VALIDATION_STATUS + # on marshmallow 2, a many schema receiving a non-list value will + # produce this specific error back -- reformat it to match the + # marshmallow 3 message so that Flask can properly encode it + messages = error.messages + if ( + MARSHMALLOW_VERSION_INFO[0] < 3 + and schema.many + and messages == {0: {}, "_schema": ["Invalid input type."]} + ): + messages.pop(0) abort( status_code, exc=error, @@ -111,9 +114,6 @@ def handle_error(self, error, req, schema, error_status_code, error_headers): headers=error_headers, ) - def handle_invalid_json_error(self, error, req, *args, **kwargs): - abort(400, exc=error, messages={"json": ["Invalid JSON body."]}) - def get_default_request(self): """Override to use Flask's thread-local request objec by default""" return flask.request diff --git a/src/webargs/multidictproxy.py b/src/webargs/multidictproxy.py new file mode 100644 index 00000000..29bc106c --- /dev/null +++ b/src/webargs/multidictproxy.py @@ -0,0 +1,66 @@ +from webargs.compat import MARSHMALLOW_VERSION_INFO, Mapping +from webargs.core import missing, is_multiple + + +class MultiDictProxy(Mapping): + """ + A proxy object which wraps multidict types along with a matching schema + Whenever a value is looked up, it is checked against the schema to see if + there is a matching field where `is_multiple` is True. If there is, then + the data should be loaded as a list or tuple. + + In all other cases, __getitem__ proxies directly to the input multidict. + """ + + def __init__(self, multidict, schema): + self.data = multidict + self.multiple_keys = self._collect_multiple_keys(schema) + + def _collect_multiple_keys(self, schema): + result = set() + for name, field in schema.fields.items(): + if not is_multiple(field): + continue + if MARSHMALLOW_VERSION_INFO[0] < 3: + result.add(field.load_from if field.load_from is not None else name) + else: + result.add(field.data_key if field.data_key is not None else name) + return result + + def __getitem__(self, key): + val = self.data.get(key, missing) + if val is missing or key not in self.multiple_keys: + return val + if hasattr(self.data, "getlist"): + return self.data.getlist(key) + elif hasattr(self.data, "getall"): + return self.data.getall(key) + elif isinstance(val, (list, tuple)): + return val + if val is None: + return None + return [val] + + def __delitem__(self, key): + del self.data[key] + + def __setitem__(self, key, value): + self.data[key] = value + + def __getattr__(self, name): + return getattr(self.data, name) + + def __iter__(self): + return iter(self.data) + + def __contains__(self, x): + return x in self.data + + def __len__(self): + return len(self.data) + + def __eq__(self, other): + return self.data == other + + def __ne__(self, other): + return self.data != other diff --git a/src/webargs/pyramidparser.py b/src/webargs/pyramidparser.py index 6da01af6..be45a0df 100644 --- a/src/webargs/pyramidparser.py +++ b/src/webargs/pyramidparser.py @@ -34,56 +34,56 @@ def hello_world(request, args): from webargs import core from webargs.core import json from webargs.compat import text_type +from webargs.multidictproxy import MultiDictProxy + + +def is_json_request(req): + return core.is_json(req.headers.get("content-type")) class PyramidParser(core.Parser): """Pyramid request argument parser.""" __location_map__ = dict( - matchdict="parse_matchdict", - path="parse_matchdict", + matchdict="load_matchdict", + path="load_matchdict", **core.Parser.__location_map__ ) - def parse_querystring(self, req, name, field): - """Pull a querystring value from the request.""" - return core.get_value(req.GET, name, field) - - def parse_form(self, req, name, field): - """Pull a form value from the request.""" - return core.get_value(req.POST, name, field) - - def parse_json(self, req, name, field): - """Pull a json value from the request.""" - json_data = self._cache.get("json") - if json_data is None: - try: - self._cache["json"] = json_data = core.parse_json(req.body, req.charset) - except json.JSONDecodeError as e: - if e.doc == "": - return core.missing - else: - return self.handle_invalid_json_error(e, req) - if json_data is None: - return core.missing - return core.get_value(json_data, name, field, allow_many_nested=True) + def _raw_load_json(self, req): + """Return a json payload from the request for the core parser's load_json + + Checks the input mimetype and may return 'missing' if the mimetype is + non-json, even if the request body is parseable as json.""" + if not is_json_request(req): + return core.missing + + return core.parse_json(req.body, req.charset) + + def load_querystring(self, req, schema): + """Return query params from the request as a MultiDictProxy.""" + return MultiDictProxy(req.GET, schema) + + def load_form(self, req, schema): + """Return form values from the request as a MultiDictProxy.""" + return MultiDictProxy(req.POST, schema) - def parse_cookies(self, req, name, field): - """Pull the value from the cookiejar.""" - return core.get_value(req.cookies, name, field) + def load_cookies(self, req, schema): + """Return cookies from the request as a MultiDictProxy.""" + return MultiDictProxy(req.cookies, schema) - def parse_headers(self, req, name, field): - """Pull a value from the header data.""" - return core.get_value(req.headers, name, field) + def load_headers(self, req, schema): + """Return headers from the request as a MultiDictProxy.""" + return MultiDictProxy(req.headers, schema) - def parse_files(self, req, name, field): - """Pull a file from the request.""" + def load_files(self, req, schema): + """Return files from the request as a MultiDictProxy.""" files = ((k, v) for k, v in req.POST.items() if hasattr(v, "file")) - return core.get_value(MultiDict(files), name, field) + return MultiDictProxy(MultiDict(files), schema) - def parse_matchdict(self, req, name, field): - """Pull a value from the request's `matchdict`.""" - return core.get_value(req.matchdict, name, field) + def load_matchdict(self, req, schema): + """Return the request's ``matchdict`` as a MultiDictProxy.""" + return MultiDictProxy(req.matchdict, schema) def handle_error(self, error, req, schema, error_status_code, error_headers): """Handles errors during parsing. Aborts the current HTTP request and @@ -100,7 +100,7 @@ def handle_error(self, error, req, schema, error_status_code, error_headers): response.body = body.encode("utf-8") if isinstance(body, text_type) else body raise response - def handle_invalid_json_error(self, error, req, *args, **kwargs): + def _handle_invalid_json_error(self, error, req, *args, **kwargs): messages = {"json": ["Invalid JSON body."]} response = exception_response( 400, detail=text_type(messages), content_type="application/json" @@ -113,7 +113,7 @@ def use_args( self, argmap, req=None, - locations=core.Parser.DEFAULT_LOCATIONS, + location=core.Parser.DEFAULT_LOCATION, as_kwargs=False, validate=None, error_status_code=None, @@ -127,7 +127,7 @@ def use_args( of argname -> `marshmallow.fields.Field` pairs, or a callable which accepts a request and returns a `marshmallow.Schema`. :param req: The request object to parse. Pulled off of the view by default. - :param tuple locations: Where on the request to search for values. + :param str location: Where on the request to load values. :param bool as_kwargs: Whether to insert arguments as keyword arguments. :param callable validate: Validation function that receives the dictionary of parsed arguments. If the function returns ``False``, the parser @@ -137,7 +137,7 @@ def use_args( :param dict error_headers: Headers passed to error handler functions when a a `ValidationError` is raised. """ - locations = locations or self.locations + location = location or self.location # Optimization: If argmap is passed as a dictionary, we only need # to generate a Schema once if isinstance(argmap, collections.Mapping): @@ -155,7 +155,7 @@ def wrapper(obj, *args, **kwargs): parsed_args = self.parse( argmap, req=request, - locations=locations, + location=location, validate=validate, error_status_code=error_status_code, error_headers=error_headers, diff --git a/src/webargs/testing.py b/src/webargs/testing.py index 922bc473..d17a71b5 100644 --- a/src/webargs/testing.py +++ b/src/webargs/testing.py @@ -40,24 +40,35 @@ def testapp(self): def test_parse_querystring_args(self, testapp): assert testapp.get("/echo?name=Fred").json == {"name": "Fred"} - def test_parse_querystring_with_query_location_specified(self, testapp): - assert testapp.get("/echo_query?name=Steve").json == {"name": "Steve"} - def test_parse_form(self, testapp): - assert testapp.post("/echo", {"name": "Joe"}).json == {"name": "Joe"} + assert testapp.post("/echo_form", {"name": "Joe"}).json == {"name": "Joe"} def test_parse_json(self, testapp): - assert testapp.post_json("/echo", {"name": "Fred"}).json == {"name": "Fred"} + assert testapp.post_json("/echo_json", {"name": "Fred"}).json == { + "name": "Fred" + } + + def test_parse_json_missing(self, testapp): + assert testapp.post("/echo_json", "").json == {"name": "World"} + + def test_parse_json_or_form(self, testapp): + assert testapp.post_json("/echo_json_or_form", {"name": "Fred"}).json == { + "name": "Fred" + } + assert testapp.post("/echo_json_or_form", {"name": "Joe"}).json == { + "name": "Joe" + } + assert testapp.post("/echo_json_or_form", "").json == {"name": "World"} def test_parse_querystring_default(self, testapp): assert testapp.get("/echo").json == {"name": "World"} def test_parse_json_default(self, testapp): - assert testapp.post_json("/echo", {}).json == {"name": "World"} + assert testapp.post_json("/echo_json", {}).json == {"name": "World"} def test_parse_json_with_charset(self, testapp): res = testapp.post( - "/echo", + "/echo_json", json.dumps({"name": "Steve"}), content_type="application/json;charset=UTF-8", ) @@ -65,23 +76,27 @@ def test_parse_json_with_charset(self, testapp): def test_parse_json_with_vendor_media_type(self, testapp): res = testapp.post( - "/echo", + "/echo_json", json.dumps({"name": "Steve"}), content_type="application/vnd.api+json;charset=UTF-8", ) assert res.json == {"name": "Steve"} - def test_parse_json_ignores_extra_data(self, testapp): - assert testapp.post_json("/echo", {"extra": "data"}).json == {"name": "World"} + def test_parse_ignore_extra_data(self, testapp): + assert testapp.post_json( + "/echo_ignoring_extra_data", {"extra": "data"} + ).json == {"name": "World"} - def test_parse_json_blank(self, testapp): - assert testapp.post_json("/echo", None).json == {"name": "World"} + def test_parse_json_empty(self, testapp): + assert testapp.post_json("/echo_json", {}).json == {"name": "World"} - def test_parse_json_ignore_unexpected_int(self, testapp): - assert testapp.post_json("/echo", 1).json == {"name": "World"} + def test_parse_json_error_unexpected_int(self, testapp): + res = testapp.post_json("/echo_json", 1, expect_errors=True) + assert res.status_code == 422 - def test_parse_json_ignore_unexpected_list(self, testapp): - assert testapp.post_json("/echo", [{"extra": "data"}]).json == {"name": "World"} + def test_parse_json_error_unexpected_list(self, testapp): + res = testapp.post_json("/echo_json", [{"extra": "data"}], expect_errors=True) + assert res.status_code == 422 def test_parse_json_many_schema_invalid_input(self, testapp): res = testapp.post_json( @@ -93,34 +108,54 @@ def test_parse_json_many_schema(self, testapp): res = testapp.post_json("/echo_many_schema", [{"name": "Steve"}]).json assert res == [{"name": "Steve"}] - def test_parse_json_many_schema_ignore_malformed_data(self, testapp): - assert testapp.post_json("/echo_many_schema", {"extra": "data"}).json == [] + def test_parse_json_many_schema_error_malformed_data(self, testapp): + res = testapp.post_json( + "/echo_many_schema", {"extra": "data"}, expect_errors=True + ) + assert res.status_code == 422 def test_parsing_form_default(self, testapp): - assert testapp.post("/echo", {}).json == {"name": "World"} + assert testapp.post("/echo_form", {}).json == {"name": "World"} def test_parse_querystring_multiple(self, testapp): expected = {"name": ["steve", "Loria"]} assert testapp.get("/echo_multi?name=steve&name=Loria").json == expected + # test that passing a single value parses correctly + # on parsers like falconparser, where there is no native MultiDict type, + # this verifies the usage of MultiDictProxy to ensure that single values + # are "listified" + def test_parse_querystring_multiple_single_value(self, testapp): + expected = {"name": ["steve"]} + assert testapp.get("/echo_multi?name=steve").json == expected + def test_parse_form_multiple(self, testapp): expected = {"name": ["steve", "Loria"]} assert ( - testapp.post("/echo_multi", {"name": ["steve", "Loria"]}).json == expected + testapp.post("/echo_multi_form", {"name": ["steve", "Loria"]}).json + == expected ) def test_parse_json_list(self, testapp): expected = {"name": ["Steve"]} - assert testapp.post_json("/echo_multi", {"name": "Steve"}).json == expected + assert ( + testapp.post_json("/echo_multi_json", {"name": ["Steve"]}).json == expected + ) + + def test_parse_json_list_error_malformed_data(self, testapp): + res = testapp.post_json( + "/echo_multi_json", {"name": "Steve"}, expect_errors=True + ) + assert res.status_code == 422 def test_parse_json_with_nonascii_chars(self, testapp): text = u"øˆƒ£ºº∆ƒˆ∆" - assert testapp.post_json("/echo", {"name": text}).json == {"name": text} + assert testapp.post_json("/echo_json", {"name": text}).json == {"name": text} # https://github.com/marshmallow-code/webargs/issues/427 def test_parse_json_with_nonutf8_chars(self, testapp): res = testapp.post( - "/echo", + "/echo_json", b"\xfe", headers={"Accept": "application/json", "Content-Type": "application/json"}, expect_errors=True, @@ -130,7 +165,7 @@ def test_parse_json_with_nonutf8_chars(self, testapp): assert res.json == {"json": ["Invalid JSON body."]} def test_validation_error_returns_422_response(self, testapp): - res = testapp.post("/echo", {"name": "b"}, expect_errors=True) + res = testapp.post_json("/echo_json", {"name": "b"}, expect_errors=True) assert res.status_code == 422 def test_user_validation_error_returns_422_response_by_default(self, testapp): @@ -187,10 +222,6 @@ def test_parse_nested_many_missing(self, testapp): res = testapp.post_json("/echo_nested_many", in_data) assert res.json == {} - def test_parse_json_if_no_json(self, testapp): - res = testapp.post("/echo") - assert res.json == {"name": "World"} - def test_parse_files(self, testapp): res = testapp.post( "/echo_file", {"myfile": webtest.Upload("README.rst", b"data")} @@ -199,8 +230,14 @@ def test_parse_files(self, testapp): # https://github.com/sloria/webargs/pull/297 def test_empty_json(self, testapp): + res = testapp.post("/echo_json") + assert res.status_code == 200 + assert res.json == {"name": "World"} + + # https://github.com/sloria/webargs/pull/297 + def test_empty_json_with_headers(self, testapp): res = testapp.post( - "/echo", + "/echo_json", "", headers={"Accept": "application/json", "Content-Type": "application/json"}, ) @@ -210,7 +247,7 @@ def test_empty_json(self, testapp): # https://github.com/sloria/webargs/issues/329 def test_invalid_json(self, testapp): res = testapp.post( - "/echo", + "/echo_json", '{"foo": "bar", }', headers={"Accept": "application/json", "Content-Type": "application/json"}, expect_errors=True, diff --git a/src/webargs/tornadoparser.py b/src/webargs/tornadoparser.py index 984c1e56..f7077379 100644 --- a/src/webargs/tornadoparser.py +++ b/src/webargs/tornadoparser.py @@ -15,11 +15,12 @@ def get(self, args): self.write(response) """ import tornado.web +import tornado.concurrent from tornado.escape import _unicode from webargs import core from webargs.compat import basestring -from webargs.core import json +from webargs.multidictproxy import MultiDictProxy class HTTPError(tornado.web.HTTPError): @@ -31,93 +32,92 @@ def __init__(self, *args, **kwargs): super(HTTPError, self).__init__(*args, **kwargs) -def parse_json_body(req): - """Return the decoded JSON body from the request.""" +def is_json_request(req): content_type = req.headers.get("Content-Type") - if content_type and core.is_json(content_type): - try: - return core.parse_json(req.body) - except TypeError: - pass - except json.JSONDecodeError as e: - if e.doc == "": - return core.missing - else: - raise - return {} + return content_type is not None and core.is_json(content_type) -# From tornado.web.RequestHandler.decode_argument -def decode_argument(value, name=None): - """Decodes an argument from the request. +class WebArgsTornadoMultiDictProxy(MultiDictProxy): + """ + Override class for Tornado multidicts, handles argument decoding + requirements. """ - try: - return _unicode(value) - except UnicodeDecodeError: - raise HTTPError(400, "Invalid unicode in %s: %r" % (name or "url", value[:40])) + def __getitem__(self, key): + try: + value = self.data.get(key, core.missing) + if value is core.missing: + return core.missing + elif key in self.multiple_keys: + return [_unicode(v) if isinstance(v, basestring) else v for v in value] + elif value and isinstance(value, (list, tuple)): + value = value[0] + + if isinstance(value, basestring): + return _unicode(value) + else: + return value + # based on tornado.web.RequestHandler.decode_argument + except UnicodeDecodeError: + raise HTTPError(400, "Invalid unicode in %s: %r" % (key, value[:40])) -def get_value(d, name, field): - """Handle gets from 'multidicts' made of lists - It handles cases: ``{"key": [value]}`` and ``{"key": value}`` +class WebArgsTornadoCookiesMultiDictProxy(MultiDictProxy): """ - multiple = core.is_multiple(field) - value = d.get(name, core.missing) - if value is core.missing: - return core.missing - if multiple and value is not core.missing: - return [ - decode_argument(v, name) if isinstance(v, basestring) else v for v in value - ] - ret = value - if value and isinstance(value, (list, tuple)): - ret = value[0] - if isinstance(ret, basestring): - return decode_argument(ret, name) - else: - return ret + And a special override for cookies because they come back as objects with a + `value` attribute we need to extract. + Also, does not use the `_unicode` decoding step + """ + + def __getitem__(self, key): + cookie = self.data.get(key, core.missing) + if cookie is core.missing: + return core.missing + elif key in self.multiple_keys: + return [cookie.value] + else: + return cookie.value class TornadoParser(core.Parser): """Tornado request argument parser.""" - def parse_json(self, req, name, field): - """Pull a json value from the request.""" - json_data = self._cache.get("json") - if json_data is None: - try: - self._cache["json"] = json_data = parse_json_body(req) - except json.JSONDecodeError as e: - return self.handle_invalid_json_error(e, req) - if json_data is None: - return core.missing - return core.get_value(json_data, name, field, allow_many_nested=True) + def _raw_load_json(self, req): + """Return a json payload from the request for the core parser's load_json - def parse_querystring(self, req, name, field): - """Pull a querystring value from the request.""" - return get_value(req.query_arguments, name, field) + Checks the input mimetype and may return 'missing' if the mimetype is + non-json, even if the request body is parseable as json.""" + if not is_json_request(req): + return core.missing - def parse_form(self, req, name, field): - """Pull a form value from the request.""" - return get_value(req.body_arguments, name, field) + # request.body may be a concurrent.Future on streaming requests + # this would cause a TypeError if we try to parse it + if isinstance(req.body, tornado.concurrent.Future): + return core.missing - def parse_headers(self, req, name, field): - """Pull a value from the header data.""" - return get_value(req.headers, name, field) + return core.parse_json(req.body) - def parse_cookies(self, req, name, field): - """Pull a value from the header data.""" - cookie = req.cookies.get(name) + def load_querystring(self, req, schema): + """Return query params from the request as a MultiDictProxy.""" + return WebArgsTornadoMultiDictProxy(req.query_arguments, schema) - if cookie is not None: - return [cookie.value] if core.is_multiple(field) else cookie.value - else: - return [] if core.is_multiple(field) else None + def load_form(self, req, schema): + """Return form values from the request as a MultiDictProxy.""" + return WebArgsTornadoMultiDictProxy(req.body_arguments, schema) + + def load_headers(self, req, schema): + """Return headers from the request as a MultiDictProxy.""" + return WebArgsTornadoMultiDictProxy(req.headers, schema) + + def load_cookies(self, req, schema): + """Return cookies from the request as a MultiDictProxy.""" + # use the specialized subclass specifically for handling Tornado + # cookies + return WebArgsTornadoCookiesMultiDictProxy(req.cookies, schema) - def parse_files(self, req, name, field): - """Pull a file from the request.""" - return get_value(req.files, name, field) + def load_files(self, req, schema): + """Return files from the request as a MultiDictProxy.""" + return WebArgsTornadoMultiDictProxy(req.files, schema) def handle_error(self, error, req, schema, error_status_code, error_headers): """Handles errors during parsing. Raises a `tornado.web.HTTPError` @@ -136,7 +136,7 @@ def handle_error(self, error, req, schema, error_status_code, error_headers): headers=error_headers, ) - def handle_invalid_json_error(self, error, req, *args, **kwargs): + def _handle_invalid_json_error(self, error, req, *args, **kwargs): raise HTTPError( 400, log_message="Invalid JSON body.", diff --git a/src/webargs/webapp2parser.py b/src/webargs/webapp2parser.py index 9da15585..90212d5f 100644 --- a/src/webargs/webapp2parser.py +++ b/src/webargs/webapp2parser.py @@ -31,45 +31,37 @@ def get_kwargs(self, name=None): import webob.multidict from webargs import core -from webargs.core import json +from webargs.multidictproxy import MultiDictProxy class Webapp2Parser(core.Parser): """webapp2 request argument parser.""" - def parse_json(self, req, name, field): - """Pull a json value from the request.""" - json_data = self._cache.get("json") - if json_data is None: - try: - self._cache["json"] = json_data = core.parse_json(req.body) - except json.JSONDecodeError as e: - if e.doc == "": - return core.missing - else: - raise - return core.get_value(json_data, name, field, allow_many_nested=True) - - def parse_querystring(self, req, name, field): - """Pull a querystring value from the request.""" - return core.get_value(req.GET, name, field) - - def parse_form(self, req, name, field): - """Pull a form value from the request.""" - return core.get_value(req.POST, name, field) - - def parse_cookies(self, req, name, field): - """Pull the value from the cookiejar.""" - return core.get_value(req.cookies, name, field) - - def parse_headers(self, req, name, field): - """Pull a value from the header data.""" - return core.get_value(req.headers, name, field) - - def parse_files(self, req, name, field): - """Pull a file from the request.""" + def _raw_load_json(self, req): + """Return a json payload from the request for the core parser's + load_json""" + return core.parse_json(req.body) + + def load_querystring(self, req, schema): + """Return query params from the request as a MultiDictProxy.""" + return MultiDictProxy(req.GET, schema) + + def load_form(self, req, schema): + """Return form values from the request as a MultiDictProxy.""" + return MultiDictProxy(req.POST, schema) + + def load_cookies(self, req, schema): + """Return cookies from the request as a MultiDictProxy.""" + return MultiDictProxy(req.cookies, schema) + + def load_headers(self, req, schema): + """Return headers from the request as a MultiDictProxy.""" + return MultiDictProxy(req.headers, schema) + + def load_files(self, req, schema): + """Return files from the request as a MultiDictProxy.""" files = ((k, v) for k, v in req.POST.items() if hasattr(v, "file")) - return core.get_value(webob.multidict.MultiDict(files), name, field) + return MultiDictProxy(webob.multidict.MultiDict(files), schema) def get_default_request(self): return webapp2.get_request() diff --git a/tests/apps/aiohttp_app.py b/tests/apps/aiohttp_app.py index dcdf6efa..c6933920 100644 --- a/tests/apps/aiohttp_app.py +++ b/tests/apps/aiohttp_app.py @@ -1,10 +1,9 @@ import asyncio import aiohttp -from aiohttp.web import json_response -from aiohttp import web import marshmallow as ma - +from aiohttp import web +from aiohttp.web import json_response from webargs import fields from webargs.aiohttpparser import parser, use_args, use_kwargs from webargs.core import MARSHMALLOW_VERSION_INFO, json @@ -25,12 +24,29 @@ class Meta: strict_kwargs = {"strict": True} if MARSHMALLOW_VERSION_INFO[0] < 3 else {} hello_many_schema = HelloSchema(many=True, **strict_kwargs) +# variant which ignores unknown fields +exclude_kwargs = ( + {"strict": True} if MARSHMALLOW_VERSION_INFO[0] < 3 else {"unknown": ma.EXCLUDE} +) +hello_exclude_schema = HelloSchema(**exclude_kwargs) + + ##### Handlers ##### async def echo(request): + parsed = await parser.parse(hello_args, request, location="query") + return json_response(parsed) + + +async def echo_form(request): + parsed = await parser.parse(hello_args, request, location="form") + return json_response(parsed) + + +async def echo_json(request): try: - parsed = await parser.parse(hello_args, request) + parsed = await parser.parse(hello_args, request, location="json") except json.JSONDecodeError: raise web.HTTPBadRequest( body=json.dumps(["Invalid JSON."]).encode("utf-8"), @@ -39,48 +55,70 @@ async def echo(request): return json_response(parsed) -async def echo_query(request): - parsed = await parser.parse(hello_args, request, locations=("query",)) +async def echo_json_or_form(request): + try: + parsed = await parser.parse(hello_args, request, location="json_or_form") + except json.JSONDecodeError: + raise web.HTTPBadRequest( + body=json.dumps(["Invalid JSON."]).encode("utf-8"), + content_type="application/json", + ) return json_response(parsed) -@use_args(hello_args) +@use_args(hello_args, location="query") async def echo_use_args(request, args): return json_response(args) -@use_kwargs(hello_args) +@use_kwargs(hello_args, location="query") async def echo_use_kwargs(request, name): return json_response({"name": name}) -@use_args({"value": fields.Int()}, validate=lambda args: args["value"] > 42) +@use_args( + {"value": fields.Int()}, validate=lambda args: args["value"] > 42, location="form" +) async def echo_use_args_validated(request, args): return json_response(args) +async def echo_ignoring_extra_data(request): + return json_response(await parser.parse(hello_exclude_schema, request)) + + async def echo_multi(request): + parsed = await parser.parse(hello_multiple, request, location="query") + return json_response(parsed) + + +async def echo_multi_form(request): + parsed = await parser.parse(hello_multiple, request, location="form") + return json_response(parsed) + + +async def echo_multi_json(request): parsed = await parser.parse(hello_multiple, request) return json_response(parsed) async def echo_many_schema(request): - parsed = await parser.parse(hello_many_schema, request, locations=("json",)) + parsed = await parser.parse(hello_many_schema, request) return json_response(parsed) -@use_args({"value": fields.Int()}) +@use_args({"value": fields.Int()}, location="query") async def echo_use_args_with_path_param(request, args): return json_response(args) -@use_kwargs({"value": fields.Int()}) +@use_kwargs({"value": fields.Int()}, location="query") async def echo_use_kwargs_with_path_param(request, value): return json_response({"value": value}) -@use_args({"page": fields.Int(), "q": fields.Int()}, locations=("query",)) -@use_args({"name": fields.Str()}, locations=("json",)) +@use_args({"page": fields.Int(), "q": fields.Int()}, location="query") +@use_args({"name": fields.Str()}) async def echo_use_args_multiple(request, query_parsed, json_parsed): return json_response({"query_parsed": query_parsed, "json_parsed": json_parsed}) @@ -95,12 +133,12 @@ def always_fail(value): async def echo_headers(request): - parsed = await parser.parse(hello_args, request, locations=("headers",)) + parsed = await parser.parse(hello_exclude_schema, request, location="headers") return json_response(parsed) async def echo_cookie(request): - parsed = await parser.parse(hello_args, request, locations=("cookies",)) + parsed = await parser.parse(hello_args, request, location="cookies") return json_response(parsed) @@ -134,25 +172,27 @@ async def echo_nested_many_data_key(request): async def echo_match_info(request): - parsed = await parser.parse({"mymatch": fields.Int(location="match_info")}, request) + parsed = await parser.parse( + {"mymatch": fields.Int()}, request, location="match_info" + ) return json_response(parsed) class EchoHandler: - @use_args(hello_args) + @use_args(hello_args, location="query") async def get(self, request, args): return json_response(args) class EchoHandlerView(web.View): @asyncio.coroutine - @use_args(hello_args) + @use_args(hello_args, location="query") def get(self, args): return json_response(args) @asyncio.coroutine -@use_args(HelloSchema, as_kwargs=True) +@use_args(HelloSchema, as_kwargs=True, location="query") def echo_use_schema_as_kwargs(request, name): return json_response({"name": name}) @@ -168,12 +208,17 @@ def add_route(app, methods, route, handler): def create_app(): app = aiohttp.web.Application() - add_route(app, ["GET", "POST"], "/echo", echo) - add_route(app, ["GET"], "/echo_query", echo_query) - add_route(app, ["GET", "POST"], "/echo_use_args", echo_use_args) - add_route(app, ["GET", "POST"], "/echo_use_kwargs", echo_use_kwargs) - add_route(app, ["GET", "POST"], "/echo_use_args_validated", echo_use_args_validated) - add_route(app, ["GET", "POST"], "/echo_multi", echo_multi) + add_route(app, ["GET"], "/echo", echo) + add_route(app, ["POST"], "/echo_form", echo_form) + add_route(app, ["POST"], "/echo_json", echo_json) + add_route(app, ["POST"], "/echo_json_or_form", echo_json_or_form) + add_route(app, ["GET"], "/echo_use_args", echo_use_args) + add_route(app, ["GET"], "/echo_use_kwargs", echo_use_kwargs) + add_route(app, ["POST"], "/echo_use_args_validated", echo_use_args_validated) + add_route(app, ["POST"], "/echo_ignoring_extra_data", echo_ignoring_extra_data) + add_route(app, ["GET"], "/echo_multi", echo_multi) + add_route(app, ["POST"], "/echo_multi_form", echo_multi_form) + add_route(app, ["POST"], "/echo_multi_json", echo_multi_json) add_route(app, ["GET", "POST"], "/echo_many_schema", echo_many_schema) add_route( app, diff --git a/tests/apps/bottle_app.py b/tests/apps/bottle_app.py index b8b9ae7c..abacea83 100644 --- a/tests/apps/bottle_app.py +++ b/tests/apps/bottle_app.py @@ -6,6 +6,7 @@ from webargs.bottleparser import parser, use_args, use_kwargs from webargs.core import MARSHMALLOW_VERSION_INFO + hello_args = {"name": fields.Str(missing="World", validate=lambda n: len(n) >= 3)} hello_multiple = {"name": fields.List(fields.Str())} @@ -17,61 +18,100 @@ class HelloSchema(ma.Schema): strict_kwargs = {"strict": True} if MARSHMALLOW_VERSION_INFO[0] < 3 else {} hello_many_schema = HelloSchema(many=True, **strict_kwargs) +# variant which ignores unknown fields +exclude_kwargs = ( + {"strict": True} if MARSHMALLOW_VERSION_INFO[0] < 3 else {"unknown": ma.EXCLUDE} +) +hello_exclude_schema = HelloSchema(**exclude_kwargs) + app = Bottle() debug(True) -@app.route("/echo", method=["GET", "POST"]) +@app.route("/echo", method=["GET"]) def echo(): - return parser.parse(hello_args, request) + return parser.parse(hello_args, request, location="query") -@app.route("/echo_query") -def echo_query(): - return parser.parse(hello_args, request, locations=("query",)) +@app.route("/echo_form", method=["POST"]) +def echo_form(): + return parser.parse(hello_args, location="form") -@app.route("/echo_use_args", method=["GET", "POST"]) -@use_args(hello_args) -def echo_use_args(args): - return args +@app.route("/echo_json", method=["POST"]) +def echo_json(): + return parser.parse(hello_args) -@app.route("/echo_use_kwargs", method=["GET", "POST"], apply=use_kwargs(hello_args)) -def echo_use_kwargs(name): - return {"name": name} +@app.route("/echo_json_or_form", method=["POST"]) +def echo_json_or_form(): + return parser.parse(hello_args, location="json_or_form") + + +@app.route("/echo_use_args", method=["GET"]) +@use_args(hello_args, location="query") +def echo_use_args(args): + return args @app.route( "/echo_use_args_validated", - method=["GET", "POST"], - apply=use_args({"value": fields.Int()}, validate=lambda args: args["value"] > 42), + method=["POST"], + apply=use_args( + {"value": fields.Int()}, + validate=lambda args: args["value"] > 42, + location="form", + ), ) def echo_use_args_validated(args): return args -@app.route("/echo_multi", method=["GET", "POST"]) +@app.route("/echo_ignoring_extra_data", method=["POST"]) +def echo_json_ignore_extra_data(): + return parser.parse(hello_exclude_schema) + + +@app.route( + "/echo_use_kwargs", method=["GET"], apply=use_kwargs(hello_args, location="query") +) +def echo_use_kwargs(name): + return {"name": name} + + +@app.route("/echo_multi", method=["GET"]) def echo_multi(): - return parser.parse(hello_multiple, request) + return parser.parse(hello_multiple, request, location="query") + + +@app.route("/echo_multi_form", method=["POST"]) +def multi_form(): + return parser.parse(hello_multiple, location="form") + + +@app.route("/echo_multi_json", method=["POST"]) +def multi_json(): + return parser.parse(hello_multiple) -@app.route("/echo_many_schema", method=["GET", "POST"]) +@app.route("/echo_many_schema", method=["POST"]) def echo_many_schema(): - arguments = parser.parse(hello_many_schema, request, locations=("json",)) + arguments = parser.parse(hello_many_schema, request) return HTTPResponse(body=json.dumps(arguments), content_type="application/json") @app.route( - "/echo_use_args_with_path_param/", apply=use_args({"value": fields.Int()}) + "/echo_use_args_with_path_param/", + apply=use_args({"value": fields.Int()}, location="query"), ) def echo_use_args_with_path_param(args, name): return args @app.route( - "/echo_use_kwargs_with_path_param/", apply=use_kwargs({"value": fields.Int()}) + "/echo_use_kwargs_with_path_param/", + apply=use_kwargs({"value": fields.Int()}, location="query"), ) def echo_use_kwargs_with_path_param(name, value): return {"value": value} @@ -88,18 +128,20 @@ def always_fail(value): @app.route("/echo_headers") def echo_headers(): - return parser.parse(hello_args, request, locations=("headers",)) + # the "exclude schema" must be used in this case because WSGI headers may + # be populated with many fields not sent by the caller + return parser.parse(hello_exclude_schema, request, location="headers") @app.route("/echo_cookie") def echo_cookie(): - return parser.parse(hello_args, request, locations=("cookies",)) + return parser.parse(hello_args, request, location="cookies") @app.route("/echo_file", method=["POST"]) def echo_file(): args = {"myfile": fields.Field()} - result = parser.parse(args, locations=("files",)) + result = parser.parse(args, location="files") myfile = result["myfile"] content = myfile.file.read().decode("utf8") return {"myfile": content} diff --git a/tests/apps/django_app/base/settings.py b/tests/apps/django_app/base/settings.py index a127df7c..0dd41b0b 100644 --- a/tests/apps/django_app/base/settings.py +++ b/tests/apps/django_app/base/settings.py @@ -7,7 +7,7 @@ TEMPLATE_DEBUG = True -ALLOWED_HOSTS = [] +ALLOWED_HOSTS = ["*"] # Application definition INSTALLED_APPS = ("django.contrib.contenttypes",) diff --git a/tests/apps/django_app/base/urls.py b/tests/apps/django_app/base/urls.py index 9613c743..07a86e91 100644 --- a/tests/apps/django_app/base/urls.py +++ b/tests/apps/django_app/base/urls.py @@ -2,12 +2,19 @@ from tests.apps.django_app.echo import views + urlpatterns = [ url(r"^echo$", views.echo), - url(r"^echo_query$", views.echo_query), + url(r"^echo_form$", views.echo_form), + url(r"^echo_json$", views.echo_json), + url(r"^echo_json_or_form$", views.echo_json_or_form), url(r"^echo_use_args$", views.echo_use_args), + url(r"^echo_use_args_validated$", views.echo_use_args_validated), + url(r"^echo_ignoring_extra_data$", views.echo_ignoring_extra_data), url(r"^echo_use_kwargs$", views.echo_use_kwargs), url(r"^echo_multi$", views.echo_multi), + url(r"^echo_multi_form$", views.echo_multi_form), + url(r"^echo_multi_json$", views.echo_multi_json), url(r"^echo_many_schema$", views.echo_many_schema), url( r"^echo_use_args_with_path_param/(?P\w+)$", diff --git a/tests/apps/django_app/echo/views.py b/tests/apps/django_app/echo/views.py index d08e83aa..d236ff42 100644 --- a/tests/apps/django_app/echo/views.py +++ b/tests/apps/django_app/echo/views.py @@ -7,6 +7,7 @@ from webargs.djangoparser import parser, use_args, use_kwargs from webargs.core import MARSHMALLOW_VERSION_INFO + hello_args = {"name": fields.Str(missing="World", validate=lambda n: len(n) >= 3)} hello_multiple = {"name": fields.List(fields.Str())} @@ -18,90 +19,143 @@ class HelloSchema(ma.Schema): strict_kwargs = {"strict": True} if MARSHMALLOW_VERSION_INFO[0] < 3 else {} hello_many_schema = HelloSchema(many=True, **strict_kwargs) +# variant which ignores unknown fields +exclude_kwargs = ( + {"strict": True} if MARSHMALLOW_VERSION_INFO[0] < 3 else {"unknown": ma.EXCLUDE} +) +hello_exclude_schema = HelloSchema(**exclude_kwargs) + def json_response(data, **kwargs): return HttpResponse(json.dumps(data), content_type="application/json", **kwargs) +def handle_view_errors(f): + def wrapped(*args, **kwargs): + try: + return f(*args, **kwargs) + except ma.ValidationError as err: + return json_response(err.messages, status=422) + except json.JSONDecodeError: + return json_response({"json": ["Invalid JSON body."]}, status=400) + + return wrapped + + +@handle_view_errors def echo(request): - try: - args = parser.parse(hello_args, request) - except ma.ValidationError as err: - return json_response(err.messages, status=parser.DEFAULT_VALIDATION_STATUS) - except json.JSONDecodeError: - return json_response({"json": ["Invalid JSON body."]}, status=400) - return json_response(args) + return json_response(parser.parse(hello_args, request, location="query")) -def echo_query(request): - return json_response(parser.parse(hello_args, request, locations=("query",))) +@handle_view_errors +def echo_form(request): + return json_response(parser.parse(hello_args, request, location="form")) -@use_args(hello_args) +@handle_view_errors +def echo_json(request): + return json_response(parser.parse(hello_args, request)) + + +@handle_view_errors +def echo_json_or_form(request): + return json_response(parser.parse(hello_args, request, location="json_or_form")) + + +@handle_view_errors +@use_args(hello_args, location="query") def echo_use_args(request, args): return json_response(args) -@use_kwargs(hello_args) +@handle_view_errors +@use_args( + {"value": fields.Int()}, validate=lambda args: args["value"] > 42, location="form" +) +def echo_use_args_validated(args): + return json_response(args) + + +@handle_view_errors +def echo_ignoring_extra_data(request): + return json_response(parser.parse(hello_exclude_schema, request)) + + +@handle_view_errors +@use_kwargs(hello_args, location="query") def echo_use_kwargs(request, name): return json_response({"name": name}) +@handle_view_errors def echo_multi(request): + return json_response(parser.parse(hello_multiple, request, location="query")) + + +@handle_view_errors +def echo_multi_form(request): + return json_response(parser.parse(hello_multiple, request, location="form")) + + +@handle_view_errors +def echo_multi_json(request): return json_response(parser.parse(hello_multiple, request)) +@handle_view_errors def echo_many_schema(request): - try: - return json_response( - parser.parse(hello_many_schema, request, locations=("json",)) - ) - except ma.ValidationError as err: - return json_response(err.messages, status=parser.DEFAULT_VALIDATION_STATUS) + return json_response(parser.parse(hello_many_schema, request)) -@use_args({"value": fields.Int()}) +@handle_view_errors +@use_args({"value": fields.Int()}, location="query") def echo_use_args_with_path_param(request, args, name): return json_response(args) -@use_kwargs({"value": fields.Int()}) +@handle_view_errors +@use_kwargs({"value": fields.Int()}, location="query") def echo_use_kwargs_with_path_param(request, value, name): return json_response({"value": value}) +@handle_view_errors def always_error(request): def always_fail(value): raise ma.ValidationError("something went wrong") argmap = {"text": fields.Str(validate=always_fail)} - try: - return parser.parse(argmap, request) - except ma.ValidationError as err: - return json_response(err.messages, status=parser.DEFAULT_VALIDATION_STATUS) + return parser.parse(argmap, request) +@handle_view_errors def echo_headers(request): - return json_response(parser.parse(hello_args, request, locations=("headers",))) + return json_response( + parser.parse(hello_exclude_schema, request, location="headers") + ) +@handle_view_errors def echo_cookie(request): - return json_response(parser.parse(hello_args, request, locations=("cookies",))) + return json_response(parser.parse(hello_args, request, location="cookies")) +@handle_view_errors def echo_file(request): args = {"myfile": fields.Field()} - result = parser.parse(args, request, locations=("files",)) + result = parser.parse(args, request, location="files") myfile = result["myfile"] content = myfile.read().decode("utf8") return json_response({"myfile": content}) +@handle_view_errors def echo_nested(request): argmap = {"name": fields.Nested({"first": fields.Str(), "last": fields.Str()})} return json_response(parser.parse(argmap, request)) +@handle_view_errors def echo_nested_many(request): argmap = { "users": fields.Nested({"id": fields.Int(), "name": fields.Str()}, many=True) @@ -110,27 +164,33 @@ def echo_nested_many(request): class EchoCBV(View): + @handle_view_errors def get(self, request): - try: - args = parser.parse(hello_args, self.request) - except ma.ValidationError as err: - return json_response(err.messages, status=parser.DEFAULT_VALIDATION_STATUS) - return json_response(args) + location_kwarg = {} if request.method == "POST" else {"location": "query"} + return json_response(parser.parse(hello_args, self.request, **location_kwarg)) post = get class EchoUseArgsCBV(View): - @use_args(hello_args) + @handle_view_errors + @use_args(hello_args, location="query") def get(self, request, args): return json_response(args) - post = get + @handle_view_errors + @use_args(hello_args) + def post(self, request, args): + return json_response(args) class EchoUseArgsWithParamCBV(View): - @use_args(hello_args) + @handle_view_errors + @use_args(hello_args, location="query") def get(self, request, args, pid): return json_response(args) - post = get + @handle_view_errors + @use_args(hello_args) + def post(self, request, args, pid): + return json_response(args) diff --git a/tests/apps/falcon_app.py b/tests/apps/falcon_app.py index f68541f5..c1a63710 100644 --- a/tests/apps/falcon_app.py +++ b/tests/apps/falcon_app.py @@ -1,10 +1,8 @@ -from webargs.core import json - import falcon import marshmallow as ma from webargs import fields +from webargs.core import MARSHMALLOW_VERSION_INFO, json from webargs.falconparser import parser, use_args, use_kwargs -from webargs.core import MARSHMALLOW_VERSION_INFO hello_args = {"name": fields.Str(missing="World", validate=lambda n: len(n) >= 3)} hello_multiple = {"name": fields.List(fields.Str())} @@ -17,74 +15,92 @@ class HelloSchema(ma.Schema): strict_kwargs = {"strict": True} if MARSHMALLOW_VERSION_INFO[0] < 3 else {} hello_many_schema = HelloSchema(many=True, **strict_kwargs) +# variant which ignores unknown fields +exclude_kwargs = ( + {"strict": True} if MARSHMALLOW_VERSION_INFO[0] < 3 else {"unknown": ma.EXCLUDE} +) +hello_exclude_schema = HelloSchema(**exclude_kwargs) + class Echo(object): def on_get(self, req, resp): - try: - parsed = parser.parse(hello_args, req) - except json.JSONDecodeError: - resp.body = json.dumps(["Invalid JSON."]) - resp.status = falcon.HTTP_400 - else: - resp.body = json.dumps(parsed) + parsed = parser.parse(hello_args, req, location="query") + resp.body = json.dumps(parsed) - on_post = on_get + +class EchoForm(object): + def on_post(self, req, resp): + parsed = parser.parse(hello_args, req, location="form") + resp.body = json.dumps(parsed) -class EchoQuery(object): - def on_get(self, req, resp): - parsed = parser.parse(hello_args, req, locations=("query",)) +class EchoJSON(object): + def on_post(self, req, resp): + parsed = parser.parse(hello_args, req) + resp.body = json.dumps(parsed) + + +class EchoJSONOrForm(object): + def on_post(self, req, resp): + parsed = parser.parse(hello_args, req, location="json_or_form") resp.body = json.dumps(parsed) class EchoUseArgs(object): - @use_args(hello_args) + @use_args(hello_args, location="query") def on_get(self, req, resp, args): resp.body = json.dumps(args) - on_post = on_get - class EchoUseKwargs(object): - @use_kwargs(hello_args) + @use_kwargs(hello_args, location="query") def on_get(self, req, resp, name): resp.body = json.dumps({"name": name}) - on_post = on_get - class EchoUseArgsValidated(object): - @use_args({"value": fields.Int()}, validate=lambda args: args["value"] > 42) - def on_get(self, req, resp, args): + @use_args( + {"value": fields.Int()}, + validate=lambda args: args["value"] > 42, + location="form", + ) + def on_post(self, req, resp, args): resp.body = json.dumps(args) - on_post = on_get + +class EchoJSONIgnoreExtraData(object): + def on_post(self, req, resp): + resp.body = json.dumps(parser.parse(hello_exclude_schema, req)) class EchoMulti(object): def on_get(self, req, resp): - resp.body = json.dumps(parser.parse(hello_multiple, req)) + resp.body = json.dumps(parser.parse(hello_multiple, req, location="query")) - on_post = on_get +class EchoMultiForm(object): + def on_post(self, req, resp): + resp.body = json.dumps(parser.parse(hello_multiple, req, location="form")) -class EchoManySchema(object): - def on_get(self, req, resp): - resp.body = json.dumps( - parser.parse(hello_many_schema, req, locations=("json",)) - ) - on_post = on_get +class EchoMultiJSON(object): + def on_post(self, req, resp): + resp.body = json.dumps(parser.parse(hello_multiple, req)) + + +class EchoManySchema(object): + def on_post(self, req, resp): + resp.body = json.dumps(parser.parse(hello_many_schema, req)) class EchoUseArgsWithPathParam(object): - @use_args({"value": fields.Int()}) + @use_args({"value": fields.Int()}, location="query") def on_get(self, req, resp, args, name): resp.body = json.dumps(args) class EchoUseKwargsWithPathParam(object): - @use_kwargs({"value": fields.Int()}) + @use_kwargs({"value": fields.Int()}, location="query") def on_get(self, req, resp, value, name): resp.body = json.dumps({"value": value}) @@ -102,12 +118,17 @@ def always_fail(value): class EchoHeaders(object): def on_get(self, req, resp): - resp.body = json.dumps(parser.parse(hello_args, req, locations=("headers",))) + class HeaderSchema(ma.Schema): + NAME = fields.Str(missing="World") + + resp.body = json.dumps( + parser.parse(HeaderSchema(**exclude_kwargs), req, location="headers") + ) class EchoCookie(object): def on_get(self, req, resp): - resp.body = json.dumps(parser.parse(hello_args, req, locations=("cookies",))) + resp.body = json.dumps(parser.parse(hello_args, req, location="cookies")) class EchoNested(object): @@ -134,7 +155,7 @@ def hook(req, resp, params): return hook -@falcon.before(use_args_hook(hello_args)) +@falcon.before(use_args_hook(hello_args, location="query")) class EchoUseArgsHook(object): def on_get(self, req, resp): resp.body = json.dumps(req.context["args"]) @@ -143,11 +164,16 @@ def on_get(self, req, resp): def create_app(): app = falcon.API() app.add_route("/echo", Echo()) - app.add_route("/echo_query", EchoQuery()) + app.add_route("/echo_form", EchoForm()) + app.add_route("/echo_json", EchoJSON()) + app.add_route("/echo_json_or_form", EchoJSONOrForm()) app.add_route("/echo_use_args", EchoUseArgs()) app.add_route("/echo_use_kwargs", EchoUseKwargs()) app.add_route("/echo_use_args_validated", EchoUseArgsValidated()) + app.add_route("/echo_ignoring_extra_data", EchoJSONIgnoreExtraData()) app.add_route("/echo_multi", EchoMulti()) + app.add_route("/echo_multi_form", EchoMultiForm()) + app.add_route("/echo_multi_json", EchoMultiJSON()) app.add_route("/echo_many_schema", EchoManySchema()) app.add_route("/echo_use_args_with_path_param/{name}", EchoUseArgsWithPathParam()) app.add_route( diff --git a/tests/apps/flask_app.py b/tests/apps/flask_app.py index 019cb9f9..5cb6230b 100644 --- a/tests/apps/flask_app.py +++ b/tests/apps/flask_app.py @@ -23,57 +23,90 @@ class HelloSchema(ma.Schema): strict_kwargs = {"strict": True} if MARSHMALLOW_VERSION_INFO[0] < 3 else {} hello_many_schema = HelloSchema(many=True, **strict_kwargs) +# variant which ignores unknown fields +exclude_kwargs = ( + {"strict": True} if MARSHMALLOW_VERSION_INFO[0] < 3 else {"unknown": ma.EXCLUDE} +) +hello_exclude_schema = HelloSchema(**exclude_kwargs) + app = Flask(__name__) app.config.from_object(TestAppConfig) -@app.route("/echo", methods=["GET", "POST"]) +@app.route("/echo", methods=["GET"]) def echo(): + return J(parser.parse(hello_args, location="query")) + + +@app.route("/echo_form", methods=["POST"]) +def echo_form(): + return J(parser.parse(hello_args, location="form")) + + +@app.route("/echo_json", methods=["POST"]) +def echo_json(): return J(parser.parse(hello_args)) -@app.route("/echo_query") -def echo_query(): - return J(parser.parse(hello_args, request, locations=("query",))) +@app.route("/echo_json_or_form", methods=["POST"]) +def echo_json_or_form(): + return J(parser.parse(hello_args, location="json_or_form")) -@app.route("/echo_use_args", methods=["GET", "POST"]) -@use_args(hello_args) +@app.route("/echo_use_args", methods=["GET"]) +@use_args(hello_args, location="query") def echo_use_args(args): return J(args) -@app.route("/echo_use_args_validated", methods=["GET", "POST"]) -@use_args({"value": fields.Int()}, validate=lambda args: args["value"] > 42) +@app.route("/echo_use_args_validated", methods=["POST"]) +@use_args( + {"value": fields.Int()}, validate=lambda args: args["value"] > 42, location="form" +) def echo_use_args_validated(args): return J(args) -@app.route("/echo_use_kwargs", methods=["GET", "POST"]) -@use_kwargs(hello_args) +@app.route("/echo_ignoring_extra_data", methods=["POST"]) +def echo_json_ignore_extra_data(): + return J(parser.parse(hello_exclude_schema)) + + +@app.route("/echo_use_kwargs", methods=["GET"]) +@use_kwargs(hello_args, location="query") def echo_use_kwargs(name): return J({"name": name}) -@app.route("/echo_multi", methods=["GET", "POST"]) +@app.route("/echo_multi", methods=["GET"]) def multi(): + return J(parser.parse(hello_multiple, location="query")) + + +@app.route("/echo_multi_form", methods=["POST"]) +def multi_form(): + return J(parser.parse(hello_multiple, location="form")) + + +@app.route("/echo_multi_json", methods=["POST"]) +def multi_json(): return J(parser.parse(hello_multiple)) @app.route("/echo_many_schema", methods=["GET", "POST"]) def many_nested(): - arguments = parser.parse(hello_many_schema, locations=("json",)) + arguments = parser.parse(hello_many_schema) return Response(json.dumps(arguments), content_type="application/json") @app.route("/echo_use_args_with_path_param/") -@use_args({"value": fields.Int()}) +@use_args({"value": fields.Int()}, location="query") def echo_use_args_with_path(args, name): return J(args) @app.route("/echo_use_kwargs_with_path_param/") -@use_kwargs({"value": fields.Int()}) +@use_kwargs({"value": fields.Int()}, location="query") def echo_use_kwargs_with_path(name, value): return J({"value": value}) @@ -89,18 +122,20 @@ def always_fail(value): @app.route("/echo_headers") def echo_headers(): - return J(parser.parse(hello_args, locations=("headers",))) + # the "exclude schema" must be used in this case because WSGI headers may + # be populated with many fields not sent by the caller + return J(parser.parse(hello_exclude_schema, location="headers")) @app.route("/echo_cookie") def echo_cookie(): - return J(parser.parse(hello_args, request, locations=("cookies",))) + return J(parser.parse(hello_args, request, location="cookies")) @app.route("/echo_file", methods=["POST"]) def echo_file(): args = {"myfile": fields.Field()} - result = parser.parse(args, locations=("files",)) + result = parser.parse(args, location="files") fp = result["myfile"] content = fp.read().decode("utf8") return J({"myfile": content}) @@ -108,11 +143,11 @@ def echo_file(): @app.route("/echo_view_arg/") def echo_view_arg(view_arg): - return J(parser.parse({"view_arg": fields.Int()}, locations=("view_args",))) + return J(parser.parse({"view_arg": fields.Int()}, location="view_args")) @app.route("/echo_view_arg_use_args/") -@use_args({"view_arg": fields.Int(location="view_args")}) +@use_args({"view_arg": fields.Int()}, location="view_args") def echo_view_arg_with_use_args(args, **kwargs): return J(args) diff --git a/tests/apps/pyramid_app.py b/tests/apps/pyramid_app.py index 438ca721..0f2c361c 100644 --- a/tests/apps/pyramid_app.py +++ b/tests/apps/pyramid_app.py @@ -19,8 +19,22 @@ class HelloSchema(ma.Schema): strict_kwargs = {"strict": True} if MARSHMALLOW_VERSION_INFO[0] < 3 else {} hello_many_schema = HelloSchema(many=True, **strict_kwargs) +# variant which ignores unknown fields +exclude_kwargs = ( + {"strict": True} if MARSHMALLOW_VERSION_INFO[0] < 3 else {"unknown": ma.EXCLUDE} +) +hello_exclude_schema = HelloSchema(**exclude_kwargs) + def echo(request): + return parser.parse(hello_args, request, location="query") + + +def echo_form(request): + return parser.parse(hello_args, request, location="form") + + +def echo_json(request): try: return parser.parse(hello_args, request) except json.JSONDecodeError: @@ -30,39 +44,69 @@ def echo(request): raise error +def echo_json_or_form(request): + try: + return parser.parse(hello_args, request, location="json_or_form") + except json.JSONDecodeError: + error = HTTPBadRequest() + error.body = json.dumps(["Invalid JSON."]).encode("utf-8") + error.content_type = "application/json" + raise error + + +def echo_json_ignore_extra_data(request): + try: + return parser.parse(hello_exclude_schema, request) + except json.JSONDecodeError: + error = HTTPBadRequest() + error.body = json.dumps(["Invalid JSON."]).encode("utf-8") + error.content_type = "application/json" + raise error + + def echo_query(request): - return parser.parse(hello_args, request, locations=("query",)) + return parser.parse(hello_args, request, location="query") -@use_args(hello_args) +@use_args(hello_args, location="query") def echo_use_args(request, args): return args -@use_args({"value": fields.Int()}, validate=lambda args: args["value"] > 42) +@use_args( + {"value": fields.Int()}, validate=lambda args: args["value"] > 42, location="form" +) def echo_use_args_validated(request, args): return args -@use_kwargs(hello_args) +@use_kwargs(hello_args, location="query") def echo_use_kwargs(request, name): return {"name": name} def echo_multi(request): + return parser.parse(hello_multiple, request, location="query") + + +def echo_multi_form(request): + return parser.parse(hello_multiple, request, location="form") + + +def echo_multi_json(request): return parser.parse(hello_multiple, request) def echo_many_schema(request): - return parser.parse(hello_many_schema, request, locations=("json",)) + return parser.parse(hello_many_schema, request) -@use_args({"value": fields.Int()}) +@use_args({"value": fields.Int()}, location="query") def echo_use_args_with_path_param(request, args): return args -@use_kwargs({"value": fields.Int()}) +@use_kwargs({"value": fields.Int()}, location="query") def echo_use_kwargs_with_path_param(request, value): return {"value": value} @@ -76,16 +120,16 @@ def always_fail(value): def echo_headers(request): - return parser.parse(hello_args, request, locations=("headers",)) + return parser.parse(hello_exclude_schema, request, location="headers") def echo_cookie(request): - return parser.parse(hello_args, request, locations=("cookies",)) + return parser.parse(hello_args, request, location="cookies") def echo_file(request): args = {"myfile": fields.Field()} - result = parser.parse(args, request, locations=("files",)) + result = parser.parse(args, request, location="files") myfile = result["myfile"] content = myfile.file.read().decode("utf8") return {"myfile": content} @@ -104,14 +148,14 @@ def echo_nested_many(request): def echo_matchdict(request): - return parser.parse({"mymatch": fields.Int()}, request, locations=("matchdict",)) + return parser.parse({"mymatch": fields.Int()}, request, location="matchdict") class EchoCallable(object): def __init__(self, request): self.request = request - @use_args({"value": fields.Int()}) + @use_args({"value": fields.Int()}, location="query") def __call__(self, args): return args @@ -127,11 +171,17 @@ def create_app(): config = Configurator() add_route(config, "/echo", echo) + add_route(config, "/echo_form", echo_form) + add_route(config, "/echo_json", echo_json) + add_route(config, "/echo_json_or_form", echo_json_or_form) add_route(config, "/echo_query", echo_query) + add_route(config, "/echo_ignoring_extra_data", echo_json_ignore_extra_data) add_route(config, "/echo_use_args", echo_use_args) add_route(config, "/echo_use_args_validated", echo_use_args_validated) add_route(config, "/echo_use_kwargs", echo_use_kwargs) add_route(config, "/echo_multi", echo_multi) + add_route(config, "/echo_multi_form", echo_multi_form) + add_route(config, "/echo_multi_json", echo_multi_json) add_route(config, "/echo_many_schema", echo_many_schema) add_route( config, "/echo_use_args_with_path_param/{name}", echo_use_args_with_path_param diff --git a/tests/test_core.py b/tests/test_core.py index be8038b7..d52aa9a6 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -9,15 +9,15 @@ from django.utils.datastructures import MultiValueDict as DjMultiDict from bottle import MultiDict as BotMultiDict -from webargs import fields, missing, ValidationError +from webargs import fields, ValidationError from webargs.core import ( Parser, - get_value, dict2schema, is_json, get_mimetype, MARSHMALLOW_VERSION_INFO, ) +from webargs.multidictproxy import MultiDictProxy strict_kwargs = {"strict": True} if MARSHMALLOW_VERSION_INFO[0] < 3 else {} @@ -33,14 +33,14 @@ def __init__(self, status_code, headers): class MockRequestParser(Parser): """A minimal parser implementation that parses mock requests.""" - def parse_querystring(self, req, name, field): - return get_value(req.query, name, field) + def load_querystring(self, req, schema): + return MultiDictProxy(req.query, schema) - def parse_json(self, req, name, field): - return get_value(req.json, name, field) + def load_json(self, req, schema): + return req.json - def parse_cookies(self, req, name, field): - return get_value(req.cookies, name, field) + def load_cookies(self, req, schema): + return req.cookies @pytest.yield_fixture(scope="function") @@ -59,65 +59,73 @@ def parser(): # Parser tests -@mock.patch("webargs.core.Parser.parse_json") -def test_parse_json_called_by_parse_arg(parse_json, web_request): - field = fields.Field() +@mock.patch("webargs.core.Parser.load_json") +def test_load_json_called_by_parse_default(load_json, web_request): + schema = dict2schema({"foo": fields.Field()})() + load_json.return_value = {"foo": 1} p = Parser() - p.parse_arg("foo", field, web_request) - parse_json.assert_called_with(web_request, "foo", field) + p.parse(schema, web_request) + load_json.assert_called_with(web_request, schema) -@mock.patch("webargs.core.Parser.parse_querystring") -def test_parse_querystring_called_by_parse_arg(parse_querystring, web_request): - field = fields.Field() - p = Parser() - p.parse_arg("foo", field, web_request) - assert parse_querystring.called_once() - +@pytest.mark.parametrize( + "location", ["querystring", "form", "headers", "cookies", "files"] +) +def test_load_nondefault_called_by_parse_with_location(location, web_request): + with mock.patch( + "webargs.core.Parser.load_{}".format(location) + ) as mock_loadfunc, mock.patch("webargs.core.Parser.load_json") as load_json: + mock_loadfunc.return_value = {} + load_json.return_value = {} + p = Parser() + + # ensure that without location=..., the loader is not called (json is + # called) + p.parse({"foo": fields.Field()}, web_request) + assert mock_loadfunc.call_count == 0 + assert load_json.call_count == 1 -@mock.patch("webargs.core.Parser.parse_form") -def test_parse_form_called_by_parse_arg(parse_form, web_request): - field = fields.Field() - p = Parser() - p.parse_arg("foo", field, web_request) - assert parse_form.called_once() + # but when location=... is given, the loader *is* called and json is + # not called + p.parse({"foo": fields.Field()}, web_request, location=location) + assert mock_loadfunc.call_count == 1 + # it was already 1, should not go up + assert load_json.call_count == 1 -@mock.patch("webargs.core.Parser.parse_json") -def test_parse_json_not_called_when_json_not_a_location(parse_json, web_request): - field = fields.Field() - p = Parser() - p.parse_arg("foo", field, web_request, locations=("form", "querystring")) - assert parse_json.call_count == 0 +def test_parse(parser, web_request): + web_request.json = {"username": 42, "password": 42} + argmap = {"username": fields.Field(), "password": fields.Field()} + ret = parser.parse(argmap, web_request) + assert {"username": 42, "password": 42} == ret -@mock.patch("webargs.core.Parser.parse_headers") -def test_parse_headers_called_when_headers_is_a_location(parse_headers, web_request): - field = fields.Field() - p = Parser() - p.parse_arg("foo", field, web_request) - assert parse_headers.call_count == 0 - p.parse_arg("foo", field, web_request, locations=("headers",)) - parse_headers.assert_called() +@pytest.mark.skipif( + MARSHMALLOW_VERSION_INFO[0] < 3, reason="unknown=... added in marshmallow3" +) +def test_parse_with_unknown_behavior_specified(parser, web_request): + # This is new in webargs 6.x ; it's the way you can "get back" the behavior + # of webargs 5.x in which extra args are ignored + from marshmallow import EXCLUDE, INCLUDE, RAISE + web_request.json = {"username": 42, "password": 42, "fjords": 42} -@mock.patch("webargs.core.Parser.parse_cookies") -def test_parse_cookies_called_when_cookies_is_a_location(parse_cookies, web_request): - field = fields.Field() - p = Parser() - p.parse_arg("foo", field, web_request) - assert parse_cookies.call_count == 0 - p.parse_arg("foo", field, web_request, locations=("cookies",)) - parse_cookies.assert_called() + class CustomSchema(Schema): + username = fields.Field() + password = fields.Field() + # with no unknown setting or unknown=RAISE, it blows up + with pytest.raises(ValidationError, match="Unknown field."): + parser.parse(CustomSchema(), web_request) + with pytest.raises(ValidationError, match="Unknown field."): + parser.parse(CustomSchema(unknown=RAISE), web_request) -@mock.patch("webargs.core.Parser.parse_json") -def test_parse(parse_json, web_request): - parse_json.return_value = 42 - argmap = {"username": fields.Field(), "password": fields.Field()} - p = Parser() - ret = p.parse(argmap, web_request) + # with unknown=EXCLUDE the data is omitted + ret = parser.parse(CustomSchema(unknown=EXCLUDE), web_request) assert {"username": 42, "password": 42} == ret + # with unknown=INCLUDE it is added even though it isn't part of the schema + ret = parser.parse(CustomSchema(unknown=INCLUDE), web_request) + assert {"username": 42, "password": 42, "fjords": 42} == ret def test_parse_required_arg_raises_validation_error(parser, web_request): @@ -141,13 +149,10 @@ def test_arg_allow_none(parser, web_request): assert result == {"first": "Steve", "last": None} -@mock.patch("webargs.core.Parser.parse_json") -def test_parse_required_arg(parse_json, web_request): - arg = fields.Field(required=True) - parse_json.return_value = 42 - p = Parser() - result = p.parse_arg("foo", arg, web_request, locations=("json",)) - assert result == 42 +def test_parse_required_arg(parser, web_request): + web_request.json = {"foo": 42} + result = parser.parse({"foo": fields.Field(required=True)}, web_request) + assert result == {"foo": 42} def test_parse_required_list(parser, web_request): @@ -185,21 +190,21 @@ def test_parse_missing_list(parser, web_request): assert parser.parse(args, web_request) == {} -def test_default_locations(): - assert set(Parser.DEFAULT_LOCATIONS) == set(["json", "querystring", "form"]) +def test_default_location(): + assert Parser.DEFAULT_LOCATION == "json" def test_missing_with_default(parser, web_request): web_request.json = {} args = {"val": fields.Field(missing="pizza")} - result = parser.parse(args, web_request, locations=("json",)) + result = parser.parse(args, web_request) assert result["val"] == "pizza" def test_default_can_be_none(parser, web_request): web_request.json = {} args = {"val": fields.Field(missing=None, allow_none=True)} - result = parser.parse(args, web_request, locations=("json",)) + result = parser.parse(args, web_request) assert result["val"] is None @@ -217,34 +222,22 @@ def test_arg_with_default_and_location(parser, web_request): assert parser.parse(args, web_request) == {"p": 1} -def test_value_error_raised_if_parse_arg_called_with_invalid_location(web_request): +def test_value_error_raised_if_parse_called_with_invalid_location(parser, web_request): field = fields.Field() - p = Parser() - with pytest.raises(ValueError) as excinfo: - p.parse_arg("foo", field, web_request, locations=("invalidlocation", "headers")) - msg = "Invalid locations arguments: {0}".format(["invalidlocation"]) - assert msg in str(excinfo.value) - - -def test_value_error_raised_if_invalid_location_on_field(web_request, parser): - with pytest.raises(ValueError) as excinfo: - parser.parse({"foo": fields.Field(location="invalidlocation")}, web_request) - msg = "Invalid locations arguments: {0}".format(["invalidlocation"]) - assert msg in str(excinfo.value) + with pytest.raises(ValueError, match="Invalid location argument: invalidlocation"): + parser.parse({"foo": field}, web_request, location="invalidlocation") @mock.patch("webargs.core.Parser.handle_error") -@mock.patch("webargs.core.Parser.parse_json") -def test_handle_error_called_when_parsing_raises_error( - parse_json, handle_error, web_request -): - val_err = ValidationError("error occurred") - parse_json.side_effect = val_err +def test_handle_error_called_when_parsing_raises_error(handle_error, web_request): + def always_fail(*args, **kwargs): + raise ValidationError("error occurred") + p = Parser() - p.parse({"foo": fields.Field()}, web_request, locations=("json",)) - handle_error.assert_called() - parse_json.side_effect = ValidationError("another exception") - p.parse({"foo": fields.Field()}, web_request, locations=("json",)) + assert handle_error.call_count == 0 + p.parse({"foo": fields.Field()}, web_request, validate=always_fail) + assert handle_error.call_count == 1 + p.parse({"foo": fields.Field()}, web_request, validate=always_fail) assert handle_error.call_count == 2 @@ -254,22 +247,15 @@ def test_handle_error_reraises_errors(web_request): p.handle_error(ValidationError("error raised"), web_request, Schema()) -@mock.patch("webargs.core.Parser.parse_headers") -def test_locations_as_init_arguments(parse_headers, web_request): - p = Parser(locations=("headers",)) +@mock.patch("webargs.core.Parser.load_headers") +def test_location_as_init_argument(load_headers, web_request): + p = Parser(location="headers") + load_headers.return_value = {} p.parse({"foo": fields.Field()}, web_request) - assert parse_headers.called + assert load_headers.called -@mock.patch("webargs.core.Parser.parse_files") -def test_parse_files(parse_files, web_request): - p = Parser() - p.parse({"foo": fields.Field()}, web_request, locations=("files",)) - assert parse_files.called - - -@mock.patch("webargs.core.Parser.parse_json") -def test_custom_error_handler(parse_json, web_request): +def test_custom_error_handler(web_request): class CustomError(Exception): pass @@ -277,19 +263,27 @@ def error_handler(error, req, schema, status_code, headers): assert isinstance(schema, Schema) raise CustomError(error) - parse_json.side_effect = ValidationError("parse_json failed") + def failing_validate_func(args): + raise ValidationError("parsing failed") + + class MySchema(Schema): + foo = fields.Int() + + myschema = MySchema(**strict_kwargs) + web_request.json = {"foo": "hello world"} + p = Parser(error_handler=error_handler) with pytest.raises(CustomError): - p.parse({"foo": fields.Field()}, web_request) + p.parse(myschema, web_request, validate=failing_validate_func) -@mock.patch("webargs.core.Parser.parse_json") -def test_custom_error_handler_decorator(parse_json, web_request): +def test_custom_error_handler_decorator(web_request): class CustomError(Exception): pass - parse_json.side_effect = ValidationError("parse_json failed") - + mock_schema = mock.Mock(spec=Schema) + mock_schema.strict = True + mock_schema.load.side_effect = ValidationError("parsing json failed") parser = Parser() @parser.error_handler @@ -298,53 +292,47 @@ def handle_error(error, req, schema, status_code, headers): raise CustomError(error) with pytest.raises(CustomError): - parser.parse({"foo": fields.Field()}, web_request) + parser.parse(mock_schema, web_request) -def test_custom_location_handler(web_request): +def test_custom_location_loader(web_request): web_request.data = {"foo": 42} parser = Parser() - @parser.location_handler("data") - def parse_data(req, name, arg): - return req.data.get(name, missing) + @parser.location_loader("data") + def load_data(req, schema): + return req.data - result = parser.parse({"foo": fields.Int()}, web_request, locations=("data",)) + result = parser.parse({"foo": fields.Int()}, web_request, location="data") assert result["foo"] == 42 -def test_custom_location_handler_with_data_key(web_request): +def test_custom_location_loader_with_data_key(web_request): web_request.data = {"X-Foo": 42} parser = Parser() - @parser.location_handler("data") - def parse_data(req, name, arg): - return req.data.get(name, missing) + @parser.location_loader("data") + def load_data(req, schema): + return req.data data_key_kwarg = { "load_from" if (MARSHMALLOW_VERSION_INFO[0] < 3) else "data_key": "X-Foo" } result = parser.parse( - {"x_foo": fields.Int(**data_key_kwarg)}, web_request, locations=("data",) + {"x_foo": fields.Int(**data_key_kwarg)}, web_request, location="data" ) assert result["x_foo"] == 42 -def test_full_input_validation(web_request): +def test_full_input_validation(parser, web_request): web_request.json = {"foo": 41, "bar": 42} - parser = MockRequestParser() args = {"foo": fields.Int(), "bar": fields.Int()} with pytest.raises(ValidationError): # Test that `validate` receives dictionary of args - parser.parse( - args, - web_request, - locations=("json",), - validate=lambda args: args["foo"] > args["bar"], - ) + parser.parse(args, web_request, validate=lambda args: args["foo"] > args["bar"]) def test_full_input_validation_with_multiple_validators(web_request, parser): @@ -360,31 +348,29 @@ def validate2(args): web_request.json = {"a": 2, "b": 1} validators = [validate1, validate2] with pytest.raises(ValidationError, match="b must be > a"): - parser.parse(args, web_request, locations=("json",), validate=validators) + parser.parse(args, web_request, validate=validators) web_request.json = {"a": 1, "b": 2} with pytest.raises(ValidationError, match="a must be > b"): - parser.parse(args, web_request, locations=("json",), validate=validators) + parser.parse(args, web_request, validate=validators) -def test_required_with_custom_error(web_request): +def test_required_with_custom_error(parser, web_request): web_request.json = {} - parser = MockRequestParser() args = { "foo": fields.Str(required=True, error_messages={"required": "We need foo"}) } with pytest.raises(ValidationError) as excinfo: # Test that `validate` receives dictionary of args - parser.parse(args, web_request, locations=("json",)) + parser.parse(args, web_request) assert "We need foo" in excinfo.value.messages["foo"] if MARSHMALLOW_VERSION_INFO[0] < 3: assert "foo" in excinfo.value.field_names -def test_required_with_custom_error_and_validation_error(web_request): +def test_required_with_custom_error_and_validation_error(parser, web_request): web_request.json = {"foo": ""} - parser = MockRequestParser() args = { "foo": fields.Str( required="We need foo", @@ -394,7 +380,7 @@ def test_required_with_custom_error_and_validation_error(web_request): } with pytest.raises(ValidationError) as excinfo: # Test that `validate` receives dictionary of args - parser.parse(args, web_request, locations=("json",)) + parser.parse(args, web_request) assert "foo required length is 3" in excinfo.value.args[0]["foo"] if MARSHMALLOW_VERSION_INFO[0] < 3: @@ -410,7 +396,7 @@ def validate(val): parser = MockRequestParser() args = {"text": fields.Str()} with pytest.raises(ValidationError) as excinfo: - parser.parse(args, web_request, locations=("json",), validate=validate) + parser.parse(args, web_request, validate=validate) assert excinfo.value.messages == ["Invalid value."] @@ -420,14 +406,6 @@ def test_invalid_argument_for_validate(web_request, parser): assert "not a callable or list of callables." in excinfo.value.args[0] -def test_get_value_basic(): - assert get_value({"foo": 42}, "foo", False) == 42 - assert get_value({"foo": 42}, "bar", False) is missing - assert get_value({"foos": ["a", "b"]}, "foos", True) == ["a", "b"] - # https://github.com/marshmallow-code/webargs/pull/30 - assert get_value({"foos": ["a", "b"]}, "bar", True) is missing - - def create_bottle_multi_dict(): d = BotMultiDict() d["foos"] = "a" @@ -443,9 +421,24 @@ def create_bottle_multi_dict(): @pytest.mark.parametrize("input_dict", multidicts) -def test_get_value_multidict(input_dict): - field = fields.List(fields.Str()) - assert get_value(input_dict, "foos", field) == ["a", "b"] +def test_multidict_proxy(input_dict): + class ListSchema(Schema): + foos = fields.List(fields.Str()) + + class StrSchema(Schema): + foos = fields.Str() + + # this MultiDictProxy is aware that "foos" is a list field and will + # therefore produce a list with __getitem__ + list_wrapped_multidict = MultiDictProxy(input_dict, ListSchema()) + + # this MultiDictProxy is under the impression that "foos" is just a string + # and it should return "a" or "b" + # the decision between "a" and "b" in this case belongs to the framework + str_wrapped_multidict = MultiDictProxy(input_dict, StrSchema()) + + assert list_wrapped_multidict["foos"] == ["a", "b"] + assert str_wrapped_multidict["foos"] in ("a", "b") def test_parse_with_data_key(web_request): @@ -456,7 +449,7 @@ def test_parse_with_data_key(web_request): "load_from" if (MARSHMALLOW_VERSION_INFO[0] < 3) else "data_key": "Content-Type" } args = {"content_type": fields.Field(**data_key_kwargs)} - parsed = parser.parse(args, web_request, locations=("json",)) + parsed = parser.parse(args, web_request) assert parsed == {"content_type": "application/json"} @@ -470,7 +463,7 @@ def test_load_from_is_checked_after_given_key(web_request): parser = MockRequestParser() args = {"content_type": fields.Field(load_from="Content-Type")} - parsed = parser.parse(args, web_request, locations=("json",)) + parsed = parser.parse(args, web_request) assert parsed == {"content_type": "application/json"} @@ -483,7 +476,7 @@ def test_parse_with_data_key_retains_field_name_in_error(web_request): } args = {"content_type": fields.Str(**data_key_kwargs)} with pytest.raises(ValidationError) as excinfo: - parser.parse(args, web_request, locations=("json",)) + parser.parse(args, web_request) assert "Content-Type" in excinfo.value.messages assert excinfo.value.messages["Content-Type"] == ["Not a valid string."] @@ -496,7 +489,7 @@ def test_parse_nested_with_data_key(web_request): } args = {"nested_arg": fields.Nested({"right": fields.Field(**data_key_kwarg)})} - parsed = parser.parse(args, web_request, locations=("json",)) + parsed = parser.parse(args, web_request) assert parsed == {"nested_arg": {"right": "OK"}} @@ -513,7 +506,7 @@ def test_parse_nested_with_missing_key_and_data_key(web_request): ) } - parsed = parser.parse(args, web_request, locations=("json",)) + parsed = parser.parse(args, web_request) assert parsed == {"nested_arg": {"found": None}} @@ -523,7 +516,7 @@ def test_parse_nested_with_default(web_request): web_request.json = {"nested_arg": {}} args = {"nested_arg": fields.Nested({"miss": fields.Field(missing="")})} - parsed = parser.parse(args, web_request, locations=("json",)) + parsed = parser.parse(args, web_request) assert parsed == {"nested_arg": {"miss": ""}} @@ -554,8 +547,8 @@ def test_use_args_stacked(web_request, parser): web_request.json = {"username": "foo"} web_request.query = {"page": 42} - @parser.use_args(query_args, web_request, locations=("query",)) - @parser.use_args(json_args, web_request, locations=("json",)) + @parser.use_args(query_args, web_request, location="query") + @parser.use_args(json_args, web_request) def viewfunc(query_parsed, json_parsed): return {"json": json_parsed, "query": query_parsed} @@ -570,8 +563,8 @@ def test_use_kwargs_stacked(web_request, parser): web_request.json = {"username": "foo"} web_request.query = {"page": 42} - @parser.use_kwargs(query_args, web_request, locations=("query",)) - @parser.use_kwargs(json_args, web_request, locations=("json",)) + @parser.use_kwargs(query_args, web_request, location="query") + @parser.use_kwargs(json_args, web_request) def viewfunc(page, username): return {"json": {"username": username}, "query": {"page": page}} @@ -592,21 +585,21 @@ def viewfunc(*args, **kwargs): def test_list_allowed_missing(web_request, parser): args = {"name": fields.List(fields.Str())} - web_request.json = {"fakedata": True} + web_request.json = {} result = parser.parse(args, web_request) assert result == {} def test_int_list_allowed_missing(web_request, parser): args = {"name": fields.List(fields.Int())} - web_request.json = {"fakedata": True} + web_request.json = {} result = parser.parse(args, web_request) assert result == {} def test_multiple_arg_required_with_int_conversion(web_request, parser): args = {"ids": fields.List(fields.Int(), required=True)} - web_request.json = {"fakedata": True} + web_request.json = {} with pytest.raises(ValidationError) as excinfo: parser.parse(args, web_request) assert excinfo.value.messages == {"ids": ["Missing data for required field."]} @@ -747,10 +740,22 @@ def test_warning_raised_if_schema_is_not_in_strict_mode(self, web_request, parse assert "strict=True" in str(warning.message) def test_use_kwargs_stacked(self, web_request, parser): + if MARSHMALLOW_VERSION_INFO[0] >= 3: + from marshmallow import EXCLUDE + + class PageSchema(Schema): + page = fields.Int() + + pageschema = PageSchema(unknown=EXCLUDE) + userschema = self.UserSchema(unknown=EXCLUDE) + else: + pageschema = {"page": fields.Int()} + userschema = self.UserSchema(**strict_kwargs) + web_request.json = {"email": "foo@bar.com", "password": "bar", "page": 42} - @parser.use_kwargs({"page": fields.Int()}, web_request) - @parser.use_kwargs(self.UserSchema(**strict_kwargs), web_request) + @parser.use_kwargs(pageschema, web_request) + @parser.use_kwargs(userschema, web_request) def viewfunc(email, password, page): return {"email": email, "password": password, "page": page} @@ -774,18 +779,18 @@ def validate_schema(self, data, original_data, **kwargs): return True web_request.json = {"name": "Eric Cartman"} - res = parser.parse(UserSchema, web_request, locations=("json",)) + res = parser.parse(UserSchema, web_request) assert res == {"name": "Eric Cartman"} -def test_use_args_with_custom_locations_in_parser(web_request, parser): +def test_use_args_with_custom_location_in_parser(web_request, parser): custom_args = {"foo": fields.Str()} web_request.json = {} - parser.locations = ("custom",) + parser.location = "custom" - @parser.location_handler("custom") - def parse_custom(req, name, arg): - return "bar" + @parser.location_loader("custom") + def load_custom(schema, req): + return {"foo": "bar"} @parser.use_args(custom_args, web_request) def viewfunc(args): @@ -913,16 +918,6 @@ def test_type_conversion_with_multiple_required(web_request, parser): parser.parse(args, web_request) -def test_arg_location_param(web_request, parser): - web_request.json = {"foo": 24} - web_request.cookies = {"foo": 42} - args = {"foo": fields.Field(location="cookies")} - - parsed = parser.parse(args, web_request) - - assert parsed["foo"] == 42 - - def test_validation_errors_in_validator_are_passed_to_handle_error(parser, web_request): def validate(value): raise ValidationError("Something went wrong.") @@ -1041,23 +1036,23 @@ def test_parse_with_error_status_code_and_headers(web_request): assert error.headers == {"X-Foo": "bar"} -@mock.patch("webargs.core.Parser.parse_json") -def test_custom_schema_class(parse_json, web_request): +@mock.patch("webargs.core.Parser.load_json") +def test_custom_schema_class(load_json, web_request): class CustomSchema(Schema): @pre_load def pre_load(self, data, **kwargs): data["value"] += " world" return data - parse_json.return_value = "hello" + load_json.return_value = {"value": "hello"} argmap = {"value": fields.Str()} p = Parser(schema_class=CustomSchema) ret = p.parse(argmap, web_request) assert ret == {"value": "hello world"} -@mock.patch("webargs.core.Parser.parse_json") -def test_custom_default_schema_class(parse_json, web_request): +@mock.patch("webargs.core.Parser.load_json") +def test_custom_default_schema_class(load_json, web_request): class CustomSchema(Schema): @pre_load def pre_load(self, data, **kwargs): @@ -1067,7 +1062,7 @@ def pre_load(self, data, **kwargs): class CustomParser(Parser): DEFAULT_SCHEMA_CLASS = CustomSchema - parse_json.return_value = "hello" + load_json.return_value = {"value": "hello"} argmap = {"value": fields.Str()} p = CustomParser() ret = p.parse(argmap, web_request) diff --git a/tests/test_djangoparser.py b/tests/test_djangoparser.py index 5b8497a4..c585653d 100644 --- a/tests/test_djangoparser.py +++ b/tests/test_djangoparser.py @@ -23,12 +23,12 @@ def test_parsing_headers(self, testapp): def test_parsing_in_class_based_view(self, testapp): assert testapp.get("/echo_cbv?name=Fred").json == {"name": "Fred"} - assert testapp.post("/echo_cbv", {"name": "Fred"}).json == {"name": "Fred"} + assert testapp.post_json("/echo_cbv", {"name": "Fred"}).json == {"name": "Fred"} def test_use_args_in_class_based_view(self, testapp): res = testapp.get("/echo_use_args_cbv?name=Fred") assert res.json == {"name": "Fred"} - res = testapp.post("/echo_use_args_cbv", {"name": "Fred"}) + res = testapp.post_json("/echo_use_args_cbv", {"name": "Fred"}) assert res.json == {"name": "Fred"} def test_use_args_in_class_based_view_with_path_param(self, testapp): diff --git a/tests/test_falconparser.py b/tests/test_falconparser.py index d6092c72..26138424 100644 --- a/tests/test_falconparser.py +++ b/tests/test_falconparser.py @@ -19,7 +19,7 @@ def test_use_args_hook(self, testapp): # https://github.com/marshmallow-code/webargs/issues/427 def test_parse_json_with_nonutf8_chars(self, testapp): res = testapp.post( - "/echo", + "/echo_json", b"\xfe", headers={"Accept": "application/json", "Content-Type": "application/json"}, expect_errors=True, @@ -31,10 +31,15 @@ def test_parse_json_with_nonutf8_chars(self, testapp): # https://github.com/sloria/webargs/issues/329 def test_invalid_json(self, testapp): res = testapp.post( - "/echo", + "/echo_json", '{"foo": "bar", }', headers={"Accept": "application/json", "Content-Type": "application/json"}, expect_errors=True, ) assert res.status_code == 400 assert res.json["errors"] == {"json": ["Invalid JSON body."]} + + # Falcon converts headers to all-caps + def test_parsing_headers(self, testapp): + res = testapp.get("/echo_headers", headers={"name": "Fred"}) + assert res.json == {"NAME": "Fred"} diff --git a/tests/test_flaskparser.py b/tests/test_flaskparser.py index 5122196c..97b447ef 100644 --- a/tests/test_flaskparser.py +++ b/tests/test_flaskparser.py @@ -7,7 +7,7 @@ import pytest from flask import Flask -from webargs import fields, ValidationError, missing +from webargs import fields, ValidationError, missing, dict2schema from webargs.flaskparser import parser, abort from webargs.core import MARSHMALLOW_VERSION_INFO, json @@ -33,23 +33,31 @@ def test_use_args_with_view_args_parsing(self, testapp): assert res.json == {"view_arg": 42} def test_use_args_on_a_method_view(self, testapp): - res = testapp.post("/echo_method_view_use_args", {"val": 42}) + res = testapp.post_json("/echo_method_view_use_args", {"val": 42}) assert res.json == {"val": 42} def test_use_kwargs_on_a_method_view(self, testapp): - res = testapp.post("/echo_method_view_use_kwargs", {"val": 42}) + res = testapp.post_json("/echo_method_view_use_kwargs", {"val": 42}) assert res.json == {"val": 42} def test_use_kwargs_with_missing_data(self, testapp): - res = testapp.post("/echo_use_kwargs_missing", {"username": "foo"}) + res = testapp.post_json("/echo_use_kwargs_missing", {"username": "foo"}) assert res.json == {"username": "foo"} # regression test for https://github.com/marshmallow-code/webargs/issues/145 def test_nested_many_with_data_key(self, testapp): - res = testapp.post_json("/echo_nested_many_data_key", {"x_field": [{"id": 42}]}) - # https://github.com/marshmallow-code/marshmallow/pull/714 + post_with_raw_fieldname_args = ( + "/echo_nested_many_data_key", + {"x_field": [{"id": 42}]}, + ) + # under marshmallow 2 this is allowed and works if MARSHMALLOW_VERSION_INFO[0] < 3: + res = testapp.post_json(*post_with_raw_fieldname_args) assert res.json == {"x_field": [{"id": 42}]} + # but under marshmallow3 , only data_key is checked, field name is ignored + else: + res = testapp.post_json(*post_with_raw_fieldname_args, expect_errors=True) + assert res.status_code == 422 res = testapp.post_json("/echo_nested_many_data_key", {"X-Field": [{"id": 24}]}) assert res.json == {"x_field": [{"id": 24}]} @@ -81,10 +89,13 @@ def validate(x): assert type(abort_kwargs["exc"]) == ValidationError -def test_parse_form_returns_missing_if_no_form(): +@pytest.mark.parametrize("mimetype", [None, "application/json"]) +def test_load_json_returns_missing_if_no_data(mimetype): req = mock.Mock() - req.form.get.side_effect = AttributeError("no form") - assert parser.parse_form(req, "foo", fields.Field()) is missing + req.mimetype = mimetype + req.get_data.return_value = "" + schema = dict2schema({"foo": fields.Field()})() + assert parser.load_json(req, schema) is missing def test_abort_with_message(): diff --git a/tests/test_py3/test_aiohttpparser.py b/tests/test_py3/test_aiohttpparser.py index d3de2fbc..b03b8793 100644 --- a/tests/test_py3/test_aiohttpparser.py +++ b/tests/test_py3/test_aiohttpparser.py @@ -38,16 +38,17 @@ def test_use_args_on_method_handler(self, testapp): # regression test for https://github.com/marshmallow-code/webargs/issues/165 def test_multiple_args(self, testapp): - res = testapp.post_json( - "/echo_multiple_args", {"first": "1", "last": "2", "_ignore": 0} - ) + res = testapp.post_json("/echo_multiple_args", {"first": "1", "last": "2"}) assert res.json == {"first": "1", "last": "2"} # regression test for https://github.com/marshmallow-code/webargs/issues/145 def test_nested_many_with_data_key(self, testapp): - res = testapp.post_json("/echo_nested_many_data_key", {"x_field": [{"id": 42}]}) # https://github.com/marshmallow-code/marshmallow/pull/714 + # on marshmallow 2, the field name can also be used if MARSHMALLOW_VERSION_INFO[0] < 3: + res = testapp.post_json( + "/echo_nested_many_data_key", {"x_field": [{"id": 42}]} + ) assert res.json == {"x_field": [{"id": 42}]} res = testapp.post_json("/echo_nested_many_data_key", {"X-Field": [{"id": 24}]}) diff --git a/tests/test_py3/test_aiohttpparser_async_functions.py b/tests/test_py3/test_aiohttpparser_async_functions.py index a0437c4a..4b732e15 100644 --- a/tests/test_py3/test_aiohttpparser_async_functions.py +++ b/tests/test_py3/test_aiohttpparser_async_functions.py @@ -11,16 +11,16 @@ async def echo_parse(request): - parsed = await parser.parse(hello_args, request) + parsed = await parser.parse(hello_args, request, location="query") return json_response(parsed) -@use_args(hello_args) +@use_args(hello_args, location="query") async def echo_use_args(request, args): return json_response(args) -@use_kwargs(hello_args) +@use_kwargs(hello_args, location="query") async def echo_use_kwargs(request, name): return json_response({"name": name}) diff --git a/tests/test_tornadoparser.py b/tests/test_tornadoparser.py index 8eb29907..99138da5 100644 --- a/tests/test_tornadoparser.py +++ b/tests/test_tornadoparser.py @@ -1,40 +1,52 @@ # -*- coding: utf-8 -*- -from webargs.core import json +import marshmallow as ma +import mock +import pytest +import tornado.concurrent +import tornado.http1connection +import tornado.httpserver +import tornado.httputil +import tornado.ioloop +import tornado.web +from tornado.testing import AsyncHTTPTestCase +from webargs import fields, missing +from webargs.core import MARSHMALLOW_VERSION_INFO, json, parse_json +from webargs.tornadoparser import ( + WebArgsTornadoMultiDictProxy, + parser, + use_args, + use_kwargs, +) try: from urllib.parse import urlencode except ImportError: # PY2 from urllib import urlencode # type: ignore -import mock -import pytest -import marshmallow as ma +name = "name" +value = "value" -import tornado.web -import tornado.httputil -import tornado.httpserver -import tornado.http1connection -import tornado.concurrent -import tornado.ioloop -from tornado.testing import AsyncHTTPTestCase -from webargs import fields, missing -from webargs.tornadoparser import parser, use_args, use_kwargs, get_value -from webargs.core import parse_json +class AuthorSchema(ma.Schema): + name = fields.Str(missing="World", validate=lambda n: len(n) >= 3) + works = fields.List(fields.Str()) -name = "name" -value = "value" +strict_kwargs = {"strict": True} if MARSHMALLOW_VERSION_INFO[0] < 3 else {} +author_schema = AuthorSchema(**strict_kwargs) -def test_get_value_basic(): - field, multifield = fields.Field(), fields.List(fields.Str()) - assert get_value({"foo": 42}, "foo", field) == 42 - assert get_value({"foo": 42}, "bar", field) is missing - assert get_value({"foos": ["a", "b"]}, "foos", multifield) == ["a", "b"] - # https://github.com/marshmallow-code/webargs/pull/30 - assert get_value({"foos": ["a", "b"]}, "bar", multifield) is missing + +def test_tornado_multidictproxy(): + for dictval, fieldname, expected in ( + ({"name": "Sophocles"}, "name", "Sophocles"), + ({"name": "Sophocles"}, "works", missing), + ({"works": ["Antigone", "Oedipus Rex"]}, "works", ["Antigone", "Oedipus Rex"]), + ({"works": ["Antigone", "Oedipus at Colonus"]}, "name", missing), + ): + proxy = WebArgsTornadoMultiDictProxy(dictval, author_schema) + assert proxy.get(fieldname) == expected class TestQueryArgs(object): @@ -42,43 +54,23 @@ def setup_method(self, method): parser.clear_cache() def test_it_should_get_single_values(self): - query = [(name, value)] - field = fields.Field() + query = [("name", "Aeschylus")] request = make_get_request(query) - - result = parser.parse_querystring(request, name, field) - - assert result == value + result = parser.load_querystring(request, author_schema) + assert result["name"] == "Aeschylus" def test_it_should_get_multiple_values(self): - query = [(name, value), (name, value)] - field = fields.List(fields.Field()) + query = [("works", "Agamemnon"), ("works", "Nereids")] request = make_get_request(query) - - result = parser.parse_querystring(request, name, field) - - assert result == [value, value] + result = parser.load_querystring(request, author_schema) + assert result["works"] == ["Agamemnon", "Nereids"] def test_it_should_return_missing_if_not_present(self): query = [] - field = fields.Field() - field2 = fields.List(fields.Int()) - request = make_get_request(query) - - result = parser.parse_querystring(request, name, field) - result2 = parser.parse_querystring(request, name, field2) - - assert result is missing - assert result2 is missing - - def test_it_should_return_empty_list_if_multiple_and_not_present(self): - query = [] - field = fields.List(fields.Field()) request = make_get_request(query) - - result = parser.parse_querystring(request, name, field) - - assert result is missing + result = parser.load_querystring(request, author_schema) + assert result["name"] is missing + assert result["works"] is missing class TestFormArgs: @@ -86,40 +78,23 @@ def setup_method(self, method): parser.clear_cache() def test_it_should_get_single_values(self): - query = [(name, value)] - field = fields.Field() + query = [("name", "Aristophanes")] request = make_form_request(query) - - result = parser.parse_form(request, name, field) - - assert result == value + result = parser.load_form(request, author_schema) + assert result["name"] == "Aristophanes" def test_it_should_get_multiple_values(self): - query = [(name, value), (name, value)] - field = fields.List(fields.Field()) + query = [("works", "The Wasps"), ("works", "The Frogs")] request = make_form_request(query) - - result = parser.parse_form(request, name, field) - - assert result == [value, value] + result = parser.load_form(request, author_schema) + assert result["works"] == ["The Wasps", "The Frogs"] def test_it_should_return_missing_if_not_present(self): query = [] - field = fields.Field() request = make_form_request(query) - - result = parser.parse_form(request, name, field) - - assert result is missing - - def test_it_should_return_empty_list_if_multiple_and_not_present(self): - query = [] - field = fields.List(fields.Field()) - request = make_form_request(query) - - result = parser.parse_form(request, name, field) - - assert result is missing + result = parser.load_form(request, author_schema) + assert result["name"] is missing + assert result["works"] is missing class TestJSONArgs(object): @@ -127,70 +102,66 @@ def setup_method(self, method): parser.clear_cache() def test_it_should_get_single_values(self): - query = {name: value} - field = fields.Field() + query = {"name": "Euripides"} request = make_json_request(query) - result = parser.parse_json(request, name, field) - - assert result == value + result = parser.load_json(request, author_schema) + assert result["name"] == "Euripides" def test_parsing_request_with_vendor_content_type(self): - query = {name: value} - field = fields.Field() + query = {"name": "Euripides"} request = make_json_request( query, content_type="application/vnd.api+json; charset=UTF-8" ) - result = parser.parse_json(request, name, field) - - assert result == value + result = parser.load_json(request, author_schema) + assert result["name"] == "Euripides" def test_it_should_get_multiple_values(self): - query = {name: [value, value]} - field = fields.List(fields.Field()) + query = {"works": ["Medea", "Electra"]} request = make_json_request(query) - result = parser.parse_json(request, name, field) - - assert result == [value, value] + result = parser.load_json(request, author_schema) + assert result["works"] == ["Medea", "Electra"] def test_it_should_get_multiple_nested_values(self): - query = {name: [{"id": 1, "name": "foo"}, {"id": 2, "name": "bar"}]} - field = fields.List( - fields.Nested({"id": fields.Field(), "name": fields.Field()}) - ) - request = make_json_request(query) - result = parser.parse_json(request, name, field) - assert result == [{"id": 1, "name": "foo"}, {"id": 2, "name": "bar"}] - - def test_it_should_return_missing_if_not_present(self): - query = {} - field = fields.Field() + class CustomSchema(ma.Schema): + works = fields.List( + fields.Nested({"author": fields.Str(), "workname": fields.Str()}) + ) + + custom_schema = CustomSchema(**strict_kwargs) + + query = { + "works": [ + {"author": "Euripides", "workname": "Hecuba"}, + {"author": "Aristophanes", "workname": "The Birds"}, + ] + } request = make_json_request(query) - result = parser.parse_json(request, name, field) - - assert result is missing + result = parser.load_json(request, custom_schema) + assert result["works"] == [ + {"author": "Euripides", "workname": "Hecuba"}, + {"author": "Aristophanes", "workname": "The Birds"}, + ] - def test_it_should_return_empty_list_if_multiple_and_not_present(self): + def test_it_should_not_include_fieldnames_if_not_present(self): query = {} - field = fields.List(fields.Field()) request = make_json_request(query) - result = parser.parse_json(request, name, field) + result = parser.load_json(request, author_schema) + assert result == {} - assert result is missing - - def test_it_should_handle_type_error_on_parse_json(self): - field = fields.Field() + def test_it_should_handle_type_error_on_load_json(self): + # but this is different from the test above where the payload was valid + # and empty -- missing vs {} request = make_request( - body=tornado.concurrent.Future, headers={"Content-Type": "application/json"} + body=tornado.concurrent.Future(), + headers={"Content-Type": "application/json"}, ) - result = parser.parse_json(request, name, field) - assert parser._cache["json"] == {} + result = parser.load_json(request, author_schema) assert result is missing def test_it_should_handle_value_error_on_parse_json(self): - field = fields.Field() request = make_request("this is json not") - result = parser.parse_json(request, name, field) - assert parser._cache["json"] == {} + result = parser.load_json(request, author_schema) + assert parser._cache.get("json") == missing assert result is missing @@ -199,39 +170,22 @@ def setup_method(self, method): parser.clear_cache() def test_it_should_get_single_values(self): - query = {name: value} - field = fields.Field() + query = {"name": "Euphorion"} request = make_request(headers=query) - - result = parser.parse_headers(request, name, field) - - assert result == value + result = parser.load_headers(request, author_schema) + assert result["name"] == "Euphorion" def test_it_should_get_multiple_values(self): - query = {name: [value, value]} - field = fields.List(fields.Field()) + query = {"works": ["Prometheus Bound", "Prometheus Unbound"]} request = make_request(headers=query) - - result = parser.parse_headers(request, name, field) - - assert result == [value, value] + result = parser.load_headers(request, author_schema) + assert result["works"] == ["Prometheus Bound", "Prometheus Unbound"] def test_it_should_return_missing_if_not_present(self): - field = fields.Field(multiple=False) request = make_request() - - result = parser.parse_headers(request, name, field) - - assert result is missing - - def test_it_should_return_empty_list_if_multiple_and_not_present(self): - query = {} - field = fields.List(fields.Field()) - request = make_request(headers=query) - - result = parser.parse_headers(request, name, field) - - assert result is missing + result = parser.load_headers(request, author_schema) + assert result["name"] is missing + assert result["works"] is missing class TestFilesArgs(object): @@ -239,40 +193,23 @@ def setup_method(self, method): parser.clear_cache() def test_it_should_get_single_values(self): - query = [(name, value)] - field = fields.Field() + query = [("name", "Sappho")] request = make_files_request(query) - - result = parser.parse_files(request, name, field) - - assert result == value + result = parser.load_files(request, author_schema) + assert result["name"] == "Sappho" def test_it_should_get_multiple_values(self): - query = [(name, value), (name, value)] - field = fields.List(fields.Field()) + query = [("works", "Sappho 31"), ("works", "Ode to Aphrodite")] request = make_files_request(query) - - result = parser.parse_files(request, name, field) - - assert result == [value, value] + result = parser.load_files(request, author_schema) + assert result["works"] == ["Sappho 31", "Ode to Aphrodite"] def test_it_should_return_missing_if_not_present(self): query = [] - field = fields.Field() - request = make_files_request(query) - - result = parser.parse_files(request, name, field) - - assert result is missing - - def test_it_should_return_empty_list_if_multiple_and_not_present(self): - query = [] - field = fields.List(fields.Field()) request = make_files_request(query) - - result = parser.parse_files(request, name, field) - - assert result is missing + result = parser.load_files(request, author_schema) + assert result["name"] is missing + assert result["works"] is missing class TestErrorHandler(object): @@ -293,7 +230,7 @@ def test_it_should_parse_query_arguments(self): [("string", "value"), ("integer", "1"), ("integer", "2")] ) - parsed = parser.parse(attrs, request) + parsed = parser.parse(attrs, request, location="query") assert parsed["integer"] == [1, 2] assert parsed["string"] == value @@ -305,7 +242,7 @@ def test_it_should_parse_form_arguments(self): [("string", "value"), ("integer", "1"), ("integer", "2")] ) - parsed = parser.parse(attrs, request) + parsed = parser.parse(attrs, request, location="form") assert parsed["integer"] == [1, 2] assert parsed["string"] == value @@ -337,7 +274,7 @@ def test_it_should_parse_header_arguments(self): request = make_request(headers={"string": "value", "integer": ["1", "2"]}) - parsed = parser.parse(attrs, request, locations=["headers"]) + parsed = parser.parse(attrs, request, location="headers") assert parsed["string"] == value assert parsed["integer"] == [1, 2] @@ -349,7 +286,7 @@ def test_it_should_parse_cookies_arguments(self): [("string", "value"), ("integer", "1"), ("integer", "2")] ) - parsed = parser.parse(attrs, request, locations=["cookies"]) + parsed = parser.parse(attrs, request, location="cookies") assert parsed["string"] == value assert parsed["integer"] == [2] @@ -361,7 +298,7 @@ def test_it_should_parse_files_arguments(self): [("string", "value"), ("integer", "1"), ("integer", "2")] ) - parsed = parser.parse(attrs, request, locations=["files"]) + parsed = parser.parse(attrs, request, location="files") assert parsed["string"] == value assert parsed["integer"] == [1, 2] @@ -509,10 +446,22 @@ def make_request(uri=None, body=None, headers=None, files=None): class EchoHandler(tornado.web.RequestHandler): ARGS = {"name": fields.Str()} - @use_args(ARGS) + @use_args(ARGS, location="query") def get(self, args): self.write(args) + +class EchoFormHandler(tornado.web.RequestHandler): + ARGS = {"name": fields.Str()} + + @use_args(ARGS, location="form") + def post(self, args): + self.write(args) + + +class EchoJSONHandler(tornado.web.RequestHandler): + ARGS = {"name": fields.Str()} + @use_args(ARGS) def post(self, args): self.write(args) @@ -521,13 +470,18 @@ def post(self, args): class EchoWithParamHandler(tornado.web.RequestHandler): ARGS = {"name": fields.Str()} - @use_args(ARGS) + @use_args(ARGS, location="query") def get(self, id, args): self.write(args) echo_app = tornado.web.Application( - [(r"/echo", EchoHandler), (r"/echo_with_param/(\d+)", EchoWithParamHandler)] + [ + (r"/echo", EchoHandler), + (r"/echo_form", EchoFormHandler), + (r"/echo_json", EchoJSONHandler), + (r"/echo_with_param/(\d+)", EchoWithParamHandler), + ] ) @@ -537,7 +491,7 @@ def get_app(self): def test_post(self): res = self.fetch( - "/echo", + "/echo_json", method="POST", headers={"Content-Type": "application/json"}, body=json.dumps({"name": "Steve"}), @@ -545,7 +499,7 @@ def test_post(self): json_body = parse_json(res.body) assert json_body["name"] == "Steve" res = self.fetch( - "/echo", + "/echo_json", method="POST", headers={"Content-Type": "application/json"}, body=json.dumps({}), @@ -577,7 +531,7 @@ class ValidateHandler(tornado.web.RequestHandler): def post(self, args): self.write(args) - @use_kwargs(ARGS) + @use_kwargs(ARGS, location="query") def get(self, name): self.write({"status": "success"}) diff --git a/tests/test_webapp2parser.py b/tests/test_webapp2parser.py index 3fcb20bd..54243b88 100644 --- a/tests/test_webapp2parser.py +++ b/tests/test_webapp2parser.py @@ -7,11 +7,13 @@ from webargs.core import json import pytest +import marshmallow as ma from marshmallow import fields, ValidationError import webtest import webapp2 from webargs.webapp2parser import parser +from webargs.core import MARSHMALLOW_VERSION_INFO hello_args = {"name": fields.Str(missing="World")} @@ -25,32 +27,43 @@ } +class HelloSchema(ma.Schema): + name = fields.Str(missing="World", validate=lambda n: len(n) >= 3) + + +# variant which ignores unknown fields +exclude_kwargs = ( + {"strict": True} if MARSHMALLOW_VERSION_INFO[0] < 3 else {"unknown": ma.EXCLUDE} +) +hello_exclude_schema = HelloSchema(**exclude_kwargs) + + def test_parse_querystring_args(): request = webapp2.Request.blank("/echo?name=Fred") - assert parser.parse(hello_args, req=request) == {"name": "Fred"} + assert parser.parse(hello_args, req=request, location="query") == {"name": "Fred"} def test_parse_querystring_multiple(): expected = {"name": ["steve", "Loria"]} request = webapp2.Request.blank("/echomulti?name=steve&name=Loria") - assert parser.parse(hello_multiple, req=request) == expected + assert parser.parse(hello_multiple, req=request, location="query") == expected def test_parse_form(): expected = {"name": "Joe"} request = webapp2.Request.blank("/echo", POST=expected) - assert parser.parse(hello_args, req=request) == expected + assert parser.parse(hello_args, req=request, location="form") == expected def test_parse_form_multiple(): expected = {"name": ["steve", "Loria"]} request = webapp2.Request.blank("/echo", POST=urlencode(expected, doseq=True)) - assert parser.parse(hello_multiple, req=request) == expected + assert parser.parse(hello_multiple, req=request, location="form") == expected def test_parsing_form_default(): request = webapp2.Request.blank("/echo", POST="") - assert parser.parse(hello_args, req=request) == {"name": "World"} + assert parser.parse(hello_args, req=request, location="form") == {"name": "World"} def test_parse_json(): @@ -95,13 +108,15 @@ def test_parsing_cookies(): request = webapp2.Request.blank( "/", headers={"Cookie": response.headers["Set-Cookie"]} ) - assert parser.parse(hello_args, req=request, locations=("cookies",)) == expected + assert parser.parse(hello_args, req=request, location="cookies") == expected def test_parsing_headers(): expected = {"name": "Fred"} request = webapp2.Request.blank("/", headers=expected) - assert parser.parse(hello_args, req=request, locations=("headers",)) == expected + assert ( + parser.parse(hello_exclude_schema, req=request, location="headers") == expected + ) def test_parse_files(): @@ -110,7 +125,7 @@ def test_parse_files(): """ class Handler(webapp2.RequestHandler): - @parser.use_args({"myfile": fields.List(fields.Field())}, locations=("files",)) + @parser.use_args({"myfile": fields.List(fields.Field())}, location="files") def post(self, args): self.response.content_type = "application/json" @@ -130,13 +145,13 @@ def _value(f): def test_exception_on_validation_error(): request = webapp2.Request.blank("/", POST={"num": "3"}) with pytest.raises(ValidationError): - parser.parse(hello_validate, req=request) + parser.parse(hello_validate, req=request, location="form") def test_validation_error_with_message(): request = webapp2.Request.blank("/", POST={"num": "3"}) with pytest.raises(ValidationError) as exc: - parser.parse(hello_validate, req=request) + parser.parse(hello_validate, req=request, location="form") assert "Houston, we've had a problem." in exc.value @@ -148,4 +163,4 @@ def test_default_app_request(): request = webapp2.Request.blank("/echo", POST=expected) app = webapp2.WSGIApplication([]) app.set_globals(app, request) - assert parser.parse(hello_args) == expected + assert parser.parse(hello_args, location="form") == expected