Skip to content

Commit

Permalink
Merge pull request #526 from marshmallow-code/ignore_field_location_m…
Browse files Browse the repository at this point in the history
…etadata

Ignore "location" field metadata
  • Loading branch information
lafrech committed Aug 31, 2020
2 parents 16a3a2c + a22bb4e commit e204fa0
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 93 deletions.
36 changes: 12 additions & 24 deletions src/apispec/ext/marshmallow/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,27 +110,21 @@ def resolve_nested_schema(self, schema):
return self.get_ref_dict(schema_instance)

def schema2parameters(
self,
schema,
*,
default_in="body",
name="body",
required=False,
description=None
self, schema, *, location, name="body", required=False, description=None
):
"""Return an array of OpenAPI parameters given a given marshmallow
:class:`Schema <marshmallow.Schema>`. If `default_in` is "body", then return an array
:class:`Schema <marshmallow.Schema>`. If `location` is "body", then return an array
of a single parameter; else return an array of a parameter for each included field in
the :class:`Schema <marshmallow.Schema>`.
https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#parameterObject
"""
openapi_default_in = __location_map__.get(default_in, default_in)
if self.openapi_version.major < 3 and openapi_default_in == "body":
openapi_location = __location_map__.get(location, location)
if self.openapi_version.major < 3 and openapi_location == "body":
prop = self.resolve_nested_schema(schema)

param = {
"in": openapi_default_in,
"in": openapi_location,
"required": required,
"name": name,
"schema": prop,
Expand All @@ -147,11 +141,11 @@ def schema2parameters(

fields = get_fields(schema, exclude_dump_only=True)

return self.fields2parameters(fields, default_in=default_in)
return self.fields2parameters(fields, location=location)

def fields2parameters(self, fields, *, default_in):
def fields2parameters(self, fields, *, location):
"""Return an array of OpenAPI parameters given a mapping between field names and
:class:`Field <marshmallow.Field>` objects. If `default_in` is "body", then return an array
:class:`Field <marshmallow.Field>` objects. If `location` is "body", then return an array
of a single parameter; else return an array of a parameter for each included field in
the :class:`Schema <marshmallow.Schema>`.
Expand All @@ -171,7 +165,7 @@ def fields2parameters(self, fields, *, default_in):
param = self.field2parameter(
field_obj,
name=self._observed_name(field_obj, field_name),
default_in=default_in,
location=location,
)
if (
self.openapi_version.major < 3
Expand All @@ -190,26 +184,22 @@ def fields2parameters(self, fields, *, default_in):
parameters.append(param)
return parameters

def field2parameter(self, field, *, name, default_in):
def field2parameter(self, field, *, name, location):
"""Return an OpenAPI parameter as a `dict`, given a marshmallow
:class:`Field <marshmallow.Field>`.
https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#parameterObject
"""
location = field.metadata.get("location", None)
prop = self.field2property(field)
return self.property2parameter(
prop,
name=name,
required=field.required,
multiple=isinstance(field, marshmallow.fields.List),
location=location,
default_in=default_in,
)

def property2parameter(
self, prop, *, name, required, multiple, location, default_in
):
def property2parameter(self, prop, *, name, required, multiple, location):
"""Return the Parameter Object definition for a JSON Schema property.
https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#parameterObject
Expand All @@ -219,12 +209,10 @@ def property2parameter(
:param bool required: Parameter is required
:param bool multiple: Parameter is repeated
:param str location: Location to look for ``name``
:param str default_in: Default location to look for ``name``
:raise: TranslationError if arg object cannot be translated to a Parameter Object schema.
:rtype: dict, a Parameter Object
"""
openapi_default_in = __location_map__.get(default_in, default_in)
openapi_location = __location_map__.get(location, openapi_default_in)
openapi_location = __location_map__.get(location, location)
ret = {"in": openapi_location, "name": name}

if openapi_location == "body":
Expand Down
2 changes: 1 addition & 1 deletion src/apispec/ext/marshmallow/schema_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class UserSchema(Schema):
):
schema_instance = resolve_schema_instance(parameter.pop("schema"))
resolved += self.converter.schema2parameters(
schema_instance, default_in=parameter.pop("in"), **parameter
schema_instance, location=parameter.pop("in"), **parameter
)
else:
self.resolve_schema(parameter)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_ext_marshmallow.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,12 +440,12 @@ def test_schema_expand_parameters_v2(self, spec_fixture):
p = get_paths(spec_fixture.spec)["/pet"]
get = p["get"]
assert get["parameters"] == spec_fixture.openapi.schema2parameters(
PetSchema(), default_in="query"
PetSchema(), location="query"
)
post = p["post"]
assert post["parameters"] == spec_fixture.openapi.schema2parameters(
PetSchema,
default_in="body",
location="body",
required=True,
name="pet",
description="a pet schema",
Expand All @@ -469,7 +469,7 @@ def test_schema_expand_parameters_v3(self, spec_fixture):
p = get_paths(spec_fixture.spec)["/pet"]
get = p["get"]
assert get["parameters"] == spec_fixture.openapi.schema2parameters(
PetSchema(), default_in="query"
PetSchema(), location="query"
)
for parameter in get["parameters"]:
description = parameter.get("description", False)
Expand Down Expand Up @@ -819,7 +819,7 @@ def test_schema_global_state_untouched_2json(self, spec_fixture):

def test_schema_global_state_untouched_2parameters(self, spec_fixture):
assert get_nested_schema(RunSchema, "sample") is None
data = spec_fixture.openapi.schema2parameters(RunSchema)
data = spec_fixture.openapi.schema2parameters(RunSchema, location="json")
json.dumps(data)
assert get_nested_schema(RunSchema, "sample") is None

Expand Down
91 changes: 27 additions & 64 deletions tests/test_ext_marshmallow_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,53 +13,20 @@
class TestMarshmallowFieldToOpenAPI:
def test_fields_with_missing_load(self, openapi):
field_dict = {"field": fields.Str(default="foo", missing="bar")}
res = openapi.fields2parameters(field_dict, default_in="query")
res = openapi.fields2parameters(field_dict, location="query")
if openapi.openapi_version.major < 3:
assert res[0]["default"] == "bar"
else:
assert res[0]["schema"]["default"] == "bar"

def test_fields_with_location(self, openapi):
field_dict = {"field": fields.Str(location="querystring")}
res = openapi.fields2parameters(field_dict, default_in="headers")
assert res[0]["in"] == "query"

# json/body is invalid for OpenAPI 3
@pytest.mark.parametrize("openapi", ("2.0",), indirect=True)
def test_fields_with_multiple_json_locations(self, openapi):
field_dict = {
"field1": fields.Str(location="json", required=True),
"field2": fields.Str(location="json", required=True),
"field3": fields.Str(location="json"),
}
res = openapi.fields2parameters(field_dict, default_in=None)
assert len(res) == 1
assert res[0]["in"] == "body"
assert res[0]["required"] is False
assert "field1" in res[0]["schema"]["properties"]
assert "field2" in res[0]["schema"]["properties"]
assert "field3" in res[0]["schema"]["properties"]
assert "required" in res[0]["schema"]
assert len(res[0]["schema"]["required"]) == 2
assert "field1" in res[0]["schema"]["required"]
assert "field2" in res[0]["schema"]["required"]

def test_fields2parameters_does_not_modify_metadata(self, openapi):
field_dict = {"field": fields.Str(location="querystring")}
res = openapi.fields2parameters(field_dict, default_in="headers")
assert res[0]["in"] == "query"

res = openapi.fields2parameters(field_dict, default_in="headers")
assert res[0]["in"] == "query"

def test_fields_location_mapping(self, openapi):
field_dict = {"field": fields.Str(location="cookies")}
res = openapi.fields2parameters(field_dict, default_in="headers")
field_dict = {"field": fields.Str()}
res = openapi.fields2parameters(field_dict, location="cookies")
assert res[0]["in"] == "cookie"

def test_fields_default_location_mapping(self, openapi):
field_dict = {"field": fields.Str()}
res = openapi.fields2parameters(field_dict, default_in="headers")
res = openapi.fields2parameters(field_dict, location="headers")
assert res[0]["in"] == "header"

# json/body is invalid for OpenAPI 3
Expand All @@ -69,16 +36,16 @@ class ExampleSchema(Schema):
id = fields.Int()

schema = ExampleSchema(many=True)
res = openapi.schema2parameters(schema=schema, default_in="json")
res = openapi.schema2parameters(schema=schema, location="json")
assert res[0]["in"] == "body"

def test_fields_with_dump_only(self, openapi):
class UserSchema(Schema):
name = fields.Str(dump_only=True)

res = openapi.fields2parameters(UserSchema._declared_fields, default_in="query")
res = openapi.fields2parameters(UserSchema._declared_fields, location="query")
assert len(res) == 0
res = openapi.fields2parameters(UserSchema().fields, default_in="query")
res = openapi.fields2parameters(UserSchema().fields, location="query")
assert len(res) == 0

class UserSchema(Schema):
Expand All @@ -87,7 +54,7 @@ class UserSchema(Schema):
class Meta:
dump_only = ("name",)

res = openapi.schema2parameters(schema=UserSchema, default_in="query")
res = openapi.schema2parameters(schema=UserSchema, location="query")
assert len(res) == 0


Expand Down Expand Up @@ -261,8 +228,8 @@ class NotASchema:
class TestMarshmallowSchemaToParameters:
@pytest.mark.parametrize("ListClass", [fields.List, CustomList])
def test_field_multiple(self, ListClass, openapi):
field = ListClass(fields.Str, location="querystring")
res = openapi.field2parameter(field, name="field", default_in=None)
field = ListClass(fields.Str)
res = openapi.field2parameter(field, name="field", location="querystring")
assert res["in"] == "query"
if openapi.openapi_version.major < 3:
assert res["type"] == "array"
Expand All @@ -275,13 +242,13 @@ def test_field_multiple(self, ListClass, openapi):
assert res["explode"] is True

def test_field_required(self, openapi):
field = fields.Str(required=True, location="query")
res = openapi.field2parameter(field, name="field", default_in=None)
field = fields.Str(required=True)
res = openapi.field2parameter(field, name="field", location="query")
assert res["required"] is True

def test_invalid_schema(self, openapi):
with pytest.raises(ValueError):
openapi.schema2parameters(None)
openapi.schema2parameters(None, location="json")

# json/body is invalid for OpenAPI 3
@pytest.mark.parametrize("openapi", ("2.0",), indirect=True)
Expand All @@ -290,7 +257,7 @@ class UserSchema(Schema):
name = fields.Str()
email = fields.Email()

res = openapi.schema2parameters(UserSchema, default_in="body")
res = openapi.schema2parameters(UserSchema, location="body")
assert len(res) == 1
param = res[0]
assert param["in"] == "body"
Expand All @@ -303,7 +270,7 @@ class UserSchema(Schema):
name = fields.Str()
email = fields.Email(dump_only=True)

res_nodump = openapi.schema2parameters(UserSchema, default_in="body")
res_nodump = openapi.schema2parameters(UserSchema, location="body")
assert len(res_nodump) == 1
param = res_nodump[0]
assert param["in"] == "body"
Expand All @@ -316,7 +283,7 @@ class UserSchema(Schema):
name = fields.Str()
email = fields.Email()

res = openapi.schema2parameters(UserSchema(many=True), default_in="body")
res = openapi.schema2parameters(UserSchema(many=True), location="body")
assert len(res) == 1
param = res[0]
assert param["in"] == "body"
Expand All @@ -328,7 +295,7 @@ class UserSchema(Schema):
name = fields.Str()
email = fields.Email()

res = openapi.schema2parameters(UserSchema, default_in="query")
res = openapi.schema2parameters(UserSchema, location="query")
assert len(res) == 2
res.sort(key=lambda param: param["name"])
assert res[0]["name"] == "email"
Expand All @@ -341,7 +308,7 @@ class UserSchema(Schema):
name = fields.Str()
email = fields.Email()

res = openapi.schema2parameters(UserSchema(), default_in="query")
res = openapi.schema2parameters(UserSchema(), location="query")
assert len(res) == 2
res.sort(key=lambda param: param["name"])
assert res[0]["name"] == "email"
Expand All @@ -355,11 +322,11 @@ class UserSchema(Schema):
email = fields.Email()

with pytest.raises(AssertionError):
openapi.schema2parameters(UserSchema(many=True), default_in="query")
openapi.schema2parameters(UserSchema(many=True), location="query")

def test_fields_query(self, openapi):
field_dict = {"name": fields.Str(), "email": fields.Email()}
res = openapi.fields2parameters(field_dict, default_in="query")
res = openapi.fields2parameters(field_dict, location="query")
assert len(res) == 2
res.sort(key=lambda param: param["name"])
assert res[0]["name"] == "email"
Expand Down Expand Up @@ -479,15 +446,13 @@ def test_openapi_tools_validate_v2():
},
openapi.field2parameter(
field=fields.List(
fields.Str(),
validate=validate.OneOf(["freddie", "roger"]),
location="querystring",
fields.Str(), validate=validate.OneOf(["freddie", "roger"]),
),
default_in=None,
location="querystring",
name="body",
),
]
+ openapi.schema2parameters(PageSchema, default_in="query"),
+ openapi.schema2parameters(PageSchema, location="query"),
"responses": {200: {"schema": PetSchema, "description": "A pet"}},
},
"post": {
Expand All @@ -500,7 +465,7 @@ def test_openapi_tools_validate_v2():
"type": "string",
}
]
+ openapi.schema2parameters(CategorySchema, default_in="body")
+ openapi.schema2parameters(CategorySchema, location="body")
),
"responses": {201: {"schema": PetSchema, "description": "A pet"}},
},
Expand Down Expand Up @@ -537,15 +502,13 @@ def test_openapi_tools_validate_v3():
},
openapi.field2parameter(
field=fields.List(
fields.Str(),
validate=validate.OneOf(["freddie", "roger"]),
location="querystring",
fields.Str(), validate=validate.OneOf(["freddie", "roger"]),
),
default_in=None,
location="querystring",
name="body",
),
]
+ openapi.schema2parameters(PageSchema, default_in="query"),
+ openapi.schema2parameters(PageSchema, location="query"),
"responses": {
200: {
"description": "success",
Expand Down

0 comments on commit e204fa0

Please sign in to comment.