Skip to content

Commit

Permalink
Add support for printing dictionaries when using custom scalars
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick91 committed Jul 28, 2022
1 parent 46f5f59 commit abf6fc0
Show file tree
Hide file tree
Showing 8 changed files with 231 additions and 5 deletions.
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ per-file-ignores =
tests/types/test_string_annotations.py:E800
tests/federation/test_printer.py:E800
tests/federation/test_printer.py:E501
tests/test_printer/test_basic.py:E501
25 changes: 25 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
Release type: minor

This release adds support for printing default values for scalars like JSON.

For example the following:

```python
import strawberry
from strawberry.scalars import JSON


@strawberry.input
class MyInput:
j: JSON = strawberry.field(default_factory=dict)
j2: JSON = strawberry.field(default_factory=lambda: {"hello": "world"})
```

will print the following schema:

```graphql
input MyInput {
j: JSON! = {}
j2: JSON! = {hello: "world"}
}
```
4 changes: 4 additions & 0 deletions strawberry/printer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .printer import print_schema


__all__ = ["print_schema"]
133 changes: 133 additions & 0 deletions strawberry/printer/ast_from_value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import re
from math import isfinite
from typing import Any, Mapping, Optional, cast

from graphql.language import (
BooleanValueNode,
EnumValueNode,
FloatValueNode,
IntValueNode,
ListValueNode,
NameNode,
NullValueNode,
ObjectFieldNode,
ObjectValueNode,
StringValueNode,
ValueNode,
)
from graphql.pyutils import Undefined, inspect, is_iterable
from graphql.type import (
GraphQLID,
GraphQLInputObjectType,
GraphQLInputType,
GraphQLList,
GraphQLNonNull,
is_enum_type,
is_input_object_type,
is_leaf_type,
is_list_type,
is_non_null_type,
)


__all__ = ["ast_from_value"]

_re_integer_string = re.compile("^-?(?:0|[1-9][0-9]*)$")


def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]:
# custom ast_from_value that allows to also serialize custom scalar that aren't
# basic types, namely JSON scalar types

if is_non_null_type(type_):
type_ = cast(GraphQLNonNull, type_)
ast_value = ast_from_value(value, type_.of_type)
if isinstance(ast_value, NullValueNode):
return None
return ast_value

# only explicit None, not Undefined or NaN
if value is None:
return NullValueNode()

# undefined
if value is Undefined:
return None

# Convert Python list to GraphQL list. If the GraphQLType is a list, but the value
# is not a list, convert the value using the list's item type.
if is_list_type(type_):
type_ = cast(GraphQLList, type_)
item_type = type_.of_type
if is_iterable(value):
maybe_value_nodes = (ast_from_value(item, item_type) for item in value)
value_nodes = tuple(node for node in maybe_value_nodes if node)
return ListValueNode(values=value_nodes)
return ast_from_value(value, item_type)

# Populate the fields of the input object by creating ASTs from each value in the
# Python dict according to the fields in the input type.
if is_input_object_type(type_):
if value is None or not isinstance(value, Mapping):
return None
type_ = cast(GraphQLInputObjectType, type_)
field_items = (
(field_name, ast_from_value(value[field_name], field.type))
for field_name, field in type_.fields.items()
if field_name in value
)
field_nodes = tuple(
ObjectFieldNode(name=NameNode(value=field_name), value=field_value)
for field_name, field_value in field_items
if field_value
)
return ObjectValueNode(fields=field_nodes)

if is_leaf_type(type_):
# Since value is an internally represented value, it must be serialized to an
# externally represented value before converting into an AST.
serialized = type_.serialize(value) # type: ignore
if serialized is None or serialized is Undefined:
return None

# Others serialize based on their corresponding Python scalar types.
if isinstance(serialized, bool):
return BooleanValueNode(value=serialized)

# Python ints and floats correspond nicely to Int and Float values.
if isinstance(serialized, int):
return IntValueNode(value=str(serialized))
if isinstance(serialized, float) and isfinite(serialized):
value = str(serialized)
if value.endswith(".0"):
value = value[:-2]
return FloatValueNode(value=value)

if isinstance(serialized, str):
# Enum types use Enum literals.
if is_enum_type(type_):
return EnumValueNode(value=serialized)

# ID types can use Int literals.
if type_ is GraphQLID and _re_integer_string.match(serialized):
return IntValueNode(value=serialized)

return StringValueNode(value=serialized)

if isinstance(serialized, dict):
return ObjectValueNode(
fields=[
ObjectFieldNode(
name=NameNode(value=key),
value=StringValueNode(value=value),
)
for key, value in serialized.items()
]
)

breakpoint()

raise TypeError(f"Cannot convert value to AST: {inspect(serialized)}.")

# Not reachable. All possible input types have been considered.
raise TypeError(f"Unexpected input type: {inspect(type_)}.")
28 changes: 26 additions & 2 deletions strawberry/printer.py → strawberry/printer/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,17 @@
overload,
)

from graphql import GraphQLArgument
from graphql.language.printer import print_ast
from graphql.type import (
is_enum_type,
is_input_type,
is_interface_type,
is_object_type,
is_scalar_type,
is_specified_directive,
)
from graphql.type.directives import GraphQLDirective
from graphql.utilities.ast_from_value import ast_from_value
from graphql.utilities.print_schema import (
is_defined_type,
print_args,
Expand All @@ -35,7 +36,6 @@
print_directive,
print_enum,
print_implemented_interfaces,
print_input_value,
print_scalar,
print_schema_definition,
print_type as original_print_type,
Expand All @@ -47,6 +47,8 @@
from strawberry.type import StrawberryContainer
from strawberry.unset import UNSET

from .ast_from_value import ast_from_value


if TYPE_CHECKING:
from strawberry.schema import BaseSchema
Expand Down Expand Up @@ -234,6 +236,25 @@ def _print_object(type_, schema: BaseSchema, *, extras: PrintExtras) -> str:
)


def _print_interface(type_, schema: BaseSchema, *, extras: PrintExtras) -> str:
return (
print_description(type_)
+ print_extends(type_, schema)
+ f"interface {type_.name}"
+ print_implemented_interfaces(type_)
+ print_type_directives(type_, schema, extras=extras)
+ print_fields(type_, schema, extras=extras)
)


def print_input_value(name: str, arg: GraphQLArgument) -> str:
default_ast = ast_from_value(arg.default_value, arg.type)
arg_decl = f"{name}: {arg.type}"
if default_ast:
arg_decl += f" = {print_ast(default_ast)}"
return arg_decl + print_deprecated(arg.deprecation_reason)


def _print_input_object(type_, schema: BaseSchema, *, extras: PrintExtras) -> str:
fields = [
print_description(field, " ", not i) + " " + print_input_value(name, field)
Expand Down Expand Up @@ -261,6 +282,9 @@ def _print_type(type_, schema: BaseSchema, *, extras: PrintExtras) -> str:
if is_input_type(type_):
return _print_input_object(type_, schema, extras=extras)

if is_interface_type(type_):
return _print_interface(type_, schema, extras=extras)

return original_print_type(type_)


Expand Down
11 changes: 8 additions & 3 deletions strawberry/schema/schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import lru_cache
from typing import Any, Dict, Iterable, List, Optional, Sequence, Type, Union
from typing import Any, Dict, Iterable, List, Optional, Type, Union

from graphql import (
ExecutionContext as GraphQLExecutionContext,
Expand Down Expand Up @@ -49,15 +49,20 @@ def __init__(
query: Type,
mutation: Optional[Type] = None,
subscription: Optional[Type] = None,
directives: Sequence[StrawberryDirective] = (),
directives: Iterable[StrawberryDirective] = (),
types=(),
extensions: Sequence[Union[Type[Extension], Extension]] = (),
extensions: Iterable[Union[Type[Extension], Extension]] = (),
execution_context_class: Optional[Type[GraphQLExecutionContext]] = None,
config: Optional[StrawberryConfig] = None,
scalar_overrides: Optional[
Dict[object, Union[ScalarWrapper, ScalarDefinition]]
] = None,
schema_directives: Iterable[object] = (),
):
self.query = query
self.mutation = mutation
self.subscription = subscription

self.extensions = extensions
self.execution_context_class = execution_context_class
self.config = config or StrawberryConfig()
Expand Down
Empty file.
34 changes: 34 additions & 0 deletions tests/test_printer/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import strawberry
from strawberry.printer import print_schema
from strawberry.scalars import JSON
from strawberry.schema.config import StrawberryConfig
from strawberry.unset import UNSET

Expand Down Expand Up @@ -158,6 +159,39 @@ def search(self, input: MyInput) -> int:
assert print_schema(schema) == textwrap.dedent(expected_type).strip()


def test_input_defaults_scalars():
@strawberry.input
class MyInput:
j: JSON = strawberry.field(default_factory=dict)
j2: JSON = strawberry.field(default_factory=lambda: {"hello": "world"})

@strawberry.type
class Query:
@strawberry.field
def search(self, input: MyInput) -> JSON:
return input.j

expected_type = """
\"\"\"
The `JSON` scalar type represents JSON values as specified by [ECMA-404](http://www.ecma-international.org/publications/files/ECMA-ST/ECMA-404.pdf).
\"\"\"
scalar JSON @specifiedBy(url: "http://www.ecma-international.org/publications/files/ECMA-ST/ECMA-404.pdf")
input MyInput {
j: JSON! = {}
j2: JSON! = {hello: "world"}
}
type Query {
search(input: MyInput!): JSON!
}
"""

schema = strawberry.Schema(query=Query)

assert print_schema(schema) == textwrap.dedent(expected_type).strip()


def test_interface():
@strawberry.interface
class Node:
Expand Down

0 comments on commit abf6fc0

Please sign in to comment.