diff --git a/changelog.d/843.change.rst b/changelog.d/843.change.rst new file mode 100644 index 000000000..746950180 --- /dev/null +++ b/changelog.d/843.change.rst @@ -0,0 +1,2 @@ +``attr.resolve_types()`` now resolves types of subclasses after the parents are resolved. +`#842 `_ diff --git a/src/attr/_funcs.py b/src/attr/_funcs.py index fda508c5c..73271c5d5 100644 --- a/src/attr/_funcs.py +++ b/src/attr/_funcs.py @@ -377,11 +377,9 @@ class and you didn't pass any attribs. .. versionadded:: 21.1.0 *attribs* """ - try: - # Since calling get_type_hints is expensive we cache whether we've - # done it already. - cls.__attrs_types_resolved__ - except AttributeError: + # Since calling get_type_hints is expensive we cache whether we've + # done it already. + if getattr(cls, "__attrs_types_resolved__", None) != cls: import typing hints = typing.get_type_hints(cls, globalns=globalns, localns=localns) @@ -389,7 +387,9 @@ class and you didn't pass any attribs. if field.name in hints: # Since fields have been frozen we must work around it. _obj_setattr(field, "type", hints[field.name]) - cls.__attrs_types_resolved__ = True + # We store the class we resolved so that subclasses know they haven't + # been resolved. + cls.__attrs_types_resolved__ = cls # Return the class so you can use it as a decorator too. return cls diff --git a/tests/test_annotations.py b/tests/test_annotations.py index 2df6311b2..dd815228d 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -618,6 +618,40 @@ class C: with pytest.raises(NameError): typing.get_type_hints(C.__init__) + def test_inheritance(self): + """ + Subclasses can be resolved after the parent is resolved. + """ + + @attr.define() + class A: + n: "int" + + @attr.define() + class B(A): + pass + + attr.resolve_types(A) + attr.resolve_types(B) + + assert int == attr.fields(A).n.type + assert int == attr.fields(B).n.type + + def test_resolve_twice(self): + """ + You can call resolve_types as many times as you like. + This test is here mostly for coverage. + """ + + @attr.define() + class A: + n: "int" + + attr.resolve_types(A) + assert int == attr.fields(A).n.type + attr.resolve_types(A) + assert int == attr.fields(A).n.type + @pytest.mark.parametrize( "annot",