Skip to content

Commit

Permalink
✨ Add support for decimal-specific configs in Field() (#3507)
Browse files Browse the repository at this point in the history
* ✨ Add support for Decimal-specific configs in Field()

* ✅ Add/update tests for condecimal and variant with Field()

* 📝 Update schema - Field() docs including Decimal-specific configs

* 📝 Add PR changes file
  • Loading branch information
tiangolo committed Dec 11, 2021
1 parent 6ad80cd commit 61d30ae
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 35 deletions.
1 change: 1 addition & 0 deletions changes/3507-tiangolo.md
@@ -0,0 +1 @@
Add support for `Decimal`-specific validation configurations in `Field()`, additionally to using `condecimal()`, to allow better suppport from editors and tooling
4 changes: 4 additions & 0 deletions docs/usage/schema.md
Expand Up @@ -65,6 +65,10 @@ It has the following arguments:
JSON Schema
* `multiple_of`: for numeric values, this adds a validation of "a multiple of" and an annotation of `multipleOf` to the
JSON Schema
* `max_digits`: for `Decimal` values, this adds a validation to have a maximum number of digits within the decimal. It
does not include a zero before the decimal point or trailing decimal zeroes.
* `decimal_places`: for `Decimal` values, this adds a validation to have at most a number of decimal places allowed. It
does not include trailing decimal zeroes.
* `min_items`: for list values, this adds a corresponding validation and an annotation of `minItems` to the
JSON Schema
* `max_items`: for list values, this adds a corresponding validation and an annotation of `maxItems` to the
Expand Down
14 changes: 14 additions & 0 deletions pydantic/fields.py
Expand Up @@ -100,6 +100,8 @@ class FieldInfo(Representation):
'lt',
'le',
'multiple_of',
'max_digits',
'decimal_places',
'min_items',
'max_items',
'unique_items',
Expand All @@ -121,6 +123,8 @@ class FieldInfo(Representation):
'ge': None,
'le': None,
'multiple_of': None,
'max_digits': None,
'decimal_places': None,
'min_items': None,
'max_items': None,
'unique_items': None,
Expand All @@ -142,6 +146,8 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
self.lt = kwargs.pop('lt', None)
self.le = kwargs.pop('le', None)
self.multiple_of = kwargs.pop('multiple_of', None)
self.max_digits = kwargs.pop('max_digits', None)
self.decimal_places = kwargs.pop('decimal_places', None)
self.min_items = kwargs.pop('min_items', None)
self.max_items = kwargs.pop('max_items', None)
self.unique_items = kwargs.pop('unique_items', None)
Expand Down Expand Up @@ -208,6 +214,8 @@ def Field(
lt: float = None,
le: float = None,
multiple_of: float = None,
max_digits: int = None,
decimal_places: int = None,
min_items: int = None,
max_items: int = None,
unique_items: bool = None,
Expand Down Expand Up @@ -244,6 +252,10 @@ def Field(
schema will have a ``maximum`` validation keyword
:param multiple_of: only applies to numbers, requires the field to be "a multiple of". The
schema will have a ``multipleOf`` validation keyword
:param max_digits: only applies to Decimals, requires the field to have a maximum number
of digits within the decimal. It does not include a zero before the decimal point or trailing decimal zeroes.
:param decimal_places: only applies to Decimals, requires the field to have at most a number of decimal places
allowed. It does not include trailing decimal zeroes.
:param min_items: only applies to lists, requires the field to have a minimum number of
elements. The schema will have a ``minItems`` validation keyword
:param max_items: only applies to lists, requires the field to have a maximum number of
Expand Down Expand Up @@ -275,6 +287,8 @@ def Field(
lt=lt,
le=le,
multiple_of=multiple_of,
max_digits=max_digits,
decimal_places=decimal_places,
min_items=min_items,
max_items=max_items,
unique_items=unique_items,
Expand Down
2 changes: 2 additions & 0 deletions pydantic/schema.py
Expand Up @@ -1056,6 +1056,8 @@ def constraint_func(**kw: Any) -> Type[Any]:
):
# Is numeric type
attrs = ('gt', 'lt', 'ge', 'le', 'multiple_of')
if issubclass(type_, Decimal):
attrs += ('max_digits', 'decimal_places')
numeric_type = next(t for t in numeric_types if issubclass(type_, t)) # pragma: no branch
constraint_func = _map_types_constraint[numeric_type]

Expand Down
72 changes: 37 additions & 35 deletions tests/test_types.py
Expand Up @@ -1840,11 +1840,11 @@ class Config:


@pytest.mark.parametrize(
'type_,value,result',
'type_args,value,result',
[
(condecimal(gt=Decimal('42.24')), Decimal('43'), Decimal('43')),
(dict(gt=Decimal('42.24')), Decimal('43'), Decimal('43')),
(
condecimal(gt=Decimal('42.24')),
dict(gt=Decimal('42.24')),
Decimal('42'),
[
{
Expand All @@ -1855,9 +1855,9 @@ class Config:
}
],
),
(condecimal(lt=Decimal('42.24')), Decimal('42'), Decimal('42')),
(dict(lt=Decimal('42.24')), Decimal('42'), Decimal('42')),
(
condecimal(lt=Decimal('42.24')),
dict(lt=Decimal('42.24')),
Decimal('43'),
[
{
Expand All @@ -1868,10 +1868,10 @@ class Config:
}
],
),
(condecimal(ge=Decimal('42.24')), Decimal('43'), Decimal('43')),
(condecimal(ge=Decimal('42.24')), Decimal('42.24'), Decimal('42.24')),
(dict(ge=Decimal('42.24')), Decimal('43'), Decimal('43')),
(dict(ge=Decimal('42.24')), Decimal('42.24'), Decimal('42.24')),
(
condecimal(ge=Decimal('42.24')),
dict(ge=Decimal('42.24')),
Decimal('42'),
[
{
Expand All @@ -1882,10 +1882,10 @@ class Config:
}
],
),
(condecimal(le=Decimal('42.24')), Decimal('42'), Decimal('42')),
(condecimal(le=Decimal('42.24')), Decimal('42.24'), Decimal('42.24')),
(dict(le=Decimal('42.24')), Decimal('42'), Decimal('42')),
(dict(le=Decimal('42.24')), Decimal('42.24'), Decimal('42.24')),
(
condecimal(le=Decimal('42.24')),
dict(le=Decimal('42.24')),
Decimal('43'),
[
{
Expand All @@ -1896,9 +1896,9 @@ class Config:
}
],
),
(condecimal(max_digits=2, decimal_places=2), Decimal('0.99'), Decimal('0.99')),
(dict(max_digits=2, decimal_places=2), Decimal('0.99'), Decimal('0.99')),
(
condecimal(max_digits=2, decimal_places=1),
dict(max_digits=2, decimal_places=1),
Decimal('0.99'),
[
{
Expand All @@ -1910,7 +1910,7 @@ class Config:
],
),
(
condecimal(max_digits=3, decimal_places=1),
dict(max_digits=3, decimal_places=1),
Decimal('999'),
[
{
Expand All @@ -1921,11 +1921,11 @@ class Config:
}
],
),
(condecimal(max_digits=4, decimal_places=1), Decimal('999'), Decimal('999')),
(condecimal(max_digits=20, decimal_places=2), Decimal('742403889818000000'), Decimal('742403889818000000')),
(condecimal(max_digits=20, decimal_places=2), Decimal('7.42403889818E+17'), Decimal('7.42403889818E+17')),
(dict(max_digits=4, decimal_places=1), Decimal('999'), Decimal('999')),
(dict(max_digits=20, decimal_places=2), Decimal('742403889818000000'), Decimal('742403889818000000')),
(dict(max_digits=20, decimal_places=2), Decimal('7.42403889818E+17'), Decimal('7.42403889818E+17')),
(
condecimal(max_digits=20, decimal_places=2),
dict(max_digits=20, decimal_places=2),
Decimal('7424742403889818000000'),
[
{
Expand All @@ -1936,9 +1936,9 @@ class Config:
}
],
),
(condecimal(max_digits=5, decimal_places=2), Decimal('7304E-1'), Decimal('7304E-1')),
(dict(max_digits=5, decimal_places=2), Decimal('7304E-1'), Decimal('7304E-1')),
(
condecimal(max_digits=5, decimal_places=2),
dict(max_digits=5, decimal_places=2),
Decimal('7304E-3'),
[
{
Expand All @@ -1949,9 +1949,9 @@ class Config:
}
],
),
(condecimal(max_digits=5, decimal_places=5), Decimal('70E-5'), Decimal('70E-5')),
(dict(max_digits=5, decimal_places=5), Decimal('70E-5'), Decimal('70E-5')),
(
condecimal(max_digits=5, decimal_places=5),
dict(max_digits=5, decimal_places=5),
Decimal('70E-6'),
[
{
Expand All @@ -1964,7 +1964,7 @@ class Config:
),
*[
(
condecimal(decimal_places=2, max_digits=10),
dict(decimal_places=2, max_digits=10),
value,
[{'loc': ('foo',), 'msg': 'value is not a valid decimal', 'type': 'value_error.decimal.not_finite'}],
)
Expand All @@ -1985,7 +1985,7 @@ class Config:
],
*[
(
condecimal(decimal_places=2, max_digits=10),
dict(decimal_places=2, max_digits=10),
Decimal(value),
[{'loc': ('foo',), 'msg': 'value is not a valid decimal', 'type': 'value_error.decimal.not_finite'}],
)
Expand All @@ -2005,7 +2005,7 @@ class Config:
)
],
(
condecimal(multiple_of=Decimal('5')),
dict(multiple_of=Decimal('5')),
Decimal('42'),
[
{
Expand All @@ -2018,16 +2018,18 @@ class Config:
),
],
)
def test_decimal_validation(type_, value, result):
model = create_model('DecimalModel', foo=(type_, ...))

if not isinstance(result, Decimal):
with pytest.raises(ValidationError) as exc_info:
model(foo=value)
assert exc_info.value.errors() == result
assert exc_info.value.json().startswith('[')
else:
assert model(foo=value).foo == result
def test_decimal_validation(type_args, value, result):
modela = create_model('DecimalModel', foo=(condecimal(**type_args), ...))
modelb = create_model('DecimalModel', foo=(Decimal, Field(..., **type_args)))

for model in (modela, modelb):
if not isinstance(result, Decimal):
with pytest.raises(ValidationError) as exc_info:
model(foo=value)
assert exc_info.value.errors() == result
assert exc_info.value.json().startswith('[')
else:
assert model(foo=value).foo == result


@pytest.mark.parametrize('value,result', (('/test/path', Path('/test/path')), (Path('/test/path'), Path('/test/path'))))
Expand Down

0 comments on commit 61d30ae

Please sign in to comment.