diff --git a/changes/4406-acmiyaguchi.md b/changes/4406-acmiyaguchi.md new file mode 100644 index 0000000000..c7028c4d15 --- /dev/null +++ b/changes/4406-acmiyaguchi.md @@ -0,0 +1 @@ +Allow for custom parsing of environment variables via `parse_env_var` in `Config`. diff --git a/docs/examples/settings_with_custom_parsing.py b/docs/examples/settings_with_custom_parsing.py new file mode 100644 index 0000000000..9de25643e1 --- /dev/null +++ b/docs/examples/settings_with_custom_parsing.py @@ -0,0 +1,19 @@ +import os +from typing import Any, List + +from pydantic import BaseSettings + + +class Settings(BaseSettings): + numbers: List[int] + + class Config: + @classmethod + def parse_env_var(cls, field_name: str, raw_val: str) -> Any: + if field_name == 'numbers': + return [int(x) for x in raw_val.split(',')] + return cls.json_loads(raw_val) + + +os.environ['numbers'] = '1,2,3' +print(Settings().dict()) diff --git a/docs/usage/settings.md b/docs/usage/settings.md index 1ba926f6eb..7268e9f0fa 100644 --- a/docs/usage/settings.md +++ b/docs/usage/settings.md @@ -87,7 +87,7 @@ export SUB_MODEL__DEEP__V4=v4 You could load a settings module thus: {!.tmp_examples/settings_nested_env.md!} -`env_nested_delimiter` can be configured via the `Config` class as shown above, or via the +`env_nested_delimiter` can be configured via the `Config` class as shown above, or via the `_env_nested_delimiter` keyword argument on instantiation. JSON is only parsed in top-level fields, if you need to parse JSON in sub-models, you will need to implement @@ -96,6 +96,11 @@ validators on those models. Nested environment variables take precedence over the top-level environment variable JSON (e.g. in the example above, `SUB_MODEL__V2` trumps `SUB_MODEL`). +You may also populate a complex type by providing your own parsing function to +the `parse_env_var` classmethod in the Config object. + +{!.tmp_examples/settings_with_custom_parsing.md!} + ## Dotenv (.env) support !!! note @@ -178,7 +183,7 @@ see [python-dotenv's documentation](https://saurabh-kumar.com/python-dotenv/#usa Placing secret values in files is a common pattern to provide sensitive configuration to an application. -A secret file follows the same principal as a dotenv file except it only contains a single value and the file name +A secret file follows the same principal as a dotenv file except it only contains a single value and the file name is used as the key. A secret file will look like the following: `/var/run/database_password`: @@ -231,7 +236,7 @@ class Settings(BaseSettings): secrets_dir = '/run/secrets' ``` !!! note - By default Docker uses `/run/secrets` as the target mount point. If you want to use a different location, change + By default Docker uses `/run/secrets` as the target mount point. If you want to use a different location, change `Config.secrets_dir` accordingly. Then, create your secret via the Docker CLI diff --git a/pydantic/env_settings.py b/pydantic/env_settings.py index 6be4b8e710..7587a1f50c 100644 --- a/pydantic/env_settings.py +++ b/pydantic/env_settings.py @@ -126,6 +126,10 @@ def customise_sources( ) -> Tuple[SettingsSourceCallable, ...]: return init_settings, env_settings, file_secret_settings + @classmethod + def parse_env_var(cls, field_name: str, raw_val: str) -> Any: + return cls.json_loads(raw_val) + # populated by the metaclass using the Config class defined above, annotated here to help IDEs only __config__: ClassVar[Type[Config]] @@ -180,7 +184,7 @@ def __call__(self, settings: BaseSettings) -> Dict[str, Any]: # noqa C901 if env_val is not None: break - is_complex, allow_json_failure = self.field_is_complex(field) + is_complex, allow_parse_failure = self.field_is_complex(field) if is_complex: if env_val is None: # field is complex but no value found so far, try explode_env_vars @@ -190,10 +194,10 @@ def __call__(self, settings: BaseSettings) -> Dict[str, Any]: # noqa C901 else: # field is complex and there's a value, decode that as JSON, then add explode_env_vars try: - env_val = settings.__config__.json_loads(env_val) + env_val = settings.__config__.parse_env_var(field.name, env_val) except ValueError as e: - if not allow_json_failure: - raise SettingsError(f'error parsing JSON for "{env_name}"') from e + if not allow_parse_failure: + raise SettingsError(f'error parsing env var "{env_name}"') from e if isinstance(env_val, dict): d[field.alias] = deep_update(env_val, self.explode_env_vars(field, env_vars)) @@ -228,13 +232,13 @@ def field_is_complex(self, field: ModelField) -> Tuple[bool, bool]: Find out if a field is complex, and if so whether JSON errors should be ignored """ if field.is_complex(): - allow_json_failure = False + allow_parse_failure = False elif is_union(get_origin(field.type_)) and field.sub_fields and any(f.is_complex() for f in field.sub_fields): - allow_json_failure = True + allow_parse_failure = True else: return False, False - return True, allow_json_failure + return True, allow_parse_failure def explode_env_vars(self, field: ModelField, env_vars: Mapping[str, Optional[str]]) -> Dict[str, Any]: """ @@ -299,9 +303,9 @@ def __call__(self, settings: BaseSettings) -> Dict[str, Any]: secret_value = path.read_text().strip() if field.is_complex(): try: - secret_value = settings.__config__.json_loads(secret_value) + secret_value = settings.__config__.parse_env_var(field.name, secret_value) except ValueError as e: - raise SettingsError(f'error parsing JSON for "{env_name}"') from e + raise SettingsError(f'error parsing env var "{env_name}"') from e secrets[field.alias] = secret_value else: diff --git a/tests/test_settings.py b/tests/test_settings.py index e08fd08b6a..d61e6209bc 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -3,7 +3,7 @@ import uuid from datetime import datetime, timezone from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import pytest @@ -221,7 +221,7 @@ def test_set_dict_model(env): def test_invalid_json(env): env.set('apples', '["russet", "granny smith",]') - with pytest.raises(SettingsError, match='error parsing JSON for "apples"'): + with pytest.raises(SettingsError, match='error parsing env var "apples"'): ComplexSettings() @@ -1054,7 +1054,7 @@ class Settings(BaseSettings): class Config: secrets_dir = tmp_path - with pytest.raises(SettingsError, match='error parsing JSON for "foo"'): + with pytest.raises(SettingsError, match='error parsing env var "foo"'): Settings() @@ -1215,3 +1215,66 @@ def test_builtins_settings_source_repr(): == "EnvSettingsSource(env_file='.env', env_file_encoding='utf-8', env_nested_delimiter=None)" ) assert repr(SecretsSettingsSource(secrets_dir='/secrets')) == "SecretsSettingsSource(secrets_dir='/secrets')" + + +def _parse_custom_dict(value: str) -> Callable[[str], Dict[int, str]]: + """A custom parsing function passed into env parsing test.""" + res = {} + for part in value.split(','): + k, v = part.split('=') + res[int(k)] = v + return res + + +def test_env_setting_source_custom_env_parse(env): + class Settings(BaseSettings): + top: Dict[int, str] + + class Config: + @classmethod + def parse_env_var(cls, field_name: str, raw_val: str): + if field_name == 'top': + return _parse_custom_dict(raw_val) + return cls.json_loads(raw_val) + + with pytest.raises(ValidationError): + Settings() + env.set('top', '1=apple,2=banana') + s = Settings() + assert s.top == {1: 'apple', 2: 'banana'} + + +def test_env_settings_source_custom_env_parse_is_bad(env): + class Settings(BaseSettings): + top: Dict[int, str] + + class Config: + @classmethod + def parse_env_var(cls, field_name: str, raw_val: str): + if field_name == 'top': + return int(raw_val) + return cls.json_loads(raw_val) + + env.set('top', '1=apple,2=banana') + with pytest.raises(SettingsError, match='error parsing env var "top"'): + Settings() + + +def test_secret_settings_source_custom_env_parse(tmp_path): + p = tmp_path / 'top' + p.write_text('1=apple,2=banana') + + class Settings(BaseSettings): + top: Dict[int, str] + + class Config: + secrets_dir = tmp_path + + @classmethod + def parse_env_var(cls, field_name: str, raw_val: str): + if field_name == 'top': + return _parse_custom_dict(raw_val) + return cls.json_loads(raw_val) + + s = Settings() + assert s.top == {1: 'apple', 2: 'banana'}