Skip to content

Commit

Permalink
add basic interface inheritance support
Browse files Browse the repository at this point in the history
  • Loading branch information
nettrino committed Dec 1, 2022
1 parent 0c4621e commit 38ab9fe
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 42 deletions.
5 changes: 3 additions & 2 deletions gql_schema_codegen/block/__init__.py
@@ -1,3 +1,4 @@
from .block import Block, BlockInfo, BlockField, BlockFieldInfo
from .block import (Block, BlockField, BlockFieldInfo, BlockInfo,
get_inheritance_tree)

__all__ = ["Block", "BlockInfo", "BlockField", "BlockFieldInfo"]
__all__ = ["Block", "BlockInfo", "BlockField", "BlockFieldInfo", "get_inheritance_tree"]
40 changes: 32 additions & 8 deletions gql_schema_codegen/block/block.py
@@ -1,11 +1,13 @@
import os
import re
from typing import List, Literal, NamedTuple, Optional, Union
from typing import Dict, List, Literal, NamedTuple, Optional, Set, Union

from ..base import BaseInfo
from ..constants import RESOLVER_TYPES, VALUE_TYPES
from ..constants.block_fields import all_block_fields
from ..dependency import Dependency, DependencyGroup
from ..dependency import (Dependency, DependencyGroup,
get_interface_dependencies,
remove_interface_dependencies)
from ..utils import pascal_case


Expand All @@ -18,6 +20,14 @@ class BlockFieldInfo(NamedTuple):
value_type: str


# a dictionary where for each node, we hold its children
inheritanceTree: Dict[str, Set[str]] = {"root": set()}


def get_inheritance_tree():
return inheritanceTree


class Block(BaseInfo):
def __init__(self, info, dependency_group: DependencyGroup) -> None:
super().__init__(info)
Expand All @@ -35,6 +45,7 @@ def display_name(self):

@property
def heading_file_line(self):
global inheritanceTree
display_name = self.display_name

if self.type == "enum":
Expand All @@ -59,15 +70,28 @@ def heading_file_line(self):
)

if not self.implements:
# check if we have an interface implementing another interface
deps = get_interface_dependencies()
if display_name in deps:
siblings = inheritanceTree.get(deps[display_name], set())
siblings.add(display_name)
inheritanceTree[deps[display_name]] = siblings
return f"@dataclass(kw_only=True)\nclass {display_name}({deps[display_name]}):"

inheritanceTree["root"].add(display_name)
return (
f"@dataclass(kw_only=True)\nclass {display_name}(DataClassJSONMixin):"
)

for el in self.info.implements.split("&"): # type: ignore
self.parent_classes.add(el.strip())

parent = ", ".join(list(self.parent_classes))
return f"@dataclass(kw_only=True)\nclass {display_name}({parent}):"
parents = remove_interface_dependencies(
[x.strip() for x in self.info.implements.split("&")] # type: ignore
)
for p in parents:
siblings = inheritanceTree.get(p, set())
siblings.add(display_name)
inheritanceTree[p] = siblings
parent_str = ", ".join(parents)
return f"@dataclass(kw_only=True)\nclass {display_name}({parent_str}):"

@property
def category(self):
Expand Down Expand Up @@ -99,8 +123,8 @@ def file_representation(self):
parent_fields = set()
for p in self.parent_classes:
parent_fields.update(all_block_fields.get(p, set()))

for f in self.fields:
# don't re-include parent fields
if str(f).split(":")[1].strip() in parent_fields:
continue

Expand Down
13 changes: 11 additions & 2 deletions gql_schema_codegen/dependency/__init__.py
@@ -1,3 +1,12 @@
from .dependency import Dependency, DependencyGroup
from .dependency import (Dependency, DependencyGroup,
get_interface_dependencies,
remove_interface_dependencies,
update_interface_dependencies)

__all__ = ["Dependency", "DependencyGroup"]
__all__ = [
"Dependency",
"DependencyGroup",
"get_interface_dependencies",
"update_interface_dependencies",
"remove_interface_dependencies",
]
73 changes: 72 additions & 1 deletion gql_schema_codegen/dependency/dependency.py
@@ -1,12 +1,83 @@
from itertools import groupby
from typing import List, NamedTuple, Set
from typing import Dict, List, NamedTuple, Set


class Dependency(NamedTuple):
imported_from: str
dependency: str


INTERMEDIATE_INTERFACES: Dict[str, str] = {}


def get_interface_dependencies():
return INTERMEDIATE_INTERFACES


def update_interface_dependencies(config_file_content):
global INTERMEDIATE_INTERFACES
if type(config_file_content) is dict:
data = config_file_content.get("interfaceInheritance")
if type(data) is dict:
INTERMEDIATE_INTERFACES = data


def remove_interface_dependencies(
interfaces: List[str],
intermediate_interfaces: Dict[str, str] = {},
) -> List[str]:
"""Filter all dependencies from intermediate interfaces
Assumes that all keys in intermediate_interfaces are leaf nodes and
returns all other parent interfaces from a given list
Intemediate interfaces in GraphQL implement other interfaces. This is part
of the spec (see https://github.com/graphql/graphql-spec/pull/373)
but not implemented in all clients yet (e.g., neo4j)
Thus we are parsing intermediate interfaces in scalars.yml so as to emit
proper python code without relying on the graphql schema being correct
>>> remove_interface_dependencies(['t1', 't2', 'i1', 'i2', 't3'], {'i1':'i2'})
['t1', 't2', 'i1', 't3']
>>> remove_interface_dependencies(['t1', 't2', 't3'], {})
['t1', 't2', 't3']
>>> remove_interface_dependencies(['t1', 't2', 'i1', 'i2', 'i3', 't3'], \
{'i1':'i2', 'i2': 'i3'})
['t1', 't2', 'i1', 't3']
>>> remove_interface_dependencies(['t1', 't2', 'i1', 'i2', 'i3', \
'ni1', 'ni2', 't3'], \
{'i1':'i2', 'i2': 'i3', 'ni1':'ni2'})
['t1', 't2', 'i1', 'ni1', 't3']
"""
deps: Set[str] = set()
if not intermediate_interfaces:
intermediate_interfaces = INTERMEDIATE_INTERFACES
for i in interfaces:
if i not in intermediate_interfaces:
# this is not an intermediate interface, thus we are keeping it
# as a dependency
continue

# add the parent dependency to the list of tracked dependencies:
deps.add(intermediate_interfaces[i])

# transitively fetch all dependencies
seen: Set[str] = set()
while deps:
d = deps.pop()
if d in seen:
continue
if d in intermediate_interfaces:
deps.add(intermediate_interfaces[d])
seen.add(d)

return list(filter(lambda x: x not in seen, interfaces))


class DependencyGroup:
def __init__(
self, deps: Set[Dependency] = set(), direct_deps: Set[str] = set()
Expand Down
78 changes: 49 additions & 29 deletions gql_schema_codegen/schema/schema.py
Expand Up @@ -2,29 +2,21 @@
import os
import re
import subprocess
from typing import List, Optional, Set
from typing import Dict, List, Optional, Set

import yaml
from graphql import (
build_client_schema,
build_schema,
get_introspection_query,
print_schema,
)
from graphql import (build_client_schema, build_schema,
get_introspection_query, print_schema)
from graphqlclient import GraphQLClient

from ..block import Block, BlockField, BlockFieldInfo, BlockInfo
from ..constants import (
BLOCK_PATTERN,
DIRECTIVE_PATTERN,
DIRECTIVE_USAGE_PATTERN,
FIELD_PATTERN,
RESOLVER_TYPES,
SCALAR_PATTERN,
UNION_PATTERN,
)
from ..block import (Block, BlockField, BlockFieldInfo, BlockInfo,
get_inheritance_tree)
from ..constants import (BLOCK_PATTERN, DIRECTIVE_PATTERN,
DIRECTIVE_USAGE_PATTERN, FIELD_PATTERN,
RESOLVER_TYPES, SCALAR_PATTERN, UNION_PATTERN)
from ..constants.block_fields import all_block_fields
from ..dependency import Dependency, DependencyGroup
from ..dependency import (Dependency, DependencyGroup,
update_interface_dependencies)
from ..scalar import ScalarInfo, ScalarType
from ..union import UnionInfo, UnionType

Expand Down Expand Up @@ -60,6 +52,8 @@ def __init__(self, **kwargs) -> None:
self._import_blocks = kwargs.get("import_blocks", self._import_blocks)
self._only_blocks = kwargs.get("only_blocks", self._only_blocks)

update_interface_dependencies(self.config_file_content)

self.dependency_group = DependencyGroup()

@property
Expand Down Expand Up @@ -209,7 +203,6 @@ def blocks(self):

block_type = block["type"]
block_name = block["name"]

all_block_fields[block_name] = set()
for field in self.get_fields_from_block(block["fields"]):
all_block_fields[block_name].add(field["name"])
Expand Down Expand Up @@ -250,11 +243,40 @@ def blocks(self):

@property
def sorted_blocks(self):
types_order = ["enum", "type", "param_type", "input"]
return sorted(
self.blocks,
key=lambda b: (types_order.index(b.type) if b.type in types_order else -1),
)
# first populate inheritance tree. this is VERY dirty for now but we
# should refactor this soon. We are calling b.heading_file_line for all
# blocks here as this is what populates the tree, and we only do this
# once
all_blocks: Dict[str, Block] = {}
for b in self.blocks:
all_blocks[b.name] = b
_ = b.heading_file_line
inheritanceTree = get_inheritance_tree()
sorted_bl: List[Block] = []
# first add enums - these have no dependencies
sorted_bl.extend(list(filter(lambda x: x.type == "enum", self.blocks)))

types_order = ["interface", "type", "param_type", "input"]
for t in types_order:
interfaces = list(filter(lambda x: x.type == t, self.blocks))

to_add = {b.name for b in interfaces}
blocks: List[Block] = []

# add nodes in a BFS manner to ensure we don't break dependencies
queue: List[str] = ["root"]
visited: Set[str] = set(["root"])
while queue:
node = queue.pop(0)
if node in to_add:
blocks.append(all_blocks[node])
for child_node in inheritanceTree.get(node, []):
if child_node not in visited:
visited.add(child_node)
queue.append(child_node)

sorted_bl.extend(blocks)
return sorted_bl

@property
def unions(self):
Expand Down Expand Up @@ -302,17 +324,15 @@ def file_representation(self):
lines: List[str] = ["\n" * 2]

if len(self.scalars) > 0:
# lines.extend(['## Scalars'] + ['\n' * 2])

for s in self.scalars:
lines.extend([s.file_representation] + ["\n" * 2])
lines.extend([s.file_representation] + ["\n"])

lines.append("\n" * 2)

if len(self.unions) > 0:
self.dependency_group.add_dependency(
Dependency(imported_from="typing", dependency="Union")
)
lines.extend(["## Union Types"] + ["\n" * 2])

for u in self.unions:
lines.extend([u.file_representation] + ["\n" * 2])

Expand Down

0 comments on commit 38ab9fe

Please sign in to comment.