From 38ab9fe386e0192f65ffd6a4758ee52bd39dd579 Mon Sep 17 00:00:00 2001 From: Theofilos Petsios Date: Thu, 1 Dec 2022 13:08:23 +0200 Subject: [PATCH] add basic interface inheritance support --- gql_schema_codegen/block/__init__.py | 5 +- gql_schema_codegen/block/block.py | 40 ++++++++--- gql_schema_codegen/dependency/__init__.py | 13 +++- gql_schema_codegen/dependency/dependency.py | 73 ++++++++++++++++++- gql_schema_codegen/schema/schema.py | 78 +++++++++++++-------- 5 files changed, 167 insertions(+), 42 deletions(-) diff --git a/gql_schema_codegen/block/__init__.py b/gql_schema_codegen/block/__init__.py index 4f15795..0cd546d 100644 --- a/gql_schema_codegen/block/__init__.py +++ b/gql_schema_codegen/block/__init__.py @@ -1,3 +1,4 @@ -from .block import Block, BlockInfo, BlockField, BlockFieldInfo +from .block import (Block, BlockField, BlockFieldInfo, BlockInfo, + get_inheritance_tree) -__all__ = ["Block", "BlockInfo", "BlockField", "BlockFieldInfo"] +__all__ = ["Block", "BlockInfo", "BlockField", "BlockFieldInfo", "get_inheritance_tree"] diff --git a/gql_schema_codegen/block/block.py b/gql_schema_codegen/block/block.py index 1b96195..bc87494 100644 --- a/gql_schema_codegen/block/block.py +++ b/gql_schema_codegen/block/block.py @@ -1,11 +1,13 @@ import os import re -from typing import List, Literal, NamedTuple, Optional, Union +from typing import Dict, List, Literal, NamedTuple, Optional, Set, Union from ..base import BaseInfo from ..constants import RESOLVER_TYPES, VALUE_TYPES from ..constants.block_fields import all_block_fields -from ..dependency import Dependency, DependencyGroup +from ..dependency import (Dependency, DependencyGroup, + get_interface_dependencies, + remove_interface_dependencies) from ..utils import pascal_case @@ -18,6 +20,14 @@ class BlockFieldInfo(NamedTuple): value_type: str +# a dictionary where for each node, we hold its children +inheritanceTree: Dict[str, Set[str]] = {"root": set()} + + +def get_inheritance_tree(): + return inheritanceTree + + class Block(BaseInfo): def __init__(self, info, dependency_group: DependencyGroup) -> None: super().__init__(info) @@ -35,6 +45,7 @@ def display_name(self): @property def heading_file_line(self): + global inheritanceTree display_name = self.display_name if self.type == "enum": @@ -59,15 +70,28 @@ def heading_file_line(self): ) if not self.implements: + # check if we have an interface implementing another interface + deps = get_interface_dependencies() + if display_name in deps: + siblings = inheritanceTree.get(deps[display_name], set()) + siblings.add(display_name) + inheritanceTree[deps[display_name]] = siblings + return f"@dataclass(kw_only=True)\nclass {display_name}({deps[display_name]}):" + + inheritanceTree["root"].add(display_name) return ( f"@dataclass(kw_only=True)\nclass {display_name}(DataClassJSONMixin):" ) - for el in self.info.implements.split("&"): # type: ignore - self.parent_classes.add(el.strip()) - - parent = ", ".join(list(self.parent_classes)) - return f"@dataclass(kw_only=True)\nclass {display_name}({parent}):" + parents = remove_interface_dependencies( + [x.strip() for x in self.info.implements.split("&")] # type: ignore + ) + for p in parents: + siblings = inheritanceTree.get(p, set()) + siblings.add(display_name) + inheritanceTree[p] = siblings + parent_str = ", ".join(parents) + return f"@dataclass(kw_only=True)\nclass {display_name}({parent_str}):" @property def category(self): @@ -99,8 +123,8 @@ def file_representation(self): parent_fields = set() for p in self.parent_classes: parent_fields.update(all_block_fields.get(p, set())) - for f in self.fields: + # don't re-include parent fields if str(f).split(":")[1].strip() in parent_fields: continue diff --git a/gql_schema_codegen/dependency/__init__.py b/gql_schema_codegen/dependency/__init__.py index e33df82..ea13959 100644 --- a/gql_schema_codegen/dependency/__init__.py +++ b/gql_schema_codegen/dependency/__init__.py @@ -1,3 +1,12 @@ -from .dependency import Dependency, DependencyGroup +from .dependency import (Dependency, DependencyGroup, + get_interface_dependencies, + remove_interface_dependencies, + update_interface_dependencies) -__all__ = ["Dependency", "DependencyGroup"] +__all__ = [ + "Dependency", + "DependencyGroup", + "get_interface_dependencies", + "update_interface_dependencies", + "remove_interface_dependencies", +] diff --git a/gql_schema_codegen/dependency/dependency.py b/gql_schema_codegen/dependency/dependency.py index 4bb4baa..0a74546 100644 --- a/gql_schema_codegen/dependency/dependency.py +++ b/gql_schema_codegen/dependency/dependency.py @@ -1,5 +1,5 @@ from itertools import groupby -from typing import List, NamedTuple, Set +from typing import Dict, List, NamedTuple, Set class Dependency(NamedTuple): @@ -7,6 +7,77 @@ class Dependency(NamedTuple): dependency: str +INTERMEDIATE_INTERFACES: Dict[str, str] = {} + + +def get_interface_dependencies(): + return INTERMEDIATE_INTERFACES + + +def update_interface_dependencies(config_file_content): + global INTERMEDIATE_INTERFACES + if type(config_file_content) is dict: + data = config_file_content.get("interfaceInheritance") + if type(data) is dict: + INTERMEDIATE_INTERFACES = data + + +def remove_interface_dependencies( + interfaces: List[str], + intermediate_interfaces: Dict[str, str] = {}, +) -> List[str]: + """Filter all dependencies from intermediate interfaces + + Assumes that all keys in intermediate_interfaces are leaf nodes and + returns all other parent interfaces from a given list + + Intemediate interfaces in GraphQL implement other interfaces. This is part + of the spec (see https://github.com/graphql/graphql-spec/pull/373) + but not implemented in all clients yet (e.g., neo4j) + Thus we are parsing intermediate interfaces in scalars.yml so as to emit + proper python code without relying on the graphql schema being correct + + + >>> remove_interface_dependencies(['t1', 't2', 'i1', 'i2', 't3'], {'i1':'i2'}) + ['t1', 't2', 'i1', 't3'] + + >>> remove_interface_dependencies(['t1', 't2', 't3'], {}) + ['t1', 't2', 't3'] + + >>> remove_interface_dependencies(['t1', 't2', 'i1', 'i2', 'i3', 't3'], \ + {'i1':'i2', 'i2': 'i3'}) + ['t1', 't2', 'i1', 't3'] + + >>> remove_interface_dependencies(['t1', 't2', 'i1', 'i2', 'i3', \ + 'ni1', 'ni2', 't3'], \ + {'i1':'i2', 'i2': 'i3', 'ni1':'ni2'}) + ['t1', 't2', 'i1', 'ni1', 't3'] + """ + deps: Set[str] = set() + if not intermediate_interfaces: + intermediate_interfaces = INTERMEDIATE_INTERFACES + for i in interfaces: + if i not in intermediate_interfaces: + # this is not an intermediate interface, thus we are keeping it + # as a dependency + continue + + # add the parent dependency to the list of tracked dependencies: + deps.add(intermediate_interfaces[i]) + + # transitively fetch all dependencies + seen: Set[str] = set() + while deps: + d = deps.pop() + if d in seen: + continue + if d in intermediate_interfaces: + deps.add(intermediate_interfaces[d]) + seen.add(d) + + return list(filter(lambda x: x not in seen, interfaces)) + + class DependencyGroup: def __init__( self, deps: Set[Dependency] = set(), direct_deps: Set[str] = set() diff --git a/gql_schema_codegen/schema/schema.py b/gql_schema_codegen/schema/schema.py index 334e77a..6bb1653 100644 --- a/gql_schema_codegen/schema/schema.py +++ b/gql_schema_codegen/schema/schema.py @@ -2,29 +2,21 @@ import os import re import subprocess -from typing import List, Optional, Set +from typing import Dict, List, Optional, Set import yaml -from graphql import ( - build_client_schema, - build_schema, - get_introspection_query, - print_schema, -) +from graphql import (build_client_schema, build_schema, + get_introspection_query, print_schema) from graphqlclient import GraphQLClient -from ..block import Block, BlockField, BlockFieldInfo, BlockInfo -from ..constants import ( - BLOCK_PATTERN, - DIRECTIVE_PATTERN, - DIRECTIVE_USAGE_PATTERN, - FIELD_PATTERN, - RESOLVER_TYPES, - SCALAR_PATTERN, - UNION_PATTERN, -) +from ..block import (Block, BlockField, BlockFieldInfo, BlockInfo, + get_inheritance_tree) +from ..constants import (BLOCK_PATTERN, DIRECTIVE_PATTERN, + DIRECTIVE_USAGE_PATTERN, FIELD_PATTERN, + RESOLVER_TYPES, SCALAR_PATTERN, UNION_PATTERN) from ..constants.block_fields import all_block_fields -from ..dependency import Dependency, DependencyGroup +from ..dependency import (Dependency, DependencyGroup, + update_interface_dependencies) from ..scalar import ScalarInfo, ScalarType from ..union import UnionInfo, UnionType @@ -60,6 +52,8 @@ def __init__(self, **kwargs) -> None: self._import_blocks = kwargs.get("import_blocks", self._import_blocks) self._only_blocks = kwargs.get("only_blocks", self._only_blocks) + update_interface_dependencies(self.config_file_content) + self.dependency_group = DependencyGroup() @property @@ -209,7 +203,6 @@ def blocks(self): block_type = block["type"] block_name = block["name"] - all_block_fields[block_name] = set() for field in self.get_fields_from_block(block["fields"]): all_block_fields[block_name].add(field["name"]) @@ -250,11 +243,40 @@ def blocks(self): @property def sorted_blocks(self): - types_order = ["enum", "type", "param_type", "input"] - return sorted( - self.blocks, - key=lambda b: (types_order.index(b.type) if b.type in types_order else -1), - ) + # first populate inheritance tree. this is VERY dirty for now but we + # should refactor this soon. We are calling b.heading_file_line for all + # blocks here as this is what populates the tree, and we only do this + # once + all_blocks: Dict[str, Block] = {} + for b in self.blocks: + all_blocks[b.name] = b + _ = b.heading_file_line + inheritanceTree = get_inheritance_tree() + sorted_bl: List[Block] = [] + # first add enums - these have no dependencies + sorted_bl.extend(list(filter(lambda x: x.type == "enum", self.blocks))) + + types_order = ["interface", "type", "param_type", "input"] + for t in types_order: + interfaces = list(filter(lambda x: x.type == t, self.blocks)) + + to_add = {b.name for b in interfaces} + blocks: List[Block] = [] + + # add nodes in a BFS manner to ensure we don't break dependencies + queue: List[str] = ["root"] + visited: Set[str] = set(["root"]) + while queue: + node = queue.pop(0) + if node in to_add: + blocks.append(all_blocks[node]) + for child_node in inheritanceTree.get(node, []): + if child_node not in visited: + visited.add(child_node) + queue.append(child_node) + + sorted_bl.extend(blocks) + return sorted_bl @property def unions(self): @@ -302,17 +324,15 @@ def file_representation(self): lines: List[str] = ["\n" * 2] if len(self.scalars) > 0: - # lines.extend(['## Scalars'] + ['\n' * 2]) - for s in self.scalars: - lines.extend([s.file_representation] + ["\n" * 2]) + lines.extend([s.file_representation] + ["\n"]) + + lines.append("\n" * 2) if len(self.unions) > 0: self.dependency_group.add_dependency( Dependency(imported_from="typing", dependency="Union") ) - lines.extend(["## Union Types"] + ["\n" * 2]) - for u in self.unions: lines.extend([u.file_representation] + ["\n" * 2])