forked from marshmallow-code/webargs
/
asyncparser.py
239 lines (210 loc) · 9.37 KB
/
asyncparser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
"""Asynchronous request parser. Compatible with Python>=3.5."""
import asyncio
import functools
import inspect
import types
import typing
from collections.abc import Mapping

import marshmallow as ma
from marshmallow import Schema, ValidationError
from marshmallow.fields import Field
from marshmallow.utils import missing

from webargs import core
# Placeholder type for the request object handed to the parser; its concrete
# type depends on the caller/framework.
Request = typing.TypeVar("Request")
# An argument map: either a marshmallow Schema or a mapping of argument
# names to marshmallow fields.
ArgMap = typing.Union[Schema, typing.Mapping[str, Field]]
# A single validator callable, or an iterable of validator callables.
Validate = typing.Union[typing.Callable, typing.Iterable[typing.Callable]]
class AsyncParser(core.Parser):
    """Asynchronous variant of `webargs.core.Parser`, where parsing methods may be
    either coroutines or regular methods.

    Location loaders, location handlers and the validation error handler may
    each be either a regular function or a coroutine function; their results
    are awaited only when necessary.
    """

    async def _parse_request(
        self, schema: Schema, req: Request, locations: typing.Iterable
    ) -> typing.Union[dict, list]:
        """Collect raw (not yet deserialized) argument values for ``schema``.

        :param schema: Schema whose fields determine which arguments to look up.
        :param req: The request object.
        :param locations: Locations to search, in order.
        :return: When ``schema.many`` is set, the value found at the ``json``
            location (``[]`` if nothing was found). Otherwise a dict mapping
            each found argument name (``data_key``/``load_from`` when set) to
            its raw value; arguments that were not found are omitted.
        """
        if schema.many:
            assert (
                "json" in locations
            ), "schema.many=True is only supported for JSON location"
            # The ad hoc Nested field is more like a workaround or a helper,
            # and it serves its purpose fine. However, if somebody has a desire
            # to re-design the support of bulk-type arguments, go ahead.
            parsed = await self.parse_arg(
                name="json",
                field=ma.fields.Nested(schema, many=True),
                req=req,
                locations=locations,
            )
            if parsed is missing:
                parsed = []
        else:
            argdict = schema.fields
            parsed = {}
            for argname, field_obj in argdict.items():
                if core.MARSHMALLOW_VERSION_INFO[0] < 3:
                    parsed_value = await self.parse_arg(
                        argname, field_obj, req, locations
                    )
                    # If load_from is specified on the field, try to parse from that key
                    if parsed_value is missing and field_obj.load_from:
                        parsed_value = await self.parse_arg(
                            field_obj.load_from, field_obj, req, locations
                        )
                        argname = field_obj.load_from
                else:
                    argname = field_obj.data_key or argname
                    parsed_value = await self.parse_arg(
                        argname, field_obj, req, locations
                    )
                if parsed_value is not missing:
                    parsed[argname] = parsed_value
        return parsed

    # TODO: Lots of duplication from core.Parser here. Rethink.
    async def parse(
        self,
        argmap: ArgMap,
        req: Request = None,
        locations: typing.Iterable = None,
        validate: Validate = None,
        error_status_code: typing.Union[int, None] = None,
        error_headers: typing.Union[typing.Mapping[str, str], None] = None,
    ) -> typing.Union[typing.Mapping, None]:
        """Coroutine variant of `webargs.core.Parser.parse`.

        Receives the same arguments as `webargs.core.Parser.parse`.

        :return: The deserialized arguments. If a ``ValidationError`` is raised
            and the error handler returns instead of raising, whatever had been
            deserialized at that point (``None`` if loading itself failed) is
            returned.
        """
        self.clear_cache()  # in case someone used `parse_*()`
        req = req if req is not None else self.get_default_request()
        assert req is not None, "Must pass req object"
        data = None
        validators = core._ensure_list_of_callables(validate)
        schema = self._get_schema(argmap, req)
        try:
            parsed = await self._parse_request(
                schema=schema, req=req, locations=locations or self.locations
            )
            result = schema.load(parsed)
            data = result.data if core.MARSHMALLOW_VERSION_INFO[0] < 3 else result
            self._validate_arguments(data, validators)
        except ma.exceptions.ValidationError as error:
            await self._on_validation_error(
                error, req, schema, error_status_code, error_headers
            )
        return data

    async def _load_location_data(self, schema, req, location):
        """Return a dictionary-like object for the location on the given request.

        Needs to have the schema in hand in order to correctly handle loading
        lists from multidict objects and `many=True` schemas. The loader may be
        a regular function or a coroutine function.
        """
        loader_func = self._get_loader(location)
        if asyncio.iscoroutinefunction(loader_func):
            data = await loader_func(req, schema)
        else:
            data = loader_func(req, schema)

        # when the desired location is empty (no data), provide an empty
        # dict as the default so that optional arguments in a location
        # (e.g. optional JSON body) work smoothly
        if data is core.missing:
            data = {}
        return data

    async def _on_validation_error(
        self,
        error: ValidationError,
        req: Request,
        schema: Schema,
        error_status_code: typing.Union[int, None],
        error_headers: typing.Union[typing.Mapping[str, str], None] = None,
    ) -> None:
        """Dispatch ``error`` to ``self.error_callback`` or ``self.handle_error``.

        The handler may be a regular function or may return an awaitable (e.g.
        a coroutine function); only in the latter case is the result awaited.
        """
        error_handler = self.error_callback or self.handle_error
        result = error_handler(error, req, schema, error_status_code, error_headers)
        # Don't await unconditionally: a synchronous handler that returns
        # (instead of raising) gives None, and `await None` raises TypeError.
        if inspect.isawaitable(result):
            await result

    def use_args(
        self,
        argmap: ArgMap,
        req: typing.Optional[Request] = None,
        locations: typing.Iterable = None,
        as_kwargs: bool = False,
        validate: Validate = None,
        error_status_code: typing.Optional[int] = None,
        error_headers: typing.Union[typing.Mapping[str, str], None] = None,
    ) -> typing.Callable[..., typing.Callable]:
        """Decorator that injects parsed arguments into a view function or method.

        Receives the same arguments as `webargs.core.Parser.use_args`.

        Coroutine functions are wrapped in a coroutine function. Any other
        function is wrapped in a generator-based coroutine (`types.coroutine`),
        so the wrapper may delegate to the native `parse` coroutine with
        ``yield from`` and may itself be awaited.
        """
        locations = locations or self.locations
        request_obj = req
        # Optimization: If argmap is passed as a dictionary, we only need
        # to generate a Schema once
        if isinstance(argmap, Mapping):
            argmap = core.dict2schema(argmap, self.schema_class)()

        def decorator(func: typing.Callable) -> typing.Callable:
            req_ = request_obj

            if inspect.iscoroutinefunction(func):

                @functools.wraps(func)
                async def wrapper(*args, **kwargs):
                    req_obj = req_
                    # Compare with None rather than truthiness: a request
                    # object defining __len__/__bool__ may be falsy.
                    if req_obj is None:
                        req_obj = self.get_request_from_view_args(func, args, kwargs)
                    # NOTE: At this point, argmap may be a Schema, callable, or dict
                    parsed_args = await self.parse(
                        argmap,
                        req=req_obj,
                        locations=locations,
                        validate=validate,
                        error_status_code=error_status_code,
                        error_headers=error_headers,
                    )
                    if as_kwargs:
                        kwargs.update(parsed_args or {})
                        return await func(*args, **kwargs)
                    else:
                        # Add parsed_args after other positional arguments
                        new_args = args + (parsed_args,)
                        return await func(*new_args, **kwargs)

            else:

                # types.coroutine marks this generator as a generator-based
                # coroutine; without it, `yield from` on the native `parse`
                # coroutine raises TypeError and the wrapper is not awaitable.
                @types.coroutine
                @functools.wraps(func)  # type: ignore
                def wrapper(*args, **kwargs):
                    req_obj = req_
                    # Compare with None rather than truthiness: a request
                    # object defining __len__/__bool__ may be falsy.
                    if req_obj is None:
                        req_obj = self.get_request_from_view_args(func, args, kwargs)
                    # NOTE: At this point, argmap may be a Schema, callable, or dict
                    parsed_args = yield from self.parse(  # type: ignore
                        argmap,
                        req=req_obj,
                        locations=locations,
                        validate=validate,
                        error_status_code=error_status_code,
                        error_headers=error_headers,
                    )
                    if as_kwargs:
                        # parse() may return None when the error handler
                        # does not raise; mirror the coroutine branch.
                        kwargs.update(parsed_args or {})
                        return func(*args, **kwargs)  # noqa: B901
                    else:
                        # Add parsed_args after other positional arguments
                        new_args = args + (parsed_args,)
                        return func(*new_args, **kwargs)

            return wrapper

        return decorator

    def use_kwargs(self, *args, **kwargs) -> typing.Callable:
        """Decorator that injects parsed arguments into a view function or method.

        Receives the same arguments as `webargs.core.Parser.use_kwargs`.
        """
        return super().use_kwargs(*args, **kwargs)

    async def parse_arg(
        self, name: str, field: Field, req: Request, locations: typing.Iterable = None
    ) -> typing.Any:
        """Return the first value found for ``name``, or ``core.missing``.

        A ``location`` set in ``field.metadata`` takes precedence; otherwise
        ``locations`` (falling back to ``self.locations`` when falsy) is
        searched in order.
        """
        field_location = field.metadata.get("location")
        if field_location:
            locations_to_check = self._validated_locations([field_location])
        else:
            locations_to_check = self._validated_locations(locations or self.locations)
        for location in locations_to_check:
            value = await self._get_value(name, field, req=req, location=location)
            # Found the value; return it (validation happens later, on load)
            if value is not core.missing:
                return value
        return core.missing

    async def _get_value(
        self, name: str, argobj: Field, req: Request, location: str
    ) -> typing.Any:
        """Look up ``name`` with the handler registered for ``location``.

        The handler may be a regular function or a coroutine function.
        """
        function = self._get_handler(location)
        if asyncio.iscoroutinefunction(function):
            value = await function(req, name, argobj)
        else:
            value = function(req, name, argobj)
        return value