Skip to content

Commit

Permalink
fix: handle basemodel fallback for custom encoders
Browse files Browse the repository at this point in the history
  • Loading branch information
PrettyWood committed Dec 19, 2021
1 parent c532e83 commit 898a9c1
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 2 deletions.
18 changes: 16 additions & 2 deletions pydantic/main.py
Expand Up @@ -467,7 +467,6 @@ def json(
DeprecationWarning,
)
exclude_unset = skip_defaults
encoder = cast(Callable[[Any], Any], encoder or self.__json_encoder__)

# We don't directly call `self.dict()`, which does exactly the same thing but
# with `to_dict = True` because we want to keep raw `BaseModel` instances and not as `dict`.
Expand All @@ -484,7 +483,22 @@ def json(
)
if self.__custom_root_type__:
data = data[ROOT_KEY]
return self.__config__.json_dumps(data, default=encoder, **dumps_kwargs)

def encoder_with_fallback(o: Any) -> Any:
if encoder is None:
# use the default encoder
return cast(Callable[[Any], Any], self.__json_encoder__)(o)
else:
try:
return encoder(o)
except Exception as e:
# when a custom encoder is set, `BaseModel` instances may not be handled
if isinstance(o, BaseModel):
return o.dict()
else:
raise e

return self.__config__.json_dumps(data, default=encoder_with_fallback, **dumps_kwargs)

@classmethod
def _enforce_dict_if_root(cls, obj: Any) -> Any:
Expand Down
38 changes: 38 additions & 0 deletions tests/test_json.py
Expand Up @@ -319,3 +319,41 @@ class Config:
pumbaa.json() == '{"name": "Pumbaa", "SSN": 234, "birthday": 737424000.0, "phone": 18007267864, "friend": null}'
)
assert timon.json() == '{"name": "Timon", "SSN": 123, "birthday": 738892800.0, "phone": 18002752273, "friend": 234}'


def test_custom_encode_fallback_basemodel():
class MyExoticType:
pass

def custom_encoder(o):
if isinstance(o, MyExoticType):
return 'exo'
raise TypeError('not serialisable')

class Foo(BaseModel):
x: MyExoticType

class Config:
arbitrary_types_allowed = True

class Bar(BaseModel):
foo: Foo

assert Bar(foo=Foo(x=MyExoticType())).json(encoder=custom_encoder) == '{"foo": {"x": "exo"}}'


def test_custom_encode_error():
class MyExoticType:
pass

def custom_encoder(o):
raise TypeError('not serialisable')

class Foo(BaseModel):
x: MyExoticType

class Config:
arbitrary_types_allowed = True

with pytest.raises(TypeError, match='not serialisable'):
Foo(x=MyExoticType()).json(encoder=custom_encoder)

0 comments on commit 898a9c1

Please sign in to comment.