From 6afc0c695afac80cb08d2d55324b07b26ee912a4 Mon Sep 17 00:00:00 2001 From: Daniele Esposti Date: Tue, 9 Aug 2022 18:47:15 +0300 Subject: [PATCH] Added abstract SecretField class for secret fields (#3717) * Added abstract SecretField class for secret fields * Added changes for changelog --- changes/3409-expobrain.md | 1 + pydantic/__init__.py | 1 + pydantic/types.py | 40 +++++++++++++++++++++++++-------------- tests/test_types.py | 19 +++++++++++++++++++ 4 files changed, 47 insertions(+), 14 deletions(-) create mode 100644 changes/3409-expobrain.md diff --git a/changes/3409-expobrain.md b/changes/3409-expobrain.md new file mode 100644 index 0000000000..714f9628aa --- /dev/null +++ b/changes/3409-expobrain.md @@ -0,0 +1 @@ +Adds the `SecretField` abstract class so that all the current and future secret fields like `SecretStr` and `SecretBytes` will derive from it. diff --git a/pydantic/__init__.py b/pydantic/__init__.py index 69eba67945..d265935513 100644 --- a/pydantic/__init__.py +++ b/pydantic/__init__.py @@ -110,6 +110,7 @@ 'DirectoryPath', 'Json', 'JsonWrapper', + 'SecretField', 'SecretStr', 'SecretBytes', 'StrictBool', diff --git a/pydantic/types.py b/pydantic/types.py index d8fa028176..217435b2a5 100644 --- a/pydantic/types.py +++ b/pydantic/types.py @@ -1,3 +1,4 @@ +import abc import math import re import warnings @@ -92,6 +93,7 @@ 'DirectoryPath', 'Json', 'JsonWrapper', + 'SecretField', 'SecretStr', 'SecretBytes', 'StrictBool', @@ -812,7 +814,29 @@ def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SECRET TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -class SecretStr: +class SecretField(abc.ABC): + """ + Note: this should be implemented as a generic like `SecretField(ABC, Generic[T])`, + the `__init__()` should be part of the abstract class and the + `get_secret_value()` method should use the generic `T` type. + + However Cython doesn't support very well generics at the moment and + the generated code fails to be imported (see + https://github.com/cython/cython/issues/2753). + """ + + def __eq__(self, other: Any) -> bool: + return isinstance(other, self.__class__) and self.get_secret_value() == other.get_secret_value() + + def __str__(self) -> str: + return '**********' if self.get_secret_value() else '' + + @abc.abstractmethod + def get_secret_value(self) -> Any: # pragma: no cover + ... + + +class SecretStr(SecretField): min_length: OptionalInt = None max_length: OptionalInt = None @@ -845,12 +869,6 @@ def __init__(self, value: str): def __repr__(self) -> str: return f"SecretStr('{self}')" - def __str__(self) -> str: - return '**********' if self._secret_value else '' - - def __eq__(self, other: Any) -> bool: - return isinstance(other, SecretStr) and self.get_secret_value() == other.get_secret_value() - def __len__(self) -> int: return len(self._secret_value) @@ -862,7 +880,7 @@ def get_secret_value(self) -> str: return self._secret_value -class SecretBytes: +class SecretBytes(SecretField): min_length: OptionalInt = None max_length: OptionalInt = None @@ -895,12 +913,6 @@ def __init__(self, value: bytes): def __repr__(self) -> str: return f"SecretBytes(b'{self}')" - def __str__(self) -> str: - return '**********' if self._secret_value else '' - - def __eq__(self, other: Any) -> bool: - return isinstance(other, SecretBytes) and self.get_secret_value() == other.get_secret_value() - def __len__(self) -> int: return len(self._secret_value) diff --git a/tests/test_types.py b/tests/test_types.py index e90c0bfe29..425bf9d02c 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -75,6 +75,7 @@ errors, validator, ) +from pydantic.types import SecretField from pydantic.typing import NoneType try: @@ -2573,6 +2574,16 @@ class Foobar(BaseModel): ] +def test_secretfield(): + class Foobar(SecretField): + ... + + message = "Can't instantiate abstract class Foobar with abstract methods? get_secret_value" + + with pytest.raises(TypeError, match=message): + Foobar() + + def test_secretstr(): class Foobar(BaseModel): password: SecretStr @@ -2605,6 +2616,10 @@ class Foobar(BaseModel): assert f != f.copy(update=dict(password='4321')) +def test_secretstr_is_secret_field(): + assert issubclass(SecretStr, SecretField) + + def test_secretstr_equality(): assert SecretStr('abc') == SecretStr('abc') assert SecretStr('123') != SecretStr('321') @@ -2694,6 +2709,10 @@ class Foobar(BaseModel): assert f != f.copy(update=dict(password=b'4321')) +def test_secretbytes_is_secret_field(): + assert issubclass(SecretBytes, SecretField) + + def test_secretbytes_equality(): assert SecretBytes(b'abc') == SecretBytes(b'abc') assert SecretBytes(b'123') != SecretBytes(b'321')