Skip to content

Commit

Permalink
Improve type hints for fastapi example (#1601)
Browse files Browse the repository at this point in the history
* Improve type hints for fastapi example

* fix static code analysis

* Add more type hint for test file
  • Loading branch information
waketzheng committed May 14, 2024
1 parent 9053fa7 commit b6680f1
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 10 deletions.
8 changes: 5 additions & 3 deletions examples/fastapi/_tests.py
@@ -1,5 +1,7 @@
# mypy: no-disallow-untyped-decorators
# pylint: disable=E0611,E0401
from typing import AsyncGenerator

import pytest
from asgi_lifespan import LifespanManager
from httpx import ASGITransport, AsyncClient
Expand All @@ -8,20 +10,20 @@


@pytest.fixture(scope="module")
def anyio_backend():
def anyio_backend() -> str:
return "asyncio"


@pytest.fixture(scope="module")
async def client():
async def client() -> AsyncGenerator[AsyncClient, None]:
async with LifespanManager(app):
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as c:
yield c


@pytest.mark.anyio
async def test_create_user(client: AsyncClient): # nosec
async def test_create_user(client: AsyncClient) -> None: # nosec
response = await client.post("/users", json={"username": "admin"})
assert response.status_code == 200, response.text
data = response.json()
Expand Down
18 changes: 15 additions & 3 deletions examples/fastapi/main.py
@@ -1,16 +1,28 @@
# pylint: disable=E0611,E0401
from contextlib import asynccontextmanager
from typing import List
from typing import TYPE_CHECKING, AsyncGenerator, List

from fastapi import FastAPI, HTTPException
from models import User_Pydantic, UserIn_Pydantic, Users
from models import Users
from pydantic import BaseModel

from tortoise.contrib.fastapi import RegisterTortoise
from tortoise.contrib.pydantic import PydanticModel

if TYPE_CHECKING: # pragma: nocoverage

class UserIn_Pydantic(Users, PydanticModel): # type:ignore[misc]
pass

class User_Pydantic(Users, PydanticModel): # type:ignore[misc]
pass

else:
from models import User_Pydantic, UserIn_Pydantic


@asynccontextmanager
async def lifespan(app: FastAPI):
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# app startup
async with RegisterTortoise(
app,
Expand Down
7 changes: 7 additions & 0 deletions pyproject.toml
Expand Up @@ -147,6 +147,13 @@ disallow_untyped_calls = false
disallow_untyped_defs = false
disallow_incomplete_defs = false

[[tool.mypy.overrides]]
module = ["examples.fastapi.*"]
check_untyped_defs = true
disallow_untyped_calls = true
disallow_untyped_defs = false
disallow_incomplete_defs = false

[[tool.mypy.overrides]]
module = ["tortoise.contrib.test.*"]
disallow_untyped_defs = false
Expand Down
14 changes: 10 additions & 4 deletions tortoise/contrib/pydantic/base.py
@@ -1,10 +1,16 @@
import sys
from typing import TYPE_CHECKING, List, Type, Union

import pydantic
from pydantic import BaseModel, ConfigDict, RootModel

from tortoise import fields

if sys.version_info >= (3, 11): # pragma: nocoverage
from typing import Self
else:
from typing_extensions import Self

if TYPE_CHECKING: # pragma: nocoverage
from tortoise.models import Model
from tortoise.queryset import QuerySet, QuerySetSingle
Expand Down Expand Up @@ -59,7 +65,7 @@ def _tortoise_convert(cls, value): # pylint: disable=E0213
return value

@classmethod
async def from_tortoise_orm(cls, obj: "Model") -> "PydanticModel":
async def from_tortoise_orm(cls, obj: "Model") -> Self:
"""
Returns a serializable pydantic model instance built from the provided model instance.
Expand All @@ -86,7 +92,7 @@ async def from_tortoise_orm(cls, obj: "Model") -> "PydanticModel":
return cls.model_validate(obj)

@classmethod
async def from_queryset_single(cls, queryset: "QuerySetSingle") -> "PydanticModel":
async def from_queryset_single(cls, queryset: "QuerySetSingle") -> Self:
"""
Returns a serializable pydantic model instance for a single model
from the provided queryset.
Expand All @@ -99,7 +105,7 @@ async def from_queryset_single(cls, queryset: "QuerySetSingle") -> "PydanticMode
return cls.model_validate(await queryset.prefetch_related(*fetch_fields))

@classmethod
async def from_queryset(cls, queryset: "QuerySet") -> "List[PydanticModel]":
async def from_queryset(cls, queryset: "QuerySet") -> List[Self]:
"""
Returns a serializable pydantic model instance that contains a list of models,
from the provided queryset.
Expand All @@ -121,7 +127,7 @@ class PydanticListModel(RootModel):
"""

@classmethod
async def from_queryset(cls, queryset: "QuerySet") -> "PydanticListModel":
async def from_queryset(cls, queryset: "QuerySet") -> Self:
"""
Returns a serializable pydantic model instance that contains a list of models,
from the provided queryset.
Expand Down

0 comments on commit b6680f1

Please sign in to comment.