diff --git a/starlette/graphql.py b/starlette/graphql.py index d3a64279e..c44eb27d8 100644 --- a/starlette/graphql.py +++ b/starlette/graphql.py @@ -10,13 +10,22 @@ try: import graphene - from graphql.execution.executors.asyncio import AsyncioExecutor - from graphql.error import format_error as format_graphql_error - from graphql.error import GraphQLError except ImportError: # pragma: nocover graphene = None + +try: + from graphql.execution.executors.asyncio import AsyncioExecutor +except ImportError: # pragma: nocover AsyncioExecutor = None # type: ignore + +try: + from graphql.error import format_error as format_graphql_error +except ImportError: # pragma: nocover format_graphql_error = None # type: ignore + +try: + from graphql.error import GraphQLError +except ImportError: # pragma: nocover GraphQLError = None # type: ignore diff --git a/tests/test_graphql.py b/tests/test_graphql.py index 25f4a5bae..02c8c4072 100644 --- a/tests/test_graphql.py +++ b/tests/test_graphql.py @@ -1,5 +1,6 @@ import graphene from graphql.execution.executors.asyncio import AsyncioExecutor +from graphql.error import GraphQLError from starlette.applications import Starlette from starlette.datastructures import Headers @@ -20,6 +21,7 @@ async def __call__(self, scope, receive, send): class Query(graphene.ObjectType): hello = graphene.String(name=graphene.String(default_value="stranger")) whoami = graphene.String() + graphql_error = graphene.String() def resolve_hello(self, info, name): return "Hello " + name @@ -31,6 +33,9 @@ def resolve_whoami(self, info): else info.context["request"]["user"] ) + def resolve_graphql_error(self, info): + raise GraphQLError("GraphQL Error") + schema = graphene.Schema(query=Query) app = GraphQLApp(schema=schema, graphiql=True) @@ -101,6 +106,13 @@ def test_graphiql_get(): assert "" in response.text +def test_graphql_error(): + response = client.get("/?query={ graphqlError }") + assert response.status_code == 400 + assert response.json()["data"] == {"graphqlError": None} + assert response.json()["errors"][0]["message"] == "GraphQL Error" + + def test_graphiql_not_found(): app = GraphQLApp(schema=schema, graphiql=False) client = TestClient(app)