diff --git a/docs/api.rst b/docs/api.rst index 9b8b93e..b15b5b2 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -3,7 +3,7 @@ .. :Created: gio 10 ago 2017 10:14:17 CEST .. :Author: Lele Gaifax .. :License: GNU General Public License version 3 or later -.. :Copyright: © 2017, 2018, 2019, 2021 Lele Gaifax +.. :Copyright: © 2017, 2018, 2019, 2021, 2022 Lele Gaifax .. =================== @@ -23,7 +23,6 @@ This chapter briefly explains some implementation detail. ast enums keywords - node stream printers sfuncs diff --git a/docs/node.rst b/docs/node.rst deleted file mode 100644 index bed369e..0000000 --- a/docs/node.rst +++ /dev/null @@ -1,38 +0,0 @@ -.. -*- coding: utf-8 -*- -.. :Project: pglast -- Node documentation -.. :Created: gio 10 ago 2017 10:28:36 CEST -.. :Author: Lele Gaifax -.. :License: GNU General Public License version 3 or later -.. :Copyright: © 2017, 2018, 2021 Lele Gaifax -.. - -===================================================================== - :mod:`pglast.node` --- The higher level interface to the parse tree -===================================================================== - -This module implements a set of classes that make it easier to deal with the :mod:`pglast.ast` -nodes. - -The :class:`pglast.node.Node` wraps a single :class:`pglast.ast.Node` adding a reference to the -parent node; the class:`pglast.node.List` wraps a sequence of them and -:class:`pglast.node.Scalar` represents plain values such a *strings*, *integers*, *booleans* or -*none*. - -Every node is identified by a *tag*, a string label that characterizes its content, exposed as -a set of *attributes* as well as with a dictionary-like interface (technically -:class:`pglast.node.Node` implements both a ``__getattr__`` method and a ``__getitem__`` -method, while underlying :class:`pglast.ast.Node` only the former). When asked for an -attribute, the node returns an instance of the base classes, i.e. another ``Node``, or a -``List`` or a ``Scalar``, depending on the data type of that item. When the node does not -contain the requested attribute it returns a singleton :data:`pglast.node.Missing` marker -instance. - -A ``List`` wraps a plain Python ``list`` and may contains a sequence of ``Node`` instances, or -in some cases other sub-lists, that can be accessed with the usual syntax, or iterated. - -Finally, a ``Scalar`` carries a single value of some scalar type, accessible through its -``value`` attribute. - -.. automodule:: pglast.node - :synopsis: The higher level interface to the parse tree - :members: diff --git a/docs/printers.rst b/docs/printers.rst index d50b93f..60ec335 100644 --- a/docs/printers.rst +++ b/docs/printers.rst @@ -3,7 +3,7 @@ .. :Created: gio 10 ago 2017 13:23:18 CEST .. :Author: Lele Gaifax .. :License: GNU General Public License version 3 or later -.. :Copyright: © 2017, 2018, 2021 Lele Gaifax +.. :Copyright: © 2017, 2018, 2021, 2022 Lele Gaifax .. ========================================================== @@ -22,7 +22,7 @@ associated :class:`~.node.Node` will be serialized. .. autoexception:: PrinterAlreadyPresentError -.. autofunction:: get_printer_for_node_tag +.. autofunction:: get_printer_for_node .. autofunction:: node_printer diff --git a/docs/usage.rst b/docs/usage.rst index 5ba556b..2d2842d 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -3,7 +3,7 @@ .. :Created: gio 10 ago 2017 10:06:38 CEST .. :Author: Lele Gaifax .. :License: GNU General Public License version 3 or later -.. :Copyright: © 2017, 2018, 2019, 2021 Lele Gaifax +.. :Copyright: © 2017, 2018, 2019, 2021, 2022 Lele Gaifax .. .. _usage: @@ -15,7 +15,7 @@ Here are some example of how the module can be used. --------- -Low level +AST level --------- The lowest level is a Python wrapper around each *parse node* returned by the ``PostgreSQL`` @@ -181,89 +181,6 @@ This basically means that you can reconstruct a syntax tree from the result of c >>> clone == stmt True ------------- -Medium level ------------- - -Parse an ``SQL`` statement and get its *AST* root node -====================================================== - -.. doctest:: - - >>> from pglast import Node - >>> root = Node(parse_sql('SELECT foo FROM bar')) - >>> print(root) - None=[1*{RawStmt}] - -Get a particular node -===================== - -.. doctest:: - - >>> from_clause = root[0].stmt.fromClause - >>> print(from_clause) - fromClause=[1*{RangeVar}] - -Obtain some information about a node -==================================== - -.. doctest:: - - >>> range_var = from_clause[0] - >>> print(range_var.node_tag) - RangeVar - >>> print(range_var.attribute_names) - ('catalogname', 'schemaname', 'relname', 'inh', 'relpersistence', 'alias', 'location') - >>> print(range_var.parent_node) - stmt={SelectStmt} - -Iterate over nodes -================== - -.. doctest:: - - >>> for a in from_clause: - ... print(a) - ... for b in a: - ... print(b) - ... - fromClause[0]={RangeVar} - inh= - location=<16> - relname=<'bar'> - relpersistence=<'p'> - -Recursively :meth:`traverse ` the parse tree -======================================================================= - -.. doctest:: - - >>> for node in root.traverse(): - ... print(node) - ... - None[0]={RawStmt} - stmt={SelectStmt} - all= - fromClause[0]={RangeVar} - inh= - location=<16> - relname=<'bar'> - relpersistence=<'p'> - limitOption= - op= - targetList[0]={ResTarget} - location=<7> - val={ColumnRef} - fields[0]={String} - val=<'foo'> - location=<7> - stmt_len=<0> - stmt_location=<0> - -As you can see, the ``repr``\ esentation of each value is mnemonic: ``{some_tag}`` means a -``Node`` with tag ``some_tag``, ``[X*{some_tag}]`` is a ``List`` containing `X` nodes of that -particular kind\ [*]_ and ```` is a ``Scalar``. - Programmatically :func:`reformat ` a ``SQL`` statement ======================================================================= @@ -318,7 +235,7 @@ that extends :class:`pglast.stream.RawStream` adding a bit a aesthetic sense. .. doctest:: - >>> print(IndentedStream()(root)) + >>> print(IndentedStream()('select foo from bar')) SELECT foo FROM bar @@ -400,7 +317,7 @@ Customize a :func:`node printer ` >>> from pglast.printers import node_printer >>> @node_printer('ParamRef', override=True) ... def replace_param_ref(node, output): - ... output.write(repr(args[node.number.value - 1])) + ... output.write(repr(args[node.number - 1])) ... >>> args = ['Hello', 'Ciao'] >>> print(prettify(sql, safety_belt=False)) @@ -542,11 +459,6 @@ Preserve comments carry their exact location, so it is not possible to differentiate between ``SELECT * /*comment*/ FROM foo`` and ``SELECT * FROM /*comment*/ foo``. ---- - -.. [*] This is an approximation, because in principle a list can contain different kinds of - nodes, or even sub-lists in some cases: the ``List`` representation arbitrarily shows - the tag of the first object. Functions vs SQL syntax ======================= diff --git a/pglast/__init__.py b/pglast/__init__.py index 15ec191..55b6e10 100644 --- a/pglast/__init__.py +++ b/pglast/__init__.py @@ -3,12 +3,13 @@ # :Created: mer 02 ago 2017 15:11:02 CEST # :Author: Lele Gaifax # :License: GNU General Public License version 3 or later -# :Copyright: © 2017, 2018, 2019, 2021 Lele Gaifax +# :Copyright: © 2017, 2018, 2019, 2021, 2022 Lele Gaifax # +from collections import namedtuple + from . import enums from .error import Error -from .node import Comment, Missing, Node try: from .parser import fingerprint, get_postgresql_version, parse_sql, scan, split except ModuleNotFoundError: # pragma: no cover @@ -31,6 +32,10 @@ def parse_plpgsql(statement): return loads(parse_plpgsql_json(statement)) +Comment = namedtuple('Comment', ('location', 'text', 'at_start_of_line', 'continue_previous')) +"A structure to carry information about a single SQL comment." + + def _extract_comments(statement): lines = [] lofs = 0 @@ -81,7 +86,7 @@ def prettify(statement, safety_belt=False, preserve_comments=False, **options): options['comments'] = _extract_comments(statement) orig_pt = parse_sql(statement) - prettified = IndentedStream(**options)(Node(orig_pt)) + prettified = IndentedStream(**options)(orig_pt) if safety_belt: from logging import getLogger import warnings @@ -109,5 +114,5 @@ def prettify(statement, safety_belt=False, preserve_comments=False, **options): return prettified -__all__ = ('Error', 'Missing', 'Node', 'enums', 'fingerprint', 'get_postgresql_version', +__all__ = ('Error', 'enums', 'fingerprint', 'get_postgresql_version', 'parse_plpgsql', 'parse_sql', 'prettify', 'split') diff --git a/pglast/node.py b/pglast/node.py deleted file mode 100644 index 43b022a..0000000 --- a/pglast/node.py +++ /dev/null @@ -1,300 +0,0 @@ -# -*- coding: utf-8 -*- -# :Project: pglast -- Generic Node implementation -# :Created: mer 02 ago 2017 15:44:14 CEST -# :Author: Lele Gaifax -# :License: GNU General Public License version 3 or later -# :Copyright: © 2017, 2018, 2019, 2021, 2022 Lele Gaifax -# - -from collections import namedtuple -from decimal import Decimal -from enum import Enum - -from . import ast - - -class Missing: - def __bool__(self): - return False - - def __repr__(self): # pragma: no cover - return "MISSING" - - def __iter__(self): - return self - - def __next__(self): - raise StopIteration() - - -Missing = Missing() -"Singleton returned when trying to get a non-existing attribute out of a :class:`Node`." - - -Comment = namedtuple('Comment', ('location', 'text', 'at_start_of_line', 'continue_previous')) -"A structure to carry information about a single SQL comment." - - -class Base: - """Common base class. - - :param details: the *parse tree* - :type parent: ``None`` or :class:`Node` instance - :param parent: ``None`` to indicate that the node is the *root* of the parse tree, - otherwise it is the immediate parent of the new node - :type name: str or tuple - :param name: the name of the attribute in the `parent` node that *points* to this one; - it may be a tuple (name, position) when ``parent[name]`` is actually a list of - nodes - - Its main purpose is to create the right kind of instance, depending on the type of the - `details` argument passed to the constructor: a :class:`ast.Node ` - produces a :class:`Node` instance, a ``list`` or ``tuple`` produces a :class:`List` - instance, everything else a :class:`Scalar` instance. - - """ - - __slots__ = ('_parent_node', '_parent_attribute') - - def __new__(cls, details, parent=None, name=None): - if parent is not None and not isinstance(parent, Node): - raise ValueError("Unexpected value for 'parent', must be either None" - " or a Node instance, got %r" % type(parent)) - if name is not None and not isinstance(name, (str, tuple)): - raise ValueError("Unexpected value for 'name', must be either None," - " a string or a tuple, got %r" % type(name)) - if isinstance(details, (list, tuple)): - self = super().__new__(List) - elif isinstance(details, ast.Node): - self = super().__new__(Node) - else: - self = super().__new__(Scalar) - self._init(details, parent, name) - return self - - def _init(self, parent, name): - self._parent_node = parent - self._parent_attribute = name - - def __str__(self): # pragma: no cover - aname = self._parent_attribute - if isinstance(aname, tuple): - aname = '%s[%d]' % aname - return '%s=%r' % (aname, self) - - @property - def parent_node(self): - "The parent :class:`Node` of this element." - - return self._parent_node - - @property - def parent_attribute(self): - "The *attribute* in the parent :class:`Node` referencing this element." - - return self._parent_attribute - - -class List(Base): - """Represent a sequence of :class:`Node` instances. - - :type items: list - :param items: a list of items, usually :class:`Node` instances - :type parent: ``None`` or :class:`Node` instance - :param parent: ``None`` to indicate that the node is the *root* of the parse tree, - otherwise it is the immediate parent of the new node - :type name: str or tuple - :param name: the name of the attribute in the `parent` node that *points* to this one; - it may be a tuple (name, position) when ``parent[name]`` is actually a list of - nodes - """ - - __slots__ = ('_items',) - - def _init(self, items, parent, name): - if not isinstance(items, (list, tuple)) or not items: # pragma: no cover - raise ValueError("Unexpected value for 'items', must be a non empty tuple or list," - " got %r" % type(items)) - super()._init(parent, name) - self._items = items - - def __len__(self): - return len(self._items) - - def __bool__(self): - return len(self) > 0 - - def __repr__(self): - if not self: # pragma: no cover - return '[]' - # There's no guarantee that a list contains the same kind of objects, - # so picking the first is rather arbitrary but serves the purpose, as - # this is primarily an helper for investigating the internals of a tree. - count = len(self) - pivot = self[0] - return '[%d*%r]' % (count, pivot) - - def __iter__(self): - pnode = self.parent_node - aname = self.parent_attribute - for idx, item in enumerate(self._items): - yield Base(item, pnode, (aname, idx)) - - def __getitem__(self, index): - return Base(self._items[index], self.parent_node, (self.parent_attribute, index)) - - def __eq__(self, other): - cls = type(self) - if cls is type(other): - return self._items == other._items - return False - - @property - def string_value(self): - if len(self) != 1: # pragma: no cover - raise TypeError('%r does not contain a single String node' % self) - node = self[0] - if node.node_tag != 'String': # pragma: no cover - raise TypeError('%r does not contain a single String node' % self) - return node.val.value - - def traverse(self): - "A generator that recursively traverse all the items in the list." - - for item in self: - yield from item.traverse() - - -class Node(Base): - """Represent a single entry in a *parse tree*. - - :type details: :class:`.ast.Node` - :param details: the *parse tree* of the node - :type parent: ``None`` or :class:`Node` instance - :param parent: ``None`` to indicate that the node is the *root* of the parse tree, - otherwise it is the immediate parent of the new node - :type name: str or tuple - :param name: the name of the attribute in the `parent` node that *points* to this one; - it may be a tuple (name, position) when ``parent[name]`` is actually a list of - nodes - """ - - __slots__ = ('node_tag', 'ast_node') - - def _init(self, details, parent=None, name=None): - if not isinstance(details, ast.Node): # pragma: no cover - raise ValueError("Unexpected value for 'details', must be a ast.Node") - super()._init(parent, name) - self.node_tag = details.__class__.__name__ - self.ast_node = details - - def __getattr__(self, attr): - value = getattr(self.ast_node, attr) - if value is None: - return Missing - else: - return Base(value, self, attr) - - def __repr__(self): - return '{%s}' % self.node_tag - - def __getitem__(self, attr): - if isinstance(attr, tuple): # pragma: no cover - attr, index = attr - return self[attr][index] - elif isinstance(attr, str): - return getattr(self, attr) - else: - raise ValueError('Wrong key type %r, must be str or tuple' - % type(attr).__name__) - - def __iter__(self): - node = self.ast_node - for attr in sorted(node.__slots__): - value = getattr(node, attr) - if value is not None: - yield Base(value, self, attr) - - def __eq__(self, other): - cls = type(self) - if cls is type(other): - return self.ast_node == other.ast_node - return False - - @property - def attribute_names(self): - "The names of the attributes present in the parse tree of the node." - - return tuple(self.ast_node) - - def traverse(self): - "A generator that recursively traverse all attributes of the node." - - yield self - for item in self: - yield from item.traverse() - - -class Scalar(Base): - "Represent a single scalar value." - - __slots__ = ('_value',) - - def _init(self, value, parent, name): - if value is not None and not isinstance(value, (bool, float, int, str, Decimal, - ast.Value)): - raise ValueError("Unexpected value for 'value', must be either None or a" - " bool|float|int|str|Decimal instance, got %r" % type(value)) - super()._init(parent, name) - self._value = value - - def __and__(self, other): - value = self._value - if isinstance(value, int) and isinstance(other, int): - return value & other - else: # pragma: no cover - raise ValueError("Wrong operands for __and__: %r & %r" - % (type(value), type(other))) - - def __bool__(self): - value = self._value - if value is None: - return False - if isinstance(value, str): - if len(value) == 0: - return False - elif len(value) == 1: - return value[0] != '\x00' - else: - return True - elif isinstance(value, bool): - return value - return True - - def __hash__(self): - return hash(self._value) - - def __eq__(self, other): - value = self._value - if isinstance(other, Enum): - return value == other - elif isinstance(other, type(value)): - return value == other - else: - cls = type(self) - if cls is type(other): - return value == other._value - return False - - def __repr__(self): # pragma: no cover - if isinstance(self._value, Enum): - return repr(self._value) - else: - return '<%r>' % self._value - - def traverse(self): - yield self - - @property - def value(self): - return self._value diff --git a/pglast/printers/__init__.py b/pglast/printers/__init__.py index c0b8287..ba181df 100644 --- a/pglast/printers/__init__.py +++ b/pglast/printers/__init__.py @@ -3,7 +3,7 @@ # :Created: sab 05 ago 2017 16:33:14 CEST # :Author: Lele Gaifax # :License: GNU General Public License version 3 or later -# :Copyright: © 2017, 2018, 2020, 2021 Lele Gaifax +# :Copyright: © 2017, 2018, 2020, 2021, 2022 Lele Gaifax # from .. import ast @@ -11,7 +11,7 @@ NODE_PRINTERS = {} -"Registry of specialized printers, keyed by their `tag`." +"Registry of specialized node printers, keyed by their class." SPECIAL_FUNCTIONS = {} @@ -22,67 +22,69 @@ class PrinterAlreadyPresentError(Error): "Exception raised trying to register another function for a tag already present." -def get_printer_for_node_tag(parent_node_tag, node_tag): - """Get specific printer implementation for given `node_tag`. +def get_printer_for_node(node): + """Get specific printer implementation for given `node`. If there is a more specific printer for it, when it's inside a particular - `parent_node_tag`, return that instead. + ancestor, return that instead. """ - printer = NODE_PRINTERS.get((parent_node_tag, node_tag)) + node_class = type(node) + if not issubclass(node_class, ast.Node): + raise ValueError('Expected an ast.Node, not a %r' % node_class.__name__) + parent = abs(node.ancestors) + parent_node_class = None if parent is None else type(parent.node) + printer = NODE_PRINTERS.get((parent_node_class, node_class)) if printer is None: - printer = NODE_PRINTERS.get(node_tag) + printer = NODE_PRINTERS.get(node_class) if printer is None: raise NotImplementedError("Printer for node %r is not implemented yet" - % node_tag) + % node_class.__name__) return printer -def node_printer(*node_tags, override=False, check_tags=True): - r"""Decorator to register a specific printer implementation for given `node_tag`. +def node_printer(*nodes, override=False): + r"""Decorator to register a specific printer implementation for a (set of) `nodes`. - :param \*node_tags: one or two node tags + :param \*nodes: a list of one or two items :param bool override: when ``True`` the function will be registered even if already present in the :data:`NODE_PRINTERS` registry - :param bool check_tags: - by default each `node_tags` is checked for validity, that is must be a valid class - implemented by :mod:`pglast.ast`; pass ``False`` to disable the check - - When `node_tags` contains a single item then the decorated function is the *generic* one, - and it will be registered in :data:`NODE_PRINTERS` with that key alone. Otherwise it must - contain two elements: the first may be either a scalar value or a sequence of parent tags, - and the function will be registered under the key ``(parent_tag, tag)``. + + When `nodes` contains a single item then the decorated function is the *generic* + one, and it will be registered in :data:`NODE_PRINTERS` with that key alone. Otherwise it + must contain two elements: the first may be either a scalar value or a sequence of parent + nodes, and the function will be registered under the key ``(parent, node)``. """ - if check_tags: - for tag in node_tags: - if not isinstance(tag, (list, tuple)): - tag = (tag,) - for t in tag: - if not isinstance(t, str): - raise ValueError(f'Invalid tag {t!r}, expected a string') - if not hasattr(ast, t): - raise ValueError(f'Unknown tag name {t}') + n = len(nodes) + if n == 1: + parent_classes = (None,) + node_class = nodes[0] + elif n == 2: + pclasses, node_class = nodes + if not isinstance(pclasses, (list, tuple)): + pclasses = (pclasses,) + parent_classes = tuple(getattr(ast, cls) if isinstance(cls, str) else cls + for cls in pclasses) + else: + raise ValueError('Invalid nodes: must contain one or two items!') + + if isinstance(node_class, str): + node_class = getattr(ast, node_class) + elif not isinstance(node_class, type) or not issubclass(node_class, ast.Node): + raise ValueError('Invalid nodes: expected an ast.Node instance or its name,' + ' got %r' % node_class) def decorator(impl): - if len(node_tags) == 1: - parent_tags = (None,) - tag = node_tags[0] - elif len(node_tags) == 2: - parent_tags, tag = node_tags - if not isinstance(parent_tags, (list, tuple)): - parent_tags = (parent_tags,) - else: - raise ValueError(f'Must specify one or two tags, got {len(node_tags)} instead') - - for parent_tag in parent_tags: - t = tag if parent_tag is None else (parent_tag, tag) - if not override and t in NODE_PRINTERS: + for parent_class in parent_classes: + key = node_class if parent_class is None else (parent_class, node_class) + if not override and key in NODE_PRINTERS: # pragma: no cover raise PrinterAlreadyPresentError("A printer is already registered for tag %r" - % t) - NODE_PRINTERS[t] = impl + % key) + NODE_PRINTERS[key] = impl return impl + return decorator @@ -123,26 +125,20 @@ def __init__(self): self.value_to_symbol = {m.value: m.name for m in self.enum} def __call__(self, value, node, output): - from ..node import Missing, Scalar - - if value is Missing: + if value is None: # Should never happen, but better safe than sorry for symbol, member in self.enum.__members__.items(): if member.value == 0: break else: # pragma: no cover raise ValueError(f"Could not determine default value of class {self.enum!r}") - elif isinstance(value, Scalar): - if isinstance(value.value, str): - # libpg_query 13+ emits enum names, not values - symbol = value.value - if symbol not in self.enum.__members__: - raise ValueError('Unexpected symbol {symbol!r}, not in {self.enum!r}') - else: - symbol = self.value_to_symbol.get(value.value) + elif isinstance(value, self.enum): + symbol = self.value_to_symbol.get(value) + elif isinstance(value, ast.Integer): + symbol = self.value_to_symbol.get(value.val) else: symbol = value - if symbol is None: + if symbol is None: # pragma: no cover raise ValueError(f"Invalid value {value!r}, not in class {self.enum!r}") method = getattr(self, symbol, None) if method is None: @@ -155,4 +151,12 @@ def __call__(self, value, node, output): method(node, output) +def get_string_value(lst): + "Helper function to get a literal string value, wrapped in a one-sized list." + + if len(lst) != 1 or not isinstance(lst[0], ast.String): # pragma: no cover + raise TypeError('%r does not contain a single String node' % lst) + return lst[0].val + + from . import ddl, dml, sfuncs # noqa: F401,E402 diff --git a/pglast/printers/ddl.py b/pglast/printers/ddl.py index 33350d1..60c57b5 100644 --- a/pglast/printers/ddl.py +++ b/pglast/printers/ddl.py @@ -8,18 +8,17 @@ import re -from .. import enums -from ..node import Missing, List -from . import IntEnumPrinter, node_printer +from .. import ast, enums +from . import IntEnumPrinter, get_string_value, node_printer @node_printer('AccessPriv') def access_priv(node, output): - if node.priv_name is Missing: + if node.priv_name is None: output.write('ALL PRIVILEGES') else: - output.write(node.priv_name.value.upper()) - if node.cols is not Missing: + output.write(node.priv_name.upper()) + if node.cols is not None: output.write(' (') output.print_list(node.cols, ',', are_names=True) output.write(')') @@ -109,7 +108,7 @@ def alter_extension_stmt(node, output): @node_printer('AlterExtensionStmt', 'DefElem') def alter_extension_stmt_def_elem(node, output): - option = node.defname.value + option = node.defname if option == 'new_version': output.write('UPDATE TO ') output.print_node(node.arg) @@ -125,7 +124,7 @@ def alter_extension_contents_stmt(node, output): output.write(' DROP ') else: output.write(' ADD ') - output.write(OBJECT_NAMES[node.objtype.value]) + output.write(OBJECT_NAMES[node.objtype]) output.write(' ') output.print_node(node.object) @@ -137,19 +136,19 @@ def alter_enum_stmt(node, output): if node.newVal: if node.oldVal: output.write('RENAME VALUE ') - output.write_quoted_string(node.oldVal.value) + output.write_quoted_string(node.oldVal) output.write('TO ') else: output.write('ADD VALUE ') if node.skipIfNewValExists: output.write('IF NOT EXISTS ') - output.write_quoted_string(node.newVal.value) + output.write_quoted_string(node.newVal) if node.newValNeighbor: if node.newValIsAfter: output.write(' AFTER ') else: output.write(' BEFORE ') - output.write_quoted_string(node.newValNeighbor.value) + output.write_quoted_string(node.newValNeighbor) @node_printer('AlterDefaultPrivilegesStmt') @@ -157,14 +156,15 @@ def alter_default_privileges_stmt(node, output): output.writes('ALTER DEFAULT PRIVILEGES') roles = None schemas = None - for opt in node.options: - optname = opt.defname.value - if optname == 'roles': - roles = opt.arg - elif optname == 'schemas': - schemas = opt.arg - else: # pragma: no cover - raise NotImplementedError('Option not implemented: %s' % optname) + if node.options is not None: + for opt in node.options: + optname = opt.defname + if optname == 'roles': + roles = opt.arg + elif optname == 'schemas': + schemas = opt.arg + else: # pragma: no cover + raise NotImplementedError('Option not implemented: %s' % optname) if roles is not None: output.newline() with output.push_indent(2): @@ -191,7 +191,7 @@ def alter_default_privileges_stmt(node, output): else: output.write('ALL PRIVILEGES') output.write(' ON ') - output.write(OBJECT_NAMES[action.objtype.value]) + output.write(OBJECT_NAMES[action.objtype]) output.write('S ') output.writes(preposition) output.print_list(action.grantees, ',') @@ -275,7 +275,7 @@ def alter_op_family_stmt(node, output): @node_printer('AlterOwnerStmt') def alter_owner_stmt(node, output): output.write('ALTER ') - output.writes(OBJECT_NAMES[node.objectType.value]) + output.writes(OBJECT_NAMES[node.objectType]) OT = enums.ObjectType if node.objectType in (OT.OBJECT_OPFAMILY, OT.OBJECT_OPCLASS): @@ -337,7 +337,7 @@ def alter_seq_stmt(node, output): @node_printer('AlterTableStmt') def alter_table_stmt(node, output): output.write('ALTER ') - output.writes(OBJECT_NAMES[node.relkind.value]) + output.writes(OBJECT_NAMES[node.relkind]) if node.missing_ok: output.write('IF EXISTS ') output.print_node(node.relation) @@ -381,7 +381,7 @@ def alter_def_elem(node, output): @node_printer('AlterTableStmt', 'RangeVar') def range_var(node, output): - if node.parent_node.relkind == enums.ObjectType.OBJECT_TABLE and not node.inh: + if abs(node.ancestors).node.relkind == enums.ObjectType.OBJECT_TABLE and not node.inh: output.write('ONLY ') if node.schemaname: output.print_name(node.schemaname) @@ -399,7 +399,7 @@ class AlterTableTypePrinter(IntEnumPrinter): def AT_AddColumn(self, node, output): output.write('ADD ') - if node.parent_node.relkind == enums.ObjectType.OBJECT_TYPE: + if abs(node.ancestors).node.relkind == enums.ObjectType.OBJECT_TYPE: output.write('ATTRIBUTE ') else: output.write('COLUMN ') @@ -421,7 +421,7 @@ def AT_AddOf(self, node, output): def AT_AlterColumnType(self, node, output): output.write('ALTER ') - if node.parent_node.relkind == enums.ObjectType.OBJECT_TYPE: + if abs(node.ancestors).node.relkind == enums.ObjectType.OBJECT_TYPE: output.write('ATTRIBUTE ') else: output.write('COLUMN ') @@ -474,7 +474,7 @@ def AT_DropCluster(self, node, output): def AT_DropColumn(self, node, output): output.write('DROP ') - if node.parent_node.relkind == enums.ObjectType.OBJECT_TYPE: + if abs(node.ancestors).node.relkind == enums.ObjectType.OBJECT_TYPE: output.write('ATTRIBUTE ') else: output.write('COLUMN ') @@ -546,7 +546,7 @@ def AT_SetStatistics(self, node, output): if node.name: output.print_name(node.name) elif node.num: - output.write(str(node.num.value)) + output.write(str(node.num)) output.write(' SET STATISTICS ') output.print_node(node.def_) @@ -554,7 +554,7 @@ def AT_SetStorage(self, node, output): output.write('ALTER COLUMN ') output.print_name(node.name) output.write(' SET STORAGE ') - output.write(node.def_.val.value) + output.write(node.def_.val) def AT_SetUnLogged(self, node, output): output.write('SET UNLOGGED') @@ -596,7 +596,7 @@ def AT_DropExpression(self, node, output): def AT_AddIdentity(self, node, output): output.write('ALTER COLUMN ') output.print_name(node.name) - if node.num.value > 0: + if node.num > 0: # FIXME: find a way to get here output.print_node(node.num) output.write(' ADD ') @@ -606,7 +606,7 @@ def AT_AddIdentity(self, node, output): def AT_DropIdentity(self, node, output): output.write('ALTER COLUMN ') output.print_name(node.name) - if node.num.value > 0: + if node.num > 0: # FIXME: find a way to get here output.print_node(node.num) output.write(' DROP IDENTITY ') @@ -645,7 +645,7 @@ def AT_DisableTrigAll(self, node, output): def AT_SetIdentity(self, node, output): output.write('ALTER COLUMN ') output.print_name(node.name) - if node.num.value > 0: + if node.num > 0: # FIXME: find a way to get here output.print_node(node.num) for elem in node.def_: @@ -660,7 +660,7 @@ def AT_SetIdentity(self, node, output): output.write('CACHE ') output.print_node(elem.arg) elif elem.defname == 'cycle': - if elem.arg.val.value == 0: + if elem.arg.val == 0: output.write('NO ') output.write('CYCLE') elif elem.defname == 'increment': @@ -723,9 +723,9 @@ class AlterTSConfigTypePrinter(IntEnumPrinter): enum = enums.AlterTSConfigType def print_simple_name(self, node, output): - if isinstance(node, List): + if isinstance(node, tuple): node = node[0] - output.write(node.val.value) + output.write(node.val) def print_simple_list(self, nodes, output): first = True @@ -851,9 +851,9 @@ def alter_subscription_stmt(node, output): output.print_list(node.options) output.write(')') elif node.kind == enums.AlterSubscriptionType.ALTER_SUBSCRIPTION_ENABLED: - if node.options[0].arg.val.value == 0: + if node.options[0].arg.val == 0: output.write('DISABLE') - elif node.options[0].arg.val.value == 1: + elif node.options[0].arg.val == 1: output.write('ENABLE') @@ -920,7 +920,7 @@ def alter_user_mapping_stmt(node, output): output.write('ALTER USER MAPPING FOR ') role_spec(node.user, output) output.writes(' SERVER ') - output.write(node.servername.value) + output.write(node.servername) alter_def_elem(node.options, output) @@ -1036,9 +1036,8 @@ def column_def(node, output): @node_printer('CommentStmt') def comment_stmt(node, output): otypes = enums.ObjectType - output.write('COMMENT ') - output.write('ON ') - output.writes(OBJECT_NAMES[node.objtype.value]) + output.write('COMMENT ON ') + output.writes(OBJECT_NAMES[node.objtype]) if node.objtype in (otypes.OBJECT_OPCLASS, otypes.OBJECT_OPFAMILY): nodes = list(node.object) using = nodes.pop(0) @@ -1064,8 +1063,8 @@ def comment_stmt(node, output): output.print_name(nodes) elif node.objtype == otypes.OBJECT_AGGREGATE: _object_with_args(node.object, output, empty_placeholder='*') - elif isinstance(node.object, List): - if node.object[0].node_tag != 'String': + elif isinstance(node.object, tuple): + if not isinstance(node.object[0], ast.String): output.write(' (') output.print_list(node.object, ' AS ', standalone_items=False) output.write(')') @@ -1078,7 +1077,7 @@ def comment_stmt(node, output): output.write('IS ') if node.comment: with output.push_indent(): - output.write_quoted_string(node.comment.value) + output.write_quoted_string(node.comment) else: output.write('NULL') @@ -1166,14 +1165,14 @@ def CONSTR_FOREIGN(self, node, output): output.write(' (') output.print_name(node.pk_attrs, ',') output.write(')') - if node.fk_matchtype and node.fk_matchtype != enums.FKCONSTR_MATCH_SIMPLE: + if node.fk_matchtype != '\0' and node.fk_matchtype != enums.FKCONSTR_MATCH_SIMPLE: output.write(' MATCH ') if node.fk_matchtype == enums.FKCONSTR_MATCH_FULL: output.write('FULL') elif node.fk_matchtype == enums.FKCONSTR_MATCH_PARTIAL: # pragma: no cover # MATCH PARTIAL not yet implemented output.write('PARTIAL') - if node.fk_del_action and node.fk_del_action != enums.FKCONSTR_ACTION_NOACTION: + if node.fk_del_action != '\0' and node.fk_del_action != enums.FKCONSTR_ACTION_NOACTION: output.write(' ON DELETE ') if node.fk_del_action == enums.FKCONSTR_ACTION_RESTRICT: output.write('RESTRICT') @@ -1183,7 +1182,7 @@ def CONSTR_FOREIGN(self, node, output): output.write('SET NULL') elif node.fk_del_action == enums.FKCONSTR_ACTION_SETDEFAULT: output.write('SET DEFAULT') - if node.fk_upd_action and node.fk_upd_action != enums.FKCONSTR_ACTION_NOACTION: + if node.fk_upd_action != '\0' and node.fk_upd_action != enums.FKCONSTR_ACTION_NOACTION: output.write(' ON UPDATE ') if node.fk_upd_action == enums.FKCONSTR_ACTION_RESTRICT: output.write('RESTRICT') @@ -1313,15 +1312,15 @@ def create_db_stmt(node, output): @node_printer('CreatedbStmt', 'DefElem') def create_db_stmt_def_elem(node, output): - option = node.defname.value + option = node.defname if option == 'connection_limit': output.write('connection limit') else: output.print_symbol(node.defname) - if node.arg is not Missing: + if node.arg is not None: output.write(' = ') - if isinstance(node.arg, List) or option in ('allow_connections', 'is_template'): - output.write(node.arg.val.value) + if isinstance(node.arg, tuple) or option in ('allow_connections', 'is_template'): + output.write(node.arg.val) else: output.print_node(node.arg) @@ -1353,8 +1352,8 @@ def create_conversion_stmt(node, output): output.write('DEFAULT ') output.write('CONVERSION ') output.print_name(node.conversion_name) - output.write(" FOR '%s' TO '%s'" % (node.for_encoding_name.value, - node.to_encoding_name.value)) + output.write(" FOR '%s' TO '%s'" % (node.for_encoding_name, + node.to_encoding_name)) output.write(' FROM ') output.print_name(node.func_name) @@ -1420,9 +1419,9 @@ def create_extension_stmt(node, output): @node_printer('CreateExtensionStmt', 'DefElem') def create_extension_stmt_def_elem(node, output): - option = node.defname.value + option = node.defname if option == 'cascade': - if node.arg.val.value == 1: + if node.arg.val == 1: output.write('CASCADE') elif option == 'old_version': # FIXME: find a way to get here @@ -1462,15 +1461,15 @@ def create_fdw_stmt(node, output): @node_printer('CreateUserMappingStmt', 'DefElem') @node_printer('CreateFdwStmt', 'DefElem') def create_fdw_stmt_def_elem(node, output): - if node.parent_attribute[0] == 'options' or node.parent_attribute[0] == 'fdwoptions': - if ' ' in node.defname.value: - output.write(f'"{node.defname.value}"') + if abs(node.ancestors).member in ('options', 'fdwoptions'): + if ' ' in node.defname: + output.write(f'"{node.defname}"') else: - output.write(node.defname.value) + output.write(node.defname) output.write(' ') output.print_node(node.arg) else: - output.write(node.defname.value.upper()) + output.write(node.defname.upper()) output.write(' ') output.print_name(node.arg) @@ -1514,10 +1513,10 @@ def create_foreign_table_stmt(node, output): @node_printer('CreateForeignTableStmt', 'DefElem') @node_printer('CreateForeignServerStmt', 'DefElem') def create_foreign_table_stmt_def_elem(node, output): - if ' ' in node.defname.value: - output.write(f'"{node.defname.value}"') + if ' ' in node.defname: + output.write(f'"{node.defname}"') else: - output.write(node.defname.value) + output.write(node.defname) output.write(' ') output.print_node(node.arg) @@ -1534,21 +1533,24 @@ def create_function_stmt(node, output): output.print_name(node.funcname) output.write('(') - # Functions returning a SETOF needs special care, because the resulting record - # definition is intermixed with real parameters: split them into two separated - # lists - real_params = node.parameters - if node.returnType and node.returnType.setof: - fpm = enums.FunctionParameterMode - record_def = [] - real_params = [] - for param in node.parameters: - if param.mode == fpm.FUNC_PARAM_TABLE: - record_def.append(param) - else: - real_params.append(param) - if real_params: - output.print_list(real_params) + if node.parameters is not None: + # Functions returning a SETOF needs special care, because the resulting record + # definition is intermixed with real parameters: split them into two separated + # lists + real_params = node.parameters + if node.returnType and node.returnType.setof: + fpm = enums.FunctionParameterMode + record_def = [] + real_params = [] + for param in node.parameters: + if param.mode == fpm.FUNC_PARAM_TABLE: + record_def.append(param) + else: + real_params.append(param) + if real_params: + output.print_list(real_params) + else: + record_def = False output.write(')') if node.returnType: @@ -1561,32 +1563,31 @@ def create_function_stmt(node, output): output.write(')') else: output.print_node(node.returnType) - for option in node.options: output.print_node(option) @node_printer(('AlterFunctionStmt', 'CreateFunctionStmt', 'DoStmt'), 'DefElem') def create_function_option(node, output): - option = node.defname.value + option = node.defname if option == 'as': - if isinstance(node.arg, List) and len(node.arg) > 1: + if isinstance(node.arg, tuple) and len(node.arg) > 1: # We are in the weird C case output.write('AS ') output.print_list(node.arg) return - if node.parent_node.node_tag == 'CreateFunctionStmt': + if isinstance(abs(node.ancestors).node, ast.CreateFunctionStmt): output.newline() output.write('AS ') # Choose a valid dollar-string delimiter - if isinstance(node.arg, List): - code = node.arg[0].val.value + if isinstance(node.arg, tuple): + code = node.arg[0].val else: - code = node.arg.val.value + code = node.arg.val used_delimiters = set(re.findall(r"\$(\w*)(?=\$)", code)) unique_delimiter = '' while unique_delimiter in used_delimiters: @@ -1601,7 +1602,7 @@ def create_function_option(node, output): return if option == 'security': - if node.arg.val.value == 1: + if node.arg.val == 1: output.swrite('SECURITY DEFINER') else: output.swrite('SECURITY INVOKER') @@ -1616,7 +1617,7 @@ def create_function_option(node, output): if option == 'volatility': output.separator() - output.write(node.arg.val.value.upper()) + output.write(node.arg.val.upper()) return if option == 'parallel': @@ -1629,7 +1630,7 @@ def create_function_option(node, output): return if option == 'leakproof': - if node.arg.val.value == 0: + if node.arg.val == 0: output.swrite('NOT') output.swrite('LEAKPROOF') return @@ -1643,7 +1644,7 @@ def create_function_option(node, output): output.write('WINDOW') return - output.writes(node.defname.value.upper()) + output.writes(node.defname.upper()) output.print_symbol(node.arg) @@ -1665,7 +1666,7 @@ def create_opclass_stmt(node, output): def create_opclass_item(node, output): if node.itemtype == enums.OPCLASS_ITEM_OPERATOR: output.write('OPERATOR ') - output.write('%d ' % node.number._value) + output.write('%d ' % node.number) if node.name: _object_with_args(node.name, output, symbol=True, skip_empty_args=True) if node.order_family: @@ -1677,7 +1678,7 @@ def create_opclass_item(node, output): output.write(')') elif node.itemtype == enums.OPCLASS_ITEM_FUNCTION: output.write('FUNCTION ') - output.write('%d ' % node.number._value) + output.write('%d ' % node.number) if node.class_args: output.write(' (') output.print_list(node.class_args, standalone_items=False) @@ -1817,7 +1818,7 @@ def create_role_stmt(node, output): @node_printer('AlterRoleStmt', 'DefElem') @node_printer('CreateRoleStmt', 'DefElem') def create_or_alter_role_option(node, output): - option = node.defname.value + option = node.defname argv = node.arg if option == 'sysid': output.write('SYSID ') @@ -1888,7 +1889,7 @@ def create_seq_stmt(node, output): output.writes('SEQUENCE') if node.if_not_exists: output.writes('IF NOT EXISTS') - if node.sequence.schemaname is not Missing: + if node.sequence.schemaname is not None: output.print_name(node.sequence.schemaname) output.write('.') output.print_name(node.sequence.relname) @@ -1903,9 +1904,9 @@ def create_seq_stmt(node, output): @node_printer('CreateSeqStmt', 'DefElem') @node_printer('AlterSeqStmt', 'DefElem') def create_seq_stmt_def_elem(node, output): - option = node.defname.value + option = node.defname if option == 'cycle': - if node.arg.val.value == 0: + if node.arg.val == 0: output.write('NO ') output.write('CYCLE') elif option == 'increment': @@ -1923,10 +1924,10 @@ def create_seq_stmt_def_elem(node, output): output.write(' WITH ') output.print_node(node.arg) else: - if node.arg is Missing: + if node.arg is None: output.write('NO ') output.write(option.upper()) - if node.arg is not Missing: + if node.arg is not None: output.write(' ') output.print_node(node.arg) if node.defaction != enums.DefElemAction.DEFELEM_UNSPEC: # pragma: no cover @@ -1952,11 +1953,7 @@ def create_stats_stmt(node, output): @node_printer('CreateStmt') def create_stmt(node, output): output.writes('CREATE') - # NB: parent_node may be None iff we are dealing with a single concrete statement, not with - # the ephemeral RawStmt returned by parse_sql(); in all other cases in this source where - # we access the parent_node we are actually in a "contextualized printer", that can only be - # entered when there is a parent node - if node.parent_node is not None and node.parent_node.node_tag == 'CreateForeignTableStmt': + if isinstance(node.ancestors[0], ast.CreateForeignTableStmt): output.writes('FOREIGN') else: if node.relation.relpersistence == enums.RELPERSISTENCE_TEMP: @@ -2045,7 +2042,7 @@ def create_table_as_stmt(node, output): output.writes('TEMPORARY') elif node.into.rel.relpersistence == enums.RELPERSISTENCE_UNLOGGED: output.writes('UNLOGGED') - output.writes(OBJECT_NAMES[node.relkind.value]) + output.writes(OBJECT_NAMES[node.relkind]) if node.if_not_exists: output.writes('IF NOT EXISTS') output.print_node(node.into) @@ -2068,7 +2065,7 @@ def create_table_space_stmt(node, output): output.print_node(node.owner) output.space() output.write('LOCATION ') - output.write_quoted_string(node.location.value) + output.write_quoted_string(node.location) if node.options: output.write(' WITH (') output.print_list(node.options) @@ -2084,7 +2081,7 @@ def create_trig_stmt(node, output): output.print_name(node.trigname) output.newline() with output.push_indent(2): - if node.timing.value: + if node.timing: if node.timing & enums.TRIGGER_TYPE_BEFORE: output.write('BEFORE ') elif node.timing & enums.TRIGGER_TYPE_INSTEAD: @@ -2149,8 +2146,8 @@ def create_subscription_stmt_stmt_def_elem(node, output): output.print_name(node.defname) if node.arg: output.write(' = ') - if node.arg.node_tag == 'String' and node.arg.val.value in ('true', 'false'): - output.write(node.arg.val.value) + if isinstance(node.arg, ast.String) and node.arg.val in ('true', 'false'): + output.write(node.arg.val) else: output.print_node(node.arg) @@ -2220,7 +2217,7 @@ def create_user_mapping_stmt(node, output): output.write('FOR ') role_spec(node.user, output) output.writes(' SERVER') - output.write(node.servername.value) + output.write(node.servername) if node.options: output.write(' OPTIONS (') output.print_list(node.options, ',') @@ -2241,22 +2238,22 @@ def define_stmt(node, output): output.write('CREATE ') if node.replace: output.write('OR REPLACE ') - output.writes(OBJECT_NAMES[node.kind.value]) + output.writes(OBJECT_NAMES[node.kind]) if node.if_not_exists: output.write('IF NOT EXISTS ') output.print_list(node.defnames, '.', standalone_items=False, are_names=True, is_symbol=node.kind == enums.ObjectType.OBJECT_OPERATOR) - if node.args is not Missing: + if node.args is not None: # args is actually a tuple (list-of-nodes, integer): the integer value, if different # from -1, is the number of nodes representing the actual arguments, remaining are # ORDER BY args, count = node.args - count = count.val.value + count = count.val output.write(' (') if count == -1: # Special case: if it's an aggregate, and the scalar is equal to # None (not is, since it's a Scalar), write a star - if ((node.kind.value == enums.ObjectType.OBJECT_AGGREGATE + if ((node.kind == enums.ObjectType.OBJECT_AGGREGATE and args == None)): output.write('*') actual_args = [] @@ -2300,7 +2297,7 @@ def define_stmt(node, output): @node_printer('DefElem') def def_elem(node, output): output.print_symbol(node.defname) - if node.arg is not Missing: + if node.arg is not None: output.write(' = ') output.print_node(node.arg) if node.defaction != enums.DefElemAction.DEFELEM_UNSPEC: # pragma: no cover @@ -2309,10 +2306,10 @@ def def_elem(node, output): @node_printer('DefineStmt', 'DefElem') def define_stmt_def_elem(node, output): - output.print_node(node.defname, is_name=True) - if node.arg is not Missing: + output.print_name(node.defname) + if node.arg is not None: output.write(' = ') - if isinstance(node.arg, List): + if isinstance(node.arg, tuple): is_symbol = node.defname in ('commutator', 'negator') if is_symbol and len(node.arg) > 1: output.write('OPERATOR(') @@ -2394,7 +2391,7 @@ def drop_stmt(node, output): otypes = enums.ObjectType output.write('DROP ') # Special case functions since they are not special objects - output.writes(OBJECT_NAMES[node.removeType.value]) + output.writes(OBJECT_NAMES[node.removeType]) if node.removeType == otypes.OBJECT_INDEX: if node.concurrent: output.write(' CONCURRENTLY') @@ -2422,13 +2419,16 @@ def drop_stmt(node, output): output.print_name(nodes[-1]) output.write(' ON ') output.print_name(on) - elif isinstance(node.objects[0], List): - if node.objects[0][0].node_tag != 'String': - output.print_lists(node.objects, ' AS ', standalone_items=False, - are_names=True) - else: - output.print_lists(node.objects, sep='.', sublist_open='', sublist_close='', - standalone_items=False, are_names=True) + elif node.removeType == otypes.OBJECT_CAST: + names = node.objects[0] + output.write('(') + output.print_name(names[0]) + output.write(' AS ') + output.print_name(names[1]) + output.write(')') + elif isinstance(node.objects[0], tuple): + output.print_lists(node.objects, sep='.', sublist_open='', sublist_close='', + standalone_items=False, are_names=True) else: output.print_list(node.objects, ',', standalone_items=False, are_names=True) if node.behavior == enums.DropBehavior.DROP_CASCADE: @@ -2470,7 +2470,7 @@ def drop_user_mapping_stmt(node, output): @node_printer('FunctionParameter') def function_parameter(node, output): - if node.mode is not Missing: + if node.mode is not None: pm = enums.FunctionParameterMode if node.mode == pm.FUNC_PARAM_IN: pass # omit the default, output.write('IN ') @@ -2488,7 +2488,7 @@ def function_parameter(node, output): output.print_name(node.name) output.write(' ') output.print_node(node.argType) - if node.defexpr is not Missing: + if node.defexpr is not None: output.write(' = ') output.print_node(node.defexpr) @@ -2508,10 +2508,10 @@ def grant_stmt(node, output): else: output.write('ALL PRIVILEGES') # hack for OBJECT_FOREIGN_SERVER - if node.objtype.value == enums.ObjectType.OBJECT_FOREIGN_SERVER: + if node.objtype == enums.ObjectType.OBJECT_FOREIGN_SERVER: object_name = 'FOREIGN SERVER' else: - object_name = OBJECT_NAMES[node.objtype.value] + object_name = OBJECT_NAMES[node.objtype] target = node.targtype output.newline() output.space(2) @@ -2631,7 +2631,7 @@ def index_stmt(node, output): def lock_stmt(node, output): output.write('LOCK ') output.print_list(node.relations, ',') - lock_mode = node.mode.value + lock_mode = node.mode lock_str = LOCK_MODE_NAMES[lock_mode] output.write('IN ') output.write(lock_str) @@ -2646,7 +2646,7 @@ def notify_stmt(node, output): output.print_name(node.conditionname) if node.payload: output.write(', ') - output.write_quoted_string(node.payload.value) + output.write_quoted_string(node.payload) def _object_with_args(node, output, unquote_name=False, symbol=False, @@ -2657,9 +2657,9 @@ def _object_with_args(node, output, unquote_name=False, symbol=False, for idx, name in enumerate(node.objname): if idx > 0: output.write('.') - output.write(name.val.value) + output.write(name.val) else: - output.write(node.objname.string_value) + output.write(get_string_value(node.objname)) elif symbol: output.print_symbol(node.objname) else: @@ -2682,7 +2682,7 @@ def object_with_args(node, output): @node_printer(('AlterObjectSchemaStmt',), 'ObjectWithArgs') def alter_object_schema_stmt_object_with_args(node, output): - symbol = node.parent_node.objectType == enums.ObjectType.OBJECT_OPERATOR + symbol = abs(node.ancestors).node.objectType == enums.ObjectType.OBJECT_OPERATOR _object_with_args(node, output, symbol=symbol) @@ -2693,20 +2693,21 @@ def alter_operator_stmt_object_with_args(node, output): @node_printer(('AlterOwnerStmt',), 'ObjectWithArgs') def alter_owner_stmt_object_with_args(node, output): - unquote_name = node.parent_node.objectType == enums.ObjectType.OBJECT_OPERATOR + unquote_name = abs(node.ancestors).node.objectType == enums.ObjectType.OBJECT_OPERATOR _object_with_args(node, output, unquote_name=unquote_name) @node_printer(('CommentStmt',), 'ObjectWithArgs') def comment_stmt_object_with_args(node, output): - unquote_name = node.parent_node.objtype == enums.ObjectType.OBJECT_OPERATOR + unquote_name = abs(node.ancestors).node.objtype == enums.ObjectType.OBJECT_OPERATOR _object_with_args(node, output, unquote_name=unquote_name) @node_printer(('DropStmt',), 'ObjectWithArgs') def drop_stmt_object_with_args(node, output): - unquote_name = node.parent_node.removeType == enums.ObjectType.OBJECT_OPERATOR - if node.parent_node.removeType == enums.ObjectType.OBJECT_AGGREGATE: + parent_node = abs(node.ancestors).node + unquote_name = parent_node.removeType == enums.ObjectType.OBJECT_OPERATOR + if parent_node.removeType == enums.ObjectType.OBJECT_AGGREGATE: _object_with_args(node, output, empty_placeholder='*', unquote_name=unquote_name) else: _object_with_args(node, output, unquote_name=unquote_name) @@ -2730,7 +2731,7 @@ def partition_bound_spec(node, output): output.write(')') elif node.strategy == enums.PARTITION_STRATEGY_HASH: output.write('WITH (MODULUS %d, REMAINDER %d)' - % (node.modulus.value, node.remainder.value)) + % (node.modulus, node.remainder)) else: raise NotImplementedError('Unhandled strategy %r' % node.strategy) @@ -2765,7 +2766,7 @@ def partition_range_datum(node, output): elif node.kind == enums.PartitionRangeDatumKind.PARTITION_RANGE_DATUM_MAXVALUE: output.write('MAXVALUE') else: - output.print_node(node.value) + output.print_node(node) @node_printer('PartitionSpec') @@ -2875,8 +2876,8 @@ def rename_stmt(node, output): @node_printer('RenameStmt', 'RangeVar') def rename_stmt_range_var(node, output): OT = enums.ObjectType - if not node.inh and node.parent_node.renameType not in (OT.OBJECT_ATTRIBUTE, - OT.OBJECT_TYPE): + if not node.inh and abs(node.ancestors).node.renameType not in (OT.OBJECT_ATTRIBUTE, + OT.OBJECT_TYPE): output.write('ONLY ') if node.schemaname: output.print_name(node.schemaname) @@ -2932,7 +2933,7 @@ def rule_stmt_printer(node, output): output.newline() with output.push_indent(2): output.write('ON ') - output.write(EVENT_NAMES[node.event.value]) + output.write(EVENT_NAMES[node.event]) output.write(' TO ') output.print_name(node.relation) if node.whereClause: @@ -2987,7 +2988,7 @@ def sec_label_stmt(node, output): output.write('FOR ') output.print_name(node.provider) output.write(' ON ') - output.write(OBJECT_NAMES[node.objtype.value]) + output.write(OBJECT_NAMES[node.objtype]) output.write(' ') output.print_name(node.object) output.write(' IS ') @@ -3048,10 +3049,10 @@ def vacuum_stmt(node, output): @node_printer('VacuumStmt', 'DefElem') def vacuum_stmt_def_elem(node, output): - output.write(node.defname.value.upper()) + output.write(node.defname.upper()) if node.arg: output.write(' ') - output.write(str(node.arg.val.value)) + output.write(str(node.arg.val)) @node_printer('VacuumRelation') diff --git a/pglast/printers/dml.py b/pglast/printers/dml.py index f0f63eb..9d4aabb 100644 --- a/pglast/printers/dml.py +++ b/pglast/printers/dml.py @@ -6,9 +6,8 @@ # :Copyright: © 2017, 2018, 2019, 2020, 2021, 2022 Lele Gaifax # -from .. import enums -from ..node import Missing, List -from . import IntEnumPrinter, node_printer +from .. import ast, enums +from . import IntEnumPrinter, get_string_value, node_printer @node_printer('A_ArrayExpr') @@ -38,24 +37,24 @@ def AEXPR_BETWEEN_SYM(self, node, output): output.print_list(node.rexpr, 'AND', relative_indent=-4) def AEXPR_DISTINCT(self, node, output): - if node.lexpr.node_tag == 'BoolExpr': + if isinstance(node.lexpr, ast.BoolExpr): output.write('(') output.print_node(node.lexpr) - if node.lexpr.node_tag == 'BoolExpr': + if isinstance(node.lexpr, ast.BoolExpr): output.write(') ') output.swrites('IS DISTINCT FROM') output.print_node(node.rexpr) def AEXPR_ILIKE(self, node, output): output.print_node(node.lexpr) - if node.name.string_value == '!~~*': + if get_string_value(node.name) == '!~~*': output.swrites('NOT') output.swrites('ILIKE') output.print_node(node.rexpr) def AEXPR_IN(self, node, output): output.print_node(node.lexpr) - if node.name.string_value == '<>': + if get_string_value(node.name) == '<>': output.swrites('NOT') output.swrite('IN (') output.print_list(node.rexpr) @@ -63,7 +62,7 @@ def AEXPR_IN(self, node, output): def AEXPR_LIKE(self, node, output): output.print_node(node.lexpr) - if node.name.string_value == '!~~': + if get_string_value(node.name) == '!~~': output.swrites('NOT') output.swrites('LIKE') output.print_node(node.rexpr) @@ -91,7 +90,7 @@ def AEXPR_NULLIF(self, node, output): def AEXPR_OF(self, node, output): output.print_node(node.lexpr) output.swrites('IS') - if node.name.string_value == '<>': + if get_string_value(node.name) == '<>': output.writes('NOT') output.write('OF (') output.print_list(node.rexpr) @@ -100,8 +99,8 @@ def AEXPR_OF(self, node, output): def AEXPR_OP(self, node, output): with output.expression(): # lexpr is optional because these are valid: -(1+1), +(1+1), ~(1+1) - if node.lexpr is not Missing: - if node.lexpr.node_tag == 'A_Expr': + if node.lexpr is not None: + if isinstance(node.lexpr, ast.A_Expr): if node.lexpr.kind == node.kind and node.lexpr.name == node.name: output.print_node(node.lexpr) else: @@ -110,15 +109,15 @@ def AEXPR_OP(self, node, output): else: output.print_node(node.lexpr) output.write(' ') - if isinstance(node.name, List) and len(node.name) > 1: + if isinstance(node.name, tuple) and len(node.name) > 1: output.write('OPERATOR(') output.print_symbol(node.name) output.write(') ') else: output.print_symbol(node.name) output.write(' ') - if node.rexpr is not Missing: - if node.rexpr.node_tag == 'A_Expr': + if node.rexpr is not None: + if isinstance(node.rexpr, ast.A_Expr): if node.rexpr.kind == node.kind and node.rexpr.name == node.name: output.print_node(node.rexpr) else: @@ -130,7 +129,7 @@ def AEXPR_OP(self, node, output): def AEXPR_OP_ALL(self, node, output): output.print_node(node.lexpr) output.write(' ') - output.write(node.name.string_value) + output.write(get_string_value(node.name)) output.write(' ALL(') output.print_node(node.rexpr) output.write(')') @@ -138,7 +137,7 @@ def AEXPR_OP_ALL(self, node, output): def AEXPR_OP_ANY(self, node, output): output.print_node(node.lexpr) output.write(' ') - output.write(node.name.string_value) + output.write(get_string_value(node.name)) output.write(' ANY(') output.print_node(node.rexpr) output.write(')') @@ -152,11 +151,11 @@ def AEXPR_PAREN(self, node, output): # pragma: no cover def AEXPR_SIMILAR(self, node, output): output.print_node(node.lexpr) - if node.name.string_value == '!~': + if get_string_value(node.name) == '!~': output.swrites('NOT') output.swrites('SIMILAR TO') - assert (node.rexpr.node_tag == 'FuncCall' - and node.rexpr.funcname[1].val.value == 'similar_to_escape') + assert (isinstance(node.rexpr, ast.FuncCall) + and node.rexpr.funcname[1].val == 'similar_to_escape') pattern = node.rexpr.args[0] output.print_node(pattern) if len(node.rexpr.args) > 1: @@ -188,11 +187,11 @@ def a_indices(node, output): @node_printer('A_Indirection') def a_indirection(node, output): - bracket = ((node.arg.node_tag in ('A_ArrayExpr', 'A_Expr', 'A_Indirection', 'FuncCall', - 'RowExpr', 'TypeCast')) + bracket = (isinstance(node.arg, (ast.A_ArrayExpr, ast.A_Expr, ast.A_Indirection, + ast.FuncCall, ast.RowExpr, ast.TypeCast)) or - (node.arg.node_tag == 'ColumnRef' - and node.indirection[0].node_tag != 'A_Indices')) + (isinstance(node.arg, ast.ColumnRef) + and not isinstance(node.indirection[0], ast.A_Indices))) if bracket: output.write('(') output.print_node(node.arg) @@ -244,9 +243,7 @@ def alias(node, output): @node_printer('BitString') def bitstring(node, output): - output.write(f"{node.val.value[0]}'") - output.write(node.val.value[1:]) - output.write("'") + output.write(f"{node.val[0]}'{node.val[1:]}'") @node_printer('BoolExpr') @@ -254,7 +251,7 @@ def bool_expr(node, output): bet = enums.BoolExprType outer_exp_level = output.expression_level with output.expression(): - in_res_target = node.parent_node.node_tag == 'ResTarget' + in_res_target = isinstance(node.ancestors[0], ast.ResTarget) if node.boolop == bet.AND_EXPR: relindent = -4 if not in_res_target and outer_exp_level == 0 else None output.print_list(node.args, 'AND', relative_indent=relindent) @@ -425,7 +422,7 @@ def copy_stmt(node, output): if node.is_program: output.write('PROGRAM ') if node.filename: - output.print_node(node.filename) + output.write_quoted_string(node.filename) else: if node.is_from: output.write('STDIN') @@ -444,7 +441,7 @@ def copy_stmt(node, output): @node_printer('CopyStmt', 'DefElem') def copy_stmt_def_elem(node, output): - option = node.defname.value + option = node.defname argv = node.arg if option == 'format': output.write('FORMAT ') @@ -452,7 +449,7 @@ def copy_stmt_def_elem(node, output): elif option == 'freeze': output.write('FREEZE') if argv: - output.swrite(str(argv.val.value)) + output.swrite(str(argv.val)) elif option == 'delimiter': output.write('DELIMITER ') output.print_node(argv) @@ -462,7 +459,7 @@ def copy_stmt_def_elem(node, output): elif option == 'header': output.write('HEADER') if argv: - output.swrite(str(argv.val.value)) + output.swrite(str(argv.val)) elif option == 'quote': output.write('QUOTE ') output.print_node(argv) @@ -472,7 +469,7 @@ def copy_stmt_def_elem(node, output): elif option == 'force_quote': output.write('FORCE_QUOTE ') # If it is a list print it. - if isinstance(argv, List): + if isinstance(argv, tuple): output.write('(') output.print_list(argv, are_names=True) output.write(')') @@ -548,7 +545,7 @@ def delete_stmt(node, output): @node_printer('ExecuteStmt') def execute_stmt(node, output): output.write('EXECUTE ') - output.print_node(node.name, is_name=True) + output.print_name(node.name) if node.params: output.write('(') output.print_list(node.params) @@ -570,7 +567,7 @@ def explain_stmt(node, output): @node_printer('ExplainStmt', 'DefElem') def explain_stmt_def_elem(node, output): output.print_symbol(node.defname) - if node.arg is not Missing: + if node.arg is not None: output.write(' ') output.print_symbol(node.arg) @@ -582,13 +579,13 @@ def FETCH_FORWARD(self, node, output): if node.howMany == enums.FETCH_ALL: output.write('ALL ') elif node.howMany != 1: - output.write(f'FORWARD {node.howMany.value} ') + output.write(f'FORWARD {node.howMany} ') def FETCH_BACKWARD(self, node, output): if node.howMany == enums.FETCH_ALL: output.write('BACKWARD ALL ') elif node.howMany != 1: - output.write(f'BACKWARD {node.howMany.value} ') + output.write(f'BACKWARD {node.howMany} ') else: output.write('PRIOR ') @@ -598,10 +595,10 @@ def FETCH_ABSOLUTE(self, node, output): elif node.howMany == -1: output.write('LAST ') else: - output.write(f'ABSOLUTE {node.howMany.value} ') + output.write(f'ABSOLUTE {node.howMany} ') def FETCH_RELATIVE(self, node, output): - output.write(f'RELATIVE {node.howMany.value} ') + output.write(f'RELATIVE {node.howMany} ') fetch_direction_printer = FetchDirectionPrinter() @@ -621,7 +618,7 @@ def float(node, output): @node_printer('FuncCall') def func_call(node, output): - name = '.'.join(n.val.value for n in node.funcname) + name = '.'.join(n.val for n in node.funcname) special_printer = output.get_printer_for_function(name) if special_printer is not None: special_printer(node, output) @@ -631,7 +628,7 @@ def func_call(node, output): output.write('(') if node.agg_distinct: output.writes('DISTINCT') - if node.args is Missing: + if node.args is None: if node.agg_star: output.write('*') else: @@ -698,7 +695,7 @@ def grouping_func(node, output): @node_printer('IndexElem') def index_elem(node, output): - if node.name is not Missing: + if node.name is not None: output.print_name(node.name) else: output.write('(') @@ -849,7 +846,7 @@ def join_expr(node, output): output.swrites('JOIN') - if node.rarg.node_tag == 'JoinExpr': + if isinstance(node.rarg, ast.JoinExpr): output.indent(3, relative=False) # need this for: # tests/test_printers_roundtrip.py::test_pg_regress_corpus[join.sql] - @@ -876,7 +873,7 @@ def join_expr(node, output): output.writes(') AS') output.print_name(node.alias) - if node.rarg.node_tag == 'JoinExpr': + if isinstance(node.rarg, ast.JoinExpr): output.dedent() @@ -946,13 +943,13 @@ def null_test(node, output): @node_printer('ParamRef') def param_ref(node, output): - if node.number is Missing: # pragma: no cover + if node.number is None: # pragma: no cover # NB: standard PG does not allow "?"-style param placeholders, this is a minor # deviation introduced by libpg_query; in version 2 apparently the case is merged # back to the standard style below output.write('?') else: - output.write('$%d' % node.number.value) + output.write('$%d' % node.number) @node_printer('PrepareStmt') @@ -1212,7 +1209,7 @@ def select_stmt(node, output): if node.valuesLists: # Is this a SELECT ... FROM (VALUES (...))? - require_parens = node.parent_node.node_tag == 'RangeSubselect' + require_parens = isinstance(node.ancestors[0], ast.RangeSubselect) if require_parens: output.write('(') output.write('VALUES ') @@ -1298,12 +1295,12 @@ def select_stmt(node, output): output.write('FETCH FIRST ') # FIXME do we need add '()' for all ? if ((node.limitCount - and node.limitCount.node_tag == "A_Expr" + and isinstance(node.limitCount, ast.A_Expr) and node.limitCount.kind == enums.A_Expr_Kind.AEXPR_OP)): output.write('(') output.print_node(node.limitCount) if ((node.limitCount - and node.limitCount.node_tag == "A_Expr" + and isinstance(node.limitCount, ast.A_Expr) and node.limitCount.kind == enums.A_Expr_Kind.AEXPR_OP)): output.write(')') if node.limitOption == enums.LimitOption.LIMIT_OPTION_WITH_TIES: @@ -1366,12 +1363,12 @@ def SVFOP_CURRENT_TIMESTAMP(self, node, output): def SVFOP_CURRENT_TIMESTAMP_N(self, node, output): # pragma: no cover output.write('CURRENT_TIMESTAMP(') - output.write(str(node.typmod.value)) + output.write(str(node.typmod)) output.write(')') def SVFOP_CURRENT_TIME_N(self, node, output): # pragma: no cover output.write('CURRENT_TIME(') - output.write(str(node.typmod.value)) + output.write(str(node.typmod)) output.write(')') def SVFOP_CURRENT_USER(self, node, output): @@ -1385,12 +1382,12 @@ def SVFOP_LOCALTIMESTAMP(self, node, output): def SVFOP_LOCALTIMESTAMP_N(self, node, output): # pragma: no cover output.write('LOCALTIMESTAMP(') - output.write(str(node.typmod.value)) + output.write(str(node.typmod)) output.write(')') def SVFOP_LOCALTIME_N(self, node, output): # pragma: no cover output.write('LOCALTIME(') - output.write(str(node.typmod.value)) + output.write(str(node.typmod)) output.write(')') def SVFOP_SESSION_USER(self, node, output): @@ -1422,13 +1419,13 @@ def sub_link(node, output): elif node.subLinkType == slt.ALL_SUBLINK: output.print_node(node.testexpr) output.write(' ') - output.write(node.operName.string_value) + output.write(get_string_value(node.operName)) output.write(' ALL ') elif node.subLinkType == slt.ANY_SUBLINK: output.print_node(node.testexpr) if node.operName: output.write(' ') - output.write(node.operName.string_value) + output.write(get_string_value(node.operName)) output.write(' ANY ') else: output.write(' IN ') @@ -1451,57 +1448,57 @@ def sub_link(node, output): @node_printer('TransactionStmt') def transaction_stmt(node, output): tsk = enums.TransactionStmtKind - if node.kind.value == tsk.TRANS_STMT_BEGIN: + if node.kind == tsk.TRANS_STMT_BEGIN: output.write('BEGIN ') if node.options: output.print_list(node.options) - elif node.kind.value == tsk.TRANS_STMT_START: + elif node.kind == tsk.TRANS_STMT_START: output.write('START TRANSACTION ') if node.options: output.print_list(node.options) - elif node.kind.value == tsk.TRANS_STMT_COMMIT: + elif node.kind == tsk.TRANS_STMT_COMMIT: output.write('COMMIT ') if node.chain: output.write('AND CHAIN ') - elif node.kind.value == tsk.TRANS_STMT_ROLLBACK: + elif node.kind == tsk.TRANS_STMT_ROLLBACK: output.write('ROLLBACK ') if node.chain: output.write('AND CHAIN ') - elif node.kind.value == tsk.TRANS_STMT_SAVEPOINT: + elif node.kind == tsk.TRANS_STMT_SAVEPOINT: output.write('SAVEPOINT ') - output.write(node.savepoint_name.value) - elif node.kind.value == tsk.TRANS_STMT_RELEASE: + output.write(node.savepoint_name) + elif node.kind == tsk.TRANS_STMT_RELEASE: output.write('RELEASE ') - output.write(node.savepoint_name.value) - elif node.kind.value == tsk.TRANS_STMT_ROLLBACK_TO: + output.write(node.savepoint_name) + elif node.kind == tsk.TRANS_STMT_ROLLBACK_TO: output.write('ROLLBACK TO SAVEPOINT ') - output.write(node.savepoint_name.value) - elif node.kind.value == tsk.TRANS_STMT_PREPARE: + output.write(node.savepoint_name) + elif node.kind == tsk.TRANS_STMT_PREPARE: output.write('PREPARE TRANSACTION ') - output.write("'%s'" % node.gid.value) - elif node.kind.value == tsk.TRANS_STMT_COMMIT_PREPARED: + output.write("'%s'" % node.gid) + elif node.kind == tsk.TRANS_STMT_COMMIT_PREPARED: output.write('COMMIT PREPARED ') - output.write("'%s'" % node.gid.value) - elif node.kind.value == tsk.TRANS_STMT_ROLLBACK_PREPARED: + output.write("'%s'" % node.gid) + elif node.kind == tsk.TRANS_STMT_ROLLBACK_PREPARED: output.write('ROLLBACK PREPARED ') - output.write("'%s'" % node.gid.value) + output.write("'%s'" % node.gid) @node_printer('TransactionStmt', 'DefElem') def transaction_stmt_def_elem(node, output): - value = node.defname.value + value = node.defname argv = node.arg.val if value == 'transaction_isolation': output.write('ISOLATION LEVEL ') - output.write(argv.val.value.upper()) + output.write(argv.val.upper()) elif value == 'transaction_read_only': output.write('READ ') - if argv.val.value == 0: + if argv.val == 0: output.write('WRITE') else: output.write('ONLY') elif value == 'transaction_deferrable': - if argv.val.value == 0: + if argv.val == 0: output.write('NOT ') output.write('DEFERRABLE') else: # pragma: no cover @@ -1520,15 +1517,15 @@ def truncate_stmt(node, output): @node_printer('TypeCast') def type_cast(node, output): - if node.arg.node_tag == 'A_Const': + if isinstance(node.arg, ast.A_Const): # Special case for boolean constants - if ((node.arg.val.node_tag != 'Null' - and node.arg.val.val.value in ('t', 'f') - and '.'.join(n.val.value for n in node.typeName.names) == 'pg_catalog.bool')): + if ((not isinstance(node.arg.val, ast.Null) + and node.arg.val.val in ('t', 'f') + and '.'.join(n.val for n in node.typeName.names) == 'pg_catalog.bool')): output.write('TRUE' if node.arg.val.val == 't' else 'FALSE') return # Special case for bpchar - elif (('.'.join(n.val.value for n in node.typeName.names) == 'pg_catalog.bpchar' + elif (('.'.join(n.val for n in node.typeName.names) == 'pg_catalog.bpchar' and not node.typeName.typmods)): output.write('char ') output.print_node(node.arg) @@ -1595,7 +1592,7 @@ def type_name(node, output): if node.setof: # FIXME: is this used only by plpgsql? output.writes('SETOF') - name = '.'.join(n.val.value for n in node.names) + name = '.'.join(n.val for n in node.names) suffix = '' if name in system_types: prefix, suffix = system_types[name] @@ -1611,7 +1608,7 @@ def type_name(node, output): else: if node.typmods: if name == 'pg_catalog.interval': - typmod = node.typmods[0].val.val.value + typmod = node.typmods[0].val.val if typmod in interval_ranges: output.swrite(interval_ranges[typmod]) if len(node.typmods) == 2: @@ -1626,7 +1623,7 @@ def type_name(node, output): if node.arrayBounds: for ab in node.arrayBounds: output.write('[') - if ab.val.value >= 0: + if ab.val >= 0: output.print_node(ab) output.write(']') @@ -1795,21 +1792,21 @@ def window_def(node, output): def print_indirection(node, output): for idx, subnode in enumerate(node): - if subnode.node_tag == 'String': + if isinstance(subnode, ast.String): output.write('.') output.print_node(subnode, is_name=True) @node_printer(('OnConflictClause', 'UpdateStmt'), 'ResTarget') def update_stmt_res_target(node, output): - if node.val.node_tag == 'MultiAssignRef': + if isinstance(node.val, ast.MultiAssignRef): if node.val.colno == 1: output.write('( ') output.indent(-2) output.print_name(node.name) if node.indirection: print_indirection(node.indirection, output) - if node.val.colno.value == node.val.ncolumns.value: + if node.val.colno == node.val.ncolumns: output.dedent() output.write(') = ') output.print_node(node.val) @@ -1884,7 +1881,8 @@ def IS_XMLPARSE(self, node, output): # XMLPARSE(text, is_doc, preserve_ws) xml_option_type_printer(node.xmloption, node, output) arg, preserve_ws = node.args output.print_node(arg) - if preserve_ws.arg.val.val.value == 't': + if preserve_ws.arg.val.val == 't': + # FIXME: find a way to get here output.write(' PRESERVE WHITESPACE') output.write(')') @@ -1901,11 +1899,11 @@ def IS_XMLROOT(self, node, output): # XMLROOT(xml, version, standalone) xml, version, standalone = node.args output.print_node(xml) output.write(', version ') - if version.val.node_tag == 'Null': + if isinstance(version.val, ast.Null): output.write('NO VALUE') else: output.print_node(version) - xml_standalone_type_printer(standalone.val.val, node, output) + xml_standalone_type_printer(standalone.val, node, output) output.write(')') def IS_XMLSERIALIZE(self, node, output): # XMLSERIALIZE(is_document, xmlval) diff --git a/pglast/printers/sfuncs.py b/pglast/printers/sfuncs.py index 9e8964e..5f52a00 100644 --- a/pglast/printers/sfuncs.py +++ b/pglast/printers/sfuncs.py @@ -3,7 +3,7 @@ # :Created: mer 22 nov 2017 08:34:34 CET # :Author: Lele Gaifax # :License: GNU General Public License version 3 or later -# :Copyright: © 2017, 2018, 2021 Lele Gaifax +# :Copyright: © 2017, 2018, 2021, 2022 Lele Gaifax # from . import special_function @@ -49,7 +49,7 @@ def date_part(node, output): ``EXTRACT(field FROM timestamp).``. """ output.write('EXTRACT(') - output.write(node.args[0].val.val.value.upper()) + output.write(node.args[0].val.val.upper()) output.write(' FROM ') output.print_node(node.args[1]) output.write(')') @@ -75,7 +75,7 @@ def normalize(node, output): output.print_node(node.args[0]) if len(node.args) > 1: output.write(', ') - output.write(node.args[1].val.val.value.upper()) + output.write(node.args[1].val.val.upper()) output.write(')') diff --git a/pglast/stream.py b/pglast/stream.py index ff08b40..715aca6 100644 --- a/pglast/stream.py +++ b/pglast/stream.py @@ -11,10 +11,9 @@ from re import match from sys import stderr -from .node import List, Missing, Node, Scalar from . import ast, parse_plpgsql, parse_sql, visitors from .keywords import RESERVED_KEYWORDS, TYPE_FUNC_NAME_KEYWORDS -from .printers import get_printer_for_node_tag, get_special_function +from .printers import get_printer_for_node, get_special_function class OutputStream(StringIO): @@ -156,25 +155,18 @@ def __call__(self, sql, plpgsql=False): :returns: a string with the equivalent SQL obtained by serializing the syntax tree The `sql` statement may be either a ``str`` containing the ``SQL`` in textual form, or - a :class:`.node.Node` instance, or a :class:`.node.List` instance containing - :class:`.node.Node` instances, or a concrete :class:`.ast.Node` instance or a tuple of - them. + a concrete :class:`.ast.Node` instance or a tuple of them. """ from . import ast if isinstance(sql, str): - sql = Node(parse_plpgsql(sql) if plpgsql else parse_sql(sql)) - elif isinstance(sql, Node): - sql = (sql,) + sql = parse_plpgsql(sql) if plpgsql else parse_sql(sql) elif isinstance(sql, ast.Node): - sql = (Node(sql),) - elif isinstance(sql, tuple) and sql and isinstance(sql[0], ast.Node): - sql = (Node(n) for n in sql) - elif not isinstance(sql, List): + sql = (sql,) + elif not (isinstance(sql, tuple) and sql and isinstance(sql[0], ast.Node)): raise ValueError("Unexpected value for 'sql', must be either a string," - " a node.Node instance, a node.List, an ast.Node or tuple of" - " them, got %r" % type(sql)) + " an ast.Node or tuple of them, got %r" % type(sql)) class UpdateAncestors(visitors.Visitor): def visit(self, ancestors, node): @@ -277,19 +269,21 @@ def write_quoted_string(self, s): def _print_scalar(self, node, is_name, is_symbol): "Print the scalar `node`, special-casing string literals." - value = node.value - if is_symbol: + value = node.val if isinstance(node, ast.Value) else node + is_string = isinstance(value, str) + if is_symbol and is_string: self.write(value) - elif is_name: - # The `scalar` represent a name of a column/table/alias: when any of its + elif is_name and is_string: + # The `node` represent a name of a column/table/alias: when any of its # characters is not a lower case letter, a digit or underscore, it must be # double quoted + value = str(value) if ((not match(r'[a-z_][a-z0-9_]*$', value) or value in RESERVED_KEYWORDS or value in TYPE_FUNC_NAME_KEYWORDS)): value = '"%s"' % value.replace('"', '""') self.write(value) - elif isinstance(value, str): # node.parent_node.node_tag == 'String': + elif is_string: self.write_quoted_string(value) else: self.write(str(value)) @@ -317,23 +311,29 @@ def print_comment(self, comment): def print_name(self, nodes, sep='.'): "Helper method, execute :meth:`print_node` or :meth:`print_list` as needed." - if isinstance(nodes, (List, list)): + if isinstance(nodes, (list, tuple)): self.print_list(nodes, sep, standalone_items=False, are_names=True) - else: + elif isinstance(nodes, ast.Node): self.print_node(nodes, is_name=True) + else: + self._print_scalar(nodes, is_name=True, is_symbol=False) + self.separator() def print_symbol(self, nodes, sep='.'): "Helper method, execute :meth:`print_node` or :meth:`print_list` as needed." - if isinstance(nodes, (List, list)): + if isinstance(nodes, (list, tuple)): self.print_list(nodes, sep, standalone_items=False, are_names=True, is_symbol=True) - else: + elif isinstance(nodes, ast.Node): self.print_node(nodes, is_name=True, is_symbol=True) + else: + self._print_scalar(str(nodes), is_name=True, is_symbol=True) + self.separator() def print_node(self, node, is_name=False, is_symbol=False): """Lookup the specific printer for the given `node` and execute it. - :param node: an instance of :class:`~.node.Node` or :class:`~.node.Scalar` + :param node: an instance of :class:`~.ast.Node` :param bool is_name: whether this is a *name* of something, that may need to be double quoted :param bool is_symbol: @@ -346,38 +346,38 @@ def print_node(self, node, is_name=False, is_symbol=False): elif hasattr(node, 'stmt_location'): node_location = getattr(node, 'stmt_location') else: - node_location = Missing - if node_location is not Missing: + node_location = None + if node_location is not None: nextc = self.comments[0] - if nextc.location <= node_location.value: + if nextc.location <= node_location: self.print_comment(self.comments.pop(0)) while self.comments and self.comments[0].continue_previous: self.print_comment(self.comments.pop(0)) - if isinstance(node, Scalar): - self._print_scalar(node, is_name, is_symbol) - elif is_name and isinstance(node, (List, list)): + if is_name and isinstance(node, (list, tuple)): self.print_list(node, '.', standalone_items=False, are_names=True) - else: - parent_node_tag = node.parent_node and node.parent_node.node_tag - printer = get_printer_for_node_tag(parent_node_tag, node.node_tag) - if is_name and node.node_tag == 'String': + elif isinstance(node, ast.Node): + printer = get_printer_for_node(node) + if is_name and isinstance(node, ast.String): printer(node, self, is_name=is_name, is_symbol=is_symbol) else: printer(node, self) + else: + self._print_scalar(node, is_name, is_symbol) + self.separator() def _is_pg_catalog_func(self, items): return ( self.remove_pg_catalog_from_functions and len(items) > 1 - and isinstance(items, List) - and items.parent_attribute == 'funcname' - and items[0].val.value == 'pg_catalog' + and isinstance(items, (list, tuple)) + and items[0].ancestors.parent.member == 'funcname' + and items[0].val == 'pg_catalog' # The list contains all functions that cannot be found without an # explicit pg_catalog schema. ie: # position(a,b) is invalid but pg_catalog.position(a,b) is fine - and items[1].val.value not in ('position', 'xmlexists') + and items[1].val not in ('position', 'xmlexists') ) def _print_items(self, items, sep, newline, are_names=False, is_symbol=False): @@ -401,13 +401,16 @@ def _print_items(self, items, sep, newline, are_names=False, is_symbol=False): self.write(sep) if sep != '.': self.write(' ') - self.print_node(item, is_name=are_names, is_symbol=is_symbol and idx == last) + if item is None: + self.write('None') + else: + self.print_node(item, is_name=are_names, is_symbol=is_symbol and idx == last) def print_list(self, nodes, sep=',', relative_indent=None, standalone_items=None, are_names=False, is_symbol=False): """Execute :meth:`print_node` on all the `nodes`, separating them with `sep`. - :param nodes: a sequence of :class:`~.node.Node` instances or a single List node + :param nodes: a sequence of :class:`~.ast.Node` instances :param str sep: the separator between them :param bool relative_indent: if given, the relative amount of indentation to apply before the first item, by @@ -421,12 +424,6 @@ def print_list(self, nodes, sep=',', relative_indent=None, standalone_items=None case the last one must be printed verbatim (e.g. ``"MySchema".===``) """ - if isinstance(nodes, Node): # pragma: no cover - if nodes.node_tag != 'List': - raise ValueError("Unexpected value for 'nodes', must be either a List instance" - " or a sequence of Node instances, got %r" % type(nodes)) - nodes = nodes.items - if relative_indent is None: if are_names or is_symbol: relative_indent = 0 @@ -436,9 +433,8 @@ def print_list(self, nodes, sep=',', relative_indent=None, standalone_items=None else 0) if standalone_items is None: - standalone_items = not all(isinstance(n, Node) - and n.node_tag in ('A_Const', 'ColumnRef', - 'SetToDefault', 'RangeVar') + standalone_items = not all(isinstance(n, (ast.A_Const, ast.ColumnRef, + ast.SetToDefault, ast.RangeVar)) for n in nodes) with self.push_indent(relative_indent): @@ -450,7 +446,7 @@ def print_lists(self, lists, sep=',', relative_indent=None, standalone_items=Non sublist_relative_indent=None): """Execute :meth:`print_list` on all the `lists` items. - :param lists: a sequence of sequences of :class:`~.node.Node` instances + :param lists: a sequence of sequences of :class:`~.ast.Node` instances :param str sep: passed as is to :meth:`print_list` :param bool relative_indent: passed as is to :meth:`print_list` :param bool standalone_items: passed as is to :meth:`print_list` @@ -594,7 +590,7 @@ def print_list(self, nodes, sep=',', relative_indent=None, standalone_items=None are_names=False, is_symbol=False): """Execute :meth:`print_node` on all the `nodes`, separating them with `sep`. - :param nodes: a sequence of :class:`~.node.Node` instances + :param nodes: a sequence of :class:`~.ast.Node` instances :param str sep: the separator between them :param bool relative_indent: if given, the relative amount of indentation to apply before the first item, by @@ -608,12 +604,6 @@ def print_list(self, nodes, sep=',', relative_indent=None, standalone_items=None must be printed verbatim (such as ``"MySchema".===``) """ - if isinstance(nodes, Node): # pragma: no cover - if nodes.node_tag != 'List': - raise ValueError("Unexpected value for 'nodes', must be either a List instance" - " or a sequence of Node instances, got %r" % type(nodes)) - nodes = nodes.items - if standalone_items is None: clm = self.compact_lists_margin if clm is not None and clm > 0: @@ -622,10 +612,9 @@ def print_list(self, nodes, sep=',', relative_indent=None, standalone_items=None self.write(rawlist) return - standalone_items = not all( - (isinstance(n, Node) - and n.node_tag in ('A_Const', 'ColumnRef', 'SetToDefault', 'RangeVar')) - for n in nodes) + standalone_items = not all(isinstance(n, (ast.A_Const, ast.ColumnRef, + ast.SetToDefault, ast.RangeVar)) + for n in nodes) if (((sep != ',' or not self.comma_at_eoln) and len(nodes) > 1 diff --git a/tests/test_node.py b/tests/test_node.py deleted file mode 100644 index 0726944..0000000 --- a/tests/test_node.py +++ /dev/null @@ -1,92 +0,0 @@ -# -*- coding: utf-8 -*- -# :Project: pglast -- Test the node.py module -# :Created: ven 04 ago 2017 09:31:57 CEST -# :Author: Lele Gaifax -# :License: GNU General Public License version 3 or later -# :Copyright: © 2017, 2018, 2021 Lele Gaifax -# - -import pytest - -from pglast import ast, Missing, Node, parse_sql -from pglast.node import Base, List - - -def test_bad_base_construction(): - pytest.raises(ValueError, Base, {}, parent=1.0) - pytest.raises(ValueError, Base, [], name=1.0) - pytest.raises(ValueError, Base, set()) - - -def test_basic(): - root = Node(parse_sql('SELECT 1')) - assert root.parent_node is None - assert root.parent_attribute is None - assert isinstance(root, List) - assert len(root) == 1 - assert repr(root) == '[1*{RawStmt}]' - with pytest.raises(AttributeError): - root.not_there - - rawstmt = root[0] - assert rawstmt != root - assert rawstmt.node_tag == 'RawStmt' - assert isinstance(rawstmt.ast_node, ast.RawStmt) - assert rawstmt.parent_node is None - assert rawstmt.parent_attribute == (None, 0) - assert repr(rawstmt) == '{RawStmt}' - assert rawstmt.attribute_names == ('stmt', 'stmt_location', 'stmt_len') - with pytest.raises(ValueError): - rawstmt[1.0] - - stmt = rawstmt.stmt - assert stmt.node_tag == 'SelectStmt' - assert stmt.parent_node is rawstmt - assert stmt.parent_attribute == 'stmt' - assert rawstmt[stmt.parent_attribute] == stmt - assert stmt.whereClause is Missing - assert not stmt.whereClause - - -def test_scalar(): - constraint = ast.Constraint() - constraint.fk_matchtype = '\00' - node = Node(constraint) - assert not node.fk_matchtype - assert node.fk_matchtype != 1 - - -def test_traverse(): - root = Node(parse_sql('SELECT a, b, c FROM sometable')) - assert [repr(n) for n in root.traverse()] == [ - "{RawStmt}", - "{SelectStmt}", - "", - "{RangeVar}", - "", - "<20>", - "<'sometable'>", - "<'p'>", - "", - "", - "{ResTarget}", - "<7>", - "{ColumnRef}", - "{String}", - "<'a'>", - "<7>", - "{ResTarget}", - "<10>", - "{ColumnRef}", - "{String}", - "<'b'>", - "<10>", - "{ResTarget}", - "<13>", - "{ColumnRef}", - "{String}", - "<'c'>", - "<13>", - "<0>", - "<0>", - ] diff --git a/tests/test_printers.py b/tests/test_printers.py index a743b25..8e34e57 100644 --- a/tests/test_printers.py +++ b/tests/test_printers.py @@ -3,22 +3,21 @@ # :Created: sab 05 ago 2017 10:31:23 CEST # :Author: Lele Gaifax # :License: GNU General Public License version 3 or later -# :Copyright: © 2017, 2018, 2019, 2021 Lele Gaifax +# :Copyright: © 2017, 2018, 2019, 2021, 2022 Lele Gaifax # import warnings import pytest -from pglast import enums, prettify -from pglast.node import Missing, Scalar +from pglast import ast, enums, prettify from pglast.printers import IntEnumPrinter, NODE_PRINTERS, PrinterAlreadyPresentError -from pglast.printers import get_printer_for_node_tag, node_printer +from pglast.printers import get_printer_for_node, node_printer def test_registry(): - with pytest.raises(NotImplementedError): - get_printer_for_node_tag(None, 'non_existing') + with pytest.raises(ValueError): + get_printer_for_node(None) with pytest.raises(ValueError): @node_printer() @@ -35,54 +34,11 @@ def invalid_tag(node, output): def too_many_tags1(node, output): pass - with pytest.raises(ValueError): - @node_printer('one', 'two', 'three', check_tags=False) - def too_many_tags2(node, output): - pass - - try: - @node_printer('test_tag1', check_tags=False) - def tag1(node, output): - pass - - assert get_printer_for_node_tag(None, 'test_tag1') is tag1 - - with pytest.raises(PrinterAlreadyPresentError): - @node_printer('test_tag1', check_tags=False) - def tag3(node, output): - pass - - @node_printer('test_tag1', override=True, check_tags=False) - def tag1_bis(node, output): - pass - - assert get_printer_for_node_tag(None, 'test_tag1') is tag1_bis - - @node_printer('test_tag_3', check_tags=False) - def generic_tag3(node, output): - pass - - @node_printer('test_tag_1', 'test_tag_3', check_tags=False) - def specific_tag3(node, output): - pass - - @node_printer(('test_tag_a', 'test_tag_b'), 'test_tag_3', check_tags=False) - def specific_tag4(node, output): - pass - - assert get_printer_for_node_tag(None, 'test_tag_3') is generic_tag3 - assert get_printer_for_node_tag('Foo', 'test_tag_3') is generic_tag3 - assert get_printer_for_node_tag('test_tag_1', 'test_tag_3') is specific_tag3 - assert get_printer_for_node_tag('test_tag_a', 'test_tag_3') is specific_tag4 - assert get_printer_for_node_tag('test_tag_b', 'test_tag_3') is specific_tag4 - finally: - NODE_PRINTERS.pop('test_tag1', None) - def test_prettify_safety_belt(): - raw_stmt_printer = NODE_PRINTERS.pop('RawStmt', None) + raw_stmt_printer = NODE_PRINTERS.pop(ast.RawStmt, None) try: - @node_printer('RawStmt') + @node_printer(ast.RawStmt) def raw_stmt_1(node, output): output.write('Yeah') @@ -94,7 +50,7 @@ def raw_stmt_1(node, output): assert output == 'select 42' assert 'Detected a bug' in str(w[0].message) - @node_printer('RawStmt', override=True) + @node_printer(ast.RawStmt, override=True) def raw_stmt_2(node, output): output.write('select 1') @@ -107,9 +63,9 @@ def raw_stmt_2(node, output): assert 'Detected a non-cosmetic difference' in str(w[0].message) finally: if raw_stmt_printer is not None: - NODE_PRINTERS['RawStmt'] = raw_stmt_printer + NODE_PRINTERS[ast.RawStmt] = raw_stmt_printer else: - NODE_PRINTERS.pop('RawStmt', None) + NODE_PRINTERS.pop(ast.RawStmt, None) def test_int_enum_printer(): @@ -127,24 +83,12 @@ def LockWaitBlock(self, node, output): with pytest.raises(NotImplementedError): lwp('LockWaitError', object(), result) - with pytest.raises(ValueError): - lwp(None, object(), result) - with pytest.raises(ValueError): lwp('FooBar', object(), result) - lwp(Scalar('LockWaitBlock'), object(), result) + lwp(None, object(), result) assert result == ['block']*2 - with pytest.raises(ValueError): - lwp(Scalar('FooBar'), object(), result) - - lwp(Scalar(0), object(), result) - assert result == ['block']*3 - - lwp(Missing, object(), result) - assert result == ['block']*4 - def test_not_int_enum_printer(): class NotIntEnum(IntEnumPrinter): diff --git a/tests/test_printers_roundtrip.py b/tests/test_printers_roundtrip.py index 72c93d1..e8110e0 100644 --- a/tests/test_printers_roundtrip.py +++ b/tests/test_printers_roundtrip.py @@ -3,7 +3,7 @@ # :Created: dom 17 mar 2019 09:24:11 CET # :Author: Lele Gaifax # :License: GNU General Public License version 3 or later -# :Copyright: © 2019, 2021 Lele Gaifax +# :Copyright: © 2019, 2021, 2022 Lele Gaifax # from pathlib import Path @@ -11,7 +11,7 @@ import pytest -from pglast import Node, parse_sql, split +from pglast import parse_sql, split from pglast.parser import ParseError from pglast.stream import RawStream, IndentedStream import pglast.printers # noqa @@ -47,7 +47,7 @@ def test_printers_roundtrip(src, lineno, statement): except: # noqa raise RuntimeError("%s:%d:Could not parse %r" % (src, lineno, statement)) - serialized = RawStream()(Node(orig_ast)) + serialized = RawStream()(orig_ast) try: serialized_ast = parse_sql(serialized) except: # noqa @@ -55,7 +55,7 @@ def test_printers_roundtrip(src, lineno, statement): assert orig_ast == serialized_ast, "%s:%s:%r != %r" % (src, lineno, statement, serialized) - indented = IndentedStream()(Node(orig_ast)) + indented = IndentedStream()(orig_ast) try: indented_ast = parse_sql(indented) except: # noqa @@ -82,7 +82,7 @@ def test_stream_call_with_single_node(src, lineno, statement): for rawstmt in parsed: stmt = rawstmt.stmt try: - RawStream()(Node(stmt)) + RawStream()(stmt) except Exception: raise AssertionError('Could not serialize single statement %r' % stmt) @@ -129,7 +129,7 @@ def test_pg_regress_corpus(filename): % (trimmed_stmt, rel_src, lineno, e)) try: - serialized = RawStream()(Node(orig_ast)) + serialized = RawStream()(orig_ast) except NotImplementedError as e: raise NotImplementedError("Statement “%s” from %s at line %d, could not reprint: %s" % (trimmed_stmt, rel_src, lineno, e)) diff --git a/tests/test_stream.py b/tests/test_stream.py index a65f35e..8772e26 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -3,12 +3,12 @@ # :Created: sab 05 ago 2017 10:31:23 CEST # :Author: Lele Gaifax # :License: GNU General Public License version 3 or later -# :Copyright: © 2017, 2018, 2019, 2021 Lele Gaifax +# :Copyright: © 2017, 2018, 2019, 2021, 2022 Lele Gaifax # import pytest -from pglast import ast, node, parse_sql +from pglast import ast, parse_sql from pglast.printers import NODE_PRINTERS, PrinterAlreadyPresentError, SPECIAL_FUNCTIONS from pglast.printers import node_printer, special_function from pglast.stream import IndentedStream, OutputStream, RawStream @@ -26,9 +26,9 @@ def test_output_stream(): def test_raw_stream_with_sql(): - raw_stmt_printer = NODE_PRINTERS.pop('RawStmt', None) + raw_stmt_printer = NODE_PRINTERS.pop(ast.RawStmt, None) try: - @node_printer('RawStmt') + @node_printer(ast.RawStmt) def raw_stmt(node, output): output.write('Yeah') @@ -37,33 +37,15 @@ def raw_stmt(node, output): assert result == 'Yeah; Yeah' finally: if raw_stmt_printer is not None: - NODE_PRINTERS['RawStmt'] = raw_stmt_printer + NODE_PRINTERS[ast.RawStmt] = raw_stmt_printer else: - NODE_PRINTERS.pop('RawStmt', None) + NODE_PRINTERS.pop(ast.RawStmt, None) -def test_raw_stream_with_node(): - raw_stmt_printer = NODE_PRINTERS.pop('RawStmt', None) +def test_raw_stream(): + raw_stmt_printer = NODE_PRINTERS.pop(ast.RawStmt, None) try: - @node_printer('RawStmt') - def raw_stmt(node, output): - output.write('Yeah') - - root = parse_sql('SELECT 1') - output = RawStream() - result = output(node.Node(root)) - assert result == 'Yeah' - finally: - if raw_stmt_printer is not None: - NODE_PRINTERS['RawStmt'] = raw_stmt_printer - else: - NODE_PRINTERS.pop('RawStmt', None) - - -def test_raw_stream_with_ast_node(): - raw_stmt_printer = NODE_PRINTERS.pop('RawStmt', None) - try: - @node_printer('RawStmt') + @node_printer(ast.RawStmt) def raw_stmt(node, output): output.write('Yeah') @@ -71,15 +53,11 @@ def raw_stmt(node, output): output = RawStream() result = output(root) assert result == 'Yeah' - - output = RawStream() - result = output(root[0]) - assert result == 'Yeah' finally: if raw_stmt_printer is not None: - NODE_PRINTERS['RawStmt'] = raw_stmt_printer + NODE_PRINTERS[ast.RawStmt] = raw_stmt_printer else: - NODE_PRINTERS.pop('RawStmt', None) + NODE_PRINTERS.pop(ast.RawStmt, None) def test_raw_stream_invalid_call(): @@ -88,9 +66,9 @@ def test_raw_stream_invalid_call(): def test_indented_stream_with_sql(): - raw_stmt_printer = NODE_PRINTERS.pop('RawStmt', None) + raw_stmt_printer = NODE_PRINTERS.pop(ast.RawStmt, None) try: - @node_printer('RawStmt') + @node_printer(ast.RawStmt) def raw_stmt(node, output): output.write('Yeah') @@ -103,16 +81,16 @@ def raw_stmt(node, output): assert result == 'Yeah;\nYeah' finally: if raw_stmt_printer is not None: - NODE_PRINTERS['RawStmt'] = raw_stmt_printer + NODE_PRINTERS[ast.RawStmt] = raw_stmt_printer else: - NODE_PRINTERS.pop('RawStmt', None) + NODE_PRINTERS.pop(ast.RawStmt, None) def test_separate_statements(): """Separate statements by ``separate_statements`` (int) newlines.""" - raw_stmt_printer = NODE_PRINTERS.pop('RawStmt', None) + raw_stmt_printer = NODE_PRINTERS.pop(ast.RawStmt, None) try: - @node_printer('RawStmt') + @node_printer(ast.RawStmt) def raw_stmt(node, output): output.write('Yeah') @@ -121,9 +99,9 @@ def raw_stmt(node, output): assert result == 'Yeah;\n\n\nYeah' finally: if raw_stmt_printer is not None: - NODE_PRINTERS['RawStmt'] = raw_stmt_printer + NODE_PRINTERS[ast.RawStmt] = raw_stmt_printer else: - NODE_PRINTERS.pop('RawStmt', None) + NODE_PRINTERS.pop(ast.RawStmt, None) def test_special_function():