diff --git a/tests/test_items.py b/tests/test_items.py index 057ed46..e36a130 100644 --- a/tests/test_items.py +++ b/tests/test_items.py @@ -10,20 +10,32 @@ from tomlkit import api from tomlkit import parse +from tomlkit.check import is_tomlkit from tomlkit.exceptions import NonExistentKey +from tomlkit.items import AoT +from tomlkit.items import Array from tomlkit.items import Bool from tomlkit.items import Comment +from tomlkit.items import Date +from tomlkit.items import DateTime +from tomlkit.items import Float from tomlkit.items import InlineTable from tomlkit.items import Integer +from tomlkit.items import Item from tomlkit.items import KeyType +from tomlkit.items import Null from tomlkit.items import SingleKey as Key from tomlkit.items import String from tomlkit.items import StringType from tomlkit.items import Table +from tomlkit.items import Time from tomlkit.items import Trivia from tomlkit.items import item from tomlkit.parser import Parser +from .util import assert_is_ppo +from .util import elementary_test + @pytest.fixture() def tz_pst(): @@ -69,6 +81,98 @@ def dst(self, dt): return UTC() +def test_item_base_has_no_unwrap(): + trivia = Trivia(indent="\t", comment_ws=" ", comment="For unit test") + item = Item(trivia) + try: + item.unwrap() + except NotImplementedError: + pass + else: + raise AssertionError("`items.Item` should not implement `unwrap`") + + +def test_integer_unwrap(): + elementary_test(item(666), int) + + +def test_float_unwrap(): + elementary_test(item(2.78), float) + + +def test_false_unwrap(): + elementary_test(item(False), bool) + + +def test_true_unwrap(): + elementary_test(item(True), bool) + + +def test_datetime_unwrap(): + dt = datetime.utcnow() + elementary_test(item(dt), datetime) + + +def test_string_unwrap(): + elementary_test(item("hello"), str) + + +def test_null_unwrap(): + n = Null() + elementary_test(n, type(None)) + + +def test_aot_unwrap(): + d = item([{"a": "A"}, {"b": "B"}]) + assert is_tomlkit(d) + unwrapped = d.unwrap() + assert_is_ppo(unwrapped, list) + for du, dw in zip(unwrapped, d): + assert_is_ppo(du, dict) + for ku in du: + vu = du[ku] + assert_is_ppo(ku, str) + assert_is_ppo(vu, str) + + +def test_time_unwrap(): + t = time(3, 8, 14) + elementary_test(item(t), time) + + +def test_date_unwrap(): + d = date.today() + elementary_test(item(d), date) + + +def test_array_unwrap(): + trivia = Trivia(indent="\t", comment_ws=" ", comment="For unit test") + i = item(666) + f = item(2.78) + b = item(False) + a = Array([i, f, b], trivia) + a_unwrapped = a.unwrap() + assert_is_ppo(a_unwrapped, list) + assert_is_ppo(a_unwrapped[0], int) + assert_is_ppo(a_unwrapped[1], float) + assert_is_ppo(a_unwrapped[2], bool) + + +def test_abstract_table_unwrap(): + table = item({"foo": "bar"}) + super_table = item({"table": table, "baz": "borg"}) + assert is_tomlkit(super_table["table"]) + + table_unwrapped = super_table.unwrap() + sub_table = table_unwrapped["table"] + assert_is_ppo(table_unwrapped, dict) + assert_is_ppo(sub_table, dict) + for ku in sub_table: + vu = sub_table[ku] + assert_is_ppo(ku, str) + assert_is_ppo(vu, str) + + def test_key_comparison(): k = Key("foo") diff --git a/tests/test_toml_document.py b/tests/test_toml_document.py index c350984..f2e57c9 100644 --- a/tests/test_toml_document.py +++ b/tests/test_toml_document.py @@ -14,6 +14,10 @@ from tomlkit._utils import _utc from tomlkit.api import document from tomlkit.exceptions import NonExistentKey +from tomlkit.toml_document import TOMLDocument + +from .util import assert_is_ppo +from .util import elementary_test def test_document_is_a_dict(example): @@ -154,6 +158,20 @@ def test_toml_document_without_super_tables(): assert "tool" in d +def test_toml_document_unwrap(): + content = """[tool.poetry] +name = "foo" +""" + + doc = parse(content) + unwrapped = doc.unwrap() + assert_is_ppo(unwrapped, dict) + assert_is_ppo(list(unwrapped.keys())[0], str) + assert_is_ppo(unwrapped["tool"], dict) + assert_is_ppo(list(unwrapped["tool"].keys())[0], str) + assert_is_ppo(unwrapped["tool"]["poetry"]["name"], str) + + def test_toml_document_with_dotted_keys(example): content = example("0.5.0") diff --git a/tests/util.py b/tests/util.py new file mode 100644 index 0000000..3a4c758 --- /dev/null +++ b/tests/util.py @@ -0,0 +1,57 @@ +from tomlkit.items import AoT +from tomlkit.items import Array +from tomlkit.items import Bool +from tomlkit.items import Comment +from tomlkit.items import Date +from tomlkit.items import DateTime +from tomlkit.items import Float +from tomlkit.items import InlineTable +from tomlkit.items import Integer +from tomlkit.items import Item +from tomlkit.items import KeyType +from tomlkit.items import Null +from tomlkit.items import SingleKey as Key +from tomlkit.items import String +from tomlkit.items import StringType +from tomlkit.items import Table +from tomlkit.items import Time +from tomlkit.items import Trivia +from tomlkit.toml_document import TOMLDocument + + +TOMLKIT_TYPES = [ + Bool, + Comment, + InlineTable, + Integer, + Float, + DateTime, + Date, + Time, + Array, + KeyType, + Key, + String, + StringType, + Table, + Trivia, + Item, + AoT, + Null, + TOMLDocument, +] + + +def assert_not_tomlkit_type(v): + for i, T in enumerate(TOMLKIT_TYPES): + assert not isinstance(v, T) + + +def assert_is_ppo(v_unwrapped, unwrappedType): + assert_not_tomlkit_type(v_unwrapped) + assert isinstance(v_unwrapped, unwrappedType) + + +def elementary_test(v, unwrappedType): + v_unwrapped = v.unwrap() + assert_is_ppo(v_unwrapped, unwrappedType) diff --git a/tomlkit/check.py b/tomlkit/check.py new file mode 100644 index 0000000..6d9327a --- /dev/null +++ b/tomlkit/check.py @@ -0,0 +1,12 @@ +def is_tomlkit(v): + from .container import Container + from .container import OutOfOrderTableProxy + from .items import Item as _Item + + if isinstance(v, _Item): + return True + if isinstance(v, Container): + return True + if isinstance(v, OutOfOrderTableProxy): + return True + return False diff --git a/tomlkit/container.py b/tomlkit/container.py index 2db19e3..5b03d5e 100644 --- a/tomlkit/container.py +++ b/tomlkit/container.py @@ -10,6 +10,7 @@ from ._compat import decode from ._utils import merge_dicts +from .check import is_tomlkit from .exceptions import KeyAlreadyPresent from .exceptions import NonExistentKey from .exceptions import TOMLKitError @@ -46,6 +47,25 @@ def __init__(self, parsed: bool = False) -> None: def body(self) -> List[Tuple[Optional[Key], Item]]: return self._body + def unwrap(self) -> str: + unwrapped = {} + for k, v in self.items(): + if k is None: + continue + + if not isinstance(k, str): + k = k.key + + if isinstance(v, Item): + v = v.unwrap() + + if k in unwrapped: + merge_dicts(unwrapped[k], v) + else: + unwrapped[k] = v + + return unwrapped + @property def value(self) -> Dict[Any, Any]: d = {} @@ -796,6 +816,9 @@ def __init__(self, container: Container, indices: Tuple[int]) -> None: if k is not None: dict.__setitem__(self, k.key, v) + def unwrap(self) -> str: + return self._internal_container.unwrap() + @property def value(self): return self._internal_container.value diff --git a/tomlkit/items.py b/tomlkit/items.py index bdc97f2..7bf88f6 100644 --- a/tomlkit/items.py +++ b/tomlkit/items.py @@ -27,6 +27,7 @@ from ._compat import decode from ._utils import CONTROL_CHARS from ._utils import escape_string +from .check import is_tomlkit from .exceptions import InvalidStringError from .toml_char import TOMLChar @@ -492,6 +493,10 @@ def as_string(self) -> str: """The TOML representation""" raise NotImplementedError() + def unwrap(self): + """Returns as pure python object (ppo)""" + raise NotImplementedError() + # Helpers def comment(self, comment: str) -> "Item": @@ -610,6 +615,9 @@ def __init__(self, _: int, trivia: Trivia, raw: str) -> None: if re.match(r"^[+\-]\d+$", raw): self._sign = True + def unwrap(self) -> int: + return int(self) + @property def discriminant(self) -> int: return 2 @@ -678,6 +686,9 @@ def __init__(self, _: float, trivia: Trivia, raw: str) -> None: if re.match(r"^[+\-].+$", raw): self._sign = True + def unwrap(self) -> float: + return float(self) + @property def discriminant(self) -> int: return 3 @@ -739,6 +750,9 @@ def __init__(self, t: int, trivia: Trivia) -> None: self._value = bool(t) + def unwrap(self) -> bool: + return bool(self) + @property def discriminant(self) -> int: return 4 @@ -821,6 +835,21 @@ def __init__( self._raw = raw or self.isoformat() + def unwrap(self) -> datetime: + ( + year, + month, + day, + hour, + minute, + second, + microsecond, + tzinfo, + _, + _, + ) = self._getstate() + return datetime(year, month, day, hour, minute, second, microsecond, tzinfo) + @property def discriminant(self) -> int: return 5 @@ -924,6 +953,10 @@ def __init__( self._raw = raw + def unwrap(self) -> date: + (year, month, day, _, _) = self._getstate() + return date(year, month, day) + @property def discriminant(self) -> int: return 6 @@ -996,6 +1029,10 @@ def __init__( self._raw = raw + def unwrap(self) -> datetime: + (hour, minute, second, microsecond, tzinfo, _, _) = self._getstate() + return time(hour, minute, second, microsecond, tzinfo) + @property def discriminant(self) -> int: return 7 @@ -1051,6 +1088,15 @@ def __init__(self, value: list, trivia: Trivia, multiline: bool = False) -> None self._multiline = multiline self._reindex() + def unwrap(self) -> str: + unwrapped = [] + for v in self: + if is_tomlkit(v): + unwrapped.append(v.unwrap()) + else: + unwrapped.append(v) + return unwrapped + @property def discriminant(self) -> int: return 8 @@ -1295,6 +1341,21 @@ def __init__(self, value: "container.Container", trivia: Trivia): if k is not None: dict.__setitem__(self, k.key, v) + def unwrap(self): + unwrapped = {} + for k in self: + if is_tomlkit(k): + nk = k.unwrap() + else: + nk = k + if is_tomlkit(self[k]): + nv = self[k].unwrap() + else: + nv = self[k] + unwrapped[nk] = nv + + return unwrapped + @property def value(self) -> "container.Container": return self._value @@ -1617,6 +1678,9 @@ def __init__(self, t: StringType, _: str, original: str, trivia: Trivia) -> None self._t = t self._original = original + def unwrap(self) -> str: + return self.as_string() + @property def discriminant(self) -> int: return 11 @@ -1675,6 +1739,15 @@ def __init__( for table in body: self.append(table) + def unwrap(self) -> str: + unwrapped = [] + for t in self._body: + if isinstance(t, Item): + unwrapped.append(t.unwrap()) + else: + unwrapped.append(t) + return unwrapped + @property def body(self) -> List[Table]: return self._body @@ -1766,6 +1839,9 @@ class Null(Item): def __init__(self) -> None: pass + def unwrap(self) -> str: + return None + @property def discriminant(self) -> int: return -1