From 0d022db3e0a4c68b8ad872761f96d520dcacd27f Mon Sep 17 00:00:00 2001 From: Sigurd Spieckermann Date: Tue, 20 Sep 2022 00:17:40 +0200 Subject: [PATCH 1/3] Fix field regex with StrictStr type annotation --- pydantic/types.py | 16 +++++++++++----- tests/test_types.py | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 5 deletions(-) diff --git a/pydantic/types.py b/pydantic/types.py index f98dba3de4..1977b8aa17 100644 --- a/pydantic/types.py +++ b/pydantic/types.py @@ -403,16 +403,17 @@ class ConstrainedStr(str): min_length: OptionalInt = None max_length: OptionalInt = None curtail_length: OptionalInt = None - regex: Optional[Pattern[str]] = None + regex: Optional[Union[str, Pattern[str]]] = None strict = False @classmethod def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + regex = cls._regex() update_not_none( field_schema, minLength=cls.min_length, maxLength=cls.max_length, - pattern=cls.regex and cls.regex.pattern, + pattern=regex and regex.pattern, ) @classmethod @@ -429,12 +430,17 @@ def validate(cls, value: Union[str]) -> Union[str]: if cls.curtail_length and len(value) > cls.curtail_length: value = value[: cls.curtail_length] - if cls.regex: - if not cls.regex.match(value): - raise errors.StrRegexError(pattern=cls.regex.pattern) + regex = cls._regex() + if regex: + if not regex.match(value): + raise errors.StrRegexError(pattern=regex.pattern) return value + @classmethod + def _regex(cls) -> Optional[Pattern[str]]: + return re.compile(cls.regex) if isinstance(cls.regex, str) else cls.regex + def constr( *, diff --git a/tests/test_types.py b/tests/test_types.py index af4a91ef1d..785e335b91 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -1759,6 +1759,45 @@ class Model(BaseModel): Model(u='1234567') +def test_strict_str_regex(): + class Model(BaseModel): + u: StrictStr = Field(..., regex=r'^[0-9]+$') + + assert Model(u='123').u == '123' + + with pytest.raises(ValidationError, match='str type expected'): + Model(u=123) + + with pytest.raises(ValidationError) as exc_info: + Model(u='abc') + assert exc_info.value.errors() == [ + { + 'loc': ('u',), + 'msg': 'string does not match regex "^[0-9]+$"', + 'type': 'value_error.str.regex', + 'ctx': {'pattern': '^[0-9]+$'}, + } + ] + + +def test_string_regex(): + class Model(BaseModel): + u: str = Field(..., regex=r'^[0-9]+$') + + assert Model(u='123').u == '123' + + with pytest.raises(ValidationError) as exc_info: + Model(u='abc') + assert exc_info.value.errors() == [ + { + 'loc': ('u',), + 'msg': 'string does not match regex "^[0-9]+$"', + 'type': 'value_error.str.regex', + 'ctx': {'pattern': '^[0-9]+$'}, + } + ] + + def test_strict_bool(): class Model(BaseModel): v: StrictBool From f4e084a80b2b7632a8aa23a12700da0e5e7e952a Mon Sep 17 00:00:00 2001 From: Sigurd Spieckermann Date: Tue, 20 Sep 2022 00:44:50 +0200 Subject: [PATCH 2/3] Add change file --- changes/4538-sisp.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changes/4538-sisp.md diff --git a/changes/4538-sisp.md b/changes/4538-sisp.md new file mode 100644 index 0000000000..afb3312ba5 --- /dev/null +++ b/changes/4538-sisp.md @@ -0,0 +1 @@ +Fix field regex with `StrictStr` type annotation. From 897d4c1fd488089ab747b1b17de6c9e810618821 Mon Sep 17 00:00:00 2001 From: Sigurd Spieckermann Date: Tue, 20 Sep 2022 12:28:37 +0200 Subject: [PATCH 3/3] Improve regex caching --- pydantic/types.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/pydantic/types.py b/pydantic/types.py index 1977b8aa17..eaf679d48e 100644 --- a/pydantic/types.py +++ b/pydantic/types.py @@ -408,12 +408,11 @@ class ConstrainedStr(str): @classmethod def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - regex = cls._regex() update_not_none( field_schema, minLength=cls.min_length, maxLength=cls.max_length, - pattern=regex and regex.pattern, + pattern=cls.regex and cls._get_pattern(cls.regex), ) @classmethod @@ -430,16 +429,15 @@ def validate(cls, value: Union[str]) -> Union[str]: if cls.curtail_length and len(value) > cls.curtail_length: value = value[: cls.curtail_length] - regex = cls._regex() - if regex: - if not regex.match(value): - raise errors.StrRegexError(pattern=regex.pattern) + if cls.regex: + if not re.match(cls.regex, value): + raise errors.StrRegexError(pattern=cls._get_pattern(cls.regex)) return value - @classmethod - def _regex(cls) -> Optional[Pattern[str]]: - return re.compile(cls.regex) if isinstance(cls.regex, str) else cls.regex + @staticmethod + def _get_pattern(regex: Union[str, Pattern[str]]) -> str: + return regex if isinstance(regex, str) else regex.pattern def constr(