Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for printing dictionaries when using custom scalars #2048

Merged
merged 15 commits into from
Jul 29, 2022
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ per-file-ignores =
tests/types/test_string_annotations.py:E800
tests/federation/test_printer.py:E800
tests/federation/test_printer.py:E501
tests/test_printer/test_basic.py:E501
25 changes: 25 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
Release type: minor

This release adds support for printing default values for scalars like JSON.

For example the following:

```python
import strawberry
from strawberry.scalars import JSON


@strawberry.input
class MyInput:
j: JSON = strawberry.field(default_factory=dict)
j2: JSON = strawberry.field(default_factory=lambda: {"hello": "world"})
```
BryceBeagle marked this conversation as resolved.
Show resolved Hide resolved

will print the following schema:

```graphql
input MyInput {
j: JSON! = {}
j2: JSON! = {hello: "world"}
}
```
4 changes: 4 additions & 0 deletions strawberry/printer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .printer import print_schema


__all__ = ["print_schema"]
145 changes: 145 additions & 0 deletions strawberry/printer/ast_from_value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import dataclasses
import re
from math import isfinite
from typing import Any, Mapping, Optional, cast

from graphql.language import (
BooleanValueNode,
EnumValueNode,
FloatValueNode,
IntValueNode,
ListValueNode,
NameNode,
NullValueNode,
ObjectFieldNode,
ObjectValueNode,
StringValueNode,
ValueNode,
)
from graphql.pyutils import Undefined, inspect, is_iterable
from graphql.type import (
GraphQLID,
GraphQLInputObjectType,
GraphQLInputType,
GraphQLList,
GraphQLNonNull,
is_enum_type,
is_input_object_type,
is_leaf_type,
is_list_type,
is_non_null_type,
)


__all__ = ["ast_from_value"]

_re_integer_string = re.compile("^-?(?:0|[1-9][0-9]*)$")


def ast_from_leaf_type(
serialized: object, type_: Optional[GraphQLInputType]
) -> ValueNode:
# Others serialize based on their corresponding Python scalar types.
if isinstance(serialized, bool):
return BooleanValueNode(value=serialized)

# Python ints and floats correspond nicely to Int and Float values.
if isinstance(serialized, int):
return IntValueNode(value=str(serialized))
if isinstance(serialized, float) and isfinite(serialized):
value = str(serialized)
if value.endswith(".0"):
value = value[:-2]
BryceBeagle marked this conversation as resolved.
Show resolved Hide resolved
return FloatValueNode(value=value)

if isinstance(serialized, str):
# Enum types use Enum literals.
if type_ and is_enum_type(type_):
return EnumValueNode(value=serialized)

# ID types can use Int literals.
if type_ is GraphQLID and _re_integer_string.match(serialized):
return IntValueNode(value=serialized)

return StringValueNode(value=serialized)

if isinstance(serialized, dict):
return ObjectValueNode(
fields=[
ObjectFieldNode(
name=NameNode(value=key),
value=ast_from_leaf_type(value, None),
)
for key, value in serialized.items()
]
)

raise TypeError(
f"Cannot convert value to AST: {inspect(serialized)}."
) # pragma: no cover


def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]:
# custom ast_from_value that allows to also serialize custom scalar that aren't
# basic types, namely JSON scalar types

if is_non_null_type(type_):
type_ = cast(GraphQLNonNull, type_)
ast_value = ast_from_value(value, type_.of_type)
if isinstance(ast_value, NullValueNode):
return None
return ast_value

# only explicit None, not Undefined or NaN
if value is None:
return NullValueNode()

# undefined
if value is Undefined:
return None

# Convert Python list to GraphQL list. If the GraphQLType is a list, but the value
# is not a list, convert the value using the list's item type.
if is_list_type(type_):
type_ = cast(GraphQLList, type_)
item_type = type_.of_type
if is_iterable(value):
maybe_value_nodes = (ast_from_value(item, item_type) for item in value)
value_nodes = tuple(node for node in maybe_value_nodes if node)
return ListValueNode(values=value_nodes)
return ast_from_value(value, item_type)

# Populate the fields of the input object by creating ASTs from each value in the
# Python dict according to the fields in the input type.
if is_input_object_type(type_):
# TODO: is this the right place?
if hasattr(value, "_type_definition"):
value = dataclasses.asdict(value)

if value is None or not isinstance(value, Mapping):
return None

type_ = cast(GraphQLInputObjectType, type_)
field_items = (
(field_name, ast_from_value(value[field_name], field.type))
for field_name, field in type_.fields.items()
if field_name in value
)
field_nodes = tuple(
ObjectFieldNode(name=NameNode(value=field_name), value=field_value)
for field_name, field_value in field_items
if field_value
)
return ObjectValueNode(fields=field_nodes)

if is_leaf_type(type_):
# Since value is an internally represented value, it must be serialized to an
# externally represented value before converting into an AST.
serialized = type_.serialize(value) # type: ignore
if serialized is None or serialized is Undefined:
return None # pragma: no cover

return ast_from_leaf_type(serialized, type_)

# Not reachable. All possible input types have been considered.
raise TypeError(f"Unexpected input type: {inspect(type_)}.") # pragma: no cover
53 changes: 50 additions & 3 deletions strawberry/printer.py → strawberry/printer/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,25 @@
overload,
)

from graphql import GraphQLArgument
from graphql.language.printer import print_ast
from graphql.type import (
is_enum_type,
is_input_type,
is_interface_type,
is_object_type,
is_scalar_type,
is_specified_directive,
)
from graphql.type.directives import GraphQLDirective
from graphql.utilities.ast_from_value import ast_from_value
from graphql.utilities.print_schema import (
is_defined_type,
print_args,
print_block,
print_deprecated,
print_description,
print_directive,
print_enum,
print_implemented_interfaces,
print_input_value,
print_scalar,
print_schema_definition,
print_type as original_print_type,
Expand All @@ -47,6 +46,8 @@
from strawberry.type import StrawberryContainer
from strawberry.unset import UNSET

from .ast_from_value import ast_from_value


if TYPE_CHECKING:
from strawberry.schema import BaseSchema
Expand Down Expand Up @@ -161,6 +162,30 @@ def print_field_directives(
)


def print_args(args: Dict[str, GraphQLArgument], indentation: str = "") -> str:
if not args:
return ""

# If every arg does not have a description, print them on one line.
if not any(arg.description for arg in args.values()):
return (
"("
+ ", ".join(print_input_value(name, arg) for name, arg in args.items())
+ ")"
)

return (
"(\n"
+ "\n".join(
print_description(arg, f" {indentation}", not i)
+ f" {indentation}"
+ print_input_value(name, arg)
for i, (name, arg) in enumerate(args.items())
)
+ f"\n{indentation})"
)
BryceBeagle marked this conversation as resolved.
Show resolved Hide resolved


def print_fields(type_, schema: BaseSchema, *, extras: PrintExtras) -> str:
fields = []

Expand Down Expand Up @@ -234,6 +259,25 @@ def _print_object(type_, schema: BaseSchema, *, extras: PrintExtras) -> str:
)


def _print_interface(type_, schema: BaseSchema, *, extras: PrintExtras) -> str:
return (
print_description(type_)
+ print_extends(type_, schema)
+ f"interface {type_.name}"
+ print_implemented_interfaces(type_)
+ print_type_directives(type_, schema, extras=extras)
+ print_fields(type_, schema, extras=extras)
)


def print_input_value(name: str, arg: GraphQLArgument) -> str:
default_ast = ast_from_value(arg.default_value, arg.type)
arg_decl = f"{name}: {arg.type}"
if default_ast:
arg_decl += f" = {print_ast(default_ast)}"
return arg_decl + print_deprecated(arg.deprecation_reason)


def _print_input_object(type_, schema: BaseSchema, *, extras: PrintExtras) -> str:
fields = [
print_description(field, " ", not i) + " " + print_input_value(name, field)
Expand Down Expand Up @@ -261,6 +305,9 @@ def _print_type(type_, schema: BaseSchema, *, extras: PrintExtras) -> str:
if is_input_type(type_):
return _print_input_object(type_, schema, extras=extras)

if is_interface_type(type_):
return _print_interface(type_, schema, extras=extras)

return original_print_type(type_)


Expand Down
8 changes: 3 additions & 5 deletions strawberry/schema/schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import lru_cache
from typing import Any, Dict, Iterable, List, Optional, Sequence, Type, Union
from typing import Any, Dict, Iterable, List, Optional, Type, Union

from graphql import (
ExecutionContext as GraphQLExecutionContext,
Expand Down Expand Up @@ -49,9 +49,9 @@ def __init__(
query: Type,
mutation: Optional[Type] = None,
subscription: Optional[Type] = None,
directives: Sequence[StrawberryDirective] = (),
directives: Iterable[StrawberryDirective] = (),
types=(),
extensions: Sequence[Union[Type[Extension], Extension]] = (),
extensions: Iterable[Union[Type[Extension], Extension]] = (),
execution_context_class: Optional[Type[GraphQLExecutionContext]] = None,
config: Optional[StrawberryConfig] = None,
scalar_overrides: Optional[
Expand Down Expand Up @@ -117,8 +117,6 @@ def __init__(
formatted_errors = "\n\n".join(f"❌ {error.message}" for error in errors)
raise ValueError(f"Invalid Schema. Errors:\n\n{formatted_errors}")

self.query = self.schema_converter.type_map[query_type.name]

def get_extensions(
self, sync: bool = False
) -> List[Union[Type[Extension], Extension]]:
Expand Down
Empty file.