Skip to content

Commit

Permalink
Added abstract SecretField class for secret fields (#3717)
Browse files Browse the repository at this point in the history
* Added abstract SecretField class for secret fields

* Added changes for changelog
  • Loading branch information
expobrain committed Aug 9, 2022
1 parent 697459a commit 6afc0c6
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 14 deletions.
1 change: 1 addition & 0 deletions changes/3409-expobrain.md
@@ -0,0 +1 @@
Adds the `SecretField` abstract class so that all the current and future secret fields like `SecretStr` and `SecretBytes` will derive from it.
1 change: 1 addition & 0 deletions pydantic/__init__.py
Expand Up @@ -110,6 +110,7 @@
'DirectoryPath',
'Json',
'JsonWrapper',
'SecretField',
'SecretStr',
'SecretBytes',
'StrictBool',
Expand Down
40 changes: 26 additions & 14 deletions pydantic/types.py
@@ -1,3 +1,4 @@
import abc
import math
import re
import warnings
Expand Down Expand Up @@ -92,6 +93,7 @@
'DirectoryPath',
'Json',
'JsonWrapper',
'SecretField',
'SecretStr',
'SecretBytes',
'StrictBool',
Expand Down Expand Up @@ -812,7 +814,29 @@ def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SECRET TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


class SecretStr:
class SecretField(abc.ABC):
"""
Note: this should be implemented as a generic like `SecretField(ABC, Generic[T])`,
the `__init__()` should be part of the abstract class and the
`get_secret_value()` method should use the generic `T` type.
However Cython doesn't support very well generics at the moment and
the generated code fails to be imported (see
https://github.com/cython/cython/issues/2753).
"""

def __eq__(self, other: Any) -> bool:
return isinstance(other, self.__class__) and self.get_secret_value() == other.get_secret_value()

def __str__(self) -> str:
return '**********' if self.get_secret_value() else ''

@abc.abstractmethod
def get_secret_value(self) -> Any: # pragma: no cover
...


class SecretStr(SecretField):
min_length: OptionalInt = None
max_length: OptionalInt = None

Expand Down Expand Up @@ -845,12 +869,6 @@ def __init__(self, value: str):
def __repr__(self) -> str:
return f"SecretStr('{self}')"

def __str__(self) -> str:
return '**********' if self._secret_value else ''

def __eq__(self, other: Any) -> bool:
return isinstance(other, SecretStr) and self.get_secret_value() == other.get_secret_value()

def __len__(self) -> int:
return len(self._secret_value)

Expand All @@ -862,7 +880,7 @@ def get_secret_value(self) -> str:
return self._secret_value


class SecretBytes:
class SecretBytes(SecretField):
min_length: OptionalInt = None
max_length: OptionalInt = None

Expand Down Expand Up @@ -895,12 +913,6 @@ def __init__(self, value: bytes):
def __repr__(self) -> str:
return f"SecretBytes(b'{self}')"

def __str__(self) -> str:
return '**********' if self._secret_value else ''

def __eq__(self, other: Any) -> bool:
return isinstance(other, SecretBytes) and self.get_secret_value() == other.get_secret_value()

def __len__(self) -> int:
return len(self._secret_value)

Expand Down
19 changes: 19 additions & 0 deletions tests/test_types.py
Expand Up @@ -75,6 +75,7 @@
errors,
validator,
)
from pydantic.types import SecretField
from pydantic.typing import NoneType

try:
Expand Down Expand Up @@ -2573,6 +2574,16 @@ class Foobar(BaseModel):
]


def test_secretfield():
class Foobar(SecretField):
...

message = "Can't instantiate abstract class Foobar with abstract methods? get_secret_value"

with pytest.raises(TypeError, match=message):
Foobar()


def test_secretstr():
class Foobar(BaseModel):
password: SecretStr
Expand Down Expand Up @@ -2605,6 +2616,10 @@ class Foobar(BaseModel):
assert f != f.copy(update=dict(password='4321'))


def test_secretstr_is_secret_field():
assert issubclass(SecretStr, SecretField)


def test_secretstr_equality():
assert SecretStr('abc') == SecretStr('abc')
assert SecretStr('123') != SecretStr('321')
Expand Down Expand Up @@ -2694,6 +2709,10 @@ class Foobar(BaseModel):
assert f != f.copy(update=dict(password=b'4321'))


def test_secretbytes_is_secret_field():
assert issubclass(SecretBytes, SecretField)


def test_secretbytes_equality():
assert SecretBytes(b'abc') == SecretBytes(b'abc')
assert SecretBytes(b'123') != SecretBytes(b'321')
Expand Down

0 comments on commit 6afc0c6

Please sign in to comment.