Skip to content

Commit

Permalink
Optimized dataclass factory
Browse files Browse the repository at this point in the history
- covering further side effects of [#4191](tiangolo/fastapi#4191)
- FASTAPI_XML_DISABLE_PYDANTIC_PATCH disables the pydantic patch
- added pydantic license
- readme update
  • Loading branch information
cercide committed Aug 18, 2022
1 parent 3ce57ad commit 3ad7519
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 29 deletions.
18 changes: 9 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,18 @@ from fastapi_xml import NonJsonRoute, XmlBody, XmlAppResponse
from fastapi import FastAPI

@dataclass
class Hello:
message: str = field(metadata={"name": "Message", "type": "Element"})

@dataclass
class World:
class HelloWorld:
class Meta:
name = "echo"
message: str = field(metadata={"name": "Message", "type": "Element"})

app = FastAPI(default_response_class=XmlAppResponse)
app.router.route_class = NonJsonRoute

@app.post("/echo", response_model=World)
def echo(x: Hello = XmlBody()) -> World:
return World(message=x.message + " For ever!")
@app.post("/echo", response_model=HelloWorld)
def echo(x: HelloWorld = XmlBody()) -> HelloWorld:
x.message += " For ever!"
return x

if __name__ == "__main__":
import uvicorn
Expand All @@ -41,7 +40,8 @@ if __name__ == "__main__":
- This package depends on fastapi and xsdata. However, fastapi depends on
pydantic, which ships a [bug](https://github.com/pydantic/pydantic/issues/4353) that causes several side effects.
Among other, this bug is fixed within the major branch. Nevertheless, the bug still occurs in the current version
(1.9.2). Anyhow, this package supports both versions.
(1.9.2). Anyhow, this package supports both versions. Set the environment variable
`FASTAPI_XML_DISABLE_PYDANTIC_PATCH != "false"` to disable that patch.

- :warning: Do not use keyword `required` for a field's metatdata. This will crash the openapi schema generator. Remove
typehint `Optional` instead.
28 changes: 14 additions & 14 deletions fastapi_xml/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""
This packaged adds xml support to :mod:`fastapi`.
"""

import os
from . import nonjson, xmlbody
from .nonjson import NonJsonRoute, NonJsonResponse
from .xmlbody import XmlResponse, XmlTextResponse, XmlAppResponse, XmlBody

Expand All @@ -16,19 +17,18 @@
"XmlBody"
]

__version__ = "1.0.0a2"


__version__ = "1.0.0a3"

nonjson.OPENAPI_SCHEMA_MODIFIER.append(xmlbody.add_openapi_xml_schema)

try:
# https://github.com/pydantic/pydantic/issues/4353
from .pydantic_dataclass_patch import pydantic_process_class_patched
except ImportError:
# the patch does not work with the pydantic.dataclasses update (commit 576e4a3a8d9c98cbf5a1fe5149450febef887cc9)
# no worries, that update works as it should and is compatible with fastapi-xml
pass
else:
import pydantic.dataclasses
pydantic.dataclasses._process_class = pydantic_process_class_patched
if os.environ.get("FASTAPI_XML_DISABLE_PYDANTIC_PATCH", "false").lower() == "false":
try:
# https://github.com/pydantic/pydantic/issues/4353
from .pydantic_dataclass_patch import pydantic_process_class_patched, _validate_dataclass
except ImportError:
# the patch does not work with the pydantic.dataclasses update (commit 576e4a3a8d9c98cbf5a1fe5149450febef887cc9)
# no worries, that update works as it should and is compatible with fastapi-xml
pass
else:
import pydantic.dataclasses
pydantic.dataclasses._process_class = pydantic_process_class_patched
74 changes: 71 additions & 3 deletions fastapi_xml/pydantic_dataclass_patch.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,44 @@
from typing import Any, Dict, Optional, Type, TYPE_CHECKING
# this module contains modified code snippets from pydantic. Hence a license copy is given below.
# Any changes are highlighted.
#
# The MIT License (MIT)
#
# Copyright (c) 2017, 2018, 2019, 2020, 2021 Samuel Colvin and other contributors
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import Any, Dict, Optional, Type, TYPE_CHECKING
from dataclasses import is_dataclass, asdict
from pydantic.typing import NoArgAnyCallable
from pydantic.class_validators import gather_all_validators
from pydantic.fields import Field, FieldInfo, Required, Undefined
from pydantic.main import create_model
from pydantic.typing import resolve_annotations
from pydantic.utils import ClassAttribute
from pydantic.dataclasses import _generate_pydantic_post_init, is_builtin_dataclass, _validate_dataclass, _get_validators, setattr_validate_assignment
from pydantic.dataclasses import _generate_pydantic_post_init, is_builtin_dataclass, _get_validators, setattr_validate_assignment, DataclassTypeError

if TYPE_CHECKING:
from pydantic.dataclasses import Dataclass

_CACHE: Dict[Type, Type['Dataclass']] = {}


def pydantic_process_class_patched(
_cls: Type[Any],
init: bool,
Expand All @@ -23,9 +49,11 @@ def pydantic_process_class_patched(
frozen: bool,
config: Optional[Type[Any]],
) -> Type['Dataclass']:
# BEGIN EDIT
or_cls = _cls
if or_cls in _CACHE:
return _CACHE[or_cls]
# END EDIT

import dataclasses

Expand Down Expand Up @@ -62,9 +90,12 @@ def pydantic_process_class_patched(
# attrs for pickle to find this class
'__module__': __name__,
'__qualname__': uniq_class_name,

# BEGIN EDIT
# addresses https://github.com/pydantic/pydantic/issues/4353
# BUGFIX: forward original fields to the new dataclass
**_cls.__dataclass_fields__
**getattr(_cls, "__dataclass_fields__", {})
# BEGIN EDIT
},
)
globals()[uniq_class_name] = _cls
Expand Down Expand Up @@ -116,5 +147,42 @@ def pydantic_process_class_patched(
cls.__setattr__ = setattr_validate_assignment # type: ignore[assignment]

cls.__pydantic_model__.__try_update_forward_refs__(**{cls.__name__: cls})

# BEGIN EDIT
cls.__origin__ = or_cls
_CACHE[or_cls] = cls
# END EDIT
return cls


def _validate_dataclass(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
if isinstance(v, cls):
# BEGIN EDIT
result = v
# END EDIT
elif isinstance(v, (list, tuple)):
# BEGIN EDIT
result = cls(*v)
# END EDIT
elif isinstance(v, dict):
# BEGIN EDIT
result = cls(**v)
# END EDIT
# In nested dataclasses, v can be of type `dataclasses.dataclass`.
# But to validate fields `cls` will be in fact a `pydantic.dataclasses.dataclass`,
# which inherits directly from the class of `v`.
elif is_builtin_dataclass(v) and cls.__bases__[0] is type(v):
# BEGIN EDIT
# import dataclasses
result = cls(**asdict(v))
# END EDIT
else:
raise DataclassTypeError(class_name=cls.__name__)

# BEGIN EDIT
clazz = getattr(cls, "__origin__", None)
if is_dataclass(clazz):
return clazz(**asdict(result))
else:
return result
# END EDIT
13 changes: 10 additions & 3 deletions fastapi_xml/xmlbody.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from xsdata.utils.constants import return_input
from xsdata.exceptions import ParserError

from .nonjson import BodyDecoder, BodyDecodeError, NonJsonResponse, OPENAPI_SCHEMA_MODIFIER
from .nonjson import BodyDecoder, BodyDecodeError, NonJsonResponse


DEFAULT_XML_CONTEXT: XmlContext = XmlContext()
Expand Down Expand Up @@ -82,7 +82,8 @@ def _get_child_node(parent: ElementNode, qname: str, attrs: Dict, ns_map: Dict,
def start(self, clazz: Optional[Type], queue: List[XmlNode], objects: List[Tuple[Optional[str], Any]], qname: str, attrs: Dict, ns_map: Dict):
if len(queue) == 0:
super().start(clazz, queue, objects, qname, attrs, ns_map)
if len(queue) > 0 and qname != queue[0].meta.qname:
qq = getattr(queue[0], "meta", None)
if len(queue) > 0 and qname != getattr(qq, "qname", None):
raise ParserError("invalid root element")
else:
item = queue[-1]
Expand Down Expand Up @@ -114,8 +115,14 @@ def decode(cls, request: Request, field: ModelField, body: bytes) -> Optional[Di
Else, it MUST return a mapping for pydantic's constructor
"""
xml_parser = cls.xml_parser if cls.xml_parser is not None else cls.xml_parser_factory()
clazz = field.type_
if not is_dataclass(clazz):
return None
if hasattr(clazz, "__origin__"):
# custom attr set by patched dataclass factory
clazz = _get_dataclass(clazz.__pydantic_model__)
try:
o = xml_parser.from_bytes(body, clazz=field.type_)
o = xml_parser.from_bytes(body, clazz=clazz)
except ParserError as e:
http_content_type: str = request.headers.get("content-type", "")
if http_content_type.endswith("/xml"):
Expand Down

0 comments on commit 3ad7519

Please sign in to comment.