From 827388b4fef0e7f466f96882d3aaed1cd85cfc2f Mon Sep 17 00:00:00 2001 From: John Carter Date: Mon, 1 Jun 2020 01:51:50 +1200 Subject: [PATCH] Add a test assertion that `default_factory` can return a singleton (#1523) --- changes/1523-therefromhere.md | 1 + tests/test_dataclasses.py | 17 +++++++++++++++++ tests/test_main.py | 14 ++++++++++++++ 3 files changed, 32 insertions(+) create mode 100644 changes/1523-therefromhere.md diff --git a/changes/1523-therefromhere.md b/changes/1523-therefromhere.md new file mode 100644 index 0000000000..3a8256776a --- /dev/null +++ b/changes/1523-therefromhere.md @@ -0,0 +1 @@ +Add a test assertion that `default_factory` can return a singleton diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index 22b126a136..d799e35a0c 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -431,6 +431,23 @@ class User: assert fields['aliases'].default == {'John': 'Joey'} +def test_default_factory_singleton_field(): + class MySingleton: + pass + + class MyConfig: + arbitrary_types_allowed = True + + MY_SINGLETON = MySingleton() + + @pydantic.dataclasses.dataclass(config=MyConfig) + class Foo: + singleton: MySingleton = dataclasses.field(default_factory=lambda: MY_SINGLETON) + + # Returning a singleton from a default_factory is supported + assert Foo().singleton is Foo().singleton + + def test_schema(): @pydantic.dataclasses.dataclass class User: diff --git a/tests/test_main.py b/tests/test_main.py index 76e05b70ad..c70ada964e 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1087,6 +1087,20 @@ class FunctionModel(BaseModel): m = FunctionModel() assert m.uid is uuid4 + # Returning a singleton from a default_factory is supported + class MySingleton: + pass + + MY_SINGLETON = MySingleton() + + class SingletonFieldModel(BaseModel): + singleton: MySingleton = Field(default_factory=lambda: MY_SINGLETON) + + class Config: + arbitrary_types_allowed = True + + assert SingletonFieldModel().singleton is SingletonFieldModel().singleton + @pytest.mark.skipif(sys.version_info < (3, 7), reason='field constraints are set but not enforced with python 3.6') def test_none_min_max_items():