From e2be8d605222eb0e9be1cb6ab1236d7f722f9722 Mon Sep 17 00:00:00 2001 From: Bertrand Bonnefoy-Claudet Date: Sat, 19 Feb 2022 14:38:01 +0100 Subject: [PATCH] Add encoding parameter to {get,set,unset}_key The parameter already exists for `dotenv_values` and `load_dotenv` and has the same meaning. --- CHANGELOG.md | 7 +++++++ src/dotenv/main.py | 29 +++++++++++++++++++---------- tests/test_main.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9b18856e..b0c06dd3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## Unreleased + +### Added + +- Add `encoding` (`Optional[str]`) parameter to `get_key`, `set_key` and `unset_key`. + (#? by [@bbc2]) + ## [0.19.2] - 2021-11-11 ### Fixed diff --git a/src/dotenv/main.py b/src/dotenv/main.py index d867f023..20ac61ba 100644 --- a/src/dotenv/main.py +++ b/src/dotenv/main.py @@ -109,23 +109,30 @@ def get(self, key: str) -> Optional[str]: return None -def get_key(dotenv_path: Union[str, _PathLike], key_to_get: str) -> Optional[str]: +def get_key( + dotenv_path: Union[str, _PathLike], + key_to_get: str, + encoding: Optional[str] = "utf-8", +) -> Optional[str]: """ - Gets the value of a given key from the given .env + Get the value of a given key from the given .env. - If the .env path given doesn't exist, fails + Returns `None` if the key isn't found or doesn't have a value. """ - return DotEnv(dotenv_path, verbose=True).get(key_to_get) + return DotEnv(dotenv_path, verbose=True, encoding=encoding).get(key_to_get) @contextmanager -def rewrite(path: Union[str, _PathLike]) -> Iterator[Tuple[IO[str], IO[str]]]: +def rewrite( + path: Union[str, _PathLike], + encoding: Optional[str], +) -> Iterator[Tuple[IO[str], IO[str]]]: try: if not os.path.isfile(path): - with io.open(path, "w+") as source: + with io.open(path, "w+", encoding=encoding) as source: source.write("") - with tempfile.NamedTemporaryFile(mode="w+", delete=False) as dest: - with io.open(path) as source: + with tempfile.NamedTemporaryFile(mode="w+", delete=False, encoding=encoding) as dest: + with io.open(path, encoding=encoding) as source: yield (source, dest) # type: ignore except BaseException: if os.path.isfile(dest.name): @@ -141,6 +148,7 @@ def set_key( value_to_set: str, quote_mode: str = "always", export: bool = False, + encoding: Optional[str] = "utf-8", ) -> Tuple[Optional[bool], str, str]: """ Adds or Updates a key/value to the given .env @@ -165,7 +173,7 @@ def set_key( else: line_out = "{}={}\n".format(key_to_set, value_out) - with rewrite(dotenv_path) as (source, dest): + with rewrite(dotenv_path, encoding=encoding) as (source, dest): replaced = False missing_newline = False for mapping in with_warn_for_invalid_lines(parse_stream(source)): @@ -187,6 +195,7 @@ def unset_key( dotenv_path: Union[str, _PathLike], key_to_unset: str, quote_mode: str = "always", + encoding: Optional[str] = "utf-8", ) -> Tuple[Optional[bool], str]: """ Removes a given key from the given .env @@ -199,7 +208,7 @@ def unset_key( return None, key_to_unset removed = False - with rewrite(dotenv_path) as (source, dest): + with rewrite(dotenv_path, encoding=encoding) as (source, dest): for mapping in with_warn_for_invalid_lines(parse_stream(source)): if mapping.key == key_to_unset: removed = True diff --git a/tests/test_main.py b/tests/test_main.py index 541ac5ee..364fc24d 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -53,6 +53,15 @@ def test_set_key(dotenv_file, before, key, value, expected, after): mock_warning.assert_not_called() +def test_set_key_encoding(dotenv_file): + encoding = "latin-1" + + result = dotenv.set_key(dotenv_file, "a", "é", encoding=encoding) + + assert result == (True, "a", "é") + assert open(dotenv_file, "r", encoding=encoding).read() == "a='é'\n" + + def test_set_key_permission_error(dotenv_file): os.chmod(dotenv_file, 0o000) @@ -107,6 +116,16 @@ def test_get_key_ok(dotenv_file): mock_warning.assert_not_called() +def test_get_key_encoding(dotenv_file): + encoding = "latin-1" + with open(dotenv_file, "w", encoding=encoding) as f: + f.write("é=è") + + result = dotenv.get_key(dotenv_file, "é", encoding=encoding) + + assert result == "è" + + def test_get_key_none(dotenv_file): logger = logging.getLogger("dotenv.main") with open(dotenv_file, "w") as f: @@ -147,6 +166,18 @@ def test_unset_no_value(dotenv_file): mock_warning.assert_not_called() +def test_unset_encoding(dotenv_file): + encoding = "latin-1" + with open(dotenv_file, "w", encoding=encoding) as f: + f.write("é=x") + + result = dotenv.unset_key(dotenv_file, "é", encoding=encoding) + + assert result == (True, "é") + with open(dotenv_file, "r", encoding=encoding) as f: + assert f.read() == "" + + def test_unset_non_existent_file(tmp_path): nx_file = str(tmp_path / "nx") logger = logging.getLogger("dotenv.main")