diff --git a/src/cryptography/utils.py b/src/cryptography/utils.py index 9bf31faadbec..ef0fc44332d0 100644 --- a/src/cryptography/utils.py +++ b/src/cryptography/utils.py @@ -41,8 +41,8 @@ def read_only_property(name: str): def register_interface(iface): - def register_decorator(klass): - verify_interface(iface, klass) + def register_decorator(klass, *, check_annotations=False): + verify_interface(iface, klass, check_annotations=check_annotations) iface.register(klass) return klass @@ -50,9 +50,9 @@ def register_decorator(klass): def register_interface_if(predicate, iface): - def register_decorator(klass): + def register_decorator(klass, *, check_annotations=False): if predicate: - verify_interface(iface, klass) + verify_interface(iface, klass, check_annotations=check_annotations) iface.register(klass) return klass @@ -69,7 +69,16 @@ class InterfaceNotImplemented(Exception): pass -def verify_interface(iface, klass): +def strip_annotation(signature): + return inspect.Signature( + [ + param.replace(annotation=inspect.Parameter.empty) + for param in signature.parameters.values() + ] + ) + + +def verify_interface(iface, klass, *, check_annotations=False): for method in iface.__abstractmethods__: if not hasattr(klass, method): raise InterfaceNotImplemented( @@ -80,7 +89,11 @@ def verify_interface(iface, klass): continue sig = inspect.signature(getattr(iface, method)) actual = inspect.signature(getattr(klass, method)) - if sig != actual: + if check_annotations: + ok = sig == actual + else: + ok = strip_annotation(sig) == strip_annotation(actual) + if not ok: raise InterfaceNotImplemented( "{}.{}'s signature differs from the expected. Expected: " "{!r}. Received: {!r}".format(klass, method, sig, actual) diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py index c5c579da0ca7..89d802aed017 100644 --- a/tests/test_interfaces.py +++ b/tests/test_interfaces.py @@ -77,3 +77,27 @@ def property(self): # Invoke this to ensure the line is covered NonImplementer().property verify_interface(SimpleInterface, NonImplementer) + + def test_signature_mismatch(self): + class SimpleInterface(metaclass=abc.ABCMeta): + @abc.abstractmethod + def method(self, other: object) -> int: + """Method with signature""" + + class ClassWithoutSignature: + def method(self, other): + """Method without signature""" + + class ClassWithSignature: + def method(self, other: object) -> int: + """Method with signature""" + + verify_interface(SimpleInterface, ClassWithoutSignature) + verify_interface(SimpleInterface, ClassWithSignature) + with pytest.raises(InterfaceNotImplemented): + verify_interface( + SimpleInterface, ClassWithoutSignature, check_annotations=True + ) + verify_interface( + SimpleInterface, ClassWithSignature, check_annotations=True + )