From 31e2fa16cf68b57cf84034a54dc3273910c98c3d Mon Sep 17 00:00:00 2001 From: Frost Ming Date: Sun, 7 Aug 2022 22:23:00 +0800 Subject: [PATCH] fix: multiline array format broken when deleting the last item --- tests/test_items.py | 21 ++- tomlkit/items.py | 318 ++++++++++++++++++++++++++++---------------- 2 files changed, 221 insertions(+), 118 deletions(-) diff --git a/tests/test_items.py b/tests/test_items.py index aad33b4..edf77a1 100644 --- a/tests/test_items.py +++ b/tests/test_items.py @@ -368,7 +368,7 @@ def test_array_multiline_modify(): def test_append_to_empty_array(): doc = parse("x = [ ]") doc["x"].append("a") - assert doc.as_string() == 'x = [ "a" ]' + assert doc.as_string() == 'x = ["a" ]' doc = parse("x = [\n]") doc["x"].append("a") assert doc.as_string() == 'x = [\n "a"\n]' @@ -410,6 +410,8 @@ def test_modify_array_with_comment(): 2 ]""" ) + doc["x"].pop(0) + assert doc.as_string() == "x = [\n 2\n]" def test_append_to_multiline_array_with_comment(): @@ -432,14 +434,25 @@ def test_append_to_multiline_array_with_comment(): 2, 3, ] +""" + ) + assert doc["x"].pop() == 3 + assert ( + doc.as_string() + == """\ +x = [ + # Here is a comment + 1, + 2, +] """ ) def test_append_dict_to_array(): - doc = parse("x = [ ]") + doc = parse("x = []") doc["x"].append({"name": "John Doe", "email": "john@doe.com"}) - expected = 'x = [ {name = "John Doe",email = "john@doe.com"} ]' + expected = 'x = [{name = "John Doe",email = "john@doe.com"}]' assert doc.as_string() == expected # Make sure the produced string is valid assert parse(doc.as_string()) == doc @@ -469,7 +482,7 @@ def test_array_add_line(): == """[ 1, 2, 3, # Line 1 4, 5, 6, # Line 2 - 7, 8 + 7, 8, ]""" ) diff --git a/tomlkit/items.py b/tomlkit/items.py index 8c6889c..5128807 100644 --- a/tomlkit/items.py +++ b/tomlkit/items.py @@ -28,7 +28,6 @@ from tomlkit._utils import CONTROL_CHARS from tomlkit._utils import escape_string from tomlkit.exceptions import InvalidStringError -from tomlkit.toml_char import TOMLChar if TYPE_CHECKING: # pragma: no cover @@ -63,47 +62,65 @@ class _CustomDict(MutableMapping, dict): @overload -def item(value: bool) -> "Bool": +def item( + value: bool, _parent: Optional["Item"] = ..., _sort_keys: bool = ... +) -> "Bool": ... @overload -def item(value: int) -> "Integer": +def item( + value: int, _parent: Optional["Item"] = ..., _sort_keys: bool = ... +) -> "Integer": ... @overload -def item(value: float) -> "Float": +def item( + value: float, _parent: Optional["Item"] = ..., _sort_keys: bool = ... +) -> "Float": ... @overload -def item(value: str) -> "String": +def item( + value: str, _parent: Optional["Item"] = ..., _sort_keys: bool = ... +) -> "String": ... @overload -def item(value: datetime) -> "DateTime": +def item( + value: datetime, _parent: Optional["Item"] = ..., _sort_keys: bool = ... +) -> "DateTime": ... @overload -def item(value: date) -> "Date": +def item( + value: date, _parent: Optional["Item"] = ..., _sort_keys: bool = ... +) -> "Date": ... @overload -def item(value: time) -> "Time": +def item( + value: time, _parent: Optional["Item"] = ..., _sort_keys: bool = ... +) -> "Time": ... @overload -def item(value: Sequence[dict]) -> "AoT": +def item( + value: Sequence[dict], _parent: Optional["Item"] = ..., _sort_keys: bool = ... +) -> "AoT": ... @overload -def item(value: Sequence) -> "Array": +def item( + value: Sequence, _parent: Optional["Item"] = ..., _sort_keys: bool = ... +) -> "Array": ... @@ -120,7 +137,9 @@ def item( @overload -def item(value: ItemT) -> ItemT: +def item( + value: ItemT, _parent: Optional["Item"] = ..., _sort_keys: bool = ... +) -> ItemT: ... @@ -1076,22 +1095,83 @@ def _getstate(self, protocol: int = 3) -> tuple: ) +class _ArrayItemGroup: + __slots__ = ("value", "indent", "comma", "comment") + + def __init__( + self, + value: Optional[Item] = None, + indent: Optional[Whitespace] = None, + comma: Optional[Whitespace] = None, + comment: Optional[Comment] = None, + ) -> None: + self.value = value + self.indent = indent + self.comma = comma + self.comment = comment + + def __iter__(self) -> Iterator[Item]: + return filter( + lambda x: x is not None, (self.indent, self.value, self.comma, self.comment) + ) + + def __repr__(self) -> str: + return repr(tuple(self)) + + def is_whitespace(self) -> bool: + return self.value is None and self.comment is None + + def __bool__(self) -> bool: + try: + next(iter(self)) + except StopIteration: + return False + return True + + class Array(Item, _CustomList): """ An array literal """ - def __init__(self, value: list, trivia: Trivia, multiline: bool = False) -> None: + def __init__( + self, value: List[Item], trivia: Trivia, multiline: bool = False + ) -> None: super().__init__(trivia) - self._index_map: Dict[int, int] = {} list.__init__( self, [v.value for v in value if not isinstance(v, (Whitespace, Comment))] ) - - self._value = value + self._index_map: Dict[int, int] = {} + self._value = self._group_values(value) self._multiline = multiline self._reindex() + def _group_values(self, value: List[Item]) -> List[_ArrayItemGroup]: + """Group the values into (indent, value, comma, comment) tuples""" + groups = [] + this_group = _ArrayItemGroup() + for item in value: + if isinstance(item, Whitespace): + if "," not in item.s: + groups.append(this_group) + this_group = _ArrayItemGroup(indent=item) + else: + if this_group.value is None: + # when comma is met and no value is provided, add a dummy Null + this_group.value = Null() + this_group.comma = item + elif isinstance(item, Comment): + if this_group.value is None: + this_group.value = Null() + this_group.comment = item + elif this_group.value is None: + this_group.value = item + else: + groups.append(this_group) + this_group = _ArrayItemGroup(value=item) + groups.append(this_group) + return [group for group in groups if group] + def unwrap(self) -> str: unwrapped = [] for v in self: @@ -1109,6 +1189,10 @@ def discriminant(self) -> int: def value(self) -> list: return self + def _iter_items(self) -> Iterator[Item]: + for v in self._value: + yield from v + def multiline(self, multiline: bool) -> "Array": """Change the array to display in multiline or not. @@ -1130,16 +1214,18 @@ def multiline(self, multiline: bool) -> "Array": def as_string(self) -> str: if not self._multiline or not self._value: - return f'[{"".join(v.as_string() for v in self._value)}]' + return f'[{"".join(v.as_string() for v in self._iter_items())}]' s = "[\n" s += "".join( self.trivia.indent + " " * 4 - + v.as_string() - + ("\n" if isinstance(v, Comment) else ",\n") + + v.value.as_string() + + ("," if not isinstance(v.value, Null) else "") + + (v.comment.as_string() if v.comment is not None else "") + + "\n" for v in self._value - if not isinstance(v, Whitespace) + if v.value is not None ) s += self.trivia.indent + "]" @@ -1149,7 +1235,7 @@ def _reindex(self) -> None: self._index_map.clear() index = 0 for i, v in enumerate(self._value): - if isinstance(v, (Whitespace, Comment)): + if v.value is None or isinstance(v.value, Null): continue self._index_map[index] = i index += 1 @@ -1178,54 +1264,56 @@ def add_line( 4, 5, 6, ] """ - values = self._value[:] - new_values = [] - - def append_item(el: Item) -> None: - if not values: - return values.append(el) - last_el = values[-1] - if ( - isinstance(el, Whitespace) - and "," not in el.s - and isinstance(last_el, Whitespace) - and "," not in last_el.s - ): - values[-1] = Whitespace(last_el.s + el.s) - else: - values.append(el) - - if newline: - append_item(Whitespace("\n")) - if indent: - append_item(Whitespace(indent)) + new_values: List[Item] = [] + first_indent = f"\n{indent}" if newline else indent + if first_indent: + new_values.append(Whitespace(first_indent)) + whitespace = "" + data_values = [] for i, el in enumerate(items): - el = item(el, _parent=self) - if isinstance(el, Comment) or add_comma and isinstance(el, Whitespace): - raise ValueError(f"item type {type(el)} is not allowed") - if not isinstance(el, Whitespace): - new_values.append(el.value) - append_item(el) - if add_comma: - append_item(Whitespace(",")) - if i != len(items) - 1: - append_item(Whitespace(" ")) + it = item(el, _parent=self) + if isinstance(it, Comment) or add_comma and isinstance(el, Whitespace): + raise ValueError(f"item type {type(it)} is not allowed in add_line") + if not isinstance(it, Whitespace): + if whitespace: + new_values.append(Whitespace(whitespace)) + whitespace = "" + new_values.append(it) + data_values.append(it.value) + if add_comma: + new_values.append(Whitespace(",")) + if i != len(items) - 1: + new_values.append(Whitespace(" ")) + elif "," not in it.s: + whitespace += it.s + else: + new_values.append(it) + if whitespace: + new_values.append(Whitespace(whitespace)) if comment: indent = " " if items else "" - append_item( + new_values.append( Comment(Trivia(indent=indent, comment=f"# {comment}", trail="")) ) - # Atomic manipulation - self._value[:] = values - list.extend(self, new_values) + list.extend(self, data_values) + if len(self._value) > 0: + last_item = self._value[-1] + last_value_item = next((v for v in self._value[::-1] if v.value), None) + if last_value_item is not None: + last_value_item.comma = Whitespace(",") + if last_item.is_whitespace(): + self._value[-1:-1] = self._group_values(new_values) + else: + self._value.extend(self._group_values(new_values)) + else: + self._value.extend(self._group_values(new_values)) self._reindex() def clear(self) -> None: """Clear the array.""" list.clear(self) - - self._value.clear() self._index_map.clear() + self._value.clear() def __len__(self) -> int: return list.__len__(self) @@ -1240,7 +1328,7 @@ def __setitem__(self, key: Union[int, slice], value: Any) -> Any: raise ValueError("slice assignment is not supported") if key < 0: key += len(self) - self._value[self._index_map[key]] = it + self._value[self._index_map[key]].value = it def insert(self, pos: int, value: Any) -> None: it = item(value, _parent=self) @@ -1252,82 +1340,84 @@ def insert(self, pos: int, value: Any) -> None: if pos < 0: pos = 0 - items = [it] - idx = 0 + idx = 0 # insert position of the self._value list + default_indent = " " if pos < length: try: idx = self._index_map[pos] - except KeyError: - raise IndexError("list index out of range") - if not isinstance(it, (Whitespace, Comment)): - items.append(Whitespace(",")) + except KeyError as e: + raise IndexError("list index out of range") from e else: idx = len(self._value) + if idx >= 1 and self._value[idx - 1].is_whitespace(): + # The last item is a pure whitespace(\n ), insert before it + idx -= 1 + if ( + self._value[idx].indent is not None + and "\n" in self._value[idx].indent.s + ): + default_indent = "\n " + indent: Optional[Item] = None + comma: Optional[Item] = Whitespace(",") if pos < length else None + if idx < len(self._value) and not self._value[idx].is_whitespace(): + # Prefer to copy the indentation from the item after + indent = self._value[idx].indent if idx > 0: last_item = self._value[idx - 1] - if isinstance(last_item, Whitespace) and "," not in last_item.s: - # the item has an indent, copy that - idx -= 1 - ws = last_item.s - if isinstance(it, Whitespace) and "," not in it.s: - # merge the whitespace - self._value[idx] = Whitespace(ws + it.s) - return - else: - ws = "" - has_newline = bool(set(ws) & set(TOMLChar.NL)) - has_space = ws and ws[-1] in TOMLChar.SPACES - if not has_space: - # four spaces for multiline array and single space otherwise - ws += " " if has_newline else " " - items.insert(0, Whitespace(ws)) - self._value[idx:idx] = items - i = idx - 1 - if pos > 0: # Check if the last item ends with a comma - while i >= 0 and isinstance(self._value[i], (Whitespace, Comment)): - if isinstance(self._value[i], Whitespace) and "," in self._value[i].s: - break - i -= 1 - else: - self._value.insert(i + 1, Whitespace(",")) - + if indent is None: + indent = last_item.indent + if not isinstance(last_item.value, Null) and "\n" in default_indent: + # Copy the comma from the last item if 1) it contains a value and + # 2) the array is multiline + comma = last_item.comma + if last_item.comma is None and not isinstance(last_item.value, Null): + # Add comma to the last item to separate it from the following items. + last_item.comma = Whitespace(",") + if indent is None and (idx > 0 or "\n" in default_indent): + # apply default indent if it isn't the first item or the array is multiline. + indent = Whitespace(default_indent) + new_item = _ArrayItemGroup(value=it, indent=indent, comma=comma) + self._value.insert(idx, new_item) self._reindex() def __delitem__(self, key: Union[int, slice]): length = len(self) list.__delitem__(self, key) - def get_indice_to_remove(idx: int) -> Iterable[int]: - try: - real_idx = self._index_map[idx] - except KeyError: - raise IndexError("list index out of range") - yield real_idx - for i in range(real_idx + 1, len(self._value)): - if isinstance(self._value[i], Whitespace): - yield i - else: - break - - indexes = set() if isinstance(key, slice): - for idx in range(key.start or 0, key.stop or length, key.step or 1): - indexes.update(get_indice_to_remove(idx)) + indices_to_remove = list( + range(key.start or 0, key.stop or length, key.step or 1) + ) else: - indexes.update(get_indice_to_remove(length + key if key < 0 else key)) - for i in sorted(indexes, reverse=True): - del self._value[i] - while self._value and isinstance(self._value[-1], Whitespace): - self._value.pop() + indices_to_remove = [length + key if key < 0 else key] + for i in sorted(indices_to_remove, reverse=True): + try: + idx = self._index_map[i] + except KeyError as e: + if not isinstance(key, slice): + raise IndexError("list index out of range") from e + else: + del self._value[idx] + if ( + idx == 0 + and len(self._value) > 0 + and "\n" not in self._value[idx].indent.s + ): + # Remove the indentation of the first item if not newline + self._value[idx].indent = None + if len(self._value) > 0: + v = self._value[-1] + if not v.is_whitespace(): + # remove the comma of the last item + v.comma = None + self._reindex() def __str__(self): - return str( - [v.value for v in self._value if not isinstance(v, (Whitespace, Comment))] - ) + return str([v.value.value for v in self._iter_items() if v.value is not None]) def _getstate(self, protocol=3): - return self._value, self._trivia + return list(self._iter_items()), self._trivia, self._multiline AT = TypeVar("AT", bound="AbstractTable")