diff --git a/src/main/java/graphql/GraphQL.java b/src/main/java/graphql/GraphQL.java index 09a875c654..575dd665d5 100644 --- a/src/main/java/graphql/GraphQL.java +++ b/src/main/java/graphql/GraphQL.java @@ -683,4 +683,4 @@ private static Instrumentation checkInstrumentationDefaultState(Instrumentation } return new ChainedInstrumentation(instrumentationList); } -} \ No newline at end of file +} diff --git a/src/main/java/graphql/execution/instrumentation/dataloader/LevelMap.java b/src/main/java/graphql/execution/instrumentation/dataloader/LevelMap.java index 209eaef581..9e151ab40b 100644 --- a/src/main/java/graphql/execution/instrumentation/dataloader/LevelMap.java +++ b/src/main/java/graphql/execution/instrumentation/dataloader/LevelMap.java @@ -1,7 +1,6 @@ package graphql.execution.instrumentation.dataloader; import graphql.Internal; - import java.util.Arrays; /** @@ -28,26 +27,21 @@ public LevelMap() { } public int get(int level) { - if (level < 0) { - throw new IllegalArgumentException("negative level " + level); - } - if (level + 1 > countsByLevel.length) { - throw new IllegalArgumentException("unknown level " + level); - } + maybeResize(level); return countsByLevel[level]; } public void increment(int level, int by) { - mutatePreconditions(level); + maybeResize(level); countsByLevel[level] += by; } public void set(int level, int newValue) { - mutatePreconditions(level); + maybeResize(level); countsByLevel[level] = newValue; } - private void mutatePreconditions(int level) { + private void maybeResize(int level) { if (level < 0) { throw new IllegalArgumentException("negative level " + level); } @@ -71,4 +65,4 @@ public String toString() { public void clear() { Arrays.fill(countsByLevel, 0); } -} \ No newline at end of file +} diff --git a/src/test/groovy/graphql/execution/instrumentation/dataloader/DataLoaderDispatcherInstrumentationTest.groovy b/src/test/groovy/graphql/execution/instrumentation/dataloader/DataLoaderDispatcherInstrumentationTest.groovy index 4b8786d838..23d00ce90f 100644 --- a/src/test/groovy/graphql/execution/instrumentation/dataloader/DataLoaderDispatcherInstrumentationTest.groovy +++ b/src/test/groovy/graphql/execution/instrumentation/dataloader/DataLoaderDispatcherInstrumentationTest.groovy @@ -79,7 +79,7 @@ class DataLoaderDispatcherInstrumentationTest extends Specification { chainedInstrumentation.instrumentations.any { instr -> instr instanceof DataLoaderDispatcherInstrumentation } } - def "dispatch is never called if not data loader registry is set in"() { + def "dispatch is never called if data loader registry is not set"() { def dataLoaderRegistry = new DataLoaderRegistry() { @Override void dispatchAll() { @@ -294,4 +294,30 @@ class DataLoaderDispatcherInstrumentationTest extends Specification { er.errors.isEmpty() er.data["field"] == "working as expected" } + + def "handles deep async queries when a data loader registry is present"() { + given: + def support = new DeepDataFetchers() + def dummyDataloaderRegistry = new DataLoaderRegistry() + def batchingInstrumentation = new DataLoaderDispatcherInstrumentation() + def graphql = GraphQL.newGraphQL(support.schema()) + .instrumentation(batchingInstrumentation) + .build() + // FieldLevelTrackingApproach uses LevelMaps with a default size of 16. + // Use a value greater than 16 to ensure that the underlying LevelMaps are resized + // as expected + def depth = 50 + + when: + def asyncResult = graphql.executeAsync( + newExecutionInput() + .query(support.buildQuery(depth)) + .dataLoaderRegistry(dummyDataloaderRegistry) + ) + def er = asyncResult.join() + + then: + er.errors.isEmpty() + er.data == support.buildResponse(depth) + } } diff --git a/src/test/groovy/graphql/execution/instrumentation/dataloader/DeepDataFetchers.java b/src/test/groovy/graphql/execution/instrumentation/dataloader/DeepDataFetchers.java new file mode 100644 index 0000000000..d090425b51 --- /dev/null +++ b/src/test/groovy/graphql/execution/instrumentation/dataloader/DeepDataFetchers.java @@ -0,0 +1,72 @@ +package graphql.execution.instrumentation.dataloader; + +import java.util.HashMap; +import java.util.concurrent.CompletableFuture; +import java.util.function.Supplier; + +import graphql.schema.DataFetcher; +import graphql.schema.GraphQLFieldDefinition; +import graphql.schema.GraphQLObjectType; +import graphql.schema.GraphQLSchema; +import graphql.schema.GraphQLTypeReference; + +public class DeepDataFetchers { + private static CompletableFuture supplyAsyncWithSleep(Supplier supplier) { + Supplier sleepSome = sleepSome(supplier); + return CompletableFuture.supplyAsync(sleepSome); + } + + private static Supplier sleepSome(Supplier supplier) { + return () -> { + try { + Thread.sleep(10L); + return supplier.get(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }; + } + + public GraphQLSchema schema() { + DataFetcher>> slowFetcher = environment -> + supplyAsyncWithSleep(HashMap::new); + + GraphQLFieldDefinition selfField = GraphQLFieldDefinition.newFieldDefinition() + .name("self") + .type(GraphQLTypeReference.typeRef("Query")) + .dataFetcher(slowFetcher) + .build(); + + GraphQLObjectType query = GraphQLObjectType.newObject() + .name("Query") + .field(selfField) + .build(); + + return GraphQLSchema.newSchema().query(query).build(); + } + + public String buildQuery(Integer depth) { + StringBuilder sb = new StringBuilder(); + sb.append("query {"); + for (Integer i = 0; i < depth; i++) { + sb.append("self {"); + } + sb.append("__typename"); + for (Integer i = 0; i < depth; i++) { + sb.append("}"); + } + sb.append("}"); + + return sb.toString(); + } + + public HashMap buildResponse(Integer depth) { + HashMap level = new HashMap<>(); + if (depth == 0) { + level.put("__typename", "Query"); + } else { + level.put("self", buildResponse(depth - 1)); + } + return level; + } +} diff --git a/src/test/groovy/graphql/execution/instrumentation/dataloader/LevelMapTest.groovy b/src/test/groovy/graphql/execution/instrumentation/dataloader/LevelMapTest.groovy index 7733877124..8ff6bece5b 100644 --- a/src/test/groovy/graphql/execution/instrumentation/dataloader/LevelMapTest.groovy +++ b/src/test/groovy/graphql/execution/instrumentation/dataloader/LevelMapTest.groovy @@ -1,6 +1,7 @@ package graphql.execution.instrumentation.dataloader import spock.lang.Specification +import graphql.AssertException class LevelMapTest extends Specification { @@ -99,4 +100,15 @@ class LevelMapTest extends Specification { then: sut.toString() == "IntMap[level=0,count=42 level=1,count=1 ]" } + + def "can get outside of its size"() { + given: + LevelMap sut = new LevelMap(0) + + when: + sut.get(1) + + then: + sut.get(1) == 0 + } }