diff --git a/changes/4538-sisp.md b/changes/4538-sisp.md new file mode 100644 index 0000000000..afb3312ba5 --- /dev/null +++ b/changes/4538-sisp.md @@ -0,0 +1 @@ +Fix field regex with `StrictStr` type annotation. diff --git a/pydantic/types.py b/pydantic/types.py index f98dba3de4..eaf679d48e 100644 --- a/pydantic/types.py +++ b/pydantic/types.py @@ -403,7 +403,7 @@ class ConstrainedStr(str): min_length: OptionalInt = None max_length: OptionalInt = None curtail_length: OptionalInt = None - regex: Optional[Pattern[str]] = None + regex: Optional[Union[str, Pattern[str]]] = None strict = False @classmethod @@ -412,7 +412,7 @@ def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: field_schema, minLength=cls.min_length, maxLength=cls.max_length, - pattern=cls.regex and cls.regex.pattern, + pattern=cls.regex and cls._get_pattern(cls.regex), ) @classmethod @@ -430,11 +430,15 @@ def validate(cls, value: Union[str]) -> Union[str]: value = value[: cls.curtail_length] if cls.regex: - if not cls.regex.match(value): - raise errors.StrRegexError(pattern=cls.regex.pattern) + if not re.match(cls.regex, value): + raise errors.StrRegexError(pattern=cls._get_pattern(cls.regex)) return value + @staticmethod + def _get_pattern(regex: Union[str, Pattern[str]]) -> str: + return regex if isinstance(regex, str) else regex.pattern + def constr( *, diff --git a/tests/test_types.py b/tests/test_types.py index af4a91ef1d..785e335b91 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -1759,6 +1759,45 @@ class Model(BaseModel): Model(u='1234567') +def test_strict_str_regex(): + class Model(BaseModel): + u: StrictStr = Field(..., regex=r'^[0-9]+$') + + assert Model(u='123').u == '123' + + with pytest.raises(ValidationError, match='str type expected'): + Model(u=123) + + with pytest.raises(ValidationError) as exc_info: + Model(u='abc') + assert exc_info.value.errors() == [ + { + 'loc': ('u',), + 'msg': 'string does not match regex "^[0-9]+$"', + 'type': 'value_error.str.regex', + 'ctx': {'pattern': '^[0-9]+$'}, + } + ] + + +def test_string_regex(): + class Model(BaseModel): + u: str = Field(..., regex=r'^[0-9]+$') + + assert Model(u='123').u == '123' + + with pytest.raises(ValidationError) as exc_info: + Model(u='abc') + assert exc_info.value.errors() == [ + { + 'loc': ('u',), + 'msg': 'string does not match regex "^[0-9]+$"', + 'type': 'value_error.str.regex', + 'ctx': {'pattern': '^[0-9]+$'}, + } + ] + + def test_strict_bool(): class Model(BaseModel): v: StrictBool