@pytest.mark.parametrize(
    "kwargs, example, expected",
    [
        # --- single-line strings -------------------------------------------
        ({}, "My\nString", '"My\\nString"'),
        ({"escape": False}, "My String\t", '"My String\t"'),
        ({"literal": True}, "My String\t", "'My String\t'"),
        ({"escape": True, "literal": True}, "My String\t", "'My String\t'"),
        # --- control characters in basic strings ---------------------------
        ({}, "My String\u0001", '"My String\\u0001"'),
        ({}, "My String\u000b", '"My String\\u000b"'),
        ({}, "My String\x08", '"My String\\b"'),
        ({}, "My String\x0c", '"My String\\f"'),
        ({}, "My String\x01", '"My String\\u0001"'),
        ({}, "My String\x06", '"My String\\u0006"'),
        ({}, "My String\x12", '"My String\\u0012"'),
        ({}, "My String\x7f", '"My String\\u007f"'),
        ({"escape": False}, "My String\u0001", '"My String\u0001"'),
        # --- multi-line basic strings --------------------------------------
        ({"multiline": True}, "\nMy\nString\n", '"""\nMy\nString\n"""'),
        ({"multiline": True}, 'My"String', '"""My"String"""'),
        ({"multiline": True}, 'My""String', '"""My""String"""'),
        ({"multiline": True}, 'My"""String', '"""My""\\"String"""'),
        ({"multiline": True}, 'My""""String', '"""My""\\""String"""'),
        (
            {"multiline": True},
            '"""My"""Str"""ing"""',
            '"""""\\"My""\\"Str""\\"ing""\\""""',
        ),
        # --- literal strings -----------------------------------------------
        ({"multiline": True, "literal": True}, "My\nString", "'''My\nString'''"),
        ({"multiline": True, "literal": True}, "My'String", "'''My'String'''"),
        ({"multiline": True, "literal": True}, "My\r\nString", "'''My\r\nString'''"),
        (
            {"literal": True},
            r"C:\Users\nodejs\templates",
            r"'C:\Users\nodejs\templates'",
        ),
        ({"literal": True}, r"<\i\c*\s*>", r"'<\i\c*\s*>'"),
        (
            {"multiline": True, "literal": True},
            r"I [dw]on't need \d{2} apples",
            r"'''I [dw]on't need \d{2} apples'''",
        ),
    ],
)
def test_create_string(kwargs, example, expected):
    """``tomlkit.string`` serialises ``example`` as ``expected`` for the given flags."""
    rendered = tomlkit.string(example, **kwargs).as_string()
    assert rendered == expected


@pytest.mark.parametrize(
    "kwargs, example",
    [
        ({"literal": True}, "My'String"),
        ({"literal": True}, "My\nString"),
        ({"literal": True}, "My\r\nString"),
        ({"literal": True}, "My\bString"),
        ({"literal": True}, "My\x08String"),
        ({"literal": True}, "My\x0cString"),
        ({"literal": True}, "My\x7fString"),
        ({"multiline": True, "literal": True}, "My'''String"),
    ],
)
def test_create_string_with_invalid_characters(kwargs, example):
    """Literal strings that cannot hold ``example`` are rejected, not mangled."""
    with pytest.raises(InvalidStringError):
        tomlkit.string(example, **kwargs)
# The C0 control range (U+0000-U+001F) plus DEL (U+007F).
# https://toml.io/en/v1.0.0#string
CONTROL_CHARS = frozenset(chr(c) for c in range(0x20)) | {chr(0x7F)}

# Escape letter -> the character it stands for.
_escaped = {
    "b": "\b",
    "t": "\t",
    "n": "\n",
    "f": "\f",
    "r": "\r",
    '"': '"',
    "\\": "\\",
}
# Raw sequence -> its compact escaped spelling. Anything missing from this
# table is rendered with ``_unicode_escape``.
_compact_escapes = {
    **{v: f"\\{k}" for k, v in _escaped.items()},
    '"""': '""\\"',
}
# Default set used by ``escape_string``. The backslash has to be part of it:
# otherwise a backslash already present in the input would be emitted as-is
# and read back as the start of an escape sequence.
_basic_escapes = CONTROL_CHARS | {'"', "\\"}


def _unicode_escape(seq: str) -> str:
    """Return ``seq`` with every character spelled as a ``\\uXXXX`` escape."""
    return "".join(f"\\u{ord(c):04x}" for c in seq)


def escape_string(s: str, escape_sequences: Collection[str] = _basic_escapes) -> str:
    """Escape every occurrence of ``escape_sequences`` found in ``s``.

    :param s: the text to escape; it is passed through ``decode`` first.
    :param escape_sequences: the character sequences that must not appear
        verbatim in the result. Defaults to the control characters, the
        double quote and the backslash.
    :return: ``s`` where each listed sequence is replaced by its entry in
        ``_compact_escapes`` or, when it has none, by ``\\uXXXX`` escapes.
    """
    s = decode(s)

    # Only sequences that actually occur can match. Empty sequences are
    # dropped because they match everywhere without consuming input. Longest
    # first, so overlapping sequences resolve deterministically whatever the
    # iteration order of the collection is.
    candidates = sorted(
        (seq for seq in escape_sequences if seq and seq in s),
        key=len,
        reverse=True,
    )

    res = []
    start = 0  # beginning of the pending run that needs no escaping
    length = len(s)
    i = 0
    while i < length:
        for seq in candidates:
            # Offset form of startswith: no copy of the tail of ``s``.
            if s.startswith(seq, i):
                if start != i:
                    res.append(s[start:i])
                res.append(_compact_escapes.get(seq) or _unicode_escape(seq))
                i += len(seq)
                start = i
                # One escape per position: do not re-test the other
                # sequences against characters that were just consumed.
                break
        else:
            i += 1

    if start != length:
        res.append(s[start:])

    return "".join(res)
def string(
    raw: str,
    *,
    literal: bool = False,
    multiline: bool = False,
    escape: bool = True,
) -> String:
    """Create a string item.

    By default, this function will create *single line basic* strings, but
    boolean flags (e.g. ``literal=True`` and/or ``multiline=True``)
    can be used for customization.

    For more information, please check the spec: `<https://toml.io/en/v1.0.0#string>`__.

    Common escaping rules will be applied for basic strings.
    This can be controlled by explicitly setting ``escape=False``.
    Please note that, if you disable escaping, you will have to make sure that
    the given strings don't contain any forbidden character or sequence.

    Literal strings are never escaped, so ``escape`` has no effect on them;
    a literal string containing a sequence it cannot represent raises
    ``InvalidStringError``.
    """
    type_ = _StringType.select(literal, multiline)
    return String.from_raw(raw, type_, escape)


class InvalidStringError(ValueError, TOMLKitError):
    """Raised when a value cannot be written with the requested delimiter."""

    def __init__(self, value: str, invalid_sequences: Collection[str], delimiter: str):
        # repr() keeps control characters visible; [1:-1] strips its quotes.
        repr_ = repr(value)[1:-1]
        # Name the sequences that are really present, in a stable order.
        # Interpolating the raw collection would list every forbidden
        # sequence in hash-dependent order. Fall back to the full (sorted)
        # collection if none of them is found in ``value``.
        offending = sorted(seq for seq in invalid_sequences if seq in value)
        reported = offending or sorted(invalid_sequences)
        super().__init__(
            f"Invalid string: {delimiter}{repr_}{delimiter}. "
            f"The character sequences {reported} are invalid."
        )
+ ) diff --git a/tomlkit/items.py b/tomlkit/items.py index c1eac02..ba7f848 100644 --- a/tomlkit/items.py +++ b/tomlkit/items.py @@ -11,6 +11,7 @@ from functools import lru_cache from typing import TYPE_CHECKING from typing import Any +from typing import Collection from typing import Dict from typing import Iterable from typing import Iterator @@ -23,7 +24,9 @@ from ._compat import PY38 from ._compat import decode +from ._utils import CONTROL_CHARS from ._utils import escape_string +from .exceptions import InvalidStringError from .toml_char import TOMLChar @@ -124,9 +127,7 @@ def item( return a elif isinstance(value, str): - escaped = escape_string(value) - - return String(StringType.SLB, decode(value), escaped, Trivia()) + return String.from_raw(value) elif isinstance(value, datetime): return DateTime( value.year, @@ -166,6 +167,39 @@ class StringType(Enum): # Multi Line Literal MLL = "'''" + @classmethod + def select(cls, literal=False, multiline=False) -> "StringType": + return { + (False, False): cls.SLB, + (False, True): cls.MLB, + (True, False): cls.SLL, + (True, True): cls.MLL, + }[(literal, multiline)] + + @property + def escaped_sequences(self) -> Collection[str]: + # https://toml.io/en/v1.0.0#string + escaped_in_basic = CONTROL_CHARS | {"\\"} + allowed_in_multiline = {"\n", "\r"} + return { + StringType.SLB: escaped_in_basic | {'"'}, + StringType.MLB: (escaped_in_basic | {'"""'}) - allowed_in_multiline, + StringType.SLL: (), + StringType.MLL: (), + }[self] + + @property + def invalid_sequences(self) -> Collection[str]: + # https://toml.io/en/v1.0.0#string + forbidden_in_literal = CONTROL_CHARS - {"\t"} + allowed_in_multiline = {"\n", "\r"} + return { + StringType.SLB: (), + StringType.MLB: (), + StringType.SLL: forbidden_in_literal | {"'"}, + StringType.MLL: (forbidden_in_literal | {"'''"}) - allowed_in_multiline, + }[self] + @property @lru_cache(maxsize=None) def unit(self) -> str: @@ -1512,6 +1546,19 @@ def _new(self, result): def _getstate(self, 
protocol=3): return self._t, str(self), self._original, self._trivia + @classmethod + def from_raw(cls, value: str, type_=StringType.SLB, escape=True) -> "String": + value = decode(value) + + invalid = type_.invalid_sequences + if any(c in value for c in invalid): + raise InvalidStringError(value, invalid, type_.value) + + escaped = type_.escaped_sequences + string_value = escape_string(value, escaped) if escape and escaped else value + + return cls(type_, decode(value), string_value, Trivia()) + class AoT(Item, _CustomList): """