Skip to content

Commit

Permalink
Allow overriding of the the arg_name convention
Browse files Browse the repository at this point in the history
Rather than `{location}_args` being a hardcoded behavior, allow users
to subclass and override it with a custom method.

This will allow users to set alternate naming conventions in a
centralized place, on their parser class. By passing the schema object
to get_default_arg_name, we enable schemas which provide their own
argument names to the parser.

There are some ordering considerations which make it impossible to
guarantee that `get_default_arg_name` gets a schema object (rather
than, e.g. a callable which returns a schema). For this first
implementation, I have opted to move the call after any dict-to-schema
conversion happens, so that users have fewer types they need to
handle.
  • Loading branch information
sirosen committed Aug 2, 2023
1 parent 2c6ed32 commit 8d3a3b8
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 5 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ Features:
default. This will allow use of fields with ``load_default`` to specify
handling of the empty value.

* The rule for default argument names has been made configurable by overriding
the ``get_default_arg_name`` method. This is described in the argument
passing documentation.

Changes:

* Type annotations for ``FlaskParser`` have been improved
Expand Down
72 changes: 72 additions & 0 deletions docs/advanced.rst
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,78 @@ parameter:
Note that ``arg_name`` is available even on parsers where
``USE_ARGS_POSITIONAL`` is not set.

Using an Alternate Argument Name Convention
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

As described above, the default naming convention for ``use_args`` arguments is
``{location}_args``. You can customize this by creating a parser class and
overriding the ``get_default_arg_name`` method.

``get_default_arg_name`` takes the ``location`` and the ``schema`` as
arguments. The default implementation is:

.. code-block:: python
def get_default_arg_name(self, location, schema):
return f"{location}_args"
You can customize this to set different arg names. For example,

.. code-block:: python
from webargs.flaskparser import FlaskParser
class MyParser(FlaskParser):
USE_ARGS_POSITIONAL = False
def get_default_arg_name(self, location, schema):
if location in ("json", "form", "json_or_form"):
return "body"
elif location in ("query", "querystring"):
return "query"
return location
@app.route("/")
@parser.use_args({"foo": fields.Int(), "bar": fields.Str()}, location="query")
@parser.use_args({"baz": fields.Str()}, location="json")
def myview(*, query, body):
...
Additionally, this makes it possible to make custom schema classes which
provide an argument name. For example,

.. code-block:: python
from marshmallow import Schema
from webargs.flaskparser import FlaskParser
class RectangleSchema(Schema):
webargs_arg_name = "rectangle"
length = fields.Float()
width = fields.Float()
class MyParser(FlaskParser):
USE_ARGS_POSITIONAL = False
def get_default_arg_name(self, location, schema):
if hasattr(schema, "webargs_arg_name"):
if isinstance(schema.webargs_arg_name, str):
return schema.webargs_arg_name
return super().get_default_arg_name(location, schema)
@app.route("/")
@parser.use_args({"foo": fields.Int(), "bar": fields.Str()}, location="query")
@parser.use_args(RectangleSchema, location="json")
def myview(*, rectangle, query_args):
...
Next Steps
----------

Expand Down
22 changes: 17 additions & 5 deletions src/webargs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,18 +574,18 @@ def greet(querystring_args):
location = location or self.location
request_obj = req

if arg_name is not None and as_kwargs:
raise ValueError("arg_name and as_kwargs are mutually exclusive")
if arg_name is None and not self.USE_ARGS_POSITIONAL:
arg_name = f"{location}_args"

# Optimization: If argmap is passed as a dictionary, we only need
# to generate a Schema once
if isinstance(argmap, typing.Mapping):
if not isinstance(argmap, dict):
argmap = dict(argmap)
argmap = self.schema_class.from_dict(argmap)()

if arg_name is not None and as_kwargs:
raise ValueError("arg_name and as_kwargs are mutually exclusive")
if arg_name is None and not self.USE_ARGS_POSITIONAL:
arg_name = self.get_default_arg_name(location, argmap)

def decorator(func: typing.Callable) -> typing.Callable:
req_ = request_obj

Expand Down Expand Up @@ -690,6 +690,18 @@ def greet(name):
error_headers=error_headers,
)

def get_default_arg_name(self, location: str, schema: ArgMap) -> str:
"""This method provides the rule by which an argument name is derived for
:meth:`use_args` if no explicit ``arg_name`` is provided.
By default, the format used is ``{location}_args``. Users may override this method
to customize the default argument naming scheme.
``schema`` will be the argument map or schema passed to :meth:`use_args` unless a
dict was used, in which case it will be the schema derived from that dict.
"""
return f"{location}_args"

def location_loader(self, name: str) -> typing.Callable[[C], C]:
"""Decorator that registers a function for loading a request location.
The wrapped function receives a schema and a request.
Expand Down
91 changes: 91 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1621,3 +1621,94 @@ def mypartial(*, json_args, query_args):
assert mypartial.__webargs_argnames__ == ("query_args",)
assert route_foo.__webargs_argnames__ == ("query_args", "json_args")
assert route_bar.__webargs_argnames__ == ("query_args", "json_args")


def test_default_arg_name_pattern_is_customizable(web_request):
class MyParser(MockRequestParser):
USE_ARGS_POSITIONAL = False

def get_default_arg_name(self, location, schema):
if location == "json":
return "body"
elif location == "query":
return "query"
else:
return super().get_default_arg_name(location, schema)

parser = MyParser()

@parser.use_args({"frob": fields.Field()}, web_request, location="json")
@parser.use_args({"snork": fields.Field()}, web_request, location="query")
def myview(*, body, query):
return (body, query)

web_request.json = {"frob": "demuddler"}
web_request.query = {"snork": 2}
assert myview() == ({"frob": "demuddler"}, {"snork": 2})


def test_default_arg_name_pattern_still_allows_conflict_detection():
class MyParser(MockRequestParser):
USE_ARGS_POSITIONAL = False

def get_default_arg_name(self, location, schema):
return "data"

parser = MyParser()

with pytest.raises(ValueError, match="Attempted to pass `arg_name='data'`"):

@parser.use_args({"frob": fields.Field()}, web_request, location="json")
@parser.use_args({"snork": fields.Field()}, web_request, location="query")
def myview(*, body, query):
return (body, query)


def test_parse_with_dict_passes_schema_to_argname_derivation(web_request):
default_argname_was_called = False

class MyParser(MockRequestParser):
USE_ARGS_POSITIONAL = False

def get_default_arg_name(self, location, schema):
assert isinstance(schema, Schema)
nonlocal default_argname_was_called
default_argname_was_called = True
return super().get_default_arg_name(location, schema)

parser = MyParser()

@parser.use_args({"foo": fields.Field()}, web_request, location="json")
def myview(*, json_args):
return json_args

web_request.json = {"foo": 42}
assert myview() == {"foo": 42}
assert default_argname_was_called


def test_default_arg_name_pattern_can_pull_schema_attribute(web_request):
# this test matches a documentation example exactly
class RectangleSchema(Schema):
_webargs_arg_name = "rectangle"
length = fields.Integer()
width = fields.Integer()

class MyParser(MockRequestParser):
USE_ARGS_POSITIONAL = False

def get_default_arg_name(self, location, schema):
assert schema is not None
if hasattr(schema, "_webargs_arg_name"):
if isinstance(schema._webargs_arg_name, str):
return schema._webargs_arg_name
return super().get_default_arg_name(location, schema)

parser = MyParser()

@parser.use_args(RectangleSchema, web_request, location="json")
def area(*, rectangle):
return rectangle["length"] * rectangle["width"]

web_request.json = {"length": 6, "width": 7}
assert area() == 42

0 comments on commit 8d3a3b8

Please sign in to comment.