Skip to content

Commit

Permalink
Query: Fixes stack overflow in SkipEmptyPageQueryPipelineStage (#2851)
Browse files Browse the repository at this point in the history
* Refactor SkipEmptyPageQueryPipelineStage implementation from recursive to iterative

* Add more test cases

Co-authored-by: j82w <j82w@users.noreply.github.com>
  • Loading branch information
neildsh and j82w committed Nov 4, 2021
1 parent 0b78b3d commit 5dcaf20
Show file tree
Hide file tree
Showing 2 changed files with 275 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,88 +44,95 @@ public async ValueTask<bool> MoveNextAsync(ITrace trace)
throw new ArgumentNullException(nameof(trace));
}

if (!await this.inputStage.MoveNextAsync(trace))
for (int documentCount = 0; documentCount == 0;)
{
if (!this.returnedFinalStats)
if (!await this.inputStage.MoveNextAsync(trace))
{
QueryPage queryPage = new QueryPage(
documents: EmptyPage,
requestCharge: this.cumulativeRequestCharge,
activityId: Guid.Empty.ToString(),
responseLengthInBytes: this.cumulativeResponseLengthInBytes,
cosmosQueryExecutionInfo: default,
disallowContinuationTokenMessage: default,
additionalHeaders: this.cumulativeAdditionalHeaders,
state: default);
this.cumulativeRequestCharge = 0;
this.cumulativeResponseLengthInBytes = 0;
this.cumulativeAdditionalHeaders = null;
this.returnedFinalStats = true;
this.Current = TryCatch<QueryPage>.FromResult(queryPage);
return true;
if (!this.returnedFinalStats)
{
QueryPage queryPage = new QueryPage(
documents: EmptyPage,
requestCharge: this.cumulativeRequestCharge,
activityId: Guid.Empty.ToString(),
responseLengthInBytes: this.cumulativeResponseLengthInBytes,
cosmosQueryExecutionInfo: default,
disallowContinuationTokenMessage: default,
additionalHeaders: this.cumulativeAdditionalHeaders,
state: default);
this.cumulativeRequestCharge = 0;
this.cumulativeResponseLengthInBytes = 0;
this.cumulativeAdditionalHeaders = null;
this.returnedFinalStats = true;
this.Current = TryCatch<QueryPage>.FromResult(queryPage);
return true;
}

this.Current = default;
return false;
}

this.Current = default;
return false;
}

TryCatch<QueryPage> tryGetSourcePage = this.inputStage.Current;
if (tryGetSourcePage.Failed)
{
this.Current = tryGetSourcePage;
return true;
}

QueryPage sourcePage = tryGetSourcePage.Result;
if (sourcePage.Documents.Count == 0)
{
if (sourcePage.State == null)
// if we are here then it means the inputStage told us there's more pages
// so we tell the same thing to our consumer
TryCatch<QueryPage> tryGetSourcePage = this.inputStage.Current;
if (tryGetSourcePage.Failed)
{
QueryPage queryPage = new QueryPage(
documents: EmptyPage,
requestCharge: sourcePage.RequestCharge + this.cumulativeRequestCharge,
activityId: sourcePage.ActivityId,
responseLengthInBytes: sourcePage.ResponseLengthInBytes + this.cumulativeResponseLengthInBytes,
cosmosQueryExecutionInfo: sourcePage.CosmosQueryExecutionInfo,
disallowContinuationTokenMessage: sourcePage.DisallowContinuationTokenMessage,
additionalHeaders: sourcePage.AdditionalHeaders,
state: default);
this.cumulativeRequestCharge = 0;
this.cumulativeResponseLengthInBytes = 0;
this.cumulativeAdditionalHeaders = null;
this.Current = TryCatch<QueryPage>.FromResult(queryPage);
this.Current = tryGetSourcePage;
return true;
}

this.cumulativeRequestCharge += sourcePage.RequestCharge;
this.cumulativeResponseLengthInBytes += sourcePage.ResponseLengthInBytes;
this.cumulativeAdditionalHeaders = sourcePage.AdditionalHeaders;

return await this.MoveNextAsync(trace);
}

QueryPage cumulativeQueryPage;
if (this.cumulativeRequestCharge != 0)
{
cumulativeQueryPage = new QueryPage(
documents: sourcePage.Documents,
requestCharge: sourcePage.RequestCharge + this.cumulativeRequestCharge,
activityId: sourcePage.ActivityId,
responseLengthInBytes: sourcePage.ResponseLengthInBytes + this.cumulativeResponseLengthInBytes,
cosmosQueryExecutionInfo: sourcePage.CosmosQueryExecutionInfo,
disallowContinuationTokenMessage: sourcePage.DisallowContinuationTokenMessage,
additionalHeaders: sourcePage.AdditionalHeaders,
state: sourcePage.State);
this.cumulativeRequestCharge = 0;
this.cumulativeResponseLengthInBytes = 0;
this.cumulativeAdditionalHeaders = null;
}
else
{
cumulativeQueryPage = sourcePage;
QueryPage sourcePage = tryGetSourcePage.Result;
documentCount = sourcePage.Documents.Count;
if (documentCount == 0)
{
if (sourcePage.State == null)
{
QueryPage queryPage = new QueryPage(
documents: EmptyPage,
requestCharge: sourcePage.RequestCharge + this.cumulativeRequestCharge,
activityId: sourcePage.ActivityId,
responseLengthInBytes: sourcePage.ResponseLengthInBytes + this.cumulativeResponseLengthInBytes,
cosmosQueryExecutionInfo: sourcePage.CosmosQueryExecutionInfo,
disallowContinuationTokenMessage: sourcePage.DisallowContinuationTokenMessage,
additionalHeaders: sourcePage.AdditionalHeaders,
state: default);
this.cumulativeRequestCharge = 0;
this.cumulativeResponseLengthInBytes = 0;
this.cumulativeAdditionalHeaders = null;
this.Current = TryCatch<QueryPage>.FromResult(queryPage);
return true;
}

this.cumulativeRequestCharge += sourcePage.RequestCharge;
this.cumulativeResponseLengthInBytes += sourcePage.ResponseLengthInBytes;
this.cumulativeAdditionalHeaders = sourcePage.AdditionalHeaders;
}
else
{
QueryPage cumulativeQueryPage;
if (this.cumulativeRequestCharge != 0)
{
cumulativeQueryPage = new QueryPage(
documents: sourcePage.Documents,
requestCharge: sourcePage.RequestCharge + this.cumulativeRequestCharge,
activityId: sourcePage.ActivityId,
responseLengthInBytes: sourcePage.ResponseLengthInBytes + this.cumulativeResponseLengthInBytes,
cosmosQueryExecutionInfo: sourcePage.CosmosQueryExecutionInfo,
disallowContinuationTokenMessage: sourcePage.DisallowContinuationTokenMessage,
additionalHeaders: sourcePage.AdditionalHeaders,
state: sourcePage.State);
this.cumulativeRequestCharge = 0;
this.cumulativeResponseLengthInBytes = 0;
this.cumulativeAdditionalHeaders = null;
}
else
{
cumulativeQueryPage = sourcePage;
}

this.Current = TryCatch<QueryPage>.FromResult(cumulativeQueryPage);
}
}

this.Current = TryCatch<QueryPage>.FromResult(cumulativeQueryPage);
return true;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
//------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
//------------------------------------------------------------

namespace Microsoft.Azure.Cosmos.Tests.Query.Pipeline
{
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.Cosmos.CosmosElements;
using Microsoft.Azure.Cosmos.Query.Core.Monads;
using Microsoft.Azure.Cosmos.Query.Core.Pipeline;
using Microsoft.Azure.Cosmos.Query.Core.Pipeline.Pagination;
using Microsoft.Azure.Cosmos.Tracing;
using Microsoft.VisualStudio.TestTools.UnitTesting;

[TestClass]
public class SkipEmptyPageQueryPipelineStageTests
{
[TestMethod]
public async Task StackOverflowTest()
{
await using IQueryPipelineStage pipeline = CreatePipeline(Enumerable
.Repeat(EmptyPagePipelineStage.PageType.Empty, 2000)
.Concat(Enumerable.Repeat(EmptyPagePipelineStage.PageType.Error, 1))
.ToList());
bool hasNext = await pipeline.MoveNextAsync(NoOpTrace.Singleton);
Assert.IsTrue(hasNext);
TryCatch<QueryPage> result = pipeline.Current;
Assert.IsFalse(result.Succeeded);
}

[TestMethod]
public async Task BasicTests()
{
IReadOnlyList<TestCase> testCases = new List<TestCase>()
{
MakeTest(
input: Enumerable
.Repeat(EmptyPagePipelineStage.PageType.Empty, 2000)
.Concat(Enumerable.Repeat(EmptyPagePipelineStage.PageType.NonEmpty, 1)),
expected: Enumerable.Repeat(true, 1)),
MakeTest(
input: Enumerable
.Repeat(EmptyPagePipelineStage.PageType.Empty, 100)
.Concat(Enumerable.Repeat(EmptyPagePipelineStage.PageType.NonEmpty, 5))
.Concat(Enumerable.Repeat(EmptyPagePipelineStage.PageType.Empty, 27))
.Concat(Enumerable.Repeat(EmptyPagePipelineStage.PageType.NonEmpty, 3))
.Concat(Enumerable.Repeat(EmptyPagePipelineStage.PageType.Empty, 32))
.Concat(Enumerable.Repeat(EmptyPagePipelineStage.PageType.NonEmpty, 3)),
expected: Enumerable.Repeat(true, 11)),
MakeTest(
input: Enumerable
.Repeat(EmptyPagePipelineStage.PageType.Empty, 100)
.Concat(Enumerable.Repeat(EmptyPagePipelineStage.PageType.NonEmpty, 5))
.Concat(Enumerable.Repeat(EmptyPagePipelineStage.PageType.Empty, 27))
.Concat(Enumerable.Repeat(EmptyPagePipelineStage.PageType.Error, 3))
.Concat(Enumerable.Repeat(EmptyPagePipelineStage.PageType.Empty, 32))
.Concat(Enumerable.Repeat(EmptyPagePipelineStage.PageType.NonEmpty, 3)),
expected: Enumerable.Repeat(true, 5)
.Concat(Enumerable.Repeat(false, 3))
.Concat(Enumerable.Repeat(true, 3))),
MakeTest(
input: Enumerable.Repeat(EmptyPagePipelineStage.PageType.NonEmpty, 500),
expected: Enumerable.Repeat(true, 500)),
MakeTest(
input: Enumerable.Repeat(EmptyPagePipelineStage.PageType.Error, 500),
expected: Enumerable.Repeat(false, 500))
};

foreach (TestCase testCase in testCases)
{
await using IQueryPipelineStage pipeline = CreatePipeline(testCase.Input);
for (int index = 0; index < testCase.Expected.Count; ++index)
{
Assert.IsTrue(await pipeline.MoveNextAsync(NoOpTrace.Singleton));

if (testCase.Expected[index])
{
Assert.IsTrue(pipeline.Current.Succeeded);
Assert.AreEqual(1, pipeline.Current.Result.Documents.Count);
Assert.AreEqual("42", pipeline.Current.Result.Documents[0].ToString());
}
else
{
Assert.IsTrue(pipeline.Current.Failed);
}
}
}
}

internal static TestCase MakeTest(IEnumerable<EmptyPagePipelineStage.PageType> input, IEnumerable<bool> expected)
{
return new TestCase(input.ToList(), expected.ToList());
}

internal readonly struct TestCase
{
public TestCase(IReadOnlyList<EmptyPagePipelineStage.PageType> input, IReadOnlyList<bool> expected)
{
this.Input = input;
this.Expected = expected;
}

public readonly IReadOnlyList<EmptyPagePipelineStage.PageType> Input { get; }

public readonly IReadOnlyList<bool> Expected { get; }
}

private static IQueryPipelineStage CreatePipeline(IReadOnlyList<EmptyPagePipelineStage.PageType> pages)
{
EmptyPagePipelineStage emptyPagePipelineStage = new EmptyPagePipelineStage(pages);
SkipEmptyPageQueryPipelineStage skipEmptyPageStage = new SkipEmptyPageQueryPipelineStage(
inputStage: emptyPagePipelineStage,
cancellationToken: default);

return new CatchAllQueryPipelineStage(inputStage: skipEmptyPageStage, cancellationToken: default);
}

internal class EmptyPagePipelineStage : IQueryPipelineStage
{
public enum PageType { Empty, NonEmpty, Error };

private static readonly TryCatch<QueryPage> Empty = TryCatch<QueryPage>.FromResult(new QueryPage(
documents: new List<CosmosElement>(),
requestCharge: 42,
activityId: Guid.NewGuid().ToString(),
responseLengthInBytes: "[]".Length,
cosmosQueryExecutionInfo: default,
disallowContinuationTokenMessage: default,
additionalHeaders: default,
state: new QueryState(CosmosString.Create("Empty"))));

private static readonly TryCatch<QueryPage> NonEmpty = TryCatch<QueryPage>.FromResult(new QueryPage(
documents: new List<CosmosElement> { CosmosElement.Parse("42") },
requestCharge: 100,
activityId: Guid.NewGuid().ToString(),
responseLengthInBytes: "[42]".Length,
cosmosQueryExecutionInfo: default,
disallowContinuationTokenMessage: default,
additionalHeaders: default,
state: new QueryState(CosmosString.Create("NonEmpty"))));

private readonly IReadOnlyList<PageType> pages;

private int current;

public EmptyPagePipelineStage(IReadOnlyList<PageType> pages)
{
this.current = -1;
this.pages = pages;
}

public TryCatch<QueryPage> Current { get; private set; }

public ValueTask DisposeAsync()
{
return new ValueTask();
}

public ValueTask<bool> MoveNextAsync(ITrace trace)
{
++this.current;
if (this.current >= this.pages.Count)
{
return new ValueTask<bool>(false);
}

switch (this.pages[this.current])
{
case PageType.Empty:
this.Current = Empty;
break;
case PageType.NonEmpty:
this.Current = NonEmpty;
break;
case PageType.Error:
throw new CosmosException(
message: "Injected failure",
statusCode: System.Net.HttpStatusCode.InternalServerError,
subStatusCode: 0,
activityId: Guid.Empty.ToString(),
requestCharge: 0);
}

return new ValueTask<bool>(true);
}

public void SetCancellationToken(CancellationToken cancellationToken)
{
}
}
}
}

0 comments on commit 5dcaf20

Please sign in to comment.