diff --git a/src/attr/_next_gen.py b/src/attr/_next_gen.py index 57e8635b1..0e3c365f3 100644 --- a/src/attr/_next_gen.py +++ b/src/attr/_next_gen.py @@ -50,9 +50,9 @@ def define( .. versionadded:: 20.1.0 """ - def do_it(auto_attribs): + def do_it(cls, auto_attribs): return attrs( - maybe_cls=maybe_cls, + maybe_cls=cls, these=these, repr=repr, hash=hash, @@ -74,12 +74,21 @@ def do_it(auto_attribs): ) if auto_attribs is not None: - return do_it(auto_attribs) - - try: - return do_it(True) - except UnannotatedAttributeError: - return do_it(False) + return do_it(maybe_cls, auto_attribs) + + def wrap(cls): + # Making this a wrapper ensures this code runs during class creation. + try: + return do_it(cls, True) + except UnannotatedAttributeError: + return do_it(cls, False) + + # maybe_cls's type depends on the usage of the decorator. It's a class + # if it's used as `@attrs` but ``None`` if used as `@attrs()`. + if maybe_cls is None: + return wrap + else: + return wrap(maybe_cls) mutable = define diff --git a/tests/test_next_gen.py b/tests/test_next_gen.py index a4460b192..1ed0108ad 100644 --- a/tests/test_next_gen.py +++ b/tests/test_next_gen.py @@ -101,6 +101,31 @@ class OldSchool: assert OldSchool(1) == OldSchool(1) + # Test with maybe_cls = None + @attr.define() + class OldSchool2: + x = attr.field() + + assert OldSchool2(1) == OldSchool2(1) + + def test_auto_attribs_detect_annotations(self): + """ + define correctly detects if a class has type annotations. + """ + + @attr.define + class NewSchool: + x: int + + assert NewSchool(1) == NewSchool(1) + + # Test with maybe_cls = None + @attr.define() + class NewSchool2: + x: int + + assert NewSchool2(1) == NewSchool2(1) + def test_exception(self): """ Exceptions are detected and correctly handled.