Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow string type to be controlled from the public API #177

Merged
merged 7 commits into from Feb 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Expand Up @@ -7,6 +7,10 @@
- Fix the only child detection when creating tables. ([#175](https://github.com/sdispater/tomlkit/issues/175))
- Include the `docs/` directory and `CHANGELOG.md` in sdist tarball. ([#176](https://github.com/sdispater/tomlkit/issues/176))

### Added

- Add keyword arguments to `string` API to allow selecting the representation type. ([#177](https://github.com/sdispater/tomlkit/pull/177))

## [0.9.2] - 2022-02-08

### Changed
Expand Down
66 changes: 66 additions & 0 deletions tests/test_api.py
Expand Up @@ -21,6 +21,7 @@
from tomlkit.exceptions import InvalidDateError
from tomlkit.exceptions import InvalidDateTimeError
from tomlkit.exceptions import InvalidNumberError
from tomlkit.exceptions import InvalidStringError
from tomlkit.exceptions import InvalidTimeError
from tomlkit.exceptions import UnexpectedCharError
from tomlkit.items import AoT
Expand Down Expand Up @@ -392,3 +393,68 @@ def test_create_super_table_with_table():
def test_create_super_table_with_aot():
    """A list of dicts nested under a table key is dumped as an array of tables."""
    nested = {"foo": {"bar": [{"a": 1}]}}
    rendered = dumps(nested)
    assert rendered == "[[foo.bar]]\na = 1\n"


@pytest.mark.parametrize(
    "kwargs, example, expected",
    [
        ({}, "My\nString", '"My\\nString"'),
        ({"escape": False}, "My String\t", '"My String\t"'),
        ({"literal": True}, "My String\t", "'My String\t'"),
        ({"escape": True, "literal": True}, "My String\t", "'My String\t'"),
        ({}, "My String\u0001", '"My String\\u0001"'),
        ({}, "My String\u000b", '"My String\\u000b"'),
        ({}, "My String\x08", '"My String\\b"'),
        ({}, "My String\x0c", '"My String\\f"'),
        ({}, "My String\x01", '"My String\\u0001"'),
        ({}, "My String\x06", '"My String\\u0006"'),
        ({}, "My String\x12", '"My String\\u0012"'),
        ({}, "My String\x7f", '"My String\\u007f"'),
        ({"escape": False}, "My String\u0001", '"My String\u0001"'),
        ({"multiline": True}, "\nMy\nString\n", '"""\nMy\nString\n"""'),
        ({"multiline": True}, 'My"String', '"""My"String"""'),
        ({"multiline": True}, 'My""String', '"""My""String"""'),
        ({"multiline": True}, 'My"""String', '"""My""\\"String"""'),
        ({"multiline": True}, 'My""""String', '"""My""\\""String"""'),
        (
            {"multiline": True},
            '"""My"""Str"""ing"""',
            '"""""\\"My""\\"Str""\\"ing""\\""""',
        ),
        ({"multiline": True, "literal": True}, "My\nString", "'''My\nString'''"),
        ({"multiline": True, "literal": True}, "My'String", "'''My'String'''"),
        ({"multiline": True, "literal": True}, "My\r\nString", "'''My\r\nString'''"),
        (
            {"literal": True},
            r"C:\Users\nodejs\templates",
            r"'C:\Users\nodejs\templates'",
        ),
        ({"literal": True}, r"<\i\c*\s*>", r"'<\i\c*\s*>'"),
        (
            {"multiline": True, "literal": True},
            r"I [dw]on't need \d{2} apples",
            r"'''I [dw]on't need \d{2} apples'''",
        ),
    ],
)
def test_create_string(kwargs, example, expected):
    """``tomlkit.string`` renders ``example`` as ``expected`` for the given flags.

    The cases cover the four string representations (basic/literal,
    single-line/multi-line) with and without escaping.
    """
    created = tomlkit.string(example, **kwargs)
    rendered = created.as_string()
    assert rendered == expected


@pytest.mark.parametrize(
    "kwargs, example",
    [
        ({"literal": True}, "My'String"),
        ({"literal": True}, "My\nString"),
        ({"literal": True}, "My\r\nString"),
        ({"literal": True}, "My\bString"),
        ({"literal": True}, "My\x08String"),
        ({"literal": True}, "My\x0cString"),
        ({"literal": True}, "My\x7fString"),
        ({"multiline": True, "literal": True}, "My'''String"),
    ],
)
def test_create_string_with_invalid_characters(kwargs, example):
    """Literal strings cannot escape, so forbidden content must be rejected."""
    # Only the call itself is expected to raise; nothing is done with a result.
    with pytest.raises(InvalidStringError):
        tomlkit.string(example, **kwargs)
45 changes: 32 additions & 13 deletions tomlkit/_utils.py
Expand Up @@ -6,6 +6,7 @@
from datetime import time
from datetime import timedelta
from datetime import timezone
from typing import Collection
from typing import Union

from ._compat import decode
Expand Down Expand Up @@ -97,31 +98,49 @@ def parse_rfc3339(string: str) -> Union[datetime, date, time]:
raise ValueError("Invalid RFC 339 string")


_escaped = {"b": "\b", "t": "\t", "n": "\n", "f": "\f", "r": "\r", '"': '"', "\\": "\\"}
_escapes = {v: k for k, v in _escaped.items()}
# https://toml.io/en/v1.0.0#string
# Control characters: U+0000..U+001F plus U+007F (DEL).
CONTROL_CHARS = frozenset(chr(c) for c in range(0x20)) | {chr(0x7F)}
# Escape letter (the character following the backslash) -> character it denotes.
_escaped = {
    "b": "\b",
    "t": "\t",
    "n": "\n",
    "f": "\f",
    "r": "\r",
    '"': '"',
    "\\": "\\",
}
# Raw sequence -> its short escaped spelling. Sequences that are not listed
# here have no compact form.
_compact_escapes = {
    **{v: f"\\{k}" for k, v in _escaped.items()},
    # Break up a run of three quotes so it cannot close a multi-line basic string.
    '"""': '""\\"',
}
# Default set of sequences escaped in a single-line basic string. The
# backslash must be part of it: a raw backslash in the output would be read
# back as the start of an escape sequence.
_basic_escapes = CONTROL_CHARS | {'"', "\\"}


def escape_string(s: str) -> str:
def _unicode_escape(seq: str) -> str:
    """Spell every character of ``seq`` as a backslash-u escape.

    Each code point is written as lowercase hexadecimal, zero-padded to at
    least four digits.
    """
    pieces = []
    for char in seq:
        pieces.append("\\u{:04x}".format(ord(char)))
    return "".join(pieces)


def escape_string(s: str, escape_sequences: Collection[str] = _basic_escapes) -> str:
s = decode(s)

res = []
start = 0
l = len(s)

def flush():
def flush(inc=1):
if start != i:
res.append(s[start:i])

return i + 1
return i + inc

i = 0
while i < len(s):
c = s[i]
if c in '"\\\n\r\t\b\f':
start = flush()
res.append("\\" + _escapes[c])
elif ord(c) < 0x20:
start = flush()
res.append("\\u%04x" % ord(c))
while i < l:
for seq in escape_sequences:
seq_len = len(seq)
if s[i:].startswith(seq):
start = flush(seq_len)
res.append(_compact_escapes.get(seq) or _unicode_escape(seq))
i += seq_len - 1 # fast-forward escape sequence
i += 1

flush()
Expand Down
26 changes: 23 additions & 3 deletions tomlkit/api.py
Expand Up @@ -23,6 +23,7 @@
from .items import Key
from .items import SingleKey
from .items import String
from .items import StringType as _StringType
from .items import Table
from .items import Time
from .items import Trivia
Expand Down Expand Up @@ -104,9 +105,28 @@ def boolean(raw: str) -> Bool:
return item(raw == "true")


def string(raw: str) -> String:
"""Create a string item."""
return item(raw)
def string(
    raw: str,
    *,
    literal: bool = False,
    multiline: bool = False,
    escape: bool = True,
) -> String:
    """Create a string item.

    By default, this function will create *single line basic* strings, but
    boolean flags (e.g. ``literal=True`` and/or ``multiline=True``)
    can be used for personalization.

    For more information, please check the spec: `https://toml.io/en/v1.0.0#string`_.

    Common escaping rules will be applied for basic strings.
    This can be controlled by explicitly setting ``escape=False``.
    Please note that, if you disable escaping, you will have to make sure that
    the given strings don't contain any forbidden character or sequence.

    :param raw: The text the item should hold.
    :param literal: Use the literal (single-quoted) representation.
    :param multiline: Use the multi-line (triple-quoted) representation.
    :param escape: Apply the escaping rules of the selected representation.
    :return: The new string item.
    """
    # The two flags together pick one of the four TOML string representations.
    type_ = _StringType.select(literal, multiline)
    return String.from_raw(raw, type_, escape)


def date(raw: str) -> Date:
Expand Down
10 changes: 10 additions & 0 deletions tomlkit/exceptions.py
@@ -1,3 +1,4 @@
from typing import Collection
from typing import Optional


Expand Down Expand Up @@ -213,3 +214,12 @@ def __init__(self, line: int, col: int, char: int, type: str) -> None:
)

super().__init__(line, col, message=message)


class InvalidStringError(ValueError, TOMLKitError):
    """Error for a string containing sequences that are reported as invalid."""

    def __init__(self, value: str, invalid_sequences: Collection[str], delimiter: str):
        # repr() makes control characters visible; its own surrounding quotes
        # are dropped so that the given delimiter can be shown instead.
        repr_ = repr(value)[1:-1]
        message = (
            f"Invalid string: {delimiter}{repr_}{delimiter}. "
            f"The character sequences {invalid_sequences} are invalid."
        )
        super().__init__(message)
53 changes: 50 additions & 3 deletions tomlkit/items.py
Expand Up @@ -11,6 +11,7 @@
from functools import lru_cache
from typing import TYPE_CHECKING
from typing import Any
from typing import Collection
from typing import Dict
from typing import Iterable
from typing import Iterator
Expand All @@ -23,7 +24,9 @@

from ._compat import PY38
from ._compat import decode
from ._utils import CONTROL_CHARS
from ._utils import escape_string
from .exceptions import InvalidStringError
from .toml_char import TOMLChar


Expand Down Expand Up @@ -124,9 +127,7 @@ def item(

return a
elif isinstance(value, str):
escaped = escape_string(value)

return String(StringType.SLB, decode(value), escaped, Trivia())
return String.from_raw(value)
elif isinstance(value, datetime):
return DateTime(
value.year,
Expand Down Expand Up @@ -166,6 +167,39 @@ class StringType(Enum):
# Multi Line Literal
MLL = "'''"

@classmethod
def select(cls, literal=False, multiline=False) -> "StringType":
    """Return the member matching the ``literal`` / ``multiline`` flags.

    With both flags off the single-line basic type is returned; each flag
    switches to the literal and/or multi-line variant.
    """
    by_flags = {}
    by_flags[(False, False)] = cls.SLB
    by_flags[(False, True)] = cls.MLB
    by_flags[(True, False)] = cls.SLL
    by_flags[(True, True)] = cls.MLL
    return by_flags[(literal, multiline)]

@property
def escaped_sequences(self) -> Collection[str]:
    """Sequences that get escaped when a string of this type is written.

    Both literal types yield an empty tuple. The basic types yield the
    control characters plus the backslash and their own quote sequence;
    the multi-line basic type leaves line feed and carriage return out.
    """
    # https://toml.io/en/v1.0.0#string
    if self is StringType.SLL or self is StringType.MLL:
        return ()

    basic = CONTROL_CHARS | {"\\"}
    if self is StringType.SLB:
        return basic | {'"'}
    if self is StringType.MLB:
        newlines = {"\n", "\r"}
        return (basic | {'"""'}) - newlines

    # Same failure mode as a missing key in a lookup table.
    raise KeyError(self)

@property
def invalid_sequences(self) -> Collection[str]:
    """Sequences that a string of this type is not allowed to contain.

    Both basic types yield an empty tuple. The literal types yield the
    control characters except tab, plus their own quote sequence; the
    multi-line literal type leaves line feed and carriage return out.
    """
    # https://toml.io/en/v1.0.0#string
    if self is StringType.SLB or self is StringType.MLB:
        return ()

    no_tab = CONTROL_CHARS - {"\t"}
    if self is StringType.SLL:
        return no_tab | {"'"}
    if self is StringType.MLL:
        newlines = {"\n", "\r"}
        return (no_tab | {"'''"}) - newlines

    # Same failure mode as a missing key in a lookup table.
    raise KeyError(self)

@property
@lru_cache(maxsize=None)
def unit(self) -> str:
Expand Down Expand Up @@ -1512,6 +1546,19 @@ def _new(self, result):
def _getstate(self, protocol=3):
    """Return ``(self._t, str(self), self._original, self._trivia)``.

    ``protocol`` is accepted but not used here.
    """
    kind = self._t
    text = str(self)
    return kind, text, self._original, self._trivia

@classmethod
def from_raw(cls, value: str, type_=StringType.SLB, escape=True) -> "String":
    """Build a string item of the given type from plain text.

    :param value: The text the item should hold.
    :param type_: The string type to use for the item.
    :param escape: Whether to escape the sequences listed by
        ``type_.escaped_sequences``.
    :raises InvalidStringError: If ``value`` contains one of
        ``type_.invalid_sequences``.
    """
    value = decode(value)

    # Reject content the chosen type reports as invalid.
    forbidden = type_.invalid_sequences
    for sequence in forbidden:
        if sequence in value:
            raise InvalidStringError(value, forbidden, type_.value)

    to_escape = type_.escaped_sequences
    if escape and to_escape:
        rendered = escape_string(value, to_escape)
    else:
        # Nothing to escape, or escaping was switched off by the caller.
        rendered = value

    return cls(type_, decode(value), rendered, Trivia())


class AoT(Item, _CustomList):
"""
Expand Down