Skip to content

Commit

Permalink
✨ Add JSON-compatible float constraints for NaN and Inf (#3994)
Browse files Browse the repository at this point in the history
* ✨ Add JSON-compatible float constraints for NaN and Inf

* switching to a single "allow_inf_nan"

* fix tests

* add change and docs

* add allow_inf_nan to Config

Co-authored-by: Samuel Colvin <s@muelcolvin.com>
  • Loading branch information
tiangolo and samuelcolvin committed Aug 22, 2022
1 parent 0bbb874 commit 8dade7e
Show file tree
Hide file tree
Showing 12 changed files with 118 additions and 8 deletions.
1 change: 1 addition & 0 deletions changes/3994-tiangolo.md
@@ -0,0 +1 @@
Add JSON-compatible float constraint `allow_inf_nan`
5 changes: 5 additions & 0 deletions docs/usage/model_config.md
Expand Up @@ -121,6 +121,11 @@ not be included in the model schemas. **Note**: this means that attributes on th
: whether stdlib dataclasses `__post_init__` should be run before (default behaviour with value `'before_validation'`)
or after (value `'after_validation'`) parsing and validation when they are [converted](dataclasses.md#stdlib-dataclasses-and-_pydantic_-dataclasses).

**`allow_inf_nan`**
: whether to allows infinity (`+inf` an `-inf`) and NaN values to float fields, defaults to `True`,
set to `False` for compatibility with `JSON`,
see [#3994](https://github.com/pydantic/pydantic/pull/3994) for more details, added in **V1.10**

## Change behaviour globally

If you wish to change the behaviour of _pydantic_ globally, you can create your own custom `BaseModel`
Expand Down
3 changes: 3 additions & 0 deletions docs/usage/types.md
Expand Up @@ -833,6 +833,9 @@ The following arguments are available when using the `confloat` type function
- `lt: float = None`: enforces float to be less than the set value
- `le: float = None`: enforces float to be less than or equal to the set value
- `multiple_of: float = None`: enforces float to be a multiple of the set value
- `allow_inf_nan: bool = True`: whether to allows infinity (`+inf` an `-inf`) and NaN values, defaults to `True`,
set to `False` for compatibility with `JSON`,
see [#3994](https://github.com/pydantic/pydantic/pull/3994) for more details, added in **V1.10**

### Arguments to `condecimal`
The following arguments are available when using the `condecimal` type function
Expand Down
1 change: 1 addition & 0 deletions pydantic/__init__.py
Expand Up @@ -100,6 +100,7 @@
'NegativeFloat',
'NonNegativeFloat',
'NonPositiveFloat',
'FiniteFloat',
'ConstrainedDecimal',
'condecimal',
'ConstrainedDate',
Expand Down
2 changes: 2 additions & 0 deletions pydantic/config.py
Expand Up @@ -67,6 +67,7 @@ class ConfigDict(TypedDict, total=False):
json_dumps: AnyArgTCallable[str]
json_encoders: Dict[Type[object], AnyCallable]
underscore_attrs_are_private: bool
allow_inf_nan: bool

# whether or not inherited models as fields should be reconstructed as base model
copy_on_model_validation: bool
Expand Down Expand Up @@ -103,6 +104,7 @@ class BaseConfig:
json_dumps: Callable[..., str] = json.dumps
json_encoders: Dict[Union[Type[Any], str, ForwardRef], AnyCallable] = {}
underscore_attrs_are_private: bool = False
allow_inf_nan: bool = True

# whether inherited models as fields should be reconstructed as base model,
# and whether such a copy should be shallow or deep
Expand Down
5 changes: 5 additions & 0 deletions pydantic/errors.py
Expand Up @@ -417,6 +417,11 @@ class NumberNotLeError(_NumberBoundError):
msg_template = 'ensure this value is less than or equal to {limit_value}'


class NumberNotFiniteError(PydanticValueError):
code = 'number.not_finite_number'
msg_template = 'ensure this value is a finite number'


class NumberNotMultipleError(PydanticValueError):
code = 'number.not_multiple'
msg_template = 'ensure this value is a multiple of {multiple_of}'
Expand Down
7 changes: 7 additions & 0 deletions pydantic/fields.py
Expand Up @@ -114,6 +114,7 @@ class FieldInfo(Representation):
'lt',
'le',
'multiple_of',
'allow_inf_nan',
'max_digits',
'decimal_places',
'min_items',
Expand All @@ -138,6 +139,7 @@ class FieldInfo(Representation):
'ge': None,
'le': None,
'multiple_of': None,
'allow_inf_nan': None,
'max_digits': None,
'decimal_places': None,
'min_items': None,
Expand All @@ -161,6 +163,7 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
self.lt = kwargs.pop('lt', None)
self.le = kwargs.pop('le', None)
self.multiple_of = kwargs.pop('multiple_of', None)
self.allow_inf_nan = kwargs.pop('allow_inf_nan', None)
self.max_digits = kwargs.pop('max_digits', None)
self.decimal_places = kwargs.pop('decimal_places', None)
self.min_items = kwargs.pop('min_items', None)
Expand Down Expand Up @@ -231,6 +234,7 @@ def Field(
lt: float = None,
le: float = None,
multiple_of: float = None,
allow_inf_nan: bool = None,
max_digits: int = None,
decimal_places: int = None,
min_items: int = None,
Expand Down Expand Up @@ -270,6 +274,8 @@ def Field(
schema will have a ``maximum`` validation keyword
:param multiple_of: only applies to numbers, requires the field to be "a multiple of". The
schema will have a ``multipleOf`` validation keyword
:param allow_inf_nan: only applies to numbers, allows the field to be NaN or infinity (+inf or -inf),
which is a valid Python float. Default True, set to False for compatibility with JSON.
:param max_digits: only applies to Decimals, requires the field to have a maximum number
of digits within the decimal. It does not include a zero before the decimal point or trailing decimal zeroes.
:param decimal_places: only applies to Decimals, requires the field to have at most a number of decimal places
Expand Down Expand Up @@ -307,6 +313,7 @@ def Field(
lt=lt,
le=le,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
min_items=min_items,
Expand Down
2 changes: 2 additions & 0 deletions pydantic/schema.py
Expand Up @@ -1115,6 +1115,8 @@ def constraint_func(**kw: Any) -> Type[Any]:
):
# Is numeric type
attrs = ('gt', 'lt', 'ge', 'le', 'multiple_of')
if issubclass(type_, float):
attrs += ('allow_inf_nan',)
if issubclass(type_, Decimal):
attrs += ('max_digits', 'decimal_places')
numeric_type = next(t for t in numeric_types if issubclass(type_, t)) # pragma: no branch
Expand Down
11 changes: 10 additions & 1 deletion pydantic/types.py
Expand Up @@ -38,6 +38,7 @@
constr_strip_whitespace,
constr_upper,
decimal_validator,
float_finite_validator,
float_validator,
frozenset_validator,
int_validator,
Expand Down Expand Up @@ -83,6 +84,7 @@
'NegativeFloat',
'NonNegativeFloat',
'NonPositiveFloat',
'FiniteFloat',
'ConstrainedDecimal',
'condecimal',
'UUID1',
Expand Down Expand Up @@ -265,6 +267,7 @@ class ConstrainedFloat(float, metaclass=ConstrainedNumberMeta):
lt: OptionalIntFloat = None
le: OptionalIntFloat = None
multiple_of: OptionalIntFloat = None
allow_inf_nan: Optional[bool] = None

@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
Expand All @@ -291,6 +294,7 @@ def __get_validators__(cls) -> 'CallableGenerator':
yield strict_float_validator if cls.strict else float_validator
yield number_size_validator
yield number_multiple_validator
yield float_finite_validator


def confloat(
Expand All @@ -301,9 +305,10 @@ def confloat(
lt: float = None,
le: float = None,
multiple_of: float = None,
allow_inf_nan: Optional[bool] = None,
) -> Type[float]:
# use kwargs then define conf in a dict to aid with IDE type hinting
namespace = dict(strict=strict, gt=gt, ge=ge, lt=lt, le=le, multiple_of=multiple_of)
namespace = dict(strict=strict, gt=gt, ge=ge, lt=lt, le=le, multiple_of=multiple_of, allow_inf_nan=allow_inf_nan)
return type('ConstrainedFloatValue', (ConstrainedFloat,), namespace)


Expand All @@ -313,6 +318,7 @@ def confloat(
NonPositiveFloat = float
NonNegativeFloat = float
StrictFloat = float
FiniteFloat = float
else:

class PositiveFloat(ConstrainedFloat):
Expand All @@ -330,6 +336,9 @@ class NonNegativeFloat(ConstrainedFloat):
class StrictFloat(ConstrainedFloat):
strict = True

class FiniteFloat(ConstrainedFloat):
allow_inf_nan = False


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BYTES TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
18 changes: 15 additions & 3 deletions pydantic/validators.py
@@ -1,3 +1,4 @@
import math
import re
from collections import OrderedDict, deque
from collections.abc import Hashable as CollectionsHashable
Expand Down Expand Up @@ -151,6 +152,16 @@ def strict_float_validator(v: Any) -> float:
raise errors.FloatError()


def float_finite_validator(v: 'Number', field: 'ModelField', config: 'BaseConfig') -> 'Number':
allow_inf_nan = getattr(field.type_, 'allow_inf_nan', None)
if allow_inf_nan is None:
allow_inf_nan = config.allow_inf_nan

if allow_inf_nan is False and (math.isnan(v) or math.isinf(v)):
raise errors.NumberNotFiniteError()
return v


def number_multiple_validator(v: 'Number', field: 'ModelField') -> 'Number':
field_type: ConstrainedNumber = field.type_
if field_type.multiple_of is not None:
Expand Down Expand Up @@ -611,12 +622,13 @@ def typeddict_validator(values: 'TypedDict') -> Dict[str, Any]: # type: ignore[


class IfConfig:
def __init__(self, validator: AnyCallable, *config_attr_names: str) -> None:
def __init__(self, validator: AnyCallable, *config_attr_names: str, ignored_value: Any = False) -> None:
self.validator = validator
self.config_attr_names = config_attr_names
self.ignored_value = ignored_value

def check(self, config: Type['BaseConfig']) -> bool:
return any(getattr(config, name) not in {None, False} for name in self.config_attr_names)
return any(getattr(config, name) not in {None, self.ignored_value} for name in self.config_attr_names)


# order is important here, for example: bool is a subclass of int so has to come first, datetime before date same,
Expand Down Expand Up @@ -646,7 +658,7 @@ def check(self, config: Type['BaseConfig']) -> bool:
),
(bool, [bool_validator]),
(int, [int_validator]),
(float, [float_validator]),
(float, [float_validator, IfConfig(float_finite_validator, 'allow_inf_nan', ignored_value=True)]),
(Path, [path_validator]),
(datetime, [parse_datetime]),
(date, [parse_date]),
Expand Down
9 changes: 8 additions & 1 deletion tests/mypy/test_mypy.py
Expand Up @@ -118,7 +118,7 @@ def test_success_cases_run(module: str) -> None:
importlib.import_module(f'tests.mypy.modules.{module}')


def test_explicit_reexports() -> None:
def test_explicit_reexports():
from pydantic import __all__ as root_all
from pydantic.main import __all__ as main
from pydantic.networks import __all__ as networks
Expand All @@ -130,6 +130,13 @@ def test_explicit_reexports() -> None:
assert export in root_all, f'{export} is in {name}.__all__ but missing from re-export in __init__.py'


def test_explicit_reexports_exist():
import pydantic

for name in pydantic.__all__:
assert hasattr(pydantic, name), f'{name} is in pydantic.__all__ but missing from pydantic'


@pytest.mark.skipif(mypy_version is None, reason='mypy is not installed')
@pytest.mark.parametrize(
'v_str,v_tuple',
Expand Down
62 changes: 59 additions & 3 deletions tests/test_types.py
@@ -1,4 +1,5 @@
import itertools
import math
import os
import re
import sys
Expand Down Expand Up @@ -42,6 +43,7 @@
EmailStr,
Field,
FilePath,
FiniteFloat,
FutureDate,
Json,
NameEmail,
Expand Down Expand Up @@ -1565,12 +1567,16 @@ class Model(BaseModel):
e: confloat(gt=4, lt=12.2) = None
f: confloat(ge=0, le=9.9) = None
g: confloat(multiple_of=0.5) = None
h: confloat(allow_inf_nan=False) = None

m = Model(a=5.1, b=-5.2, c=0, d=0, e=5.3, f=9.9, g=2.5)
assert m.dict() == {'a': 5.1, 'b': -5.2, 'c': 0, 'd': 0, 'e': 5.3, 'f': 9.9, 'g': 2.5}
m = Model(a=5.1, b=-5.2, c=0, d=0, e=5.3, f=9.9, g=2.5, h=42)
assert m.dict() == {'a': 5.1, 'b': -5.2, 'c': 0, 'd': 0, 'e': 5.3, 'f': 9.9, 'g': 2.5, 'h': 42}

assert Model(a=float('inf')).a == float('inf')
assert Model(b=float('-inf')).b == float('-inf')

with pytest.raises(ValidationError) as exc_info:
Model(a=-5.1, b=5.2, c=-5.1, d=5.1, e=-5.3, f=9.91, g=4.2)
Model(a=-5.1, b=5.2, c=-5.1, d=5.1, e=-5.3, f=9.91, g=4.2, h=float('nan'))
assert exc_info.value.errors() == [
{
'loc': ('a',),
Expand Down Expand Up @@ -1614,6 +1620,56 @@ class Model(BaseModel):
'type': 'value_error.number.not_multiple',
'ctx': {'multiple_of': 0.5},
},
{
'loc': ('h',),
'msg': 'ensure this value is a finite number',
'type': 'value_error.number.not_finite_number',
},
]


def test_finite_float_validation():
class Model(BaseModel):
a: float = None

assert Model(a=float('inf')).a == float('inf')
assert Model(a=float('-inf')).a == float('-inf')
assert math.isnan(Model(a=float('nan')).a)


@pytest.mark.parametrize('value', [float('inf'), float('-inf'), float('nan')])
def test_finite_float_validation_error(value):
class Model(BaseModel):
a: FiniteFloat

assert Model(a=42).a == 42
with pytest.raises(ValidationError) as exc_info:
Model(a=value)
assert exc_info.value.errors() == [
{
'loc': ('a',),
'msg': 'ensure this value is a finite number',
'type': 'value_error.number.not_finite_number',
},
]


def test_finite_float_config():
class Model(BaseModel):
a: float

class Config:
allow_inf_nan = False

assert Model(a=42).a == 42
with pytest.raises(ValidationError) as exc_info:
Model(a=float('nan'))
assert exc_info.value.errors() == [
{
'loc': ('a',),
'msg': 'ensure this value is a finite number',
'type': 'value_error.number.not_finite_number',
},
]


Expand Down

0 comments on commit 8dade7e

Please sign in to comment.