Skip to content

Commit

Permalink
Compatibility with apispec 4
Browse files Browse the repository at this point in the history
1. Rename 'default_in'
(marshmallow-code/apispec#526)
2. Dict schema: convert it to object and handle special case 'body',
since prior used method no longer exists
(marshmallow-code/apispec#581)
  • Loading branch information
kam193 committed Oct 11, 2020
1 parent 4c3cabe commit 1267861
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 22 deletions.
55 changes: 39 additions & 16 deletions flask_apispec/apidoc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import functools

import apispec
from apispec.core import VALID_METHODS
Expand All @@ -15,7 +16,6 @@
)

class Converter:

def __init__(self, app, spec, document_options=True):
self.app = app
self.spec = spec
Expand Down Expand Up @@ -80,23 +80,19 @@ def get_parameters(self, rule, view, docs, parent=None):
extra_params = []
for args in annotation.options:
schema = args.get('args', {})
if is_instance_or_subclass(schema, Schema):
converter = openapi.schema2parameters
elif callable(schema):
schema = schema(request=None)
if is_instance_or_subclass(schema, Schema):
converter = openapi.schema2parameters
openapi_converter = openapi.schema2parameters
if not is_instance_or_subclass(schema, Schema):
if callable(schema):
schema = schema(request=None)
else:
converter = openapi.fields2parameters
else:
converter = openapi.fields2parameters
schema = Schema.from_dict(schema)
openapi_converter = functools.partial(
self._convert_dict_schema, openapi_converter)

options = copy.copy(args.get('kwargs', {}))
location = options.pop('location', None)
if location:
options['default_in'] = location
elif 'default_in' not in options:
options['default_in'] = 'body'
extra_params += converter(schema, **options) if args else []
if not options.get('location'):
options['location'] = 'body'
extra_params += openapi_converter(schema, **options) if args else []

rule_params = rule_to_params(rule, docs.get('params')) or []

Expand All @@ -106,6 +102,33 @@ def get_responses(self, view, parent=None):
annotation = resolve_annotations(view, 'schemas', parent)
return merge_recursive(annotation.options)

def _convert_dict_schema(self, openapi_converter, schema, location, **options):
"""When location is 'body' and OpenApi is 2, return one param for body fields.
Otherwise return fields exactly as converted by apispec."""
if self.spec.openapi_version.major < 3 and location == 'body':
params = openapi_converter(schema, location=None, **options)
body_parameter = {
"in": "body",
"name": "body",
"required": False,
"schema": {
"type": "object",
"properties": {},
},
}
for param in params:
name = param["name"]
body_parameter["schema"]["properties"].update({name: param})
if param.get("required", False):
body_parameter["schema"].setdefault("required", []).append(name)
del param["name"]
del param["in"]
del param["required"]
return [body_parameter]

return openapi_converter(schema, location=location, **options)

class ViewConverter(Converter):

def get_operations(self, rule, view):
Expand Down
38 changes: 32 additions & 6 deletions tests/test_openapi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from unittest import mock
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from marshmallow import fields, Schema
Expand All @@ -7,7 +8,7 @@
from flask_apispec.paths import rule_to_params
from flask_apispec.views import MethodResource
from flask_apispec import doc, use_kwargs, marshal_with
from flask_apispec.apidoc import APISPEC_VERSION_INFO, ViewConverter, ResourceConverter
from flask_apispec.apidoc import APISPEC_VERSION_INFO, Converter, ViewConverter, ResourceConverter

@pytest.fixture()
def marshmallow_plugin():
Expand Down Expand Up @@ -113,8 +114,8 @@ def test_params(self, app, path, openapi):
params = path['get']['parameters']
rule = app.url_map._rules_by_endpoint['get_band'][0]
expected = (
openapi.fields2parameters(
{'name': fields.Str()}, default_in='query') +
openapi.schema2parameters(
Schema.from_dict({'name': fields.Str()}), location='query') +
rule_to_params(rule)
)
assert params == expected
Expand Down Expand Up @@ -184,8 +185,7 @@ def test_params(self, app, path, openapi):
params = path['get']['parameters']
rule = app.url_map._rules_by_endpoint['band'][0]
expected = (
openapi.fields2parameters(
{'name': fields.Str()}, default_in='query') +
[{'in': 'query', 'name': 'name', 'required': False, 'type': 'string'}] +
rule_to_params(rule)
)
assert params == expected
Expand Down Expand Up @@ -242,7 +242,6 @@ def test_params(self, app, path):
)
assert params == expected


class TestGetFieldsNoLocationProvided:

@pytest.fixture
Expand Down Expand Up @@ -277,6 +276,33 @@ def test_params(self, app, path):
},
} in params

class TestGetFieldsBodyLocation(TestGetFieldsNoLocationProvided):

@pytest.fixture
def function_view(self, app):
@app.route('/bands/<int:band_id>/')
@use_kwargs({'name': fields.Str(required=True), 'address': fields.Str(), 'email': fields.Str(required=True)})
def get_band(**kwargs):
return kwargs

return get_band

def test_params(self, app, path):
params = path['get']['parameters']
assert {
'in': 'body',
'name': 'body',
'required': False,
'schema': {
'properties': {
'address': {'type': 'string'},
'name': {'type': 'string'},
'email': {'type': 'string'},
},
'required': ["name", "email"],
'type': 'object',
},
} in params

class TestSchemaNoLocationProvided:

Expand Down

0 comments on commit 1267861

Please sign in to comment.