-
-
Notifications
You must be signed in to change notification settings - Fork 863
/
test_database.py
175 lines (134 loc) Β· 5.36 KB
/
test_database.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import databases
import pytest
import sqlalchemy
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route
# File-backed SQLite database shared by the schema fixture (synchronous
# SQLAlchemy engine) and the async `databases` client used by the endpoints.
DATABASE_URL = "sqlite:///test.db"

metadata = sqlalchemy.MetaData()

# Table behind every `/notes` endpoint in this module.
notes = sqlalchemy.Table(
    "notes",
    metadata,
    sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
    sqlalchemy.Column("text", sqlalchemy.String(100)),
    sqlalchemy.Column("completed", sqlalchemy.Boolean),
)

# Every test in this module requests the `no_trio_support` fixture, which is
# defined outside this file.
pytestmark = pytest.mark.usefixtures("no_trio_support")
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
    """Create the ``notes`` schema before this module's tests and drop it after.

    A short-lived synchronous SQLAlchemy engine does the DDL; the app itself
    talks to the same database through the async ``database`` client. On
    teardown the engine is disposed so its connection pool is released, even
    if dropping the tables fails.
    """
    engine = sqlalchemy.create_engine(DATABASE_URL)
    metadata.create_all(engine)
    yield
    try:
        metadata.drop_all(engine)
    finally:
        # The original never released the engine's connection pool.
        engine.dispose()
# Async client shared by all endpoints. With `force_rollback=True`, notes
# written during one `TestClient` session are expected to be gone in the next;
# the isolation test in this module relies on that.
database = databases.Database(url=DATABASE_URL, force_rollback=True)
async def startup():
    """Application startup hook: open the shared ``database`` connection."""
    return await database.connect()
async def shutdown():
    """Application shutdown hook: close the shared ``database`` connection."""
    return await database.disconnect()
async def list_notes(request: Request):
    """GET /notes: return every stored note as a JSON list.

    Each element carries only the ``text`` and ``completed`` columns.
    """
    rows = await database.fetch_all(notes.select())
    return JSONResponse(
        [{field: row[field] for field in ("text", "completed")} for row in rows]
    )
@database.transaction()
async def add_note(request: Request):
    """POST /notes: insert one note from the JSON body and echo it back.

    The whole handler runs inside ``database.transaction()``. When the
    ``raise_exc`` query parameter is present, it raises ``RuntimeError`` after
    the insert, so tests can check that the insert is not kept.
    """
    payload = await request.json()
    text, completed = payload["text"], payload["completed"]
    await database.execute(notes.insert().values(text=text, completed=completed))
    if "raise_exc" in request.query_params:
        raise RuntimeError()
    return JSONResponse({"text": text, "completed": completed})
async def bulk_create_notes(request: Request):
    """POST /notes/bulk_create: insert every note in the JSON array body.

    All rows go through a single ``execute_many`` call, and the submitted
    list is echoed back under the ``"notes"`` key.
    """
    rows = await request.json()
    await database.execute_many(notes.insert(), rows)
    return JSONResponse({"notes": rows})
async def read_note(request: Request):
    """GET /notes/{note_id}: return one note's ``text`` and ``completed`` as JSON."""
    lookup = notes.select().where(notes.c.id == request.path_params["note_id"])
    row = await database.fetch_one(lookup)
    # NOTE(review): an unknown id fails this assert instead of producing a 404.
    assert row is not None
    return JSONResponse({"text": row["text"], "completed": row["completed"]})
async def read_note_text(request: Request):
    """GET /notes/{note_id}/text: return only the note's text as a JSON string.

    The query selects just the ``text`` column, so the value is read
    positionally from the returned row.
    """
    note_id = request.path_params["note_id"]
    # NOTE(review): `select([...])` is SQLAlchemy's legacy list calling style
    # (deprecated in 1.4). It is kept because the pinned version isn't visible here.
    text_only = sqlalchemy.select([notes.c.text])
    row = await database.fetch_one(text_only.where(notes.c.id == note_id))
    assert row is not None
    return JSONResponse(row[0])
# Test application: opens and closes the shared database connection around its
# lifetime and exposes the notes endpoints defined above.
app = Starlette(
    on_startup=[startup],
    on_shutdown=[shutdown],
    routes=[
        # `/notes` is served by two endpoints, split by HTTP method.
        Route("/notes", endpoint=list_notes, methods=["GET"]),
        Route("/notes", endpoint=add_note, methods=["POST"]),
        Route("/notes/bulk_create", endpoint=bulk_create_notes, methods=["POST"]),
        Route("/notes/{note_id:int}", endpoint=read_note, methods=["GET"]),
        Route("/notes/{note_id:int}/text", endpoint=read_note_text, methods=["GET"]),
    ],
)
def test_database(test_client_factory):
    """Exercise the notes endpoints end to end, including transaction rollback.

    The second POST asks ``add_note`` to raise after inserting. That endpoint
    runs inside ``database.transaction()``, so its row must not appear in the
    later listing.
    """
    with test_client_factory(app) as client:
        response = client.post(
            "/notes", json={"text": "buy the milk", "completed": True}
        )
        assert response.status_code == 200

        # The endpoint's RuntimeError propagates out of the client, so there is
        # no response to bind here (the original's assignment was dead code).
        with pytest.raises(RuntimeError):
            client.post(
                "/notes",
                json={"text": "you wont see me", "completed": False},
                params={"raise_exc": "true"},
            )

        response = client.post(
            "/notes", json={"text": "walk the dog", "completed": False}
        )
        assert response.status_code == 200

        # The rolled-back note must be absent.
        response = client.get("/notes")
        assert response.status_code == 200
        assert response.json() == [
            {"text": "buy the milk", "completed": True},
            {"text": "walk the dog", "completed": False},
        ]

        # The first note is expected to have id 1.
        response = client.get("/notes/1")
        assert response.status_code == 200
        assert response.json() == {"text": "buy the milk", "completed": True}

        response = client.get("/notes/1/text")
        assert response.status_code == 200
        assert response.json() == "buy the milk"
def test_database_execute_many(test_client_factory):
    """Bulk-insert two notes and check that listing returns exactly them, in order."""
    payload = [
        {"text": "buy the milk", "completed": True},
        {"text": "walk the dog", "completed": False},
    ]
    with test_client_factory(app) as client:
        created = client.post("/notes/bulk_create", json=payload)
        assert created.status_code == 200

        listed = client.get("/notes")
        assert listed.status_code == 200
        assert listed.json() == payload
def test_database_isolated_during_test_cases(test_client_factory):
    """
    Check that notes do not leak between separate `TestClient` sessions.

    Each pass opens the app with `TestClient` used as a context manager, adds
    one note, and expects the listing to contain exactly that note. So the
    note written in the first session must be gone in the second.
    """
    for _ in range(2):
        with test_client_factory(app) as client:
            response = client.post(
                "/notes", json={"text": "just one note", "completed": True}
            )
            assert response.status_code == 200

            response = client.get("/notes")
            assert response.status_code == 200
            assert response.json() == [{"text": "just one note", "completed": True}]