Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ignore "location" field metadata #526

Merged
merged 4 commits into from
Aug 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
36 changes: 12 additions & 24 deletions src/apispec/ext/marshmallow/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,27 +110,21 @@ def resolve_nested_schema(self, schema):
return self.get_ref_dict(schema_instance)

def schema2parameters(
self,
schema,
*,
default_in="body",
name="body",
required=False,
description=None
self, schema, *, location, name="body", required=False, description=None
):
"""Return an array of OpenAPI parameters given a given marshmallow
:class:`Schema <marshmallow.Schema>`. If `default_in` is "body", then return an array
:class:`Schema <marshmallow.Schema>`. If `location` is "body", then return an array
of a single parameter; else return an array of a parameter for each included field in
the :class:`Schema <marshmallow.Schema>`.
https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#parameterObject
"""
openapi_default_in = __location_map__.get(default_in, default_in)
if self.openapi_version.major < 3 and openapi_default_in == "body":
openapi_location = __location_map__.get(location, location)
if self.openapi_version.major < 3 and openapi_location == "body":
prop = self.resolve_nested_schema(schema)

param = {
"in": openapi_default_in,
"in": openapi_location,
"required": required,
"name": name,
"schema": prop,
Expand All @@ -147,11 +141,11 @@ def schema2parameters(

fields = get_fields(schema, exclude_dump_only=True)

return self.fields2parameters(fields, default_in=default_in)
return self.fields2parameters(fields, location=location)

def fields2parameters(self, fields, *, default_in):
def fields2parameters(self, fields, *, location):
"""Return an array of OpenAPI parameters given a mapping between field names and
:class:`Field <marshmallow.Field>` objects. If `default_in` is "body", then return an array
:class:`Field <marshmallow.Field>` objects. If `location` is "body", then return an array
of a single parameter; else return an array of a parameter for each included field in
the :class:`Schema <marshmallow.Schema>`.
Expand All @@ -171,7 +165,7 @@ def fields2parameters(self, fields, *, default_in):
param = self.field2parameter(
field_obj,
name=self._observed_name(field_obj, field_name),
default_in=default_in,
location=location,
)
if (
self.openapi_version.major < 3
Expand All @@ -190,26 +184,22 @@ def fields2parameters(self, fields, *, default_in):
parameters.append(param)
return parameters

def field2parameter(self, field, *, name, default_in):
def field2parameter(self, field, *, name, location):
"""Return an OpenAPI parameter as a `dict`, given a marshmallow
:class:`Field <marshmallow.Field>`.
https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#parameterObject
"""
location = field.metadata.get("location", None)
prop = self.field2property(field)
return self.property2parameter(
prop,
name=name,
required=field.required,
multiple=isinstance(field, marshmallow.fields.List),
location=location,
default_in=default_in,
)

def property2parameter(
self, prop, *, name, required, multiple, location, default_in
):
def property2parameter(self, prop, *, name, required, multiple, location):
"""Return the Parameter Object definition for a JSON Schema property.
https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#parameterObject
Expand All @@ -219,12 +209,10 @@ def property2parameter(
:param bool required: Parameter is required
:param bool multiple: Parameter is repeated
:param str location: Location to look for ``name``
:param str default_in: Default location to look for ``name``
:raise: TranslationError if arg object cannot be translated to a Parameter Object schema.
:rtype: dict, a Parameter Object
"""
openapi_default_in = __location_map__.get(default_in, default_in)
openapi_location = __location_map__.get(location, openapi_default_in)
openapi_location = __location_map__.get(location, location)
ret = {"in": openapi_location, "name": name}

if openapi_location == "body":
Expand Down
2 changes: 1 addition & 1 deletion src/apispec/ext/marshmallow/schema_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class UserSchema(Schema):
):
schema_instance = resolve_schema_instance(parameter.pop("schema"))
resolved += self.converter.schema2parameters(
schema_instance, default_in=parameter.pop("in"), **parameter
schema_instance, location=parameter.pop("in"), **parameter
)
else:
self.resolve_schema(parameter)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_ext_marshmallow.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,12 +440,12 @@ def test_schema_expand_parameters_v2(self, spec_fixture):
p = get_paths(spec_fixture.spec)["/pet"]
get = p["get"]
assert get["parameters"] == spec_fixture.openapi.schema2parameters(
PetSchema(), default_in="query"
PetSchema(), location="query"
)
post = p["post"]
assert post["parameters"] == spec_fixture.openapi.schema2parameters(
PetSchema,
default_in="body",
location="body",
required=True,
name="pet",
description="a pet schema",
Expand All @@ -469,7 +469,7 @@ def test_schema_expand_parameters_v3(self, spec_fixture):
p = get_paths(spec_fixture.spec)["/pet"]
get = p["get"]
assert get["parameters"] == spec_fixture.openapi.schema2parameters(
PetSchema(), default_in="query"
PetSchema(), location="query"
)
for parameter in get["parameters"]:
description = parameter.get("description", False)
Expand Down Expand Up @@ -819,7 +819,7 @@ def test_schema_global_state_untouched_2json(self, spec_fixture):

def test_schema_global_state_untouched_2parameters(self, spec_fixture):
assert get_nested_schema(RunSchema, "sample") is None
data = spec_fixture.openapi.schema2parameters(RunSchema)
data = spec_fixture.openapi.schema2parameters(RunSchema, location="json")
json.dumps(data)
assert get_nested_schema(RunSchema, "sample") is None

Expand Down
91 changes: 27 additions & 64 deletions tests/test_ext_marshmallow_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,53 +13,20 @@
class TestMarshmallowFieldToOpenAPI:
def test_fields_with_missing_load(self, openapi):
field_dict = {"field": fields.Str(default="foo", missing="bar")}
res = openapi.fields2parameters(field_dict, default_in="query")
res = openapi.fields2parameters(field_dict, location="query")
if openapi.openapi_version.major < 3:
assert res[0]["default"] == "bar"
else:
assert res[0]["schema"]["default"] == "bar"

def test_fields_with_location(self, openapi):
field_dict = {"field": fields.Str(location="querystring")}
res = openapi.fields2parameters(field_dict, default_in="headers")
assert res[0]["in"] == "query"

# json/body is invalid for OpenAPI 3
@pytest.mark.parametrize("openapi", ("2.0",), indirect=True)
def test_fields_with_multiple_json_locations(self, openapi):
field_dict = {
"field1": fields.Str(location="json", required=True),
"field2": fields.Str(location="json", required=True),
"field3": fields.Str(location="json"),
}
res = openapi.fields2parameters(field_dict, default_in=None)
assert len(res) == 1
assert res[0]["in"] == "body"
assert res[0]["required"] is False
assert "field1" in res[0]["schema"]["properties"]
assert "field2" in res[0]["schema"]["properties"]
assert "field3" in res[0]["schema"]["properties"]
assert "required" in res[0]["schema"]
assert len(res[0]["schema"]["required"]) == 2
assert "field1" in res[0]["schema"]["required"]
assert "field2" in res[0]["schema"]["required"]

def test_fields2parameters_does_not_modify_metadata(self, openapi):
field_dict = {"field": fields.Str(location="querystring")}
res = openapi.fields2parameters(field_dict, default_in="headers")
assert res[0]["in"] == "query"

res = openapi.fields2parameters(field_dict, default_in="headers")
assert res[0]["in"] == "query"

def test_fields_location_mapping(self, openapi):
field_dict = {"field": fields.Str(location="cookies")}
res = openapi.fields2parameters(field_dict, default_in="headers")
field_dict = {"field": fields.Str()}
res = openapi.fields2parameters(field_dict, location="cookies")
assert res[0]["in"] == "cookie"

def test_fields_default_location_mapping(self, openapi):
field_dict = {"field": fields.Str()}
res = openapi.fields2parameters(field_dict, default_in="headers")
res = openapi.fields2parameters(field_dict, location="headers")
assert res[0]["in"] == "header"

# json/body is invalid for OpenAPI 3
Expand All @@ -69,16 +36,16 @@ class ExampleSchema(Schema):
id = fields.Int()

schema = ExampleSchema(many=True)
res = openapi.schema2parameters(schema=schema, default_in="json")
res = openapi.schema2parameters(schema=schema, location="json")
assert res[0]["in"] == "body"

def test_fields_with_dump_only(self, openapi):
class UserSchema(Schema):
name = fields.Str(dump_only=True)

res = openapi.fields2parameters(UserSchema._declared_fields, default_in="query")
res = openapi.fields2parameters(UserSchema._declared_fields, location="query")
assert len(res) == 0
res = openapi.fields2parameters(UserSchema().fields, default_in="query")
res = openapi.fields2parameters(UserSchema().fields, location="query")
assert len(res) == 0

class UserSchema(Schema):
Expand All @@ -87,7 +54,7 @@ class UserSchema(Schema):
class Meta:
dump_only = ("name",)

res = openapi.schema2parameters(schema=UserSchema, default_in="query")
res = openapi.schema2parameters(schema=UserSchema, location="query")
assert len(res) == 0


Expand Down Expand Up @@ -261,8 +228,8 @@ class NotASchema:
class TestMarshmallowSchemaToParameters:
@pytest.mark.parametrize("ListClass", [fields.List, CustomList])
def test_field_multiple(self, ListClass, openapi):
field = ListClass(fields.Str, location="querystring")
res = openapi.field2parameter(field, name="field", default_in=None)
field = ListClass(fields.Str)
res = openapi.field2parameter(field, name="field", location="querystring")
assert res["in"] == "query"
if openapi.openapi_version.major < 3:
assert res["type"] == "array"
Expand All @@ -275,13 +242,13 @@ def test_field_multiple(self, ListClass, openapi):
assert res["explode"] is True

def test_field_required(self, openapi):
field = fields.Str(required=True, location="query")
res = openapi.field2parameter(field, name="field", default_in=None)
field = fields.Str(required=True)
res = openapi.field2parameter(field, name="field", location="query")
assert res["required"] is True

def test_invalid_schema(self, openapi):
with pytest.raises(ValueError):
openapi.schema2parameters(None)
openapi.schema2parameters(None, location="json")

# json/body is invalid for OpenAPI 3
@pytest.mark.parametrize("openapi", ("2.0",), indirect=True)
Expand All @@ -290,7 +257,7 @@ class UserSchema(Schema):
name = fields.Str()
email = fields.Email()

res = openapi.schema2parameters(UserSchema, default_in="body")
res = openapi.schema2parameters(UserSchema, location="body")
assert len(res) == 1
param = res[0]
assert param["in"] == "body"
Expand All @@ -303,7 +270,7 @@ class UserSchema(Schema):
name = fields.Str()
email = fields.Email(dump_only=True)

res_nodump = openapi.schema2parameters(UserSchema, default_in="body")
res_nodump = openapi.schema2parameters(UserSchema, location="body")
assert len(res_nodump) == 1
param = res_nodump[0]
assert param["in"] == "body"
Expand All @@ -316,7 +283,7 @@ class UserSchema(Schema):
name = fields.Str()
email = fields.Email()

res = openapi.schema2parameters(UserSchema(many=True), default_in="body")
res = openapi.schema2parameters(UserSchema(many=True), location="body")
assert len(res) == 1
param = res[0]
assert param["in"] == "body"
Expand All @@ -328,7 +295,7 @@ class UserSchema(Schema):
name = fields.Str()
email = fields.Email()

res = openapi.schema2parameters(UserSchema, default_in="query")
res = openapi.schema2parameters(UserSchema, location="query")
assert len(res) == 2
res.sort(key=lambda param: param["name"])
assert res[0]["name"] == "email"
Expand All @@ -341,7 +308,7 @@ class UserSchema(Schema):
name = fields.Str()
email = fields.Email()

res = openapi.schema2parameters(UserSchema(), default_in="query")
res = openapi.schema2parameters(UserSchema(), location="query")
assert len(res) == 2
res.sort(key=lambda param: param["name"])
assert res[0]["name"] == "email"
Expand All @@ -355,11 +322,11 @@ class UserSchema(Schema):
email = fields.Email()

with pytest.raises(AssertionError):
openapi.schema2parameters(UserSchema(many=True), default_in="query")
openapi.schema2parameters(UserSchema(many=True), location="query")

def test_fields_query(self, openapi):
field_dict = {"name": fields.Str(), "email": fields.Email()}
res = openapi.fields2parameters(field_dict, default_in="query")
res = openapi.fields2parameters(field_dict, location="query")
assert len(res) == 2
res.sort(key=lambda param: param["name"])
assert res[0]["name"] == "email"
Expand Down Expand Up @@ -479,15 +446,13 @@ def test_openapi_tools_validate_v2():
},
openapi.field2parameter(
field=fields.List(
fields.Str(),
validate=validate.OneOf(["freddie", "roger"]),
location="querystring",
fields.Str(), validate=validate.OneOf(["freddie", "roger"]),
),
default_in=None,
location="querystring",
name="body",
),
]
+ openapi.schema2parameters(PageSchema, default_in="query"),
+ openapi.schema2parameters(PageSchema, location="query"),
"responses": {200: {"schema": PetSchema, "description": "A pet"}},
},
"post": {
Expand All @@ -500,7 +465,7 @@ def test_openapi_tools_validate_v2():
"type": "string",
}
]
+ openapi.schema2parameters(CategorySchema, default_in="body")
+ openapi.schema2parameters(CategorySchema, location="body")
),
"responses": {201: {"schema": PetSchema, "description": "A pet"}},
},
Expand Down Expand Up @@ -537,15 +502,13 @@ def test_openapi_tools_validate_v3():
},
openapi.field2parameter(
field=fields.List(
fields.Str(),
validate=validate.OneOf(["freddie", "roger"]),
location="querystring",
fields.Str(), validate=validate.OneOf(["freddie", "roger"]),
),
default_in=None,
location="querystring",
name="body",
),
]
+ openapi.schema2parameters(PageSchema, default_in="query"),
+ openapi.schema2parameters(PageSchema, location="query"),
"responses": {
200: {
"description": "success",
Expand Down