Skip to content

Commit

Permalink
Remove converter from path when generating OpenAPI schema (#1648)
Browse files Browse the repository at this point in the history
* Remove converter from path when generating `OpenAPI` schema

* Update starlette/schemas.py

Co-authored-by: Tom Christie <tom@tomchristie.com>

Co-authored-by: Tom Christie <tom@tomchristie.com>
  • Loading branch information
Kludex and tomchristie committed Jun 28, 2022
1 parent 92c1f1e commit 795cf60
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 5 deletions.
21 changes: 16 additions & 5 deletions starlette/schemas.py
@@ -1,4 +1,5 @@
import inspect
import re
import typing

from starlette.requests import Request
Expand Down Expand Up @@ -49,10 +50,11 @@ def get_endpoints(

for route in routes:
if isinstance(route, Mount):
path = self._remove_converter(route.path)
routes = route.routes or []
sub_endpoints = [
EndpointInfo(
path="".join((route.path, sub_endpoint.path)),
path="".join((path, sub_endpoint.path)),
http_method=sub_endpoint.http_method,
func=sub_endpoint.func,
)
Expand All @@ -64,23 +66,32 @@ def get_endpoints(
continue

elif inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint):
path = self._remove_converter(route.path)
for method in route.methods or ["GET"]:
if method == "HEAD":
continue
endpoints_info.append(
EndpointInfo(route.path, method.lower(), route.endpoint)
EndpointInfo(path, method.lower(), route.endpoint)
)
else:
path = self._remove_converter(route.path)
for method in ["get", "post", "put", "patch", "delete", "options"]:
if not hasattr(route.endpoint, method):
continue
func = getattr(route.endpoint, method)
endpoints_info.append(
EndpointInfo(route.path, method.lower(), func)
)
endpoints_info.append(EndpointInfo(path, method.lower(), func))

return endpoints_info

def _remove_converter(self, path: str) -> str:
"""
Remove the converter from the path.
For example, a route like this:
Route("/users/{id:int}", endpoint=get_user, methods=["GET"])
Should be represented as `/users/{id}` in the OpenAPI schema.
"""
return re.sub(r":\w+}", "}", path)

def parse_docstring(self, func_or_method: typing.Callable) -> dict:
"""
Given a function, parse the docstring as YAML and return a dictionary of info.
Expand Down
29 changes: 29 additions & 0 deletions tests/test_schemas.py
Expand Up @@ -13,6 +13,17 @@ def ws(session):
pass # pragma: no cover


def get_user(request):
"""
responses:
200:
description: A user.
examples:
{"username": "tom"}
"""
pass # pragma: no cover


def list_users(request):
"""
responses:
Expand Down Expand Up @@ -103,6 +114,7 @@ def schema(request):
app = Starlette(
routes=[
WebSocketRoute("/ws", endpoint=ws),
Route("/users/{id:int}", endpoint=get_user, methods=["GET"]),
Route("/users", endpoint=list_users, methods=["GET", "HEAD"]),
Route("/users", endpoint=create_user, methods=["POST"]),
Route("/orgs", endpoint=OrganisationsEndpoint),
Expand Down Expand Up @@ -168,6 +180,16 @@ def test_schema_generation():
}
},
},
"/users/{id}": {
"get": {
"responses": {
200: {
"description": "A user.",
"examples": {"username": "tom"},
}
}
},
},
},
}

Expand Down Expand Up @@ -216,6 +238,13 @@ def test_schema_generation():
description: A user.
examples:
username: tom
/users/{id}:
get:
responses:
200:
description: A user.
examples:
username: tom
"""


Expand Down

0 comments on commit 795cf60

Please sign in to comment.