diff --git a/src/dotenv/main.py b/src/dotenv/main.py index 58a23f3d..ea523d48 100644 --- a/src/dotenv/main.py +++ b/src/dotenv/main.py @@ -4,7 +4,6 @@ import io import logging import os -import re import shutil import sys import tempfile @@ -13,13 +12,13 @@ from .compat import IS_TYPE_CHECKING, PY2, StringIO, to_env from .parser import Binding, parse_stream +from .variables import parse_variables logger = logging.getLogger(__name__) if IS_TYPE_CHECKING: - from typing import ( - Dict, Iterable, Iterator, Match, Optional, Pattern, Union, Text, IO, Tuple - ) + from typing import (IO, Dict, Iterable, Iterator, Mapping, Optional, Text, + Tuple, Union) if sys.version_info >= (3, 6): _PathLike = os.PathLike else: @@ -30,18 +29,6 @@ else: _StringIO = StringIO[Text] -__posix_variable = re.compile( - r""" - \$\{ - (?P[^\}:]*) - (?::- - (?P[^\}]*) - )? - \} - """, - re.VERBOSE, -) # type: Pattern[Text] - def with_warn_for_invalid_lines(mappings): # type: (Iterator[Binding]) -> Iterator[Binding] @@ -83,13 +70,14 @@ def dict(self): if self._dict: return self._dict + raw_values = self.parse() + if self.interpolate: - values = resolve_nested_variables(self.parse()) + self._dict = OrderedDict(resolve_variables(raw_values)) else: - values = OrderedDict(self.parse()) + self._dict = OrderedDict(raw_values) - self._dict = values - return values + return self._dict def parse(self): # type: () -> Iterator[Tuple[Text, Optional[Text]]] @@ -217,27 +205,22 @@ def unset_key(dotenv_path, key_to_unset, quote_mode="always"): return removed, key_to_unset -def resolve_nested_variables(values): - # type: (Iterable[Tuple[Text, Optional[Text]]]) -> Dict[Text, Optional[Text]] - def _replacement(name, default): - # type: (Text, Optional[Text]) -> Text - default = default if default is not None else "" - ret = new_values.get(name, os.getenv(name, default)) - return ret # type: ignore +def resolve_variables(values): + # type: (Iterable[Tuple[Text, Optional[Text]]]) -> Mapping[Text, Optional[Text]] - def _re_sub_callback(match): - # type: (Match[Text]) -> Text - """ - From a match object gets the variable name and returns - the correct replacement - """ - matches = match.groupdict() - return _replacement(name=matches["name"], default=matches["default"]) # type: ignore + new_values = {} # type: Dict[Text, Optional[Text]] - new_values = {} + for (name, value) in values: + if value is None: + result = None + else: + atoms = parse_variables(value) + env = {} # type: Dict[Text, Optional[Text]] + env.update(os.environ) # type: ignore + env.update(new_values) + result = "".join(atom.resolve(env) for atom in atoms) - for (k, v) in values: - new_values[k] = __posix_variable.sub(_re_sub_callback, v) if v is not None else None + new_values[name] = result return new_values diff --git a/src/dotenv/variables.py b/src/dotenv/variables.py new file mode 100644 index 00000000..4828dfc2 --- /dev/null +++ b/src/dotenv/variables.py @@ -0,0 +1,106 @@ +import re +from abc import ABCMeta + +from .compat import IS_TYPE_CHECKING + +if IS_TYPE_CHECKING: + from typing import Iterator, Mapping, Optional, Pattern, Text + + +_posix_variable = re.compile( + r""" + \$\{ + (?P[^\}:]*) + (?::- + (?P[^\}]*) + )? + \} + """, + re.VERBOSE, +) # type: Pattern[Text] + + +class Atom(): + __metaclass__ = ABCMeta + + def __ne__(self, other): + # type: (object) -> bool + result = self.__eq__(other) + if result is NotImplemented: + return NotImplemented + return not result + + def resolve(self, env): + # type: (Mapping[Text, Optional[Text]]) -> Text + raise NotImplementedError + + +class Literal(Atom): + def __init__(self, value): + # type: (Text) -> None + self.value = value + + def __repr__(self): + # type: () -> str + return "Literal(value={})".format(self.value) + + def __eq__(self, other): + # type: (object) -> bool + if not isinstance(other, self.__class__): + return NotImplemented + return self.value == other.value + + def __hash__(self): + # type: () -> int + return hash((self.__class__, self.value)) + + def resolve(self, env): + # type: (Mapping[Text, Optional[Text]]) -> Text + return self.value + + +class Variable(Atom): + def __init__(self, name, default): + # type: (Text, Optional[Text]) -> None + self.name = name + self.default = default + + def __repr__(self): + # type: () -> str + return "Variable(name={}, default={})".format(self.name, self.default) + + def __eq__(self, other): + # type: (object) -> bool + if not isinstance(other, self.__class__): + return NotImplemented + return (self.name, self.default) == (other.name, other.default) + + def __hash__(self): + # type: () -> int + return hash((self.__class__, self.name, self.default)) + + def resolve(self, env): + # type: (Mapping[Text, Optional[Text]]) -> Text + default = self.default if self.default is not None else "" + result = env.get(self.name, default) + return result if result is not None else "" + + +def parse_variables(value): + # type: (Text) -> Iterator[Atom] + cursor = 0 + + for match in _posix_variable.finditer(value): + (start, end) = match.span() + name = match.groupdict()["name"] + default = match.groupdict()["default"] + + if start > cursor: + yield Literal(value=value[cursor:start]) + + yield Variable(name=name, default=default) + cursor = end + + length = len(value) + if cursor < length: + yield Literal(value=value[cursor:length]) diff --git a/tests/test_variables.py b/tests/test_variables.py new file mode 100644 index 00000000..86b06466 --- /dev/null +++ b/tests/test_variables.py @@ -0,0 +1,35 @@ +import pytest + +from dotenv.variables import Literal, Variable, parse_variables + + +@pytest.mark.parametrize( + "value,expected", + [ + ("", []), + ("a", [Literal(value="a")]), + ("${a}", [Variable(name="a", default=None)]), + ("${a:-b}", [Variable(name="a", default="b")]), + ( + "${a}${b}", + [ + Variable(name="a", default=None), + Variable(name="b", default=None), + ], + ), + ( + "a${b}c${d}e", + [ + Literal(value="a"), + Variable(name="b", default=None), + Literal(value="c"), + Variable(name="d", default=None), + Literal(value="e"), + ], + ), + ] +) +def test_parse_variables(value, expected): + result = parse_variables(value) + + assert list(result) == expected