Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add support for decimal-specific configs in Field() #3507

Merged
merged 4 commits into from Dec 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3507-tiangolo.md
@@ -0,0 +1 @@
Add support for `Decimal`-specific validation configurations in `Field()`, additionally to using `condecimal()`, to allow better suppport from editors and tooling
4 changes: 4 additions & 0 deletions docs/usage/schema.md
Expand Up @@ -65,6 +65,10 @@ It has the following arguments:
JSON Schema
* `multiple_of`: for numeric values, this adds a validation of "a multiple of" and an annotation of `multipleOf` to the
JSON Schema
* `max_digits`: for `Decimal` values, this adds a validation to have a maximum number of digits within the decimal. It
does not include a zero before the decimal point or trailing decimal zeroes.
* `decimal_places`: for `Decimal` values, this adds a validation to have at most a number of decimal places allowed. It
does not include trailing decimal zeroes.
* `min_items`: for list values, this adds a corresponding validation and an annotation of `minItems` to the
JSON Schema
* `max_items`: for list values, this adds a corresponding validation and an annotation of `maxItems` to the
Expand Down
14 changes: 14 additions & 0 deletions pydantic/fields.py
Expand Up @@ -101,6 +101,8 @@ class FieldInfo(Representation):
'lt',
'le',
'multiple_of',
'max_digits',
'decimal_places',
'min_items',
'max_items',
'unique_items',
Expand All @@ -122,6 +124,8 @@ class FieldInfo(Representation):
'ge': None,
'le': None,
'multiple_of': None,
'max_digits': None,
'decimal_places': None,
'min_items': None,
'max_items': None,
'unique_items': None,
Expand All @@ -143,6 +147,8 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
self.lt = kwargs.pop('lt', None)
self.le = kwargs.pop('le', None)
self.multiple_of = kwargs.pop('multiple_of', None)
self.max_digits = kwargs.pop('max_digits', None)
self.decimal_places = kwargs.pop('decimal_places', None)
self.min_items = kwargs.pop('min_items', None)
self.max_items = kwargs.pop('max_items', None)
self.unique_items = kwargs.pop('unique_items', None)
Expand Down Expand Up @@ -209,6 +215,8 @@ def Field(
lt: float = None,
le: float = None,
multiple_of: float = None,
max_digits: int = None,
decimal_places: int = None,
min_items: int = None,
max_items: int = None,
unique_items: bool = None,
Expand Down Expand Up @@ -245,6 +253,10 @@ def Field(
schema will have a ``maximum`` validation keyword
:param multiple_of: only applies to numbers, requires the field to be "a multiple of". The
schema will have a ``multipleOf`` validation keyword
:param max_digits: only applies to Decimals, requires the field to have a maximum number
of digits within the decimal. It does not include a zero before the decimal point or trailing decimal zeroes.
:param decimal_places: only applies to Decimals, requires the field to have at most a number of decimal places
allowed. It does not include trailing decimal zeroes.
:param min_items: only applies to lists, requires the field to have a minimum number of
elements. The schema will have a ``minItems`` validation keyword
:param max_items: only applies to lists, requires the field to have a maximum number of
Expand Down Expand Up @@ -276,6 +288,8 @@ def Field(
lt=lt,
le=le,
multiple_of=multiple_of,
max_digits=max_digits,
decimal_places=decimal_places,
min_items=min_items,
max_items=max_items,
unique_items=unique_items,
Expand Down
2 changes: 2 additions & 0 deletions pydantic/schema.py
Expand Up @@ -1056,6 +1056,8 @@ def constraint_func(**kw: Any) -> Type[Any]:
):
# Is numeric type
attrs = ('gt', 'lt', 'ge', 'le', 'multiple_of')
if issubclass(type_, Decimal):
attrs += ('max_digits', 'decimal_places')
numeric_type = next(t for t in numeric_types if issubclass(type_, t)) # pragma: no branch
constraint_func = _map_types_constraint[numeric_type]

Expand Down
72 changes: 37 additions & 35 deletions tests/test_types.py
Expand Up @@ -1840,11 +1840,11 @@ class Config:


@pytest.mark.parametrize(
'type_,value,result',
'type_args,value,result',
[
(condecimal(gt=Decimal('42.24')), Decimal('43'), Decimal('43')),
(dict(gt=Decimal('42.24')), Decimal('43'), Decimal('43')),
(
condecimal(gt=Decimal('42.24')),
dict(gt=Decimal('42.24')),
Decimal('42'),
[
{
Expand All @@ -1855,9 +1855,9 @@ class Config:
}
],
),
(condecimal(lt=Decimal('42.24')), Decimal('42'), Decimal('42')),
(dict(lt=Decimal('42.24')), Decimal('42'), Decimal('42')),
(
condecimal(lt=Decimal('42.24')),
dict(lt=Decimal('42.24')),
Decimal('43'),
[
{
Expand All @@ -1868,10 +1868,10 @@ class Config:
}
],
),
(condecimal(ge=Decimal('42.24')), Decimal('43'), Decimal('43')),
(condecimal(ge=Decimal('42.24')), Decimal('42.24'), Decimal('42.24')),
(dict(ge=Decimal('42.24')), Decimal('43'), Decimal('43')),
(dict(ge=Decimal('42.24')), Decimal('42.24'), Decimal('42.24')),
(
condecimal(ge=Decimal('42.24')),
dict(ge=Decimal('42.24')),
Decimal('42'),
[
{
Expand All @@ -1882,10 +1882,10 @@ class Config:
}
],
),
(condecimal(le=Decimal('42.24')), Decimal('42'), Decimal('42')),
(condecimal(le=Decimal('42.24')), Decimal('42.24'), Decimal('42.24')),
(dict(le=Decimal('42.24')), Decimal('42'), Decimal('42')),
(dict(le=Decimal('42.24')), Decimal('42.24'), Decimal('42.24')),
(
condecimal(le=Decimal('42.24')),
dict(le=Decimal('42.24')),
Decimal('43'),
[
{
Expand All @@ -1896,9 +1896,9 @@ class Config:
}
],
),
(condecimal(max_digits=2, decimal_places=2), Decimal('0.99'), Decimal('0.99')),
(dict(max_digits=2, decimal_places=2), Decimal('0.99'), Decimal('0.99')),
(
condecimal(max_digits=2, decimal_places=1),
dict(max_digits=2, decimal_places=1),
Decimal('0.99'),
[
{
Expand All @@ -1910,7 +1910,7 @@ class Config:
],
),
(
condecimal(max_digits=3, decimal_places=1),
dict(max_digits=3, decimal_places=1),
Decimal('999'),
[
{
Expand All @@ -1921,11 +1921,11 @@ class Config:
}
],
),
(condecimal(max_digits=4, decimal_places=1), Decimal('999'), Decimal('999')),
(condecimal(max_digits=20, decimal_places=2), Decimal('742403889818000000'), Decimal('742403889818000000')),
(condecimal(max_digits=20, decimal_places=2), Decimal('7.42403889818E+17'), Decimal('7.42403889818E+17')),
(dict(max_digits=4, decimal_places=1), Decimal('999'), Decimal('999')),
(dict(max_digits=20, decimal_places=2), Decimal('742403889818000000'), Decimal('742403889818000000')),
(dict(max_digits=20, decimal_places=2), Decimal('7.42403889818E+17'), Decimal('7.42403889818E+17')),
(
condecimal(max_digits=20, decimal_places=2),
dict(max_digits=20, decimal_places=2),
Decimal('7424742403889818000000'),
[
{
Expand All @@ -1936,9 +1936,9 @@ class Config:
}
],
),
(condecimal(max_digits=5, decimal_places=2), Decimal('7304E-1'), Decimal('7304E-1')),
(dict(max_digits=5, decimal_places=2), Decimal('7304E-1'), Decimal('7304E-1')),
(
condecimal(max_digits=5, decimal_places=2),
dict(max_digits=5, decimal_places=2),
Decimal('7304E-3'),
[
{
Expand All @@ -1949,9 +1949,9 @@ class Config:
}
],
),
(condecimal(max_digits=5, decimal_places=5), Decimal('70E-5'), Decimal('70E-5')),
(dict(max_digits=5, decimal_places=5), Decimal('70E-5'), Decimal('70E-5')),
(
condecimal(max_digits=5, decimal_places=5),
dict(max_digits=5, decimal_places=5),
Decimal('70E-6'),
[
{
Expand All @@ -1964,7 +1964,7 @@ class Config:
),
*[
(
condecimal(decimal_places=2, max_digits=10),
dict(decimal_places=2, max_digits=10),
value,
[{'loc': ('foo',), 'msg': 'value is not a valid decimal', 'type': 'value_error.decimal.not_finite'}],
)
Expand All @@ -1985,7 +1985,7 @@ class Config:
],
*[
(
condecimal(decimal_places=2, max_digits=10),
dict(decimal_places=2, max_digits=10),
Decimal(value),
[{'loc': ('foo',), 'msg': 'value is not a valid decimal', 'type': 'value_error.decimal.not_finite'}],
)
Expand All @@ -2005,7 +2005,7 @@ class Config:
)
],
(
condecimal(multiple_of=Decimal('5')),
dict(multiple_of=Decimal('5')),
Decimal('42'),
[
{
Expand All @@ -2018,16 +2018,18 @@ class Config:
),
],
)
def test_decimal_validation(type_, value, result):
model = create_model('DecimalModel', foo=(type_, ...))

if not isinstance(result, Decimal):
with pytest.raises(ValidationError) as exc_info:
model(foo=value)
assert exc_info.value.errors() == result
assert exc_info.value.json().startswith('[')
else:
assert model(foo=value).foo == result
def test_decimal_validation(type_args, value, result):
modela = create_model('DecimalModel', foo=(condecimal(**type_args), ...))
modelb = create_model('DecimalModel', foo=(Decimal, Field(..., **type_args)))

for model in (modela, modelb):
if not isinstance(result, Decimal):
with pytest.raises(ValidationError) as exc_info:
model(foo=value)
assert exc_info.value.errors() == result
assert exc_info.value.json().startswith('[')
else:
assert model(foo=value).foo == result


@pytest.mark.parametrize('value,result', (('/test/path', Path('/test/path')), (Path('/test/path'), Path('/test/path'))))
Expand Down