Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ValueTask stream overloads on SNI streams and Waits #902

Merged
merged 5 commits into from Jul 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -344,6 +344,7 @@
<Compile Include="Microsoft\Data\SqlClient\SqlDiagnosticListener.NetStandard.cs" />
<Compile Include="Microsoft\Data\SqlClient\SqlDelegatedTransaction.NetStandard.cs" />
<Compile Include="Microsoft\Data\SqlClient\TdsParser.NetStandard.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIStreams.Task.cs" />
cheenamalhotra marked this conversation as resolved.
Show resolved Hide resolved
<Compile Include="Microsoft\Data\SqlClient\SNI\SslOverTdsStream.NetStandard.cs" />
</ItemGroup>
<ItemGroup Condition="'$(OSGroup)' != 'AnyOS' AND '$(TargetFramework)' != 'netstandard2.0'">
Expand Down Expand Up @@ -393,6 +394,8 @@
<Compile Include="Microsoft\Data\SqlClient\SNI\SslOverTdsStream.NetCoreApp.cs" />
<Compile Include="Microsoft\Data\SqlClient\SqlConnectionFactory.AssemblyLoadContext.cs" />
<Compile Include="Microsoft\Data\SqlClient\SqlDependencyUtils.AssemblyLoadContext.cs" />
<Compile Condition="$(TargetFramework.StartsWith('netcoreapp2.'))" Include="Microsoft\Data\SqlClient\SNI\SNIStreams.Task.cs" />
<Compile Condition="!$(TargetFramework.StartsWith('netcoreapp2.'))" Include="Microsoft\Data\SqlClient\SNI\SNIStreams.ValueTask.cs" />
cheenamalhotra marked this conversation as resolved.
Show resolved Hide resolved
</ItemGroup>
<ItemGroup Condition="'$(OSGroup)' != 'AnyOS' AND '$(TargetFramework)' != 'netstandard2.0' AND '$(BuildSimulator)' == 'true'">
<Compile Include="Microsoft\Data\SqlClient\SimulatorEnclaveProvider.NetCoreApp.cs" />
Expand Down Expand Up @@ -522,6 +525,7 @@
<Compile Include="$(CommonPath)\CoreLib\System\Threading\Tasks\TaskToApm.cs">
<Link>Common\CoreLib\System\Threading\Tasks\TaskToApm.cs</Link>
</Compile>
<Compile Include="Microsoft\Data\SqlClient\SNI\ConcurrentQueueSemaphore.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIError.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIHandle.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNILoadHandle.cs" />
Expand All @@ -533,8 +537,8 @@
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIPacketPool.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIPhysicalHandle.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIProxy.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNITcpHandle.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIStreams.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNITcpHandle.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SslOverTdsStream.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNICommon.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SspiClientContextStatus.cs" />
Expand Down Expand Up @@ -826,6 +830,7 @@
</ItemGroup>
<ItemGroup>
<Compile Include="Microsoft\Data\SqlClient\AAsyncCallContext.cs" />

<Compile Include="Resources\Strings.Designer.cs">
<DesignTime>True</DesignTime>
<AutoGen>True</AutoGen>
Expand Down
@@ -0,0 +1,60 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Data.SqlClient.SNI
{
/// <summary>
/// This class implements a FIFO Queue with SemaphoreSlim for ordered execution of parallel tasks.
/// Currently used in Managed SNI (SNISslStream) to override SslStream's WriteAsync implementation.
/// </summary>
internal sealed partial class ConcurrentQueueSemaphore
{
private readonly SemaphoreSlim _semaphore;
private readonly ConcurrentQueue<TaskCompletionSource<bool>> _queue;

public ConcurrentQueueSemaphore(int initialCount)
{
_semaphore = new SemaphoreSlim(initialCount);
_queue = new ConcurrentQueue<TaskCompletionSource<bool>>();
}

public Task WaitAsync(CancellationToken cancellationToken)
{
// try sync wait with 0 which will not block to see if we need to do an async wait
if (_semaphore.Wait(0, cancellationToken))
{
return Task.CompletedTask;
}
else
{
var tcs = new TaskCompletionSource<bool>();
_queue.Enqueue(tcs);
_semaphore.WaitAsync().ContinueWith(
continuationAction: static (Task task, object state) =>
{
ConcurrentQueue<TaskCompletionSource<bool>> queue = (ConcurrentQueue<TaskCompletionSource<bool>>)state;
if (queue.TryDequeue(out TaskCompletionSource<bool> popped))
{
popped.SetResult(true);
}
},
state: _queue,
cancellationToken: cancellationToken
);
return tcs.Task;
}
}

public void Release()
{
_semaphore.Release();
}
}

}
@@ -0,0 +1,73 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Data.SqlClient.SNI
{
// NetCore2.1:
// DO NOT OVERRIDE ValueTask versions of ReadAsync and WriteAsync because the underlying SslStream implements them
// by calling the Task versions which are already overridden meaning that if a caller uses Task WriteAsync this would
// call ValueTask WriteAsync which then called TaskWriteAsync introducing a lock cycle and never return

internal sealed partial class SNISslStream
{
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
return await base.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_readAsyncSemaphore.Release();
}
}

public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
await base.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_writeAsyncSemaphore.Release();
}
}
}

internal sealed partial class SNINetworkStream
{
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
return await base.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_readAsyncSemaphore.Release();
}
}

// Prevent the WriteAsync collisions by running the task in a Semaphore Slim
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
await base.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_writeAsyncSemaphore.Release();
}
}
}
}
@@ -0,0 +1,89 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Threading;
using System.Threading.Tasks;
using System;

namespace Microsoft.Data.SqlClient.SNI
{
internal sealed partial class SNISslStream
{
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
}

public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
return await base.ReadAsync(buffer, cancellationToken).ConfigureAwait(false);
}
finally
{
_readAsyncSemaphore.Release();
}
}

public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return WriteAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
}

public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
await base.WriteAsync(buffer, cancellationToken).ConfigureAwait(false);
}
finally
{
_writeAsyncSemaphore.Release();
}
}
}

internal sealed partial class SNINetworkStream
{
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
}

public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
return await base.ReadAsync(buffer, cancellationToken).ConfigureAwait(false);
}
finally
{
_readAsyncSemaphore.Release();
}
}

// Prevent the WriteAsync collisions by running the task in a Semaphore Slim
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return WriteAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
}

public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
await base.WriteAsync(buffer, cancellationToken).ConfigureAwait(false);
}
finally
{
_writeAsyncSemaphore.Release();
}
}
}
}
Expand Up @@ -4,16 +4,14 @@

using System.Net.Security;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using System.Net.Sockets;

namespace Microsoft.Data.SqlClient.SNI
{
/// <summary>
Wraith2 marked this conversation as resolved.
Show resolved Hide resolved
/// This class extends SslStream to customize stream behavior for Managed SNI implementation.
/// </summary>
internal class SNISslStream : SslStream
internal sealed partial class SNISslStream : SslStream
{
private readonly ConcurrentQueueSemaphore _writeAsyncSemaphore;
private readonly ConcurrentQueueSemaphore _readAsyncSemaphore;
Expand All @@ -24,40 +22,12 @@ public SNISslStream(Stream innerStream, bool leaveInnerStreamOpen, RemoteCertifi
_writeAsyncSemaphore = new ConcurrentQueueSemaphore(1);
_readAsyncSemaphore = new ConcurrentQueueSemaphore(1);
}

// Prevent ReadAsync collisions by running the task in a Semaphore Slim
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
return await base.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_readAsyncSemaphore.Release();
}
}

// Prevent the WriteAsync collisions by running the task in a Semaphore Slim
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
await base.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_writeAsyncSemaphore.Release();
}
}
}

/// <summary>
/// This class extends NetworkStream to customize stream behavior for Managed SNI implementation.
/// </summary>
internal class SNINetworkStream : NetworkStream
internal sealed partial class SNINetworkStream : NetworkStream
{
private readonly ConcurrentQueueSemaphore _writeAsyncSemaphore;
private readonly ConcurrentQueueSemaphore _readAsyncSemaphore;
Expand All @@ -67,33 +37,5 @@ public SNINetworkStream(Socket socket, bool ownsSocket) : base(socket, ownsSocke
_writeAsyncSemaphore = new ConcurrentQueueSemaphore(1);
_readAsyncSemaphore = new ConcurrentQueueSemaphore(1);
}

// Prevent ReadAsync collisions by running the task in a Semaphore Slim
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
return await base.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_readAsyncSemaphore.Release();
}
}

// Prevent the WriteAsync collisions by running the task in a Semaphore Slim
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
await base.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_writeAsyncSemaphore.Release();
}
}
}
}