-
Notifications
You must be signed in to change notification settings - Fork 26
/
database_models_factory.py
171 lines (136 loc) · 5.45 KB
/
database_models_factory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
""" Automatic creation of pydantic model classes from a sqlalchemy table
SEE: Copied and adapted from https://github.com/tiangolo/pydantic-sqlalchemy/blob/master/pydantic_sqlalchemy/main.py
"""
import json
import warnings
from datetime import datetime
from typing import Any, Callable, Container, Optional
from uuid import UUID
import sqlalchemy as sa
import sqlalchemy.sql.functions
from pydantic import BaseConfig, BaseModel, Field, create_model
from pydantic.types import NonNegativeInt
from sqlalchemy import null
from sqlalchemy.sql.schema import Column
# Emitted when this module is imported, to flag the factory as experimental.
warnings.warn(
    "This is still a concept under development. "
    "Currently only intended for testing. "
    "DO NOT USE in production.",
    category=UserWarning,
)
class OrmConfig(BaseConfig):
    """Pydantic (v1) config with ``orm_mode`` turned on.

    Used as the default ``config`` of the models produced by
    ``create_pydantic_model_from_sa_table``.
    """

    orm_mode: bool = True
_RESERVED = {
"schema",
# e.g. Field name "schema" shadows a BaseModel attribute; use a different field name with "alias='schema'".
}
def _parse_jsonb_server_default(text: str) -> Any:
    """Decode a jsonb server-default SQL text (e.g. ``'{"a": 1}'::jsonb``) to python.

    Removes the trailing ``::jsonb`` cast and the single quotes that delimit the
    SQL string literal, then un-escapes doubled quotes (``''`` -> ``'``). Only the
    outer delimiters are removed, so apostrophes inside JSON string values survive.
    """
    literal = text.strip()
    suffix = "::jsonb"
    if literal.endswith(suffix):
        literal = literal[: -len(suffix)].rstrip()
    if len(literal) >= 2 and literal[0] == literal[-1] == "'":
        # SQL escapes a quote inside a string literal by doubling it
        literal = literal[1:-1].replace("''", "'")
    return json.loads(literal)


def _eval_defaults(
    column: Column, pydantic_type: type, *, include_server_defaults: bool = True
):
    """
    Uses some heuristics to determine the default value/factory produced
    parsing both the client and the server (if include_server_defaults==True) defaults
    in the sa model.

    Args:
        column: sqlalchemy column whose ``default``/``server_default`` are inspected.
        pydantic_type: python type already inferred for ``column``. It selects
            the server-default heuristic (list, dict or datetime).
        include_server_defaults: when True, ``column.server_default`` is parsed too.

    Returns:
        A ``(default, default_factory)`` tuple.

        ``default`` is ``...`` (pydantic's "required" marker) only when
        ``include_server_defaults`` is True and the column is non-nullable with
        neither a client nor a server default. Otherwise it is one of:
        the scalar client default, the value decoded from a jsonb server
        default, or None.

        ``default_factory`` is ``list`` for list server defaults and
        ``datetime.now`` for datetime server defaults; otherwise it is None.
    """
    default: Optional[Any] = None
    default_factory: Optional[Callable] = None

    # NOTE(review): when include_server_defaults is False the middle clause is
    # always False, so no field is ever marked required. Presumably deliberate
    # (an unparsed server default might still fill the column) -- confirm.
    if (
        column.default is None
        and (include_server_defaults and column.server_default is None)
        and not column.nullable
    ):
        default = ...

    # Only scalar client defaults become field defaults; other kinds are ignored.
    if column.default and column.default.is_scalar:
        assert not column.default.is_server_default  # nosec
        default = column.default.arg

    if include_server_defaults and column.server_default:
        assert column.server_default.is_server_default  # nosec
        #
        # FIXME: Map server's DefaultClauses to correct values
        # Heuristics based on test against all our tables
        #
        if pydantic_type:
            if issubclass(pydantic_type, list):
                assert column.server_default.arg == "{}"  # nosec
                default_factory = list
            elif issubclass(pydantic_type, dict):
                assert column.server_default.arg.text.endswith("::jsonb")  # nosec
                default = _parse_jsonb_server_default(column.server_default.arg.text)
            elif issubclass(pydantic_type, datetime):
                assert isinstance(  # nosec
                    column.server_default.arg,
                    (type(null()), sqlalchemy.sql.functions.now),
                )
                default_factory = datetime.now
    return default, default_factory
# Signature of a policy hook: it receives the column, the current default and the
# current pydantic type, and returns the (possibly updated) default/type pair.
PolicyCallable = Callable[
    [Column, Any, type],
    tuple[Any, type],
]
def eval_name_policy(column: Column, default: Any, pydantic_type: type):
    """Promote ``str`` columns with a ``uuid`` name token to ``UUID``.

    The column name is split on ``_``. If one of the tokens is exactly ``"uuid"``
    and the inferred type is ``str``, the type becomes ``UUID``. A string default,
    if present, is parsed into a ``UUID`` as well.

    Returns:
        ``(default, pydantic_type)``, unchanged when the policy does not apply.
    """
    name_tokens = str(column.name).split("_")
    if "uuid" not in name_tokens or pydantic_type != str:
        return default, pydantic_type

    converted_default = UUID(default) if isinstance(default, str) else default
    return converted_default, UUID
# Policies applied, in order, by create_pydantic_model_from_sa_table when the
# caller does not pass ``extra_policies``.
DEFAULT_EXTRA_POLICIES = [eval_name_policy]
def _infer_python_type(column: Column) -> Optional[type]:
    """Return the python type declared by the column's sa type (or its ``impl``), or None."""
    column_type = column.type
    if hasattr(column_type, "impl"):
        impl = column_type.impl
        return impl.python_type if hasattr(impl, "python_type") else None
    if hasattr(column_type, "python_type"):
        return column_type.python_type
    return None


def create_pydantic_model_from_sa_table(
    table: sa.Table,
    *,
    config: type = OrmConfig,
    exclude: Optional[Container[str]] = None,
    include_server_defaults: bool = False,
    extra_policies: Optional[list[PolicyCallable]] = None,
) -> type[BaseModel]:
    """Build a pydantic model class whose fields mirror the columns of ``table``.

    Args:
        table: sqlalchemy table whose columns become model fields.
        config: pydantic config class passed as ``__config__``.
        exclude: column names to leave out of the model.
        include_server_defaults: also derive defaults from ``server_default``
            clauses (see ``_eval_defaults``).
        extra_policies: callables applied in order to every column's
            ``(default, pydantic_type)`` pair. ``None`` selects
            ``DEFAULT_EXTRA_POLICIES``; an empty list disables policies.

    Returns:
        A new pydantic model class named ``table.name.capitalize()``.

    Raises:
        AssertionError: if no python type can be inferred for a column
            (asserts are stripped under ``python -O``).
    """
    fields: dict[str, Any] = {}
    exclude = exclude or []
    # `is None` (not `or`) so an explicit empty list really disables policies
    policies = DEFAULT_EXTRA_POLICIES if extra_policies is None else extra_policies

    for column in table.columns:
        name = str(column.name)
        if name in exclude:
            continue

        field_args: dict[str, Any] = {}
        if name in _RESERVED:
            # Name clashes with a BaseModel attribute: expose it through an alias
            # and store the field under a table-prefixed name instead.
            field_args["alias"] = name
            name = f"{table.name.lower()}_{name}"

        # type ---
        pydantic_type = _infer_python_type(column)
        assert pydantic_type, f"Could not infer pydantic_type for {column}"  # nosec

        # integer primary keys are constrained to non-negative values
        if column.primary_key and issubclass(pydantic_type, int):
            pydantic_type = NonNegativeInt

        # default ----
        default, default_factory = _eval_defaults(
            column, pydantic_type, include_server_defaults=include_server_defaults
        )

        # Policies based on naming conventions
        #
        # TODO: implement it as a pluggable policy class.
        # Base policy class is abstract interface
        # and user can add as many in a given order in the arguments
        #
        for apply_policy in policies:
            default, pydantic_type = apply_policy(column, default, pydantic_type)

        if default_factory:
            field_args["default_factory"] = default_factory
        else:
            field_args["default"] = default

        doc = getattr(column, "doc", None)
        if doc:
            field_args["description"] = doc

        fields[name] = (pydantic_type, Field(**field_args))

    # create domain models from db-schemas
    pydantic_model = create_model(
        table.name.capitalize(), __config__=config, **fields  # type: ignore
    )
    assert issubclass(pydantic_model, BaseModel)  # nosec
    return pydantic_model