Skip to content

Commit

Permalink
Convert Parser.pass_all_args to Parser.unknown
Browse files Browse the repository at this point in the history
This commit is an intermediate step meant to show the distinction
between a standalone pass_all_args flag and *always* passing all args
but allowing the user to specify a value for `unknown` to pass to
`Schema.load()`.

In this version, the webargs.Parser object is able to use
`Schema.load(unknown=EXCLUDE)` on marshmallow 3 as the default for
backwards compatibility. Users can pass `unknown=None` to specify that
the Parser should not pass a value and the schema default should be
used. In this way, the backwards-incompatible change to be made in a
future webargs release will be to change the default for
Parser.unknown from EXCLUDE to None (or to eliminate it, and make the
`None` behavior the only option).
  • Loading branch information
sirosen committed Aug 4, 2019
1 parent 8ccde33 commit 8a85d5e
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 29 deletions.
60 changes: 36 additions & 24 deletions src/webargs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@

DEFAULT_VALIDATION_STATUS = 422 # type: int

if MARSHMALLOW_VERSION_INFO[0] < 3:
PARSER_UNKNOWN_DEFAULT = None
else:
PARSER_UNKNOWN_DEFAULT = ma.EXCLUDE


def _callable_or_raise(obj):
"""Makes sure an object is callable if it is not ``None``. If not
Expand Down Expand Up @@ -160,8 +165,10 @@ class Parser(object):
:param tuple locations: Default locations to parse.
:param callable error_handler: Custom error handler function.
:param bool pass_all_args: Pass all arguments to the schema, so that the
schema's `unknown` behavior is used (default is False)
:param bool unknown: Value for the "unknown" argument to use in the
schema's ``load`` method on marshmallow v3 and later. Pass ``None``
to use the schema default. Has no effect on marhsmallow v2.
(defaults to EXCLUDE)
"""

#: Default locations to check for data
Expand All @@ -185,16 +192,16 @@ class Parser(object):
}

def __init__(
self, pass_all_args=False, locations=None, error_handler=None, schema_class=None
self,
unknown=PARSER_UNKNOWN_DEFAULT,
locations=None,
error_handler=None,
schema_class=None,
):
self.locations = locations or self.DEFAULT_LOCATIONS
self.error_callback = _callable_or_raise(error_handler)
self.schema_class = schema_class or self.DEFAULT_SCHEMA_CLASS
# TODO: pass_all_args is for compatibility, remove in a future webargs
# version and make it the only behavior (?)
self.pass_all_args = pass_all_args
if pass_all_args and MARSHMALLOW_VERSION_INFO[0] < 3:
raise ValueError("Parser.pass_all_args requires marshmallow v3")
self.unknown = unknown
#: A short-lived cache to store results from processing request bodies.
self._cache = {}

Expand Down Expand Up @@ -302,22 +309,21 @@ def _parse_request(self, schema, req, locations):
if parsed is missing:
parsed = []
else:
# start with known args
parsed = self._parse_specified_args(schema, req, locations)
if self.pass_all_args:
# in this case, start with known args above, then collect all of the
# unknown ones here
for location, argnames in iteritems(
self.get_args_by_location(req, locations)
):
to_add = {}
for argname in argnames:
if argname not in parsed: # else, it's already known
to_add[argname] = ma.fields.Raw()
parsed.update(
self._parse_specified_args(
schema, req, (location,), argdict=to_add
)
)
# then collect all of the unknown ones
for location, argnames in iteritems(
self.get_args_by_location(req, locations)
):
if argnames is missing:
continue
to_add = {}
for argname in argnames:
if argname not in parsed: # else, it's already known
to_add[argname] = ma.fields.Raw()
parsed.update(
self._parse_specified_args(schema, req, (location,), argdict=to_add)
)

return parsed

Expand Down Expand Up @@ -367,6 +373,9 @@ def get_args_by_location(self, req, locations):
"""
Get a complete mapping of locations to iterables of args
Pulled from the given request and limited to the given set of locations
May return a location name mapped to "missing", so results must be
checked
"""
return {}

Expand Down Expand Up @@ -409,7 +418,10 @@ def parse(
parsed = parser._parse_request(
schema=schema, req=req, locations=locations or self.locations
)
result = schema.load(parsed)
if self.unknown is not None:
result = schema.load(parsed, unknown=self.unknown)
else:
result = schema.load(parsed)
data = result.data if MARSHMALLOW_VERSION_INFO[0] < 3 else result
parser._validate_arguments(data, validators)
except ma.exceptions.ValidationError as error:
Expand Down
7 changes: 6 additions & 1 deletion src/webargs/flaskparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,12 @@ def get_args_by_location(self, req, locations):
if "json" in locations:
data = self._load_json_data(req)
if data is not core.missing:
data = data.keys()
if isinstance(data, dict):
data = data.keys()
# this is slightly unintuitive, but if we parse JSON which is
# not a dict, we don't know any arg names
else:
data = core.missing
result["json"] = data
if "querystring" in locations:
result["querystring"] = req.args.keys()
Expand Down
26 changes: 22 additions & 4 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,20 +959,38 @@ def test_parse_basic(web_request, parser):
MARSHMALLOW_VERSION_INFO[0] < 3,
reason="Support for unknown=... was added in marshmallow 3",
)
def test_parse_with_unknown_raises_schema(web_request, parser, monkeypatch):
def test_parse_with_unknown_raise_set_by_schema(web_request, parser, monkeypatch):
from marshmallow import RAISE

class RaisingSchema(Schema):
class MySchema(Schema):
foo = fields.Int()

raising_schema = RaisingSchema(unknown=RAISE)
raising_schema = MySchema(unknown=RAISE)

web_request.json = {"foo": "42", "bar": "baz"}
monkeypatch.setattr(parser, "pass_all_args", True)
monkeypatch.setattr(parser, "unknown", None)
with pytest.raises(ValidationError):
parser.parse(raising_schema, web_request)


@pytest.mark.skipif(
MARSHMALLOW_VERSION_INFO[0] < 3,
reason="Support for unknown=... was added in marshmallow 3",
)
def test_parse_with_unknown_raise_set_by_parser(web_request, parser, monkeypatch):
from marshmallow import RAISE

class MySchema(Schema):
foo = fields.Int()

schema = MySchema()

web_request.json = {"foo": "42", "bar": "baz"}
monkeypatch.setattr(parser, "unknown", RAISE)
with pytest.raises(ValidationError):
parser.parse(schema, web_request)


def test_parse_raises_validation_error_if_data_invalid(web_request, parser):
args = {"email": fields.Email()}
web_request.json = {"email": "invalid"}
Expand Down

0 comments on commit 8a85d5e

Please sign in to comment.