From 01a83ab25072c836e45c8e257cbc075d2f278fde Mon Sep 17 00:00:00 2001 From: Lele Gaifax Date: Sat, 28 May 2022 14:30:41 +0200 Subject: [PATCH] Get rid of the old Node wrappers Those wrappers were needed in v1 to hide the underlying JSON structures emitted by libpg_query, making it possible to see them as generic Python instances. Version 3 obsoleted that by introducing a set of concrete AST classes, and the wrappers were kept mainly for their ability to track the parentship of each node, required by some of the printer functions. Now they are gone, and everything works directly on AST classes, and the printers use the new "ancestors" slot attached to each instance. This implements issue #80. --- docs/api.rst | 3 +- docs/node.rst | 38 ---- docs/printers.rst | 4 +- docs/usage.rst | 96 +--------- pglast/__init__.py | 13 +- pglast/node.py | 300 ------------------------------- pglast/printers/__init__.py | 114 ++++++------ pglast/printers/ddl.py | 289 ++++++++++++++--------------- pglast/printers/dml.py | 174 +++++++++--------- pglast/printers/sfuncs.py | 6 +- pglast/stream.py | 111 ++++++------ tests/test_node.py | 92 ---------- tests/test_printers.py | 78 ++------ tests/test_printers_roundtrip.py | 12 +- tests/test_stream.py | 60 ++----- 15 files changed, 395 insertions(+), 995 deletions(-) delete mode 100644 docs/node.rst delete mode 100644 pglast/node.py delete mode 100644 tests/test_node.py diff --git a/docs/api.rst b/docs/api.rst index 9b8b93eb..b15b5b25 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -3,7 +3,7 @@ .. :Created: gio 10 ago 2017 10:14:17 CEST .. :Author: Lele Gaifax .. :License: GNU General Public License version 3 or later -.. :Copyright: © 2017, 2018, 2019, 2021 Lele Gaifax +.. :Copyright: © 2017, 2018, 2019, 2021, 2022 Lele Gaifax .. =================== @@ -23,7 +23,6 @@ This chapter briefly explains some implementation detail. ast enums keywords - node stream printers sfuncs diff --git a/docs/node.rst b/docs/node.rst deleted file mode 100644 index bed369e2..00000000 --- a/docs/node.rst +++ /dev/null @@ -1,38 +0,0 @@ -.. -*- coding: utf-8 -*- -.. :Project: pglast -- Node documentation -.. :Created: gio 10 ago 2017 10:28:36 CEST -.. :Author: Lele Gaifax -.. :License: GNU General Public License version 3 or later -.. :Copyright: © 2017, 2018, 2021 Lele Gaifax -.. - -===================================================================== - :mod:`pglast.node` --- The higher level interface to the parse tree -===================================================================== - -This module implements a set of classes that make it easier to deal with the :mod:`pglast.ast` -nodes. - -The :class:`pglast.node.Node` wraps a single :class:`pglast.ast.Node` adding a reference to the -parent node; the class:`pglast.node.List` wraps a sequence of them and -:class:`pglast.node.Scalar` represents plain values such a *strings*, *integers*, *booleans* or -*none*. - -Every node is identified by a *tag*, a string label that characterizes its content, exposed as -a set of *attributes* as well as with a dictionary-like interface (technically -:class:`pglast.node.Node` implements both a ``__getattr__`` method and a ``__getitem__`` -method, while underlying :class:`pglast.ast.Node` only the former). When asked for an -attribute, the node returns an instance of the base classes, i.e. another ``Node``, or a -``List`` or a ``Scalar``, depending on the data type of that item. When the node does not -contain the requested attribute it returns a singleton :data:`pglast.node.Missing` marker -instance. - -A ``List`` wraps a plain Python ``list`` and may contains a sequence of ``Node`` instances, or -in some cases other sub-lists, that can be accessed with the usual syntax, or iterated. - -Finally, a ``Scalar`` carries a single value of some scalar type, accessible through its -``value`` attribute. - -.. automodule:: pglast.node - :synopsis: The higher level interface to the parse tree - :members: diff --git a/docs/printers.rst b/docs/printers.rst index d50b93f5..60ec335b 100644 --- a/docs/printers.rst +++ b/docs/printers.rst @@ -3,7 +3,7 @@ .. :Created: gio 10 ago 2017 13:23:18 CEST .. :Author: Lele Gaifax .. :License: GNU General Public License version 3 or later -.. :Copyright: © 2017, 2018, 2021 Lele Gaifax +.. :Copyright: © 2017, 2018, 2021, 2022 Lele Gaifax .. ========================================================== @@ -22,7 +22,7 @@ associated :class:`~.node.Node` will be serialized. .. autoexception:: PrinterAlreadyPresentError -.. autofunction:: get_printer_for_node_tag +.. autofunction:: get_printer_for_node .. autofunction:: node_printer diff --git a/docs/usage.rst b/docs/usage.rst index 5ba556b5..2d2842da 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -3,7 +3,7 @@ .. :Created: gio 10 ago 2017 10:06:38 CEST .. :Author: Lele Gaifax .. :License: GNU General Public License version 3 or later -.. :Copyright: © 2017, 2018, 2019, 2021 Lele Gaifax +.. :Copyright: © 2017, 2018, 2019, 2021, 2022 Lele Gaifax .. .. _usage: @@ -15,7 +15,7 @@ Here are some example of how the module can be used. --------- -Low level +AST level --------- The lowest level is a Python wrapper around each *parse node* returned by the ``PostgreSQL`` @@ -181,89 +181,6 @@ This basically means that you can reconstruct a syntax tree from the result of c >>> clone == stmt True ------------- -Medium level ------------- - -Parse an ``SQL`` statement and get its *AST* root node -====================================================== - -.. doctest:: - - >>> from pglast import Node - >>> root = Node(parse_sql('SELECT foo FROM bar')) - >>> print(root) - None=[1*{RawStmt}] - -Get a particular node -===================== - -.. doctest:: - - >>> from_clause = root[0].stmt.fromClause - >>> print(from_clause) - fromClause=[1*{RangeVar}] - -Obtain some information about a node -==================================== - -.. doctest:: - - >>> range_var = from_clause[0] - >>> print(range_var.node_tag) - RangeVar - >>> print(range_var.attribute_names) - ('catalogname', 'schemaname', 'relname', 'inh', 'relpersistence', 'alias', 'location') - >>> print(range_var.parent_node) - stmt={SelectStmt} - -Iterate over nodes -================== - -.. doctest:: - - >>> for a in from_clause: - ... print(a) - ... for b in a: - ... print(b) - ... - fromClause[0]={RangeVar} - inh= - location=<16> - relname=<'bar'> - relpersistence=<'p'> - -Recursively :meth:`traverse ` the parse tree -======================================================================= - -.. doctest:: - - >>> for node in root.traverse(): - ... print(node) - ... - None[0]={RawStmt} - stmt={SelectStmt} - all= - fromClause[0]={RangeVar} - inh= - location=<16> - relname=<'bar'> - relpersistence=<'p'> - limitOption= - op= - targetList[0]={ResTarget} - location=<7> - val={ColumnRef} - fields[0]={String} - val=<'foo'> - location=<7> - stmt_len=<0> - stmt_location=<0> - -As you can see, the ``repr``\ esentation of each value is mnemonic: ``{some_tag}`` means a -``Node`` with tag ``some_tag``, ``[X*{some_tag}]`` is a ``List`` containing `X` nodes of that -particular kind\ [*]_ and ```` is a ``Scalar``. - Programmatically :func:`reformat ` a ``SQL`` statement ======================================================================= @@ -318,7 +235,7 @@ that extends :class:`pglast.stream.RawStream` adding a bit a aesthetic sense. .. doctest:: - >>> print(IndentedStream()(root)) + >>> print(IndentedStream()('select foo from bar')) SELECT foo FROM bar @@ -400,7 +317,7 @@ Customize a :func:`node printer ` >>> from pglast.printers import node_printer >>> @node_printer('ParamRef', override=True) ... def replace_param_ref(node, output): - ... output.write(repr(args[node.number.value - 1])) + ... output.write(repr(args[node.number - 1])) ... >>> args = ['Hello', 'Ciao'] >>> print(prettify(sql, safety_belt=False)) @@ -542,11 +459,6 @@ Preserve comments carry their exact location, so it is not possible to differentiate between ``SELECT * /*comment*/ FROM foo`` and ``SELECT * FROM /*comment*/ foo``. ---- - -.. [*] This is an approximation, because in principle a list can contain different kinds of - nodes, or even sub-lists in some cases: the ``List`` representation arbitrarily shows - the tag of the first object. Functions vs SQL syntax ======================= diff --git a/pglast/__init__.py b/pglast/__init__.py index 15ec191d..55b6e106 100644 --- a/pglast/__init__.py +++ b/pglast/__init__.py @@ -3,12 +3,13 @@ # :Created: mer 02 ago 2017 15:11:02 CEST # :Author: Lele Gaifax # :License: GNU General Public License version 3 or later -# :Copyright: © 2017, 2018, 2019, 2021 Lele Gaifax +# :Copyright: © 2017, 2018, 2019, 2021, 2022 Lele Gaifax # +from collections import namedtuple + from . import enums from .error import Error -from .node import Comment, Missing, Node try: from .parser import fingerprint, get_postgresql_version, parse_sql, scan, split except ModuleNotFoundError: # pragma: no cover @@ -31,6 +32,10 @@ def parse_plpgsql(statement): return loads(parse_plpgsql_json(statement)) +Comment = namedtuple('Comment', ('location', 'text', 'at_start_of_line', 'continue_previous')) +"A structure to carry information about a single SQL comment." + + def _extract_comments(statement): lines = [] lofs = 0 @@ -81,7 +86,7 @@ def prettify(statement, safety_belt=False, preserve_comments=False, **options): options['comments'] = _extract_comments(statement) orig_pt = parse_sql(statement) - prettified = IndentedStream(**options)(Node(orig_pt)) + prettified = IndentedStream(**options)(orig_pt) if safety_belt: from logging import getLogger import warnings @@ -109,5 +114,5 @@ def prettify(statement, safety_belt=False, preserve_comments=False, **options): return prettified -__all__ = ('Error', 'Missing', 'Node', 'enums', 'fingerprint', 'get_postgresql_version', +__all__ = ('Error', 'enums', 'fingerprint', 'get_postgresql_version', 'parse_plpgsql', 'parse_sql', 'prettify', 'split') diff --git a/pglast/node.py b/pglast/node.py deleted file mode 100644 index 43b022aa..00000000 --- a/pglast/node.py +++ /dev/null @@ -1,300 +0,0 @@ -# -*- coding: utf-8 -*- -# :Project: pglast -- Generic Node implementation -# :Created: mer 02 ago 2017 15:44:14 CEST -# :Author: Lele Gaifax -# :License: GNU General Public License version 3 or later -# :Copyright: © 2017, 2018, 2019, 2021, 2022 Lele Gaifax -# - -from collections import namedtuple -from decimal import Decimal -from enum import Enum - -from . import ast - - -class Missing: - def __bool__(self): - return False - - def __repr__(self): # pragma: no cover - return "MISSING" - - def __iter__(self): - return self - - def __next__(self): - raise StopIteration() - - -Missing = Missing() -"Singleton returned when trying to get a non-existing attribute out of a :class:`Node`." - - -Comment = namedtuple('Comment', ('location', 'text', 'at_start_of_line', 'continue_previous')) -"A structure to carry information about a single SQL comment." - - -class Base: - """Common base class. - - :param details: the *parse tree* - :type parent: ``None`` or :class:`Node` instance - :param parent: ``None`` to indicate that the node is the *root* of the parse tree, - otherwise it is the immediate parent of the new node - :type name: str or tuple - :param name: the name of the attribute in the `parent` node that *points* to this one; - it may be a tuple (name, position) when ``parent[name]`` is actually a list of - nodes - - Its main purpose is to create the right kind of instance, depending on the type of the - `details` argument passed to the constructor: a :class:`ast.Node ` - produces a :class:`Node` instance, a ``list`` or ``tuple`` produces a :class:`List` - instance, everything else a :class:`Scalar` instance. - - """ - - __slots__ = ('_parent_node', '_parent_attribute') - - def __new__(cls, details, parent=None, name=None): - if parent is not None and not isinstance(parent, Node): - raise ValueError("Unexpected value for 'parent', must be either None" - " or a Node instance, got %r" % type(parent)) - if name is not None and not isinstance(name, (str, tuple)): - raise ValueError("Unexpected value for 'name', must be either None," - " a string or a tuple, got %r" % type(name)) - if isinstance(details, (list, tuple)): - self = super().__new__(List) - elif isinstance(details, ast.Node): - self = super().__new__(Node) - else: - self = super().__new__(Scalar) - self._init(details, parent, name) - return self - - def _init(self, parent, name): - self._parent_node = parent - self._parent_attribute = name - - def __str__(self): # pragma: no cover - aname = self._parent_attribute - if isinstance(aname, tuple): - aname = '%s[%d]' % aname - return '%s=%r' % (aname, self) - - @property - def parent_node(self): - "The parent :class:`Node` of this element." - - return self._parent_node - - @property - def parent_attribute(self): - "The *attribute* in the parent :class:`Node` referencing this element." - - return self._parent_attribute - - -class List(Base): - """Represent a sequence of :class:`Node` instances. - - :type items: list - :param items: a list of items, usually :class:`Node` instances - :type parent: ``None`` or :class:`Node` instance - :param parent: ``None`` to indicate that the node is the *root* of the parse tree, - otherwise it is the immediate parent of the new node - :type name: str or tuple - :param name: the name of the attribute in the `parent` node that *points* to this one; - it may be a tuple (name, position) when ``parent[name]`` is actually a list of - nodes - """ - - __slots__ = ('_items',) - - def _init(self, items, parent, name): - if not isinstance(items, (list, tuple)) or not items: # pragma: no cover - raise ValueError("Unexpected value for 'items', must be a non empty tuple or list," - " got %r" % type(items)) - super()._init(parent, name) - self._items = items - - def __len__(self): - return len(self._items) - - def __bool__(self): - return len(self) > 0 - - def __repr__(self): - if not self: # pragma: no cover - return '[]' - # There's no guarantee that a list contains the same kind of objects, - # so picking the first is rather arbitrary but serves the purpose, as - # this is primarily an helper for investigating the internals of a tree. - count = len(self) - pivot = self[0] - return '[%d*%r]' % (count, pivot) - - def __iter__(self): - pnode = self.parent_node - aname = self.parent_attribute - for idx, item in enumerate(self._items): - yield Base(item, pnode, (aname, idx)) - - def __getitem__(self, index): - return Base(self._items[index], self.parent_node, (self.parent_attribute, index)) - - def __eq__(self, other): - cls = type(self) - if cls is type(other): - return self._items == other._items - return False - - @property - def string_value(self): - if len(self) != 1: # pragma: no cover - raise TypeError('%r does not contain a single String node' % self) - node = self[0] - if node.node_tag != 'String': # pragma: no cover - raise TypeError('%r does not contain a single String node' % self) - return node.val.value - - def traverse(self): - "A generator that recursively traverse all the items in the list." - - for item in self: - yield from item.traverse() - - -class Node(Base): - """Represent a single entry in a *parse tree*. - - :type details: :class:`.ast.Node` - :param details: the *parse tree* of the node - :type parent: ``None`` or :class:`Node` instance - :param parent: ``None`` to indicate that the node is the *root* of the parse tree, - otherwise it is the immediate parent of the new node - :type name: str or tuple - :param name: the name of the attribute in the `parent` node that *points* to this one; - it may be a tuple (name, position) when ``parent[name]`` is actually a list of - nodes - """ - - __slots__ = ('node_tag', 'ast_node') - - def _init(self, details, parent=None, name=None): - if not isinstance(details, ast.Node): # pragma: no cover - raise ValueError("Unexpected value for 'details', must be a ast.Node") - super()._init(parent, name) - self.node_tag = details.__class__.__name__ - self.ast_node = details - - def __getattr__(self, attr): - value = getattr(self.ast_node, attr) - if value is None: - return Missing - else: - return Base(value, self, attr) - - def __repr__(self): - return '{%s}' % self.node_tag - - def __getitem__(self, attr): - if isinstance(attr, tuple): # pragma: no cover - attr, index = attr - return self[attr][index] - elif isinstance(attr, str): - return getattr(self, attr) - else: - raise ValueError('Wrong key type %r, must be str or tuple' - % type(attr).__name__) - - def __iter__(self): - node = self.ast_node - for attr in sorted(node.__slots__): - value = getattr(node, attr) - if value is not None: - yield Base(value, self, attr) - - def __eq__(self, other): - cls = type(self) - if cls is type(other): - return self.ast_node == other.ast_node - return False - - @property - def attribute_names(self): - "The names of the attributes present in the parse tree of the node." - - return tuple(self.ast_node) - - def traverse(self): - "A generator that recursively traverse all attributes of the node." - - yield self - for item in self: - yield from item.traverse() - - -class Scalar(Base): - "Represent a single scalar value." - - __slots__ = ('_value',) - - def _init(self, value, parent, name): - if value is not None and not isinstance(value, (bool, float, int, str, Decimal, - ast.Value)): - raise ValueError("Unexpected value for 'value', must be either None or a" - " bool|float|int|str|Decimal instance, got %r" % type(value)) - super()._init(parent, name) - self._value = value - - def __and__(self, other): - value = self._value - if isinstance(value, int) and isinstance(other, int): - return value & other - else: # pragma: no cover - raise ValueError("Wrong operands for __and__: %r & %r" - % (type(value), type(other))) - - def __bool__(self): - value = self._value - if value is None: - return False - if isinstance(value, str): - if len(value) == 0: - return False - elif len(value) == 1: - return value[0] != '\x00' - else: - return True - elif isinstance(value, bool): - return value - return True - - def __hash__(self): - return hash(self._value) - - def __eq__(self, other): - value = self._value - if isinstance(other, Enum): - return value == other - elif isinstance(other, type(value)): - return value == other - else: - cls = type(self) - if cls is type(other): - return value == other._value - return False - - def __repr__(self): # pragma: no cover - if isinstance(self._value, Enum): - return repr(self._value) - else: - return '<%r>' % self._value - - def traverse(self): - yield self - - @property - def value(self): - return self._value diff --git a/pglast/printers/__init__.py b/pglast/printers/__init__.py index c0b8287d..ba181df5 100644 --- a/pglast/printers/__init__.py +++ b/pglast/printers/__init__.py @@ -3,7 +3,7 @@ # :Created: sab 05 ago 2017 16:33:14 CEST # :Author: Lele Gaifax # :License: GNU General Public License version 3 or later -# :Copyright: © 2017, 2018, 2020, 2021 Lele Gaifax +# :Copyright: © 2017, 2018, 2020, 2021, 2022 Lele Gaifax # from .. import ast @@ -11,7 +11,7 @@ NODE_PRINTERS = {} -"Registry of specialized printers, keyed by their `tag`." +"Registry of specialized node printers, keyed by their class." SPECIAL_FUNCTIONS = {} @@ -22,67 +22,69 @@ class PrinterAlreadyPresentError(Error): "Exception raised trying to register another function for a tag already present." -def get_printer_for_node_tag(parent_node_tag, node_tag): - """Get specific printer implementation for given `node_tag`. +def get_printer_for_node(node): + """Get specific printer implementation for given `node`. If there is a more specific printer for it, when it's inside a particular - `parent_node_tag`, return that instead. + ancestor, return that instead. """ - printer = NODE_PRINTERS.get((parent_node_tag, node_tag)) + node_class = type(node) + if not issubclass(node_class, ast.Node): + raise ValueError('Expected an ast.Node, not a %r' % node_class.__name__) + parent = abs(node.ancestors) + parent_node_class = None if parent is None else type(parent.node) + printer = NODE_PRINTERS.get((parent_node_class, node_class)) if printer is None: - printer = NODE_PRINTERS.get(node_tag) + printer = NODE_PRINTERS.get(node_class) if printer is None: raise NotImplementedError("Printer for node %r is not implemented yet" - % node_tag) + % node_class.__name__) return printer -def node_printer(*node_tags, override=False, check_tags=True): - r"""Decorator to register a specific printer implementation for given `node_tag`. +def node_printer(*nodes, override=False): + r"""Decorator to register a specific printer implementation for a (set of) `nodes`. - :param \*node_tags: one or two node tags + :param \*nodes: a list of one or two items :param bool override: when ``True`` the function will be registered even if already present in the :data:`NODE_PRINTERS` registry - :param bool check_tags: - by default each `node_tags` is checked for validity, that is must be a valid class - implemented by :mod:`pglast.ast`; pass ``False`` to disable the check - - When `node_tags` contains a single item then the decorated function is the *generic* one, - and it will be registered in :data:`NODE_PRINTERS` with that key alone. Otherwise it must - contain two elements: the first may be either a scalar value or a sequence of parent tags, - and the function will be registered under the key ``(parent_tag, tag)``. + + When `nodes` contains a single item then the decorated function is the *generic* + one, and it will be registered in :data:`NODE_PRINTERS` with that key alone. Otherwise it + must contain two elements: the first may be either a scalar value or a sequence of parent + nodes, and the function will be registered under the key ``(parent, node)``. """ - if check_tags: - for tag in node_tags: - if not isinstance(tag, (list, tuple)): - tag = (tag,) - for t in tag: - if not isinstance(t, str): - raise ValueError(f'Invalid tag {t!r}, expected a string') - if not hasattr(ast, t): - raise ValueError(f'Unknown tag name {t}') + n = len(nodes) + if n == 1: + parent_classes = (None,) + node_class = nodes[0] + elif n == 2: + pclasses, node_class = nodes + if not isinstance(pclasses, (list, tuple)): + pclasses = (pclasses,) + parent_classes = tuple(getattr(ast, cls) if isinstance(cls, str) else cls + for cls in pclasses) + else: + raise ValueError('Invalid nodes: must contain one or two items!') + + if isinstance(node_class, str): + node_class = getattr(ast, node_class) + elif not isinstance(node_class, type) or not issubclass(node_class, ast.Node): + raise ValueError('Invalid nodes: expected an ast.Node instance or its name,' + ' got %r' % node_class) def decorator(impl): - if len(node_tags) == 1: - parent_tags = (None,) - tag = node_tags[0] - elif len(node_tags) == 2: - parent_tags, tag = node_tags - if not isinstance(parent_tags, (list, tuple)): - parent_tags = (parent_tags,) - else: - raise ValueError(f'Must specify one or two tags, got {len(node_tags)} instead') - - for parent_tag in parent_tags: - t = tag if parent_tag is None else (parent_tag, tag) - if not override and t in NODE_PRINTERS: + for parent_class in parent_classes: + key = node_class if parent_class is None else (parent_class, node_class) + if not override and key in NODE_PRINTERS: # pragma: no cover raise PrinterAlreadyPresentError("A printer is already registered for tag %r" - % t) - NODE_PRINTERS[t] = impl + % key) + NODE_PRINTERS[key] = impl return impl + return decorator @@ -123,26 +125,20 @@ def __init__(self): self.value_to_symbol = {m.value: m.name for m in self.enum} def __call__(self, value, node, output): - from ..node import Missing, Scalar - - if value is Missing: + if value is None: # Should never happen, but better safe than sorry for symbol, member in self.enum.__members__.items(): if member.value == 0: break else: # pragma: no cover raise ValueError(f"Could not determine default value of class {self.enum!r}") - elif isinstance(value, Scalar): - if isinstance(value.value, str): - # libpg_query 13+ emits enum names, not values - symbol = value.value - if symbol not in self.enum.__members__: - raise ValueError('Unexpected symbol {symbol!r}, not in {self.enum!r}') - else: - symbol = self.value_to_symbol.get(value.value) + elif isinstance(value, self.enum): + symbol = self.value_to_symbol.get(value) + elif isinstance(value, ast.Integer): + symbol = self.value_to_symbol.get(value.val) else: symbol = value - if symbol is None: + if symbol is None: # pragma: no cover raise ValueError(f"Invalid value {value!r}, not in class {self.enum!r}") method = getattr(self, symbol, None) if method is None: @@ -155,4 +151,12 @@ def __call__(self, value, node, output): method(node, output) +def get_string_value(lst): + "Helper function to get a literal string value, wrapped in a one-sized list." + + if len(lst) != 1 or not isinstance(lst[0], ast.String): # pragma: no cover + raise TypeError('%r does not contain a single String node' % lst) + return lst[0].val + + from . import ddl, dml, sfuncs # noqa: F401,E402 diff --git a/pglast/printers/ddl.py b/pglast/printers/ddl.py index 33350d12..60c57b59 100644 --- a/pglast/printers/ddl.py +++ b/pglast/printers/ddl.py @@ -8,18 +8,17 @@ import re -from .. import enums -from ..node import Missing, List -from . import IntEnumPrinter, node_printer +from .. import ast, enums +from . import IntEnumPrinter, get_string_value, node_printer @node_printer('AccessPriv') def access_priv(node, output): - if node.priv_name is Missing: + if node.priv_name is None: output.write('ALL PRIVILEGES') else: - output.write(node.priv_name.value.upper()) - if node.cols is not Missing: + output.write(node.priv_name.upper()) + if node.cols is not None: output.write(' (') output.print_list(node.cols, ',', are_names=True) output.write(')') @@ -109,7 +108,7 @@ def alter_extension_stmt(node, output): @node_printer('AlterExtensionStmt', 'DefElem') def alter_extension_stmt_def_elem(node, output): - option = node.defname.value + option = node.defname if option == 'new_version': output.write('UPDATE TO ') output.print_node(node.arg) @@ -125,7 +124,7 @@ def alter_extension_contents_stmt(node, output): output.write(' DROP ') else: output.write(' ADD ') - output.write(OBJECT_NAMES[node.objtype.value]) + output.write(OBJECT_NAMES[node.objtype]) output.write(' ') output.print_node(node.object) @@ -137,19 +136,19 @@ def alter_enum_stmt(node, output): if node.newVal: if node.oldVal: output.write('RENAME VALUE ') - output.write_quoted_string(node.oldVal.value) + output.write_quoted_string(node.oldVal) output.write('TO ') else: output.write('ADD VALUE ') if node.skipIfNewValExists: output.write('IF NOT EXISTS ') - output.write_quoted_string(node.newVal.value) + output.write_quoted_string(node.newVal) if node.newValNeighbor: if node.newValIsAfter: output.write(' AFTER ') else: output.write(' BEFORE ') - output.write_quoted_string(node.newValNeighbor.value) + output.write_quoted_string(node.newValNeighbor) @node_printer('AlterDefaultPrivilegesStmt') @@ -157,14 +156,15 @@ def alter_default_privileges_stmt(node, output): output.writes('ALTER DEFAULT PRIVILEGES') roles = None schemas = None - for opt in node.options: - optname = opt.defname.value - if optname == 'roles': - roles = opt.arg - elif optname == 'schemas': - schemas = opt.arg - else: # pragma: no cover - raise NotImplementedError('Option not implemented: %s' % optname) + if node.options is not None: + for opt in node.options: + optname = opt.defname + if optname == 'roles': + roles = opt.arg + elif optname == 'schemas': + schemas = opt.arg + else: # pragma: no cover + raise NotImplementedError('Option not implemented: %s' % optname) if roles is not None: output.newline() with output.push_indent(2): @@ -191,7 +191,7 @@ def alter_default_privileges_stmt(node, output): else: output.write('ALL PRIVILEGES') output.write(' ON ') - output.write(OBJECT_NAMES[action.objtype.value]) + output.write(OBJECT_NAMES[action.objtype]) output.write('S ') output.writes(preposition) output.print_list(action.grantees, ',') @@ -275,7 +275,7 @@ def alter_op_family_stmt(node, output): @node_printer('AlterOwnerStmt') def alter_owner_stmt(node, output): output.write('ALTER ') - output.writes(OBJECT_NAMES[node.objectType.value]) + output.writes(OBJECT_NAMES[node.objectType]) OT = enums.ObjectType if node.objectType in (OT.OBJECT_OPFAMILY, OT.OBJECT_OPCLASS): @@ -337,7 +337,7 @@ def alter_seq_stmt(node, output): @node_printer('AlterTableStmt') def alter_table_stmt(node, output): output.write('ALTER ') - output.writes(OBJECT_NAMES[node.relkind.value]) + output.writes(OBJECT_NAMES[node.relkind]) if node.missing_ok: output.write('IF EXISTS ') output.print_node(node.relation) @@ -381,7 +381,7 @@ def alter_def_elem(node, output): @node_printer('AlterTableStmt', 'RangeVar') def range_var(node, output): - if node.parent_node.relkind == enums.ObjectType.OBJECT_TABLE and not node.inh: + if abs(node.ancestors).node.relkind == enums.ObjectType.OBJECT_TABLE and not node.inh: output.write('ONLY ') if node.schemaname: output.print_name(node.schemaname) @@ -399,7 +399,7 @@ class AlterTableTypePrinter(IntEnumPrinter): def AT_AddColumn(self, node, output): output.write('ADD ') - if node.parent_node.relkind == enums.ObjectType.OBJECT_TYPE: + if abs(node.ancestors).node.relkind == enums.ObjectType.OBJECT_TYPE: output.write('ATTRIBUTE ') else: output.write('COLUMN ') @@ -421,7 +421,7 @@ def AT_AddOf(self, node, output): def AT_AlterColumnType(self, node, output): output.write('ALTER ') - if node.parent_node.relkind == enums.ObjectType.OBJECT_TYPE: + if abs(node.ancestors).node.relkind == enums.ObjectType.OBJECT_TYPE: output.write('ATTRIBUTE ') else: output.write('COLUMN ') @@ -474,7 +474,7 @@ def AT_DropCluster(self, node, output): def AT_DropColumn(self, node, output): output.write('DROP ') - if node.parent_node.relkind == enums.ObjectType.OBJECT_TYPE: + if abs(node.ancestors).node.relkind == enums.ObjectType.OBJECT_TYPE: output.write('ATTRIBUTE ') else: output.write('COLUMN ') @@ -546,7 +546,7 @@ def AT_SetStatistics(self, node, output): if node.name: output.print_name(node.name) elif node.num: - output.write(str(node.num.value)) + output.write(str(node.num)) output.write(' SET STATISTICS ') output.print_node(node.def_) @@ -554,7 +554,7 @@ def AT_SetStorage(self, node, output): output.write('ALTER COLUMN ') output.print_name(node.name) output.write(' SET STORAGE ') - output.write(node.def_.val.value) + output.write(node.def_.val) def AT_SetUnLogged(self, node, output): output.write('SET UNLOGGED') @@ -596,7 +596,7 @@ def AT_DropExpression(self, node, output): def AT_AddIdentity(self, node, output): output.write('ALTER COLUMN ') output.print_name(node.name) - if node.num.value > 0: + if node.num > 0: # FIXME: find a way to get here output.print_node(node.num) output.write(' ADD ') @@ -606,7 +606,7 @@ def AT_AddIdentity(self, node, output): def AT_DropIdentity(self, node, output): output.write('ALTER COLUMN ') output.print_name(node.name) - if node.num.value > 0: + if node.num > 0: # FIXME: find a way to get here output.print_node(node.num) output.write(' DROP IDENTITY ') @@ -645,7 +645,7 @@ def AT_DisableTrigAll(self, node, output): def AT_SetIdentity(self, node, output): output.write('ALTER COLUMN ') output.print_name(node.name) - if node.num.value > 0: + if node.num > 0: # FIXME: find a way to get here output.print_node(node.num) for elem in node.def_: @@ -660,7 +660,7 @@ def AT_SetIdentity(self, node, output): output.write('CACHE ') output.print_node(elem.arg) elif elem.defname == 'cycle': - if elem.arg.val.value == 0: + if elem.arg.val == 0: output.write('NO ') output.write('CYCLE') elif elem.defname == 'increment': @@ -723,9 +723,9 @@ class AlterTSConfigTypePrinter(IntEnumPrinter): enum = enums.AlterTSConfigType def print_simple_name(self, node, output): - if isinstance(node, List): + if isinstance(node, tuple): node = node[0] - output.write(node.val.value) + output.write(node.val) def print_simple_list(self, nodes, output): first = True @@ -851,9 +851,9 @@ def alter_subscription_stmt(node, output): output.print_list(node.options) output.write(')') elif node.kind == enums.AlterSubscriptionType.ALTER_SUBSCRIPTION_ENABLED: - if node.options[0].arg.val.value == 0: + if node.options[0].arg.val == 0: output.write('DISABLE') - elif node.options[0].arg.val.value == 1: + elif node.options[0].arg.val == 1: output.write('ENABLE') @@ -920,7 +920,7 @@ def alter_user_mapping_stmt(node, output): output.write('ALTER USER MAPPING FOR ') role_spec(node.user, output) output.writes(' SERVER ') - output.write(node.servername.value) + output.write(node.servername) alter_def_elem(node.options, output) @@ -1036,9 +1036,8 @@ def column_def(node, output): @node_printer('CommentStmt') def comment_stmt(node, output): otypes = enums.ObjectType - output.write('COMMENT ') - output.write('ON ') - output.writes(OBJECT_NAMES[node.objtype.value]) + output.write('COMMENT ON ') + output.writes(OBJECT_NAMES[node.objtype]) if node.objtype in (otypes.OBJECT_OPCLASS, otypes.OBJECT_OPFAMILY): nodes = list(node.object) using = nodes.pop(0) @@ -1064,8 +1063,8 @@ def comment_stmt(node, output): output.print_name(nodes) elif node.objtype == otypes.OBJECT_AGGREGATE: _object_with_args(node.object, output, empty_placeholder='*') - elif isinstance(node.object, List): - if node.object[0].node_tag != 'String': + elif isinstance(node.object, tuple): + if not isinstance(node.object[0], ast.String): output.write(' (') output.print_list(node.object, ' AS ', standalone_items=False) output.write(')') @@ -1078,7 +1077,7 @@ def comment_stmt(node, output): output.write('IS ') if node.comment: with output.push_indent(): - output.write_quoted_string(node.comment.value) + output.write_quoted_string(node.comment) else: output.write('NULL') @@ -1166,14 +1165,14 @@ def CONSTR_FOREIGN(self, node, output): output.write(' (') output.print_name(node.pk_attrs, ',') output.write(')') - if node.fk_matchtype and node.fk_matchtype != enums.FKCONSTR_MATCH_SIMPLE: + if node.fk_matchtype != '\0' and node.fk_matchtype != enums.FKCONSTR_MATCH_SIMPLE: output.write(' MATCH ') if node.fk_matchtype == enums.FKCONSTR_MATCH_FULL: output.write('FULL') elif node.fk_matchtype == enums.FKCONSTR_MATCH_PARTIAL: # pragma: no cover # MATCH PARTIAL not yet implemented output.write('PARTIAL') - if node.fk_del_action and node.fk_del_action != enums.FKCONSTR_ACTION_NOACTION: + if node.fk_del_action != '\0' and node.fk_del_action != enums.FKCONSTR_ACTION_NOACTION: output.write(' ON DELETE ') if node.fk_del_action == enums.FKCONSTR_ACTION_RESTRICT: output.write('RESTRICT') @@ -1183,7 +1182,7 @@ def CONSTR_FOREIGN(self, node, output): output.write('SET NULL') elif node.fk_del_action == enums.FKCONSTR_ACTION_SETDEFAULT: output.write('SET DEFAULT') - if node.fk_upd_action and node.fk_upd_action != enums.FKCONSTR_ACTION_NOACTION: + if node.fk_upd_action != '\0' and node.fk_upd_action != enums.FKCONSTR_ACTION_NOACTION: output.write(' ON UPDATE ') if node.fk_upd_action == enums.FKCONSTR_ACTION_RESTRICT: output.write('RESTRICT') @@ -1313,15 +1312,15 @@ def create_db_stmt(node, output): @node_printer('CreatedbStmt', 'DefElem') def create_db_stmt_def_elem(node, output): - option = node.defname.value + option = node.defname if option == 'connection_limit': output.write('connection limit') else: output.print_symbol(node.defname) - if node.arg is not Missing: + if node.arg is not None: output.write(' = ') - if isinstance(node.arg, List) or option in ('allow_connections', 'is_template'): - output.write(node.arg.val.value) + if isinstance(node.arg, tuple) or option in ('allow_connections', 'is_template'): + output.write(node.arg.val) else: output.print_node(node.arg) @@ -1353,8 +1352,8 @@ def create_conversion_stmt(node, output): output.write('DEFAULT ') output.write('CONVERSION ') output.print_name(node.conversion_name) - output.write(" FOR '%s' TO '%s'" % (node.for_encoding_name.value, - node.to_encoding_name.value)) + output.write(" FOR '%s' TO '%s'" % (node.for_encoding_name, + node.to_encoding_name)) output.write(' FROM ') output.print_name(node.func_name) @@ -1420,9 +1419,9 @@ def create_extension_stmt(node, output): @node_printer('CreateExtensionStmt', 'DefElem') def create_extension_stmt_def_elem(node, output): - option = node.defname.value + option = node.defname if option == 'cascade': - if node.arg.val.value == 1: + if node.arg.val == 1: output.write('CASCADE') elif option == 'old_version': # FIXME: find a way to get here @@ -1462,15 +1461,15 @@ def create_fdw_stmt(node, output): @node_printer('CreateUserMappingStmt', 'DefElem') @node_printer('CreateFdwStmt', 'DefElem') def create_fdw_stmt_def_elem(node, output): - if node.parent_attribute[0] == 'options' or node.parent_attribute[0] == 'fdwoptions': - if ' ' in node.defname.value: - output.write(f'"{node.defname.value}"') + if abs(node.ancestors).member in ('options', 'fdwoptions'): + if ' ' in node.defname: + output.write(f'"{node.defname}"') else: - output.write(node.defname.value) + output.write(node.defname) output.write(' ') output.print_node(node.arg) else: - output.write(node.defname.value.upper()) + output.write(node.defname.upper()) output.write(' ') output.print_name(node.arg) @@ -1514,10 +1513,10 @@ def create_foreign_table_stmt(node, output): @node_printer('CreateForeignTableStmt', 'DefElem') @node_printer('CreateForeignServerStmt', 'DefElem') def create_foreign_table_stmt_def_elem(node, output): - if ' ' in node.defname.value: - output.write(f'"{node.defname.value}"') + if ' ' in node.defname: + output.write(f'"{node.defname}"') else: - output.write(node.defname.value) + output.write(node.defname) output.write(' ') output.print_node(node.arg) @@ -1534,21 +1533,24 @@ def create_function_stmt(node, output): output.print_name(node.funcname) output.write('(') - # Functions returning a SETOF needs special care, because the resulting record - # definition is intermixed with real parameters: split them into two separated - # lists - real_params = node.parameters - if node.returnType and node.returnType.setof: - fpm = enums.FunctionParameterMode - record_def = [] - real_params = [] - for param in node.parameters: - if param.mode == fpm.FUNC_PARAM_TABLE: - record_def.append(param) - else: - real_params.append(param) - if real_params: - output.print_list(real_params) + if node.parameters is not None: + # Functions returning a SETOF needs special care, because the resulting record + # definition is intermixed with real parameters: split them into two separated + # lists + real_params = node.parameters + if node.returnType and node.returnType.setof: + fpm = enums.FunctionParameterMode + record_def = [] + real_params = [] + for param in node.parameters: + if param.mode == fpm.FUNC_PARAM_TABLE: + record_def.append(param) + else: + real_params.append(param) + if real_params: + output.print_list(real_params) + else: + record_def = False output.write(')') if node.returnType: @@ -1561,32 +1563,31 @@ def create_function_stmt(node, output): output.write(')') else: output.print_node(node.returnType) - for option in node.options: output.print_node(option) @node_printer(('AlterFunctionStmt', 'CreateFunctionStmt', 'DoStmt'), 'DefElem') def create_function_option(node, output): - option = node.defname.value + option = node.defname if option == 'as': - if isinstance(node.arg, List) and len(node.arg) > 1: + if isinstance(node.arg, tuple) and len(node.arg) > 1: # We are in the weird C case output.write('AS ') output.print_list(node.arg) return - if node.parent_node.node_tag == 'CreateFunctionStmt': + if isinstance(abs(node.ancestors).node, ast.CreateFunctionStmt): output.newline() output.write('AS ') # Choose a valid dollar-string delimiter - if isinstance(node.arg, List): - code = node.arg[0].val.value + if isinstance(node.arg, tuple): + code = node.arg[0].val else: - code = node.arg.val.value + code = node.arg.val used_delimiters = set(re.findall(r"\$(\w*)(?=\$)", code)) unique_delimiter = '' while unique_delimiter in used_delimiters: @@ -1601,7 +1602,7 @@ def create_function_option(node, output): return if option == 'security': - if node.arg.val.value == 1: + if node.arg.val == 1: output.swrite('SECURITY DEFINER') else: output.swrite('SECURITY INVOKER') @@ -1616,7 +1617,7 @@ def create_function_option(node, output): if option == 'volatility': output.separator() - output.write(node.arg.val.value.upper()) + output.write(node.arg.val.upper()) return if option == 'parallel': @@ -1629,7 +1630,7 @@ def create_function_option(node, output): return if option == 'leakproof': - if node.arg.val.value == 0: + if node.arg.val == 0: output.swrite('NOT') output.swrite('LEAKPROOF') return @@ -1643,7 +1644,7 @@ def create_function_option(node, output): output.write('WINDOW') return - output.writes(node.defname.value.upper()) + output.writes(node.defname.upper()) output.print_symbol(node.arg) @@ -1665,7 +1666,7 @@ def create_opclass_stmt(node, output): def create_opclass_item(node, output): if node.itemtype == enums.OPCLASS_ITEM_OPERATOR: output.write('OPERATOR ') - output.write('%d ' % node.number._value) + output.write('%d ' % node.number) if node.name: _object_with_args(node.name, output, symbol=True, skip_empty_args=True) if node.order_family: @@ -1677,7 +1678,7 @@ def create_opclass_item(node, output): output.write(')') elif node.itemtype == enums.OPCLASS_ITEM_FUNCTION: output.write('FUNCTION ') - output.write('%d ' % node.number._value) + output.write('%d ' % node.number) if node.class_args: output.write(' (') output.print_list(node.class_args, standalone_items=False) @@ -1817,7 +1818,7 @@ def create_role_stmt(node, output): @node_printer('AlterRoleStmt', 'DefElem') @node_printer('CreateRoleStmt', 'DefElem') def create_or_alter_role_option(node, output): - option = node.defname.value + option = node.defname argv = node.arg if option == 'sysid': output.write('SYSID ') @@ -1888,7 +1889,7 @@ def create_seq_stmt(node, output): output.writes('SEQUENCE') if node.if_not_exists: output.writes('IF NOT EXISTS') - if node.sequence.schemaname is not Missing: + if node.sequence.schemaname is not None: output.print_name(node.sequence.schemaname) output.write('.') output.print_name(node.sequence.relname) @@ -1903,9 +1904,9 @@ def create_seq_stmt(node, output): @node_printer('CreateSeqStmt', 'DefElem') @node_printer('AlterSeqStmt', 'DefElem') def create_seq_stmt_def_elem(node, output): - option = node.defname.value + option = node.defname if option == 'cycle': - if node.arg.val.value == 0: + if node.arg.val == 0: output.write('NO ') output.write('CYCLE') elif option == 'increment': @@ -1923,10 +1924,10 @@ def create_seq_stmt_def_elem(node, output): output.write(' WITH ') output.print_node(node.arg) else: - if node.arg is Missing: + if node.arg is None: output.write('NO ') output.write(option.upper()) - if node.arg is not Missing: + if node.arg is not None: output.write(' ') output.print_node(node.arg) if node.defaction != enums.DefElemAction.DEFELEM_UNSPEC: # pragma: no cover @@ -1952,11 +1953,7 @@ def create_stats_stmt(node, output): @node_printer('CreateStmt') def create_stmt(node, output): output.writes('CREATE') - # NB: parent_node may be None iff we are dealing with a single concrete statement, not with - # the ephemeral RawStmt returned by parse_sql(); in all other cases in this source where - # we access the parent_node we are actually in a "contextualized printer", that can only be - # entered when there is a parent node - if node.parent_node is not None and node.parent_node.node_tag == 'CreateForeignTableStmt': + if isinstance(node.ancestors[0], ast.CreateForeignTableStmt): output.writes('FOREIGN') else: if node.relation.relpersistence == enums.RELPERSISTENCE_TEMP: @@ -2045,7 +2042,7 @@ def create_table_as_stmt(node, output): output.writes('TEMPORARY') elif node.into.rel.relpersistence == enums.RELPERSISTENCE_UNLOGGED: output.writes('UNLOGGED') - output.writes(OBJECT_NAMES[node.relkind.value]) + output.writes(OBJECT_NAMES[node.relkind]) if node.if_not_exists: output.writes('IF NOT EXISTS') output.print_node(node.into) @@ -2068,7 +2065,7 @@ def create_table_space_stmt(node, output): output.print_node(node.owner) output.space() output.write('LOCATION ') - output.write_quoted_string(node.location.value) + output.write_quoted_string(node.location) if node.options: output.write(' WITH (') output.print_list(node.options) @@ -2084,7 +2081,7 @@ def create_trig_stmt(node, output): output.print_name(node.trigname) output.newline() with output.push_indent(2): - if node.timing.value: + if node.timing: if node.timing & enums.TRIGGER_TYPE_BEFORE: output.write('BEFORE ') elif node.timing & enums.TRIGGER_TYPE_INSTEAD: @@ -2149,8 +2146,8 @@ def create_subscription_stmt_stmt_def_elem(node, output): output.print_name(node.defname) if node.arg: output.write(' = ') - if node.arg.node_tag == 'String' and node.arg.val.value in ('true', 'false'): - output.write(node.arg.val.value) + if isinstance(node.arg, ast.String) and node.arg.val in ('true', 'false'): + output.write(node.arg.val) else: output.print_node(node.arg) @@ -2220,7 +2217,7 @@ def create_user_mapping_stmt(node, output): output.write('FOR ') role_spec(node.user, output) output.writes(' SERVER') - output.write(node.servername.value) + output.write(node.servername) if node.options: output.write(' OPTIONS (') output.print_list(node.options, ',') @@ -2241,22 +2238,22 @@ def define_stmt(node, output): output.write('CREATE ') if node.replace: output.write('OR REPLACE ') - output.writes(OBJECT_NAMES[node.kind.value]) + output.writes(OBJECT_NAMES[node.kind]) if node.if_not_exists: output.write('IF NOT EXISTS ') output.print_list(node.defnames, '.', standalone_items=False, are_names=True, is_symbol=node.kind == enums.ObjectType.OBJECT_OPERATOR) - if node.args is not Missing: + if node.args is not None: # args is actually a tuple (list-of-nodes, integer): the integer value, if different # from -1, is the number of nodes representing the actual arguments, remaining are # ORDER BY args, count = node.args - count = count.val.value + count = count.val output.write(' (') if count == -1: # Special case: if it's an aggregate, and the scalar is equal to # None (not is, since it's a Scalar), write a star - if ((node.kind.value == enums.ObjectType.OBJECT_AGGREGATE + if ((node.kind == enums.ObjectType.OBJECT_AGGREGATE and args == None)): output.write('*') actual_args = [] @@ -2300,7 +2297,7 @@ def define_stmt(node, output): @node_printer('DefElem') def def_elem(node, output): output.print_symbol(node.defname) - if node.arg is not Missing: + if node.arg is not None: output.write(' = ') output.print_node(node.arg) if node.defaction != enums.DefElemAction.DEFELEM_UNSPEC: # pragma: no cover @@ -2309,10 +2306,10 @@ def def_elem(node, output): @node_printer('DefineStmt', 'DefElem') def define_stmt_def_elem(node, output): - output.print_node(node.defname, is_name=True) - if node.arg is not Missing: + output.print_name(node.defname) + if node.arg is not None: output.write(' = ') - if isinstance(node.arg, List): + if isinstance(node.arg, tuple): is_symbol = node.defname in ('commutator', 'negator') if is_symbol and len(node.arg) > 1: output.write('OPERATOR(') @@ -2394,7 +2391,7 @@ def drop_stmt(node, output): otypes = enums.ObjectType output.write('DROP ') # Special case functions since they are not special objects - output.writes(OBJECT_NAMES[node.removeType.value]) + output.writes(OBJECT_NAMES[node.removeType]) if node.removeType == otypes.OBJECT_INDEX: if node.concurrent: output.write(' CONCURRENTLY') @@ -2422,13 +2419,16 @@ def drop_stmt(node, output): output.print_name(nodes[-1]) output.write(' ON ') output.print_name(on) - elif isinstance(node.objects[0], List): - if node.objects[0][0].node_tag != 'String': - output.print_lists(node.objects, ' AS ', standalone_items=False, - are_names=True) - else: - output.print_lists(node.objects, sep='.', sublist_open='', sublist_close='', - standalone_items=False, are_names=True) + elif node.removeType == otypes.OBJECT_CAST: + names = node.objects[0] + output.write('(') + output.print_name(names[0]) + output.write(' AS ') + output.print_name(names[1]) + output.write(')') + elif isinstance(node.objects[0], tuple): + output.print_lists(node.objects, sep='.', sublist_open='', sublist_close='', + standalone_items=False, are_names=True) else: output.print_list(node.objects, ',', standalone_items=False, are_names=True) if node.behavior == enums.DropBehavior.DROP_CASCADE: @@ -2470,7 +2470,7 @@ def drop_user_mapping_stmt(node, output): @node_printer('FunctionParameter') def function_parameter(node, output): - if node.mode is not Missing: + if node.mode is not None: pm = enums.FunctionParameterMode if node.mode == pm.FUNC_PARAM_IN: pass # omit the default, output.write('IN ') @@ -2488,7 +2488,7 @@ def function_parameter(node, output): output.print_name(node.name) output.write(' ') output.print_node(node.argType) - if node.defexpr is not Missing: + if node.defexpr is not None: output.write(' = ') output.print_node(node.defexpr) @@ -2508,10 +2508,10 @@ def grant_stmt(node, output): else: output.write('ALL PRIVILEGES') # hack for OBJECT_FOREIGN_SERVER - if node.objtype.value == enums.ObjectType.OBJECT_FOREIGN_SERVER: + if node.objtype == enums.ObjectType.OBJECT_FOREIGN_SERVER: object_name = 'FOREIGN SERVER' else: - object_name = OBJECT_NAMES[node.objtype.value] + object_name = OBJECT_NAMES[node.objtype] target = node.targtype output.newline() output.space(2) @@ -2631,7 +2631,7 @@ def index_stmt(node, output): def lock_stmt(node, output): output.write('LOCK ') output.print_list(node.relations, ',') - lock_mode = node.mode.value + lock_mode = node.mode lock_str = LOCK_MODE_NAMES[lock_mode] output.write('IN ') output.write(lock_str) @@ -2646,7 +2646,7 @@ def notify_stmt(node, output): output.print_name(node.conditionname) if node.payload: output.write(', ') - output.write_quoted_string(node.payload.value) + output.write_quoted_string(node.payload) def _object_with_args(node, output, unquote_name=False, symbol=False, @@ -2657,9 +2657,9 @@ def _object_with_args(node, output, unquote_name=False, symbol=False, for idx, name in enumerate(node.objname): if idx > 0: output.write('.') - output.write(name.val.value) + output.write(name.val) else: - output.write(node.objname.string_value) + output.write(get_string_value(node.objname)) elif symbol: output.print_symbol(node.objname) else: @@ -2682,7 +2682,7 @@ def object_with_args(node, output): @node_printer(('AlterObjectSchemaStmt',), 'ObjectWithArgs') def alter_object_schema_stmt_object_with_args(node, output): - symbol = node.parent_node.objectType == enums.ObjectType.OBJECT_OPERATOR + symbol = abs(node.ancestors).node.objectType == enums.ObjectType.OBJECT_OPERATOR _object_with_args(node, output, symbol=symbol) @@ -2693,20 +2693,21 @@ def alter_operator_stmt_object_with_args(node, output): @node_printer(('AlterOwnerStmt',), 'ObjectWithArgs') def alter_owner_stmt_object_with_args(node, output): - unquote_name = node.parent_node.objectType == enums.ObjectType.OBJECT_OPERATOR + unquote_name = abs(node.ancestors).node.objectType == enums.ObjectType.OBJECT_OPERATOR _object_with_args(node, output, unquote_name=unquote_name) @node_printer(('CommentStmt',), 'ObjectWithArgs') def comment_stmt_object_with_args(node, output): - unquote_name = node.parent_node.objtype == enums.ObjectType.OBJECT_OPERATOR + unquote_name = abs(node.ancestors).node.objtype == enums.ObjectType.OBJECT_OPERATOR _object_with_args(node, output, unquote_name=unquote_name) @node_printer(('DropStmt',), 'ObjectWithArgs') def drop_stmt_object_with_args(node, output): - unquote_name = node.parent_node.removeType == enums.ObjectType.OBJECT_OPERATOR - if node.parent_node.removeType == enums.ObjectType.OBJECT_AGGREGATE: + parent_node = abs(node.ancestors).node + unquote_name = parent_node.removeType == enums.ObjectType.OBJECT_OPERATOR + if parent_node.removeType == enums.ObjectType.OBJECT_AGGREGATE: _object_with_args(node, output, empty_placeholder='*', unquote_name=unquote_name) else: _object_with_args(node, output, unquote_name=unquote_name) @@ -2730,7 +2731,7 @@ def partition_bound_spec(node, output): output.write(')') elif node.strategy == enums.PARTITION_STRATEGY_HASH: output.write('WITH (MODULUS %d, REMAINDER %d)' - % (node.modulus.value, node.remainder.value)) + % (node.modulus, node.remainder)) else: raise NotImplementedError('Unhandled strategy %r' % node.strategy) @@ -2765,7 +2766,7 @@ def partition_range_datum(node, output): elif node.kind == enums.PartitionRangeDatumKind.PARTITION_RANGE_DATUM_MAXVALUE: output.write('MAXVALUE') else: - output.print_node(node.value) + output.print_node(node) @node_printer('PartitionSpec') @@ -2875,8 +2876,8 @@ def rename_stmt(node, output): @node_printer('RenameStmt', 'RangeVar') def rename_stmt_range_var(node, output): OT = enums.ObjectType - if not node.inh and node.parent_node.renameType not in (OT.OBJECT_ATTRIBUTE, - OT.OBJECT_TYPE): + if not node.inh and abs(node.ancestors).node.renameType not in (OT.OBJECT_ATTRIBUTE, + OT.OBJECT_TYPE): output.write('ONLY ') if node.schemaname: output.print_name(node.schemaname) @@ -2932,7 +2933,7 @@ def rule_stmt_printer(node, output): output.newline() with output.push_indent(2): output.write('ON ') - output.write(EVENT_NAMES[node.event.value]) + output.write(EVENT_NAMES[node.event]) output.write(' TO ') output.print_name(node.relation) if node.whereClause: @@ -2987,7 +2988,7 @@ def sec_label_stmt(node, output): output.write('FOR ') output.print_name(node.provider) output.write(' ON ') - output.write(OBJECT_NAMES[node.objtype.value]) + output.write(OBJECT_NAMES[node.objtype]) output.write(' ') output.print_name(node.object) output.write(' IS ') @@ -3048,10 +3049,10 @@ def vacuum_stmt(node, output): @node_printer('VacuumStmt', 'DefElem') def vacuum_stmt_def_elem(node, output): - output.write(node.defname.value.upper()) + output.write(node.defname.upper()) if node.arg: output.write(' ') - output.write(str(node.arg.val.value)) + output.write(str(node.arg.val)) @node_printer('VacuumRelation') diff --git a/pglast/printers/dml.py b/pglast/printers/dml.py index f0f63eba..9d4aabb8 100644 --- a/pglast/printers/dml.py +++ b/pglast/printers/dml.py @@ -6,9 +6,8 @@ # :Copyright: © 2017, 2018, 2019, 2020, 2021, 2022 Lele Gaifax # -from .. import enums -from ..node import Missing, List -from . import IntEnumPrinter, node_printer +from .. import ast, enums +from . import IntEnumPrinter, get_string_value, node_printer @node_printer('A_ArrayExpr') @@ -38,24 +37,24 @@ def AEXPR_BETWEEN_SYM(self, node, output): output.print_list(node.rexpr, 'AND', relative_indent=-4) def AEXPR_DISTINCT(self, node, output): - if node.lexpr.node_tag == 'BoolExpr': + if isinstance(node.lexpr, ast.BoolExpr): output.write('(') output.print_node(node.lexpr) - if node.lexpr.node_tag == 'BoolExpr': + if isinstance(node.lexpr, ast.BoolExpr): output.write(') ') output.swrites('IS DISTINCT FROM') output.print_node(node.rexpr) def AEXPR_ILIKE(self, node, output): output.print_node(node.lexpr) - if node.name.string_value == '!~~*': + if get_string_value(node.name) == '!~~*': output.swrites('NOT') output.swrites('ILIKE') output.print_node(node.rexpr) def AEXPR_IN(self, node, output): output.print_node(node.lexpr) - if node.name.string_value == '<>': + if get_string_value(node.name) == '<>': output.swrites('NOT') output.swrite('IN (') output.print_list(node.rexpr) @@ -63,7 +62,7 @@ def AEXPR_IN(self, node, output): def AEXPR_LIKE(self, node, output): output.print_node(node.lexpr) - if node.name.string_value == '!~~': + if get_string_value(node.name) == '!~~': output.swrites('NOT') output.swrites('LIKE') output.print_node(node.rexpr) @@ -91,7 +90,7 @@ def AEXPR_NULLIF(self, node, output): def AEXPR_OF(self, node, output): output.print_node(node.lexpr) output.swrites('IS') - if node.name.string_value == '<>': + if get_string_value(node.name) == '<>': output.writes('NOT') output.write('OF (') output.print_list(node.rexpr) @@ -100,8 +99,8 @@ def AEXPR_OF(self, node, output): def AEXPR_OP(self, node, output): with output.expression(): # lexpr is optional because these are valid: -(1+1), +(1+1), ~(1+1) - if node.lexpr is not Missing: - if node.lexpr.node_tag == 'A_Expr': + if node.lexpr is not None: + if isinstance(node.lexpr, ast.A_Expr): if node.lexpr.kind == node.kind and node.lexpr.name == node.name: output.print_node(node.lexpr) else: @@ -110,15 +109,15 @@ def AEXPR_OP(self, node, output): else: output.print_node(node.lexpr) output.write(' ') - if isinstance(node.name, List) and len(node.name) > 1: + if isinstance(node.name, tuple) and len(node.name) > 1: output.write('OPERATOR(') output.print_symbol(node.name) output.write(') ') else: output.print_symbol(node.name) output.write(' ') - if node.rexpr is not Missing: - if node.rexpr.node_tag == 'A_Expr': + if node.rexpr is not None: + if isinstance(node.rexpr, ast.A_Expr): if node.rexpr.kind == node.kind and node.rexpr.name == node.name: output.print_node(node.rexpr) else: @@ -130,7 +129,7 @@ def AEXPR_OP(self, node, output): def AEXPR_OP_ALL(self, node, output): output.print_node(node.lexpr) output.write(' ') - output.write(node.name.string_value) + output.write(get_string_value(node.name)) output.write(' ALL(') output.print_node(node.rexpr) output.write(')') @@ -138,7 +137,7 @@ def AEXPR_OP_ALL(self, node, output): def AEXPR_OP_ANY(self, node, output): output.print_node(node.lexpr) output.write(' ') - output.write(node.name.string_value) + output.write(get_string_value(node.name)) output.write(' ANY(') output.print_node(node.rexpr) output.write(')') @@ -152,11 +151,11 @@ def AEXPR_PAREN(self, node, output): # pragma: no cover def AEXPR_SIMILAR(self, node, output): output.print_node(node.lexpr) - if node.name.string_value == '!~': + if get_string_value(node.name) == '!~': output.swrites('NOT') output.swrites('SIMILAR TO') - assert (node.rexpr.node_tag == 'FuncCall' - and node.rexpr.funcname[1].val.value == 'similar_to_escape') + assert (isinstance(node.rexpr, ast.FuncCall) + and node.rexpr.funcname[1].val == 'similar_to_escape') pattern = node.rexpr.args[0] output.print_node(pattern) if len(node.rexpr.args) > 1: @@ -188,11 +187,11 @@ def a_indices(node, output): @node_printer('A_Indirection') def a_indirection(node, output): - bracket = ((node.arg.node_tag in ('A_ArrayExpr', 'A_Expr', 'A_Indirection', 'FuncCall', - 'RowExpr', 'TypeCast')) + bracket = (isinstance(node.arg, (ast.A_ArrayExpr, ast.A_Expr, ast.A_Indirection, + ast.FuncCall, ast.RowExpr, ast.TypeCast)) or - (node.arg.node_tag == 'ColumnRef' - and node.indirection[0].node_tag != 'A_Indices')) + (isinstance(node.arg, ast.ColumnRef) + and not isinstance(node.indirection[0], ast.A_Indices))) if bracket: output.write('(') output.print_node(node.arg) @@ -244,9 +243,7 @@ def alias(node, output): @node_printer('BitString') def bitstring(node, output): - output.write(f"{node.val.value[0]}'") - output.write(node.val.value[1:]) - output.write("'") + output.write(f"{node.val[0]}'{node.val[1:]}'") @node_printer('BoolExpr') @@ -254,7 +251,7 @@ def bool_expr(node, output): bet = enums.BoolExprType outer_exp_level = output.expression_level with output.expression(): - in_res_target = node.parent_node.node_tag == 'ResTarget' + in_res_target = isinstance(node.ancestors[0], ast.ResTarget) if node.boolop == bet.AND_EXPR: relindent = -4 if not in_res_target and outer_exp_level == 0 else None output.print_list(node.args, 'AND', relative_indent=relindent) @@ -425,7 +422,7 @@ def copy_stmt(node, output): if node.is_program: output.write('PROGRAM ') if node.filename: - output.print_node(node.filename) + output.write_quoted_string(node.filename) else: if node.is_from: output.write('STDIN') @@ -444,7 +441,7 @@ def copy_stmt(node, output): @node_printer('CopyStmt', 'DefElem') def copy_stmt_def_elem(node, output): - option = node.defname.value + option = node.defname argv = node.arg if option == 'format': output.write('FORMAT ') @@ -452,7 +449,7 @@ def copy_stmt_def_elem(node, output): elif option == 'freeze': output.write('FREEZE') if argv: - output.swrite(str(argv.val.value)) + output.swrite(str(argv.val)) elif option == 'delimiter': output.write('DELIMITER ') output.print_node(argv) @@ -462,7 +459,7 @@ def copy_stmt_def_elem(node, output): elif option == 'header': output.write('HEADER') if argv: - output.swrite(str(argv.val.value)) + output.swrite(str(argv.val)) elif option == 'quote': output.write('QUOTE ') output.print_node(argv) @@ -472,7 +469,7 @@ def copy_stmt_def_elem(node, output): elif option == 'force_quote': output.write('FORCE_QUOTE ') # If it is a list print it. - if isinstance(argv, List): + if isinstance(argv, tuple): output.write('(') output.print_list(argv, are_names=True) output.write(')') @@ -548,7 +545,7 @@ def delete_stmt(node, output): @node_printer('ExecuteStmt') def execute_stmt(node, output): output.write('EXECUTE ') - output.print_node(node.name, is_name=True) + output.print_name(node.name) if node.params: output.write('(') output.print_list(node.params) @@ -570,7 +567,7 @@ def explain_stmt(node, output): @node_printer('ExplainStmt', 'DefElem') def explain_stmt_def_elem(node, output): output.print_symbol(node.defname) - if node.arg is not Missing: + if node.arg is not None: output.write(' ') output.print_symbol(node.arg) @@ -582,13 +579,13 @@ def FETCH_FORWARD(self, node, output): if node.howMany == enums.FETCH_ALL: output.write('ALL ') elif node.howMany != 1: - output.write(f'FORWARD {node.howMany.value} ') + output.write(f'FORWARD {node.howMany} ') def FETCH_BACKWARD(self, node, output): if node.howMany == enums.FETCH_ALL: output.write('BACKWARD ALL ') elif node.howMany != 1: - output.write(f'BACKWARD {node.howMany.value} ') + output.write(f'BACKWARD {node.howMany} ') else: output.write('PRIOR ') @@ -598,10 +595,10 @@ def FETCH_ABSOLUTE(self, node, output): elif node.howMany == -1: output.write('LAST ') else: - output.write(f'ABSOLUTE {node.howMany.value} ') + output.write(f'ABSOLUTE {node.howMany} ') def FETCH_RELATIVE(self, node, output): - output.write(f'RELATIVE {node.howMany.value} ') + output.write(f'RELATIVE {node.howMany} ') fetch_direction_printer = FetchDirectionPrinter() @@ -621,7 +618,7 @@ def float(node, output): @node_printer('FuncCall') def func_call(node, output): - name = '.'.join(n.val.value for n in node.funcname) + name = '.'.join(n.val for n in node.funcname) special_printer = output.get_printer_for_function(name) if special_printer is not None: special_printer(node, output) @@ -631,7 +628,7 @@ def func_call(node, output): output.write('(') if node.agg_distinct: output.writes('DISTINCT') - if node.args is Missing: + if node.args is None: if node.agg_star: output.write('*') else: @@ -698,7 +695,7 @@ def grouping_func(node, output): @node_printer('IndexElem') def index_elem(node, output): - if node.name is not Missing: + if node.name is not None: output.print_name(node.name) else: output.write('(') @@ -849,7 +846,7 @@ def join_expr(node, output): output.swrites('JOIN') - if node.rarg.node_tag == 'JoinExpr': + if isinstance(node.rarg, ast.JoinExpr): output.indent(3, relative=False) # need this for: # tests/test_printers_roundtrip.py::test_pg_regress_corpus[join.sql] - @@ -876,7 +873,7 @@ def join_expr(node, output): output.writes(') AS') output.print_name(node.alias) - if node.rarg.node_tag == 'JoinExpr': + if isinstance(node.rarg, ast.JoinExpr): output.dedent() @@ -946,13 +943,13 @@ def null_test(node, output): @node_printer('ParamRef') def param_ref(node, output): - if node.number is Missing: # pragma: no cover + if node.number is None: # pragma: no cover # NB: standard PG does not allow "?"-style param placeholders, this is a minor # deviation introduced by libpg_query; in version 2 apparently the case is merged # back to the standard style below output.write('?') else: - output.write('$%d' % node.number.value) + output.write('$%d' % node.number) @node_printer('PrepareStmt') @@ -1212,7 +1209,7 @@ def select_stmt(node, output): if node.valuesLists: # Is this a SELECT ... FROM (VALUES (...))? - require_parens = node.parent_node.node_tag == 'RangeSubselect' + require_parens = isinstance(node.ancestors[0], ast.RangeSubselect) if require_parens: output.write('(') output.write('VALUES ') @@ -1298,12 +1295,12 @@ def select_stmt(node, output): output.write('FETCH FIRST ') # FIXME do we need add '()' for all ? if ((node.limitCount - and node.limitCount.node_tag == "A_Expr" + and isinstance(node.limitCount, ast.A_Expr) and node.limitCount.kind == enums.A_Expr_Kind.AEXPR_OP)): output.write('(') output.print_node(node.limitCount) if ((node.limitCount - and node.limitCount.node_tag == "A_Expr" + and isinstance(node.limitCount, ast.A_Expr) and node.limitCount.kind == enums.A_Expr_Kind.AEXPR_OP)): output.write(')') if node.limitOption == enums.LimitOption.LIMIT_OPTION_WITH_TIES: @@ -1366,12 +1363,12 @@ def SVFOP_CURRENT_TIMESTAMP(self, node, output): def SVFOP_CURRENT_TIMESTAMP_N(self, node, output): # pragma: no cover output.write('CURRENT_TIMESTAMP(') - output.write(str(node.typmod.value)) + output.write(str(node.typmod)) output.write(')') def SVFOP_CURRENT_TIME_N(self, node, output): # pragma: no cover output.write('CURRENT_TIME(') - output.write(str(node.typmod.value)) + output.write(str(node.typmod)) output.write(')') def SVFOP_CURRENT_USER(self, node, output): @@ -1385,12 +1382,12 @@ def SVFOP_LOCALTIMESTAMP(self, node, output): def SVFOP_LOCALTIMESTAMP_N(self, node, output): # pragma: no cover output.write('LOCALTIMESTAMP(') - output.write(str(node.typmod.value)) + output.write(str(node.typmod)) output.write(')') def SVFOP_LOCALTIME_N(self, node, output): # pragma: no cover output.write('LOCALTIME(') - output.write(str(node.typmod.value)) + output.write(str(node.typmod)) output.write(')') def SVFOP_SESSION_USER(self, node, output): @@ -1422,13 +1419,13 @@ def sub_link(node, output): elif node.subLinkType == slt.ALL_SUBLINK: output.print_node(node.testexpr) output.write(' ') - output.write(node.operName.string_value) + output.write(get_string_value(node.operName)) output.write(' ALL ') elif node.subLinkType == slt.ANY_SUBLINK: output.print_node(node.testexpr) if node.operName: output.write(' ') - output.write(node.operName.string_value) + output.write(get_string_value(node.operName)) output.write(' ANY ') else: output.write(' IN ') @@ -1451,57 +1448,57 @@ def sub_link(node, output): @node_printer('TransactionStmt') def transaction_stmt(node, output): tsk = enums.TransactionStmtKind - if node.kind.value == tsk.TRANS_STMT_BEGIN: + if node.kind == tsk.TRANS_STMT_BEGIN: output.write('BEGIN ') if node.options: output.print_list(node.options) - elif node.kind.value == tsk.TRANS_STMT_START: + elif node.kind == tsk.TRANS_STMT_START: output.write('START TRANSACTION ') if node.options: output.print_list(node.options) - elif node.kind.value == tsk.TRANS_STMT_COMMIT: + elif node.kind == tsk.TRANS_STMT_COMMIT: output.write('COMMIT ') if node.chain: output.write('AND CHAIN ') - elif node.kind.value == tsk.TRANS_STMT_ROLLBACK: + elif node.kind == tsk.TRANS_STMT_ROLLBACK: output.write('ROLLBACK ') if node.chain: output.write('AND CHAIN ') - elif node.kind.value == tsk.TRANS_STMT_SAVEPOINT: + elif node.kind == tsk.TRANS_STMT_SAVEPOINT: output.write('SAVEPOINT ') - output.write(node.savepoint_name.value) - elif node.kind.value == tsk.TRANS_STMT_RELEASE: + output.write(node.savepoint_name) + elif node.kind == tsk.TRANS_STMT_RELEASE: output.write('RELEASE ') - output.write(node.savepoint_name.value) - elif node.kind.value == tsk.TRANS_STMT_ROLLBACK_TO: + output.write(node.savepoint_name) + elif node.kind == tsk.TRANS_STMT_ROLLBACK_TO: output.write('ROLLBACK TO SAVEPOINT ') - output.write(node.savepoint_name.value) - elif node.kind.value == tsk.TRANS_STMT_PREPARE: + output.write(node.savepoint_name) + elif node.kind == tsk.TRANS_STMT_PREPARE: output.write('PREPARE TRANSACTION ') - output.write("'%s'" % node.gid.value) - elif node.kind.value == tsk.TRANS_STMT_COMMIT_PREPARED: + output.write("'%s'" % node.gid) + elif node.kind == tsk.TRANS_STMT_COMMIT_PREPARED: output.write('COMMIT PREPARED ') - output.write("'%s'" % node.gid.value) - elif node.kind.value == tsk.TRANS_STMT_ROLLBACK_PREPARED: + output.write("'%s'" % node.gid) + elif node.kind == tsk.TRANS_STMT_ROLLBACK_PREPARED: output.write('ROLLBACK PREPARED ') - output.write("'%s'" % node.gid.value) + output.write("'%s'" % node.gid) @node_printer('TransactionStmt', 'DefElem') def transaction_stmt_def_elem(node, output): - value = node.defname.value + value = node.defname argv = node.arg.val if value == 'transaction_isolation': output.write('ISOLATION LEVEL ') - output.write(argv.val.value.upper()) + output.write(argv.val.upper()) elif value == 'transaction_read_only': output.write('READ ') - if argv.val.value == 0: + if argv.val == 0: output.write('WRITE') else: output.write('ONLY') elif value == 'transaction_deferrable': - if argv.val.value == 0: + if argv.val == 0: output.write('NOT ') output.write('DEFERRABLE') else: # pragma: no cover @@ -1520,15 +1517,15 @@ def truncate_stmt(node, output): @node_printer('TypeCast') def type_cast(node, output): - if node.arg.node_tag == 'A_Const': + if isinstance(node.arg, ast.A_Const): # Special case for boolean constants - if ((node.arg.val.node_tag != 'Null' - and node.arg.val.val.value in ('t', 'f') - and '.'.join(n.val.value for n in node.typeName.names) == 'pg_catalog.bool')): + if ((not isinstance(node.arg.val, ast.Null) + and node.arg.val.val in ('t', 'f') + and '.'.join(n.val for n in node.typeName.names) == 'pg_catalog.bool')): output.write('TRUE' if node.arg.val.val == 't' else 'FALSE') return # Special case for bpchar - elif (('.'.join(n.val.value for n in node.typeName.names) == 'pg_catalog.bpchar' + elif (('.'.join(n.val for n in node.typeName.names) == 'pg_catalog.bpchar' and not node.typeName.typmods)): output.write('char ') output.print_node(node.arg) @@ -1595,7 +1592,7 @@ def type_name(node, output): if node.setof: # FIXME: is this used only by plpgsql? output.writes('SETOF') - name = '.'.join(n.val.value for n in node.names) + name = '.'.join(n.val for n in node.names) suffix = '' if name in system_types: prefix, suffix = system_types[name] @@ -1611,7 +1608,7 @@ def type_name(node, output): else: if node.typmods: if name == 'pg_catalog.interval': - typmod = node.typmods[0].val.val.value + typmod = node.typmods[0].val.val if typmod in interval_ranges: output.swrite(interval_ranges[typmod]) if len(node.typmods) == 2: @@ -1626,7 +1623,7 @@ def type_name(node, output): if node.arrayBounds: for ab in node.arrayBounds: output.write('[') - if ab.val.value >= 0: + if ab.val >= 0: output.print_node(ab) output.write(']') @@ -1795,21 +1792,21 @@ def window_def(node, output): def print_indirection(node, output): for idx, subnode in enumerate(node): - if subnode.node_tag == 'String': + if isinstance(subnode, ast.String): output.write('.') output.print_node(subnode, is_name=True) @node_printer(('OnConflictClause', 'UpdateStmt'), 'ResTarget') def update_stmt_res_target(node, output): - if node.val.node_tag == 'MultiAssignRef': + if isinstance(node.val, ast.MultiAssignRef): if node.val.colno == 1: output.write('( ') output.indent(-2) output.print_name(node.name) if node.indirection: print_indirection(node.indirection, output) - if node.val.colno.value == node.val.ncolumns.value: + if node.val.colno == node.val.ncolumns: output.dedent() output.write(') = ') output.print_node(node.val) @@ -1884,7 +1881,8 @@ def IS_XMLPARSE(self, node, output): # XMLPARSE(text, is_doc, preserve_ws) xml_option_type_printer(node.xmloption, node, output) arg, preserve_ws = node.args output.print_node(arg) - if preserve_ws.arg.val.val.value == 't': + if preserve_ws.arg.val.val == 't': + # FIXME: find a way to get here output.write(' PRESERVE WHITESPACE') output.write(')') @@ -1901,11 +1899,11 @@ def IS_XMLROOT(self, node, output): # XMLROOT(xml, version, standalone) xml, version, standalone = node.args output.print_node(xml) output.write(', version ') - if version.val.node_tag == 'Null': + if isinstance(version.val, ast.Null): output.write('NO VALUE') else: output.print_node(version) - xml_standalone_type_printer(standalone.val.val, node, output) + xml_standalone_type_printer(standalone.val, node, output) output.write(')') def IS_XMLSERIALIZE(self, node, output): # XMLSERIALIZE(is_document, xmlval) diff --git a/pglast/printers/sfuncs.py b/pglast/printers/sfuncs.py index 9e8964e4..5f52a009 100644 --- a/pglast/printers/sfuncs.py +++ b/pglast/printers/sfuncs.py @@ -3,7 +3,7 @@ # :Created: mer 22 nov 2017 08:34:34 CET # :Author: Lele Gaifax # :License: GNU General Public License version 3 or later -# :Copyright: © 2017, 2018, 2021 Lele Gaifax +# :Copyright: © 2017, 2018, 2021, 2022 Lele Gaifax # from . import special_function @@ -49,7 +49,7 @@ def date_part(node, output): ``EXTRACT(field FROM timestamp).``. """ output.write('EXTRACT(') - output.write(node.args[0].val.val.value.upper()) + output.write(node.args[0].val.val.upper()) output.write(' FROM ') output.print_node(node.args[1]) output.write(')') @@ -75,7 +75,7 @@ def normalize(node, output): output.print_node(node.args[0]) if len(node.args) > 1: output.write(', ') - output.write(node.args[1].val.val.value.upper()) + output.write(node.args[1].val.val.upper()) output.write(')') diff --git a/pglast/stream.py b/pglast/stream.py index ff08b404..715aca63 100644 --- a/pglast/stream.py +++ b/pglast/stream.py @@ -11,10 +11,9 @@ from re import match from sys import stderr -from .node import List, Missing, Node, Scalar from . import ast, parse_plpgsql, parse_sql, visitors from .keywords import RESERVED_KEYWORDS, TYPE_FUNC_NAME_KEYWORDS -from .printers import get_printer_for_node_tag, get_special_function +from .printers import get_printer_for_node, get_special_function class OutputStream(StringIO): @@ -156,25 +155,18 @@ def __call__(self, sql, plpgsql=False): :returns: a string with the equivalent SQL obtained by serializing the syntax tree The `sql` statement may be either a ``str`` containing the ``SQL`` in textual form, or - a :class:`.node.Node` instance, or a :class:`.node.List` instance containing - :class:`.node.Node` instances, or a concrete :class:`.ast.Node` instance or a tuple of - them. + a concrete :class:`.ast.Node` instance or a tuple of them. """ from . import ast if isinstance(sql, str): - sql = Node(parse_plpgsql(sql) if plpgsql else parse_sql(sql)) - elif isinstance(sql, Node): - sql = (sql,) + sql = parse_plpgsql(sql) if plpgsql else parse_sql(sql) elif isinstance(sql, ast.Node): - sql = (Node(sql),) - elif isinstance(sql, tuple) and sql and isinstance(sql[0], ast.Node): - sql = (Node(n) for n in sql) - elif not isinstance(sql, List): + sql = (sql,) + elif not (isinstance(sql, tuple) and sql and isinstance(sql[0], ast.Node)): raise ValueError("Unexpected value for 'sql', must be either a string," - " a node.Node instance, a node.List, an ast.Node or tuple of" - " them, got %r" % type(sql)) + " an ast.Node or tuple of them, got %r" % type(sql)) class UpdateAncestors(visitors.Visitor): def visit(self, ancestors, node): @@ -277,19 +269,21 @@ def write_quoted_string(self, s): def _print_scalar(self, node, is_name, is_symbol): "Print the scalar `node`, special-casing string literals." - value = node.value - if is_symbol: + value = node.val if isinstance(node, ast.Value) else node + is_string = isinstance(value, str) + if is_symbol and is_string: self.write(value) - elif is_name: - # The `scalar` represent a name of a column/table/alias: when any of its + elif is_name and is_string: + # The `node` represent a name of a column/table/alias: when any of its # characters is not a lower case letter, a digit or underscore, it must be # double quoted + value = str(value) if ((not match(r'[a-z_][a-z0-9_]*$', value) or value in RESERVED_KEYWORDS or value in TYPE_FUNC_NAME_KEYWORDS)): value = '"%s"' % value.replace('"', '""') self.write(value) - elif isinstance(value, str): # node.parent_node.node_tag == 'String': + elif is_string: self.write_quoted_string(value) else: self.write(str(value)) @@ -317,23 +311,29 @@ def print_comment(self, comment): def print_name(self, nodes, sep='.'): "Helper method, execute :meth:`print_node` or :meth:`print_list` as needed." - if isinstance(nodes, (List, list)): + if isinstance(nodes, (list, tuple)): self.print_list(nodes, sep, standalone_items=False, are_names=True) - else: + elif isinstance(nodes, ast.Node): self.print_node(nodes, is_name=True) + else: + self._print_scalar(nodes, is_name=True, is_symbol=False) + self.separator() def print_symbol(self, nodes, sep='.'): "Helper method, execute :meth:`print_node` or :meth:`print_list` as needed." - if isinstance(nodes, (List, list)): + if isinstance(nodes, (list, tuple)): self.print_list(nodes, sep, standalone_items=False, are_names=True, is_symbol=True) - else: + elif isinstance(nodes, ast.Node): self.print_node(nodes, is_name=True, is_symbol=True) + else: + self._print_scalar(str(nodes), is_name=True, is_symbol=True) + self.separator() def print_node(self, node, is_name=False, is_symbol=False): """Lookup the specific printer for the given `node` and execute it. - :param node: an instance of :class:`~.node.Node` or :class:`~.node.Scalar` + :param node: an instance of :class:`~.ast.Node` :param bool is_name: whether this is a *name* of something, that may need to be double quoted :param bool is_symbol: @@ -346,38 +346,38 @@ def print_node(self, node, is_name=False, is_symbol=False): elif hasattr(node, 'stmt_location'): node_location = getattr(node, 'stmt_location') else: - node_location = Missing - if node_location is not Missing: + node_location = None + if node_location is not None: nextc = self.comments[0] - if nextc.location <= node_location.value: + if nextc.location <= node_location: self.print_comment(self.comments.pop(0)) while self.comments and self.comments[0].continue_previous: self.print_comment(self.comments.pop(0)) - if isinstance(node, Scalar): - self._print_scalar(node, is_name, is_symbol) - elif is_name and isinstance(node, (List, list)): + if is_name and isinstance(node, (list, tuple)): self.print_list(node, '.', standalone_items=False, are_names=True) - else: - parent_node_tag = node.parent_node and node.parent_node.node_tag - printer = get_printer_for_node_tag(parent_node_tag, node.node_tag) - if is_name and node.node_tag == 'String': + elif isinstance(node, ast.Node): + printer = get_printer_for_node(node) + if is_name and isinstance(node, ast.String): printer(node, self, is_name=is_name, is_symbol=is_symbol) else: printer(node, self) + else: + self._print_scalar(node, is_name, is_symbol) + self.separator() def _is_pg_catalog_func(self, items): return ( self.remove_pg_catalog_from_functions and len(items) > 1 - and isinstance(items, List) - and items.parent_attribute == 'funcname' - and items[0].val.value == 'pg_catalog' + and isinstance(items, (list, tuple)) + and items[0].ancestors.parent.member == 'funcname' + and items[0].val == 'pg_catalog' # The list contains all functions that cannot be found without an # explicit pg_catalog schema. ie: # position(a,b) is invalid but pg_catalog.position(a,b) is fine - and items[1].val.value not in ('position', 'xmlexists') + and items[1].val not in ('position', 'xmlexists') ) def _print_items(self, items, sep, newline, are_names=False, is_symbol=False): @@ -401,13 +401,16 @@ def _print_items(self, items, sep, newline, are_names=False, is_symbol=False): self.write(sep) if sep != '.': self.write(' ') - self.print_node(item, is_name=are_names, is_symbol=is_symbol and idx == last) + if item is None: + self.write('None') + else: + self.print_node(item, is_name=are_names, is_symbol=is_symbol and idx == last) def print_list(self, nodes, sep=',', relative_indent=None, standalone_items=None, are_names=False, is_symbol=False): """Execute :meth:`print_node` on all the `nodes`, separating them with `sep`. - :param nodes: a sequence of :class:`~.node.Node` instances or a single List node + :param nodes: a sequence of :class:`~.ast.Node` instances :param str sep: the separator between them :param bool relative_indent: if given, the relative amount of indentation to apply before the first item, by @@ -421,12 +424,6 @@ def print_list(self, nodes, sep=',', relative_indent=None, standalone_items=None case the last one must be printed verbatim (e.g. ``"MySchema".===``) """ - if isinstance(nodes, Node): # pragma: no cover - if nodes.node_tag != 'List': - raise ValueError("Unexpected value for 'nodes', must be either a List instance" - " or a sequence of Node instances, got %r" % type(nodes)) - nodes = nodes.items - if relative_indent is None: if are_names or is_symbol: relative_indent = 0 @@ -436,9 +433,8 @@ def print_list(self, nodes, sep=',', relative_indent=None, standalone_items=None else 0) if standalone_items is None: - standalone_items = not all(isinstance(n, Node) - and n.node_tag in ('A_Const', 'ColumnRef', - 'SetToDefault', 'RangeVar') + standalone_items = not all(isinstance(n, (ast.A_Const, ast.ColumnRef, + ast.SetToDefault, ast.RangeVar)) for n in nodes) with self.push_indent(relative_indent): @@ -450,7 +446,7 @@ def print_lists(self, lists, sep=',', relative_indent=None, standalone_items=Non sublist_relative_indent=None): """Execute :meth:`print_list` on all the `lists` items. - :param lists: a sequence of sequences of :class:`~.node.Node` instances + :param lists: a sequence of sequences of :class:`~.ast.Node` instances :param str sep: passed as is to :meth:`print_list` :param bool relative_indent: passed as is to :meth:`print_list` :param bool standalone_items: passed as is to :meth:`print_list` @@ -594,7 +590,7 @@ def print_list(self, nodes, sep=',', relative_indent=None, standalone_items=None are_names=False, is_symbol=False): """Execute :meth:`print_node` on all the `nodes`, separating them with `sep`. - :param nodes: a sequence of :class:`~.node.Node` instances + :param nodes: a sequence of :class:`~.ast.Node` instances :param str sep: the separator between them :param bool relative_indent: if given, the relative amount of indentation to apply before the first item, by @@ -608,12 +604,6 @@ def print_list(self, nodes, sep=',', relative_indent=None, standalone_items=None must be printed verbatim (such as ``"MySchema".===``) """ - if isinstance(nodes, Node): # pragma: no cover - if nodes.node_tag != 'List': - raise ValueError("Unexpected value for 'nodes', must be either a List instance" - " or a sequence of Node instances, got %r" % type(nodes)) - nodes = nodes.items - if standalone_items is None: clm = self.compact_lists_margin if clm is not None and clm > 0: @@ -622,10 +612,9 @@ def print_list(self, nodes, sep=',', relative_indent=None, standalone_items=None self.write(rawlist) return - standalone_items = not all( - (isinstance(n, Node) - and n.node_tag in ('A_Const', 'ColumnRef', 'SetToDefault', 'RangeVar')) - for n in nodes) + standalone_items = not all(isinstance(n, (ast.A_Const, ast.ColumnRef, + ast.SetToDefault, ast.RangeVar)) + for n in nodes) if (((sep != ',' or not self.comma_at_eoln) and len(nodes) > 1 diff --git a/tests/test_node.py b/tests/test_node.py deleted file mode 100644 index 0726944d..00000000 --- a/tests/test_node.py +++ /dev/null @@ -1,92 +0,0 @@ -# -*- coding: utf-8 -*- -# :Project: pglast -- Test the node.py module -# :Created: ven 04 ago 2017 09:31:57 CEST -# :Author: Lele Gaifax -# :License: GNU General Public License version 3 or later -# :Copyright: © 2017, 2018, 2021 Lele Gaifax -# - -import pytest - -from pglast import ast, Missing, Node, parse_sql -from pglast.node import Base, List - - -def test_bad_base_construction(): - pytest.raises(ValueError, Base, {}, parent=1.0) - pytest.raises(ValueError, Base, [], name=1.0) - pytest.raises(ValueError, Base, set()) - - -def test_basic(): - root = Node(parse_sql('SELECT 1')) - assert root.parent_node is None - assert root.parent_attribute is None - assert isinstance(root, List) - assert len(root) == 1 - assert repr(root) == '[1*{RawStmt}]' - with pytest.raises(AttributeError): - root.not_there - - rawstmt = root[0] - assert rawstmt != root - assert rawstmt.node_tag == 'RawStmt' - assert isinstance(rawstmt.ast_node, ast.RawStmt) - assert rawstmt.parent_node is None - assert rawstmt.parent_attribute == (None, 0) - assert repr(rawstmt) == '{RawStmt}' - assert rawstmt.attribute_names == ('stmt', 'stmt_location', 'stmt_len') - with pytest.raises(ValueError): - rawstmt[1.0] - - stmt = rawstmt.stmt - assert stmt.node_tag == 'SelectStmt' - assert stmt.parent_node is rawstmt - assert stmt.parent_attribute == 'stmt' - assert rawstmt[stmt.parent_attribute] == stmt - assert stmt.whereClause is Missing - assert not stmt.whereClause - - -def test_scalar(): - constraint = ast.Constraint() - constraint.fk_matchtype = '\00' - node = Node(constraint) - assert not node.fk_matchtype - assert node.fk_matchtype != 1 - - -def test_traverse(): - root = Node(parse_sql('SELECT a, b, c FROM sometable')) - assert [repr(n) for n in root.traverse()] == [ - "{RawStmt}", - "{SelectStmt}", - "", - "{RangeVar}", - "", - "<20>", - "<'sometable'>", - "<'p'>", - "", - "", - "{ResTarget}", - "<7>", - "{ColumnRef}", - "{String}", - "<'a'>", - "<7>", - "{ResTarget}", - "<10>", - "{ColumnRef}", - "{String}", - "<'b'>", - "<10>", - "{ResTarget}", - "<13>", - "{ColumnRef}", - "{String}", - "<'c'>", - "<13>", - "<0>", - "<0>", - ] diff --git a/tests/test_printers.py b/tests/test_printers.py index a743b250..8e34e57d 100644 --- a/tests/test_printers.py +++ b/tests/test_printers.py @@ -3,22 +3,21 @@ # :Created: sab 05 ago 2017 10:31:23 CEST # :Author: Lele Gaifax # :License: GNU General Public License version 3 or later -# :Copyright: © 2017, 2018, 2019, 2021 Lele Gaifax +# :Copyright: © 2017, 2018, 2019, 2021, 2022 Lele Gaifax # import warnings import pytest -from pglast import enums, prettify -from pglast.node import Missing, Scalar +from pglast import ast, enums, prettify from pglast.printers import IntEnumPrinter, NODE_PRINTERS, PrinterAlreadyPresentError -from pglast.printers import get_printer_for_node_tag, node_printer +from pglast.printers import get_printer_for_node, node_printer def test_registry(): - with pytest.raises(NotImplementedError): - get_printer_for_node_tag(None, 'non_existing') + with pytest.raises(ValueError): + get_printer_for_node(None) with pytest.raises(ValueError): @node_printer() @@ -35,54 +34,11 @@ def invalid_tag(node, output): def too_many_tags1(node, output): pass - with pytest.raises(ValueError): - @node_printer('one', 'two', 'three', check_tags=False) - def too_many_tags2(node, output): - pass - - try: - @node_printer('test_tag1', check_tags=False) - def tag1(node, output): - pass - - assert get_printer_for_node_tag(None, 'test_tag1') is tag1 - - with pytest.raises(PrinterAlreadyPresentError): - @node_printer('test_tag1', check_tags=False) - def tag3(node, output): - pass - - @node_printer('test_tag1', override=True, check_tags=False) - def tag1_bis(node, output): - pass - - assert get_printer_for_node_tag(None, 'test_tag1') is tag1_bis - - @node_printer('test_tag_3', check_tags=False) - def generic_tag3(node, output): - pass - - @node_printer('test_tag_1', 'test_tag_3', check_tags=False) - def specific_tag3(node, output): - pass - - @node_printer(('test_tag_a', 'test_tag_b'), 'test_tag_3', check_tags=False) - def specific_tag4(node, output): - pass - - assert get_printer_for_node_tag(None, 'test_tag_3') is generic_tag3 - assert get_printer_for_node_tag('Foo', 'test_tag_3') is generic_tag3 - assert get_printer_for_node_tag('test_tag_1', 'test_tag_3') is specific_tag3 - assert get_printer_for_node_tag('test_tag_a', 'test_tag_3') is specific_tag4 - assert get_printer_for_node_tag('test_tag_b', 'test_tag_3') is specific_tag4 - finally: - NODE_PRINTERS.pop('test_tag1', None) - def test_prettify_safety_belt(): - raw_stmt_printer = NODE_PRINTERS.pop('RawStmt', None) + raw_stmt_printer = NODE_PRINTERS.pop(ast.RawStmt, None) try: - @node_printer('RawStmt') + @node_printer(ast.RawStmt) def raw_stmt_1(node, output): output.write('Yeah') @@ -94,7 +50,7 @@ def raw_stmt_1(node, output): assert output == 'select 42' assert 'Detected a bug' in str(w[0].message) - @node_printer('RawStmt', override=True) + @node_printer(ast.RawStmt, override=True) def raw_stmt_2(node, output): output.write('select 1') @@ -107,9 +63,9 @@ def raw_stmt_2(node, output): assert 'Detected a non-cosmetic difference' in str(w[0].message) finally: if raw_stmt_printer is not None: - NODE_PRINTERS['RawStmt'] = raw_stmt_printer + NODE_PRINTERS[ast.RawStmt] = raw_stmt_printer else: - NODE_PRINTERS.pop('RawStmt', None) + NODE_PRINTERS.pop(ast.RawStmt, None) def test_int_enum_printer(): @@ -127,24 +83,12 @@ def LockWaitBlock(self, node, output): with pytest.raises(NotImplementedError): lwp('LockWaitError', object(), result) - with pytest.raises(ValueError): - lwp(None, object(), result) - with pytest.raises(ValueError): lwp('FooBar', object(), result) - lwp(Scalar('LockWaitBlock'), object(), result) + lwp(None, object(), result) assert result == ['block']*2 - with pytest.raises(ValueError): - lwp(Scalar('FooBar'), object(), result) - - lwp(Scalar(0), object(), result) - assert result == ['block']*3 - - lwp(Missing, object(), result) - assert result == ['block']*4 - def test_not_int_enum_printer(): class NotIntEnum(IntEnumPrinter): diff --git a/tests/test_printers_roundtrip.py b/tests/test_printers_roundtrip.py index 72c93d19..e8110e04 100644 --- a/tests/test_printers_roundtrip.py +++ b/tests/test_printers_roundtrip.py @@ -3,7 +3,7 @@ # :Created: dom 17 mar 2019 09:24:11 CET # :Author: Lele Gaifax # :License: GNU General Public License version 3 or later -# :Copyright: © 2019, 2021 Lele Gaifax +# :Copyright: © 2019, 2021, 2022 Lele Gaifax # from pathlib import Path @@ -11,7 +11,7 @@ import pytest -from pglast import Node, parse_sql, split +from pglast import parse_sql, split from pglast.parser import ParseError from pglast.stream import RawStream, IndentedStream import pglast.printers # noqa @@ -47,7 +47,7 @@ def test_printers_roundtrip(src, lineno, statement): except: # noqa raise RuntimeError("%s:%d:Could not parse %r" % (src, lineno, statement)) - serialized = RawStream()(Node(orig_ast)) + serialized = RawStream()(orig_ast) try: serialized_ast = parse_sql(serialized) except: # noqa @@ -55,7 +55,7 @@ def test_printers_roundtrip(src, lineno, statement): assert orig_ast == serialized_ast, "%s:%s:%r != %r" % (src, lineno, statement, serialized) - indented = IndentedStream()(Node(orig_ast)) + indented = IndentedStream()(orig_ast) try: indented_ast = parse_sql(indented) except: # noqa @@ -82,7 +82,7 @@ def test_stream_call_with_single_node(src, lineno, statement): for rawstmt in parsed: stmt = rawstmt.stmt try: - RawStream()(Node(stmt)) + RawStream()(stmt) except Exception: raise AssertionError('Could not serialize single statement %r' % stmt) @@ -129,7 +129,7 @@ def test_pg_regress_corpus(filename): % (trimmed_stmt, rel_src, lineno, e)) try: - serialized = RawStream()(Node(orig_ast)) + serialized = RawStream()(orig_ast) except NotImplementedError as e: raise NotImplementedError("Statement “%s” from %s at line %d, could not reprint: %s" % (trimmed_stmt, rel_src, lineno, e)) diff --git a/tests/test_stream.py b/tests/test_stream.py index a65f35ec..8772e261 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -3,12 +3,12 @@ # :Created: sab 05 ago 2017 10:31:23 CEST # :Author: Lele Gaifax # :License: GNU General Public License version 3 or later -# :Copyright: © 2017, 2018, 2019, 2021 Lele Gaifax +# :Copyright: © 2017, 2018, 2019, 2021, 2022 Lele Gaifax # import pytest -from pglast import ast, node, parse_sql +from pglast import ast, parse_sql from pglast.printers import NODE_PRINTERS, PrinterAlreadyPresentError, SPECIAL_FUNCTIONS from pglast.printers import node_printer, special_function from pglast.stream import IndentedStream, OutputStream, RawStream @@ -26,9 +26,9 @@ def test_output_stream(): def test_raw_stream_with_sql(): - raw_stmt_printer = NODE_PRINTERS.pop('RawStmt', None) + raw_stmt_printer = NODE_PRINTERS.pop(ast.RawStmt, None) try: - @node_printer('RawStmt') + @node_printer(ast.RawStmt) def raw_stmt(node, output): output.write('Yeah') @@ -37,33 +37,15 @@ def raw_stmt(node, output): assert result == 'Yeah; Yeah' finally: if raw_stmt_printer is not None: - NODE_PRINTERS['RawStmt'] = raw_stmt_printer + NODE_PRINTERS[ast.RawStmt] = raw_stmt_printer else: - NODE_PRINTERS.pop('RawStmt', None) + NODE_PRINTERS.pop(ast.RawStmt, None) -def test_raw_stream_with_node(): - raw_stmt_printer = NODE_PRINTERS.pop('RawStmt', None) +def test_raw_stream(): + raw_stmt_printer = NODE_PRINTERS.pop(ast.RawStmt, None) try: - @node_printer('RawStmt') - def raw_stmt(node, output): - output.write('Yeah') - - root = parse_sql('SELECT 1') - output = RawStream() - result = output(node.Node(root)) - assert result == 'Yeah' - finally: - if raw_stmt_printer is not None: - NODE_PRINTERS['RawStmt'] = raw_stmt_printer - else: - NODE_PRINTERS.pop('RawStmt', None) - - -def test_raw_stream_with_ast_node(): - raw_stmt_printer = NODE_PRINTERS.pop('RawStmt', None) - try: - @node_printer('RawStmt') + @node_printer(ast.RawStmt) def raw_stmt(node, output): output.write('Yeah') @@ -71,15 +53,11 @@ def raw_stmt(node, output): output = RawStream() result = output(root) assert result == 'Yeah' - - output = RawStream() - result = output(root[0]) - assert result == 'Yeah' finally: if raw_stmt_printer is not None: - NODE_PRINTERS['RawStmt'] = raw_stmt_printer + NODE_PRINTERS[ast.RawStmt] = raw_stmt_printer else: - NODE_PRINTERS.pop('RawStmt', None) + NODE_PRINTERS.pop(ast.RawStmt, None) def test_raw_stream_invalid_call(): @@ -88,9 +66,9 @@ def test_raw_stream_invalid_call(): def test_indented_stream_with_sql(): - raw_stmt_printer = NODE_PRINTERS.pop('RawStmt', None) + raw_stmt_printer = NODE_PRINTERS.pop(ast.RawStmt, None) try: - @node_printer('RawStmt') + @node_printer(ast.RawStmt) def raw_stmt(node, output): output.write('Yeah') @@ -103,16 +81,16 @@ def raw_stmt(node, output): assert result == 'Yeah;\nYeah' finally: if raw_stmt_printer is not None: - NODE_PRINTERS['RawStmt'] = raw_stmt_printer + NODE_PRINTERS[ast.RawStmt] = raw_stmt_printer else: - NODE_PRINTERS.pop('RawStmt', None) + NODE_PRINTERS.pop(ast.RawStmt, None) def test_separate_statements(): """Separate statements by ``separate_statements`` (int) newlines.""" - raw_stmt_printer = NODE_PRINTERS.pop('RawStmt', None) + raw_stmt_printer = NODE_PRINTERS.pop(ast.RawStmt, None) try: - @node_printer('RawStmt') + @node_printer(ast.RawStmt) def raw_stmt(node, output): output.write('Yeah') @@ -121,9 +99,9 @@ def raw_stmt(node, output): assert result == 'Yeah;\n\n\nYeah' finally: if raw_stmt_printer is not None: - NODE_PRINTERS['RawStmt'] = raw_stmt_printer + NODE_PRINTERS[ast.RawStmt] = raw_stmt_printer else: - NODE_PRINTERS.pop('RawStmt', None) + NODE_PRINTERS.pop(ast.RawStmt, None) def test_special_function():