Skip to content

Commit

Permalink
Add encoding parameter to {get,set,unset}_key
Browse files Browse the repository at this point in the history
The parameter already exists for `dotenv_values` and `load_dotenv` and
has the same meaning.
  • Loading branch information
bbc2 committed Mar 12, 2022
1 parent ba9408c commit 157282c
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 10 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Expand Up @@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this
project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased

### Added

- Add `encoding` (`Optional[str]`) parameter to `get_key`, `set_key` and `unset_key`.
(#379 by [@bbc2])

## [0.19.2] - 2021-11-11

### Fixed
Expand Down
29 changes: 19 additions & 10 deletions src/dotenv/main.py
Expand Up @@ -109,23 +109,30 @@ def get(self, key: str) -> Optional[str]:
return None


def get_key(dotenv_path: Union[str, _PathLike], key_to_get: str) -> Optional[str]:
def get_key(
dotenv_path: Union[str, _PathLike],
key_to_get: str,
encoding: Optional[str] = "utf-8",
) -> Optional[str]:
"""
Gets the value of a given key from the given .env
Get the value of a given key from the given .env.
If the .env path given doesn't exist, fails
Returns `None` if the key isn't found or doesn't have a value.
"""
return DotEnv(dotenv_path, verbose=True).get(key_to_get)
return DotEnv(dotenv_path, verbose=True, encoding=encoding).get(key_to_get)


@contextmanager
def rewrite(path: Union[str, _PathLike]) -> Iterator[Tuple[IO[str], IO[str]]]:
def rewrite(
path: Union[str, _PathLike],
encoding: Optional[str],
) -> Iterator[Tuple[IO[str], IO[str]]]:
try:
if not os.path.isfile(path):
with io.open(path, "w+") as source:
with io.open(path, "w+", encoding=encoding) as source:
source.write("")
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as dest:
with io.open(path) as source:
with tempfile.NamedTemporaryFile(mode="w+", delete=False, encoding=encoding) as dest:
with io.open(path, encoding=encoding) as source:
yield (source, dest) # type: ignore
except BaseException:
if os.path.isfile(dest.name):
Expand All @@ -141,6 +148,7 @@ def set_key(
value_to_set: str,
quote_mode: str = "always",
export: bool = False,
encoding: Optional[str] = "utf-8",
) -> Tuple[Optional[bool], str, str]:
"""
Adds or Updates a key/value to the given .env
Expand All @@ -165,7 +173,7 @@ def set_key(
else:
line_out = "{}={}\n".format(key_to_set, value_out)

with rewrite(dotenv_path) as (source, dest):
with rewrite(dotenv_path, encoding=encoding) as (source, dest):
replaced = False
missing_newline = False
for mapping in with_warn_for_invalid_lines(parse_stream(source)):
Expand All @@ -187,6 +195,7 @@ def unset_key(
dotenv_path: Union[str, _PathLike],
key_to_unset: str,
quote_mode: str = "always",
encoding: Optional[str] = "utf-8",
) -> Tuple[Optional[bool], str]:
"""
Removes a given key from the given .env
Expand All @@ -199,7 +208,7 @@ def unset_key(
return None, key_to_unset

removed = False
with rewrite(dotenv_path) as (source, dest):
with rewrite(dotenv_path, encoding=encoding) as (source, dest):
for mapping in with_warn_for_invalid_lines(parse_stream(source)):
if mapping.key == key_to_unset:
removed = True
Expand Down
31 changes: 31 additions & 0 deletions tests/test_main.py
Expand Up @@ -53,6 +53,15 @@ def test_set_key(dotenv_file, before, key, value, expected, after):
mock_warning.assert_not_called()


def test_set_key_encoding(dotenv_file):
encoding = "latin-1"

result = dotenv.set_key(dotenv_file, "a", "é", encoding=encoding)

assert result == (True, "a", "é")
assert open(dotenv_file, "r", encoding=encoding).read() == "a='é'\n"


def test_set_key_permission_error(dotenv_file):
os.chmod(dotenv_file, 0o000)

Expand Down Expand Up @@ -107,6 +116,16 @@ def test_get_key_ok(dotenv_file):
mock_warning.assert_not_called()


def test_get_key_encoding(dotenv_file):
encoding = "latin-1"
with open(dotenv_file, "w", encoding=encoding) as f:
f.write("é=è")

result = dotenv.get_key(dotenv_file, "é", encoding=encoding)

assert result == "è"


def test_get_key_none(dotenv_file):
logger = logging.getLogger("dotenv.main")
with open(dotenv_file, "w") as f:
Expand Down Expand Up @@ -147,6 +166,18 @@ def test_unset_no_value(dotenv_file):
mock_warning.assert_not_called()


def test_unset_encoding(dotenv_file):
encoding = "latin-1"
with open(dotenv_file, "w", encoding=encoding) as f:
f.write("é=x")

result = dotenv.unset_key(dotenv_file, "é", encoding=encoding)

assert result == (True, "é")
with open(dotenv_file, "r", encoding=encoding) as f:
assert f.read() == ""


def test_unset_non_existent_file(tmp_path):
nx_file = str(tmp_path / "nx")
logger = logging.getLogger("dotenv.main")
Expand Down

0 comments on commit 157282c

Please sign in to comment.