Skip to content

Commit

Permalink
Merge pull request #52260 from CyrusNajmabadi/cloudcacheCrash
Browse files Browse the repository at this point in the history
Address crash in the cloud cache system when doing two single byte reads.
  • Loading branch information
CyrusNajmabadi committed Apr 1, 2021
2 parents 35997a1 + f8e3f7c commit ee26791
Show file tree
Hide file tree
Showing 6 changed files with 304 additions and 5 deletions.
2 changes: 2 additions & 0 deletions src/Tools/IdeCoreBenchmarks/IdeCoreBenchmarks.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
<ItemGroup>
<Compile Include="..\..\VisualStudio\Core\Def\Storage\AbstractCloudCachePersistentStorageService.cs" Link="CloudCache\AbstractCloudCachePersistentStorageService.cs" />
<Compile Include="..\..\VisualStudio\Core\Def\Storage\CloudCachePersistentStorage.cs" Link="CloudCache\CloudCachePersistentStorage.cs" />
<Compile Include="..\..\VisualStudio\Core\Def\Storage\Nerdbank\ReadOnlySequenceStream.cs" Link="CloudCache\ReadOnlySequenceStream.cs" />
<Compile Include="..\..\VisualStudio\Core\Def\Storage\ProjectContainerKeyCache.cs" Link="CloudCache\ProjectContainerKeyCache.cs" />
<Compile Include="..\..\VisualStudio\CSharp\Test\PersistentStorage\Mocks\AuthorizationServiceMock.cs" Link="CloudCache\AuthorizationServiceMock.cs" />
<Compile Include="..\..\VisualStudio\CSharp\Test\PersistentStorage\Mocks\FileSystemServiceMock.cs" Link="CloudCache\FileSystemServiceMock.cs" />
Expand All @@ -38,6 +39,7 @@
<PackageReference Include="Microsoft.Win32.Registry" Version="$(MicrosoftWin32RegistryVersion)" />
<PackageReference Include="Newtonsoft.Json" Version="$(NewtonsoftJsonVersion)" />
<PackageReference Include="Microsoft.VisualStudio.Cache" Version="$(MicrosoftVisualStudioCacheVersion)" />
<PackageReference Include="Nerdbank.Streams" Version="$(NerdbankStreamsVersion)" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,26 @@ public void CacheDirectoryShouldNotBeAtRoot()
Assert.False(location?.StartsWith("/") ?? false);
}

[Theory]
[CombinatorialData]
public async Task PersistentService_ReadByteTwice(Size size, bool withChecksum)
{
var solution = CreateOrOpenSolution();
var streamName1 = "PersistentService_ReadByteTwice";

await using (var storage = await GetStorageAsync(solution))
{
Assert.True(await storage.WriteStreamAsync(streamName1, EncodeString(GetData1(size)), GetChecksum1(withChecksum)));
}

await using (var storage = await GetStorageAsync(solution))
{
using var stream = await storage.ReadStreamAsync(streamName1, GetChecksum1(withChecksum));
stream.ReadByte();
stream.ReadByte();
}
}

[PartNotDiscoverable]
[ExportWorkspaceService(typeof(IPersistentStorageLocationService), layer: ServiceLayer.Test), Shared]
private class TestPersistentStorageLocationService : DefaultPersistentStorageLocationService
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@
<PackageReference Include="NuGet.VisualStudio" Version="$(NuGetVisualStudioVersion)" />
<PackageReference Include="Newtonsoft.Json" Version="$(NewtonsoftJsonVersion)" />
<PackageReference Include="System.Threading.Tasks.Dataflow" Version="$(SystemThreadingTasksDataflowVersion)" />
<PackageReference Include="Nerdbank.Streams" Version="$(NerdbankStreamsVersion)" />

<!-- By default build assets that define embedded interop types do not flow. Set PrivateAssets to none to make them flow. -->
<PackageReference Include="Microsoft.VisualStudio.SDK.EmbedInteropTypes" Version="$(MicrosoftVisualStudioSDKEmbedInteropTypesVersion)" PrivateAssets="none" />
Expand Down
22 changes: 17 additions & 5 deletions src/VisualStudio/Core/Def/Storage/CloudCachePersistentStorage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using Microsoft.CodeAnalysis.PersistentStorage;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.VisualStudio.RpcContracts.Caching;
using Nerdbank.Streams;
using Roslyn.Utilities;

namespace Microsoft.VisualStudio.LanguageServices.Storage
Expand Down Expand Up @@ -163,16 +164,27 @@ private async Task<bool> ChecksumMatchesAsync(string name, Checksum checksum, Ca
// and then pass that out. This should not be a problem in practice as PipeReader internally intelligently
// uses and pools reasonable sized buffers, preventing us from exacerbating the GC or causing LOH
// allocations.
return await AsPrebufferedStreamAsync(pipe.Reader, cancellationToken).ConfigureAwait(false);
}

private static async Task<Stream> AsPrebufferedStreamAsync(PipeReader pipeReader, CancellationToken cancellationToken = default)
{
while (true)
{
var readResult = await pipe.Reader.ReadAsync(cancellationToken).ConfigureAwait(false);
pipe.Reader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.End);
// Read and immediately report all bytes as "examined" so that the next ReadAsync call will block till more bytes come in.
// The goal here is to force the PipeReader to buffer everything internally (even if it were to exceed its natural writer threshold limit).
ReadResult readResult = await pipeReader.ReadAsync(cancellationToken).ConfigureAwait(false);
pipeReader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.End);

if (readResult.IsCompleted)
break;
{
// After having buffered and "examined" all the bytes, the stream returned from PipeReader.AsStream() would fail
// because it may not "examine" all bytes at once.
// Instead, we'll create our own Stream over just the buffer itself, and recycle the buffers when the stream is disposed
// the way the stream returned from PipeReader.AsStream() would have.
return new ReadOnlySequenceStream(readResult.Buffer, reader => ((PipeReader)reader!).Complete(), pipeReader);
}
}

return pipe.Reader.AsStream();
}

public sealed override Task<bool> WriteStreamAsync(string name, Stream stream, Checksum? checksum, CancellationToken cancellationToken)
Expand Down
262 changes: 262 additions & 0 deletions src/VisualStudio/Core/Def/Storage/Nerdbank/ReadOnlySequenceStream.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

// Copied from https://raw.githubusercontent.com/AArnott/Nerdbank.Streams/2b142fa6a38b15e4b06ecc53bf073aa49fd1de34/src/Nerdbank.Streams/ReadOnlySequenceStream.cs
// Remove once we move to Nerdbank.Streams 2.7.62-alpha

namespace Nerdbank.Streams
{
using System;
using System.Buffers;
using System.IO;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft;

internal class ReadOnlySequenceStream : Stream, IDisposableObservable
{
private static readonly Task<int> TaskOfZero = Task.FromResult(0);

private readonly Action<object?>? disposeAction;
private readonly object? disposeActionArg;

/// <summary>
/// A reusable task if two consecutive reads return the same number of bytes.
/// </summary>
private Task<int>? lastReadTask;

private readonly ReadOnlySequence<byte> readOnlySequence;

private SequencePosition position;

internal ReadOnlySequenceStream(ReadOnlySequence<byte> readOnlySequence, Action<object?>? disposeAction, object? disposeActionArg)
{
this.readOnlySequence = readOnlySequence;
this.disposeAction = disposeAction;
this.disposeActionArg = disposeActionArg;
this.position = readOnlySequence.Start;
}

/// <inheritdoc/>
public override bool CanRead => !this.IsDisposed;

/// <inheritdoc/>
public override bool CanSeek => !this.IsDisposed;

/// <inheritdoc/>
public override bool CanWrite => false;

/// <inheritdoc/>
public override long Length => this.ReturnOrThrowDisposed(this.readOnlySequence.Length);

/// <inheritdoc/>
public override long Position
{
get => this.readOnlySequence.Slice(0, this.position).Length;
set
{
Requires.Range(value >= 0, nameof(value));
this.position = this.readOnlySequence.GetPosition(value, this.readOnlySequence.Start);
}
}

/// <inheritdoc/>
public bool IsDisposed { get; private set; }

/// <inheritdoc/>
public override void Flush() => this.ThrowDisposedOr(new NotSupportedException());

/// <inheritdoc/>
public override Task FlushAsync(CancellationToken cancellationToken) => throw this.ThrowDisposedOr(new NotSupportedException());

/// <inheritdoc/>
public override int Read(byte[] buffer, int offset, int count)
{
ReadOnlySequence<byte> remaining = this.readOnlySequence.Slice(this.position);
ReadOnlySequence<byte> toCopy = remaining.Slice(0, Math.Min(count, remaining.Length));
this.position = toCopy.End;
toCopy.CopyTo(buffer.AsSpan(offset, count));
return (int)toCopy.Length;
}

/// <inheritdoc/>
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
int bytesRead = this.Read(buffer, offset, count);
if (bytesRead == 0)
{
return TaskOfZero;
}

if (this.lastReadTask?.Result == bytesRead)
{
return this.lastReadTask;
}
else
{
return this.lastReadTask = Task.FromResult(bytesRead);
}
}

/// <inheritdoc/>
public override int ReadByte()
{
ReadOnlySequence<byte> remaining = this.readOnlySequence.Slice(this.position);
if (remaining.Length > 0)
{
byte result = remaining.First.Span[0];
this.position = this.readOnlySequence.GetPosition(1, this.position);
return result;
}
else
{
return -1;
}
}

/// <inheritdoc/>
public override long Seek(long offset, SeekOrigin origin)
{
Verify.NotDisposed(this);

SequencePosition relativeTo;
switch (origin)
{
case SeekOrigin.Begin:
relativeTo = this.readOnlySequence.Start;
break;
case SeekOrigin.Current:
if (offset >= 0)
{
relativeTo = this.position;
}
else
{
relativeTo = this.readOnlySequence.Start;
offset += this.Position;
}

break;
case SeekOrigin.End:
if (offset >= 0)
{
relativeTo = this.readOnlySequence.End;
}
else
{
relativeTo = this.readOnlySequence.Start;
offset += this.Position;
}

break;
default:
throw new ArgumentOutOfRangeException(nameof(origin));
}

this.position = this.readOnlySequence.GetPosition(offset, relativeTo);
return this.Position;
}

/// <inheritdoc/>
public override void SetLength(long value) => this.ThrowDisposedOr(new NotSupportedException());

/// <inheritdoc/>
public override void Write(byte[] buffer, int offset, int count) => this.ThrowDisposedOr(new NotSupportedException());

/// <inheritdoc/>
public override void WriteByte(byte value) => this.ThrowDisposedOr(new NotSupportedException());

/// <inheritdoc/>
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) => throw this.ThrowDisposedOr(new NotSupportedException());

/// <inheritdoc/>
public override async Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
{
foreach (var segment in this.readOnlySequence)
{
await WriteAsync(destination, segment, cancellationToken).ConfigureAwait(false);
}
}

private static ValueTask WriteAsync(Stream stream, ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
Requires.NotNull(stream, nameof(stream));

if (MemoryMarshal.TryGetArray(buffer, out ArraySegment<byte> array))
{
return new ValueTask(stream.WriteAsync(array.Array!, array.Offset, array.Count, cancellationToken));
}
else
{
byte[] sharedBuffer = ArrayPool<byte>.Shared.Rent(buffer.Length);
buffer.Span.CopyTo(sharedBuffer);
return new ValueTask(FinishWriteAsync(stream.WriteAsync(sharedBuffer, 0, buffer.Length, cancellationToken), sharedBuffer));
}

async Task FinishWriteAsync(Task writeTask, byte[] localBuffer)
{
try
{
await writeTask.ConfigureAwait(false);
}
finally
{
ArrayPool<byte>.Shared.Return(localBuffer);
}
}
}

#if SPAN_BUILTIN

/// <inheritdoc/>
public override int Read(Span<byte> buffer)
{
ReadOnlySequence<byte> remaining = this.readOnlySequence.Slice(this.position);
ReadOnlySequence<byte> toCopy = remaining.Slice(0, Math.Min(buffer.Length, remaining.Length));
this.position = toCopy.End;
toCopy.CopyTo(buffer);
return (int)toCopy.Length;
}

/// <inheritdoc/>
public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();
return new ValueTask<int>(this.Read(buffer.Span));
}

/// <inheritdoc/>
public override void Write(ReadOnlySpan<byte> buffer) => throw this.ThrowDisposedOr(new NotSupportedException());

/// <inheritdoc/>
public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default) => throw this.ThrowDisposedOr(new NotSupportedException());

#endif

/// <inheritdoc/>
protected override void Dispose(bool disposing)
{
if (!this.IsDisposed)
{
this.IsDisposed = true;
this.disposeAction?.Invoke(this.disposeActionArg);
base.Dispose(disposing);
}
}

private T ReturnOrThrowDisposed<T>(T value)
{
Verify.NotDisposed(this);
return value;
}

private Exception ThrowDisposedOr(Exception ex)
{
Verify.NotDisposed(this);
throw ex;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
<PackageReference Include="Microsoft.VisualStudio.Threading" Version="$(MicrosoftVisualStudioThreadingVersion)" />
<PackageReference Include="Microsoft.ServiceHub.Framework" Version="$(MicrosoftServiceHubFrameworkVersion)" />
<PackageReference Include="Microsoft.VisualStudio.RpcContracts" Version="$(MicrosoftVisualStudioRpcContractsVersion)" />
<PackageReference Include="Nerdbank.Streams" Version="$(NerdbankStreamsVersion)" />
</ItemGroup>
<ItemGroup>
<PublicAPI Include="PublicAPI.Shipped.txt" />
Expand All @@ -39,6 +40,7 @@
<Compile Include="..\..\..\VisualStudio\Core\Def\Implementation\Watson\WatsonReporter.cs" Link="Services\WorkspaceTelemetry\WatsonReporter.cs" />
<Compile Include="..\..\..\VisualStudio\Core\Def\Storage\AbstractCloudCachePersistentStorageService.cs" Link="Host\Storage\AbstractCloudCachePersistentStorageService.cs" />
<Compile Include="..\..\..\VisualStudio\Core\Def\Storage\CloudCachePersistentStorage.cs" Link="Host\Storage\CloudCachePersistentStorage.cs" />
<Compile Include="..\..\..\VisualStudio\Core\Def\Storage\Nerdbank\ReadOnlySequenceStream.cs" Link="Host\Storage\ReadOnlySequenceStream.cs" />
<Compile Include="..\..\..\VisualStudio\Core\Def\Storage\ProjectContainerKeyCache.cs" Link="Host\Storage\ProjectContainerKeyCache.cs" />
<Compile Include="..\..\..\VisualStudio\Core\Def\Telemetry\VSTelemetryCache.cs" Link="Services\WorkspaceTelemetry\VSTelemetryCache.cs" />
<Compile Include="..\..\..\VisualStudio\Core\Def\Telemetry\VSTelemetryLogger.cs" Link="Services\WorkspaceTelemetry\VSTelemetryLogger.cs" />
Expand Down

0 comments on commit ee26791

Please sign in to comment.