Skip to content

Commit

Permalink
Add ValueTask stream overloads on SNI streams and Waits (dotnet#902)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wraith2 authored and Davoud Eshtehari committed Aug 31, 2021
1 parent e6b1f9e commit 1ee0cba
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 114 deletions.
Expand Up @@ -344,6 +344,7 @@
<Compile Include="Microsoft\Data\SqlClient\SqlDiagnosticListener.NetStandard.cs" />
<Compile Include="Microsoft\Data\SqlClient\SqlDelegatedTransaction.NetStandard.cs" />
<Compile Include="Microsoft\Data\SqlClient\TdsParser.NetStandard.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIStreams.Task.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SslOverTdsStream.NetStandard.cs" />
</ItemGroup>
<ItemGroup Condition="'$(OSGroup)' != 'AnyOS' AND '$(TargetFramework)' != 'netstandard2.0'">
Expand Down Expand Up @@ -393,6 +394,8 @@
<Compile Include="Microsoft\Data\SqlClient\SNI\SslOverTdsStream.NetCoreApp.cs" />
<Compile Include="Microsoft\Data\SqlClient\SqlConnectionFactory.AssemblyLoadContext.cs" />
<Compile Include="Microsoft\Data\SqlClient\SqlDependencyUtils.AssemblyLoadContext.cs" />
<Compile Condition="$(TargetFramework.StartsWith('netcoreapp2.'))" Include="Microsoft\Data\SqlClient\SNI\SNIStreams.Task.cs" />
<Compile Condition="!$(TargetFramework.StartsWith('netcoreapp2.'))" Include="Microsoft\Data\SqlClient\SNI\SNIStreams.ValueTask.cs" />
</ItemGroup>
<ItemGroup Condition="'$(OSGroup)' != 'AnyOS' AND '$(TargetFramework)' != 'netstandard2.0' AND '$(BuildSimulator)' == 'true'">
<Compile Include="Microsoft\Data\SqlClient\SimulatorEnclaveProvider.NetCoreApp.cs" />
Expand Down Expand Up @@ -522,6 +525,7 @@
<Compile Include="$(CommonPath)\CoreLib\System\Threading\Tasks\TaskToApm.cs">
<Link>Common\CoreLib\System\Threading\Tasks\TaskToApm.cs</Link>
</Compile>
<Compile Include="Microsoft\Data\SqlClient\SNI\ConcurrentQueueSemaphore.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIError.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIHandle.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNILoadHandle.cs" />
Expand All @@ -533,8 +537,8 @@
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIPacketPool.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIPhysicalHandle.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIProxy.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNITcpHandle.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIStreams.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNITcpHandle.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SslOverTdsStream.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNICommon.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SspiClientContextStatus.cs" />
Expand Down Expand Up @@ -826,6 +830,7 @@
</ItemGroup>
<ItemGroup>
<Compile Include="Microsoft\Data\SqlClient\AAsyncCallContext.cs" />

<Compile Include="Resources\Strings.Designer.cs">
<DesignTime>True</DesignTime>
<AutoGen>True</AutoGen>
Expand Down
@@ -0,0 +1,60 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Data.SqlClient.SNI
{
/// <summary>
/// This class implements a FIFO Queue with SemaphoreSlim for ordered execution of parallel tasks.
/// Currently used in Managed SNI (SNISslStream) to override SslStream's WriteAsync implementation.
/// </summary>
internal sealed partial class ConcurrentQueueSemaphore
{
private readonly SemaphoreSlim _semaphore;
private readonly ConcurrentQueue<TaskCompletionSource<bool>> _queue;

public ConcurrentQueueSemaphore(int initialCount)
{
_semaphore = new SemaphoreSlim(initialCount);
_queue = new ConcurrentQueue<TaskCompletionSource<bool>>();
}

public Task WaitAsync(CancellationToken cancellationToken)
{
// try sync wait with 0 which will not block to see if we need to do an async wait
if (_semaphore.Wait(0, cancellationToken))
{
return Task.CompletedTask;
}
else
{
var tcs = new TaskCompletionSource<bool>();
_queue.Enqueue(tcs);
_semaphore.WaitAsync().ContinueWith(
continuationAction: static (Task task, object state) =>
{
ConcurrentQueue<TaskCompletionSource<bool>> queue = (ConcurrentQueue<TaskCompletionSource<bool>>)state;
if (queue.TryDequeue(out TaskCompletionSource<bool> popped))
{
popped.SetResult(true);
}
},
state: _queue,
cancellationToken: cancellationToken
);
return tcs.Task;
}
}

public void Release()
{
_semaphore.Release();
}
}

}
@@ -0,0 +1,73 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Data.SqlClient.SNI
{
// NetCore2.1:
// DO NOT OVERRIDE ValueTask versions of ReadAsync and WriteAsync because the underlying SslStream implements them
// by calling the Task versions which are already overridden meaning that if a caller uses Task WriteAsync this would
// call ValueTask WriteAsync which then called TaskWriteAsync introducing a lock cycle and never return

internal sealed partial class SNISslStream
{
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
return await base.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_readAsyncSemaphore.Release();
}
}

public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
await base.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_writeAsyncSemaphore.Release();
}
}
}

internal sealed partial class SNINetworkStream
{
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
return await base.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_readAsyncSemaphore.Release();
}
}

// Prevent the WriteAsync collisions by running the task in a Semaphore Slim
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
await base.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_writeAsyncSemaphore.Release();
}
}
}
}
@@ -0,0 +1,89 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Threading;
using System.Threading.Tasks;
using System;

namespace Microsoft.Data.SqlClient.SNI
{
internal sealed partial class SNISslStream
{
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
}

public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
return await base.ReadAsync(buffer, cancellationToken).ConfigureAwait(false);
}
finally
{
_readAsyncSemaphore.Release();
}
}

public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return WriteAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
}

public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
await base.WriteAsync(buffer, cancellationToken).ConfigureAwait(false);
}
finally
{
_writeAsyncSemaphore.Release();
}
}
}

internal sealed partial class SNINetworkStream
{
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
}

public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
return await base.ReadAsync(buffer, cancellationToken).ConfigureAwait(false);
}
finally
{
_readAsyncSemaphore.Release();
}
}

// Prevent the WriteAsync collisions by running the task in a Semaphore Slim
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return WriteAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
}

public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
await base.WriteAsync(buffer, cancellationToken).ConfigureAwait(false);
}
finally
{
_writeAsyncSemaphore.Release();
}
}
}
}
Expand Up @@ -4,16 +4,14 @@

using System.Net.Security;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using System.Net.Sockets;

namespace Microsoft.Data.SqlClient.SNI
{
/// <summary>
/// This class extends SslStream to customize stream behavior for Managed SNI implementation.
/// </summary>
internal class SNISslStream : SslStream
internal sealed partial class SNISslStream : SslStream
{
private readonly ConcurrentQueueSemaphore _writeAsyncSemaphore;
private readonly ConcurrentQueueSemaphore _readAsyncSemaphore;
Expand All @@ -24,40 +22,12 @@ public SNISslStream(Stream innerStream, bool leaveInnerStreamOpen, RemoteCertifi
_writeAsyncSemaphore = new ConcurrentQueueSemaphore(1);
_readAsyncSemaphore = new ConcurrentQueueSemaphore(1);
}

// Prevent ReadAsync collisions by running the task in a Semaphore Slim
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
return await base.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_readAsyncSemaphore.Release();
}
}

// Prevent the WriteAsync collisions by running the task in a Semaphore Slim
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
await base.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_writeAsyncSemaphore.Release();
}
}
}

/// <summary>
/// This class extends NetworkStream to customize stream behavior for Managed SNI implementation.
/// </summary>
internal class SNINetworkStream : NetworkStream
internal sealed partial class SNINetworkStream : NetworkStream
{
private readonly ConcurrentQueueSemaphore _writeAsyncSemaphore;
private readonly ConcurrentQueueSemaphore _readAsyncSemaphore;
Expand All @@ -67,33 +37,5 @@ public SNINetworkStream(Socket socket, bool ownsSocket) : base(socket, ownsSocke
_writeAsyncSemaphore = new ConcurrentQueueSemaphore(1);
_readAsyncSemaphore = new ConcurrentQueueSemaphore(1);
}

// Prevent ReadAsync collisions by running the task in a Semaphore Slim
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
return await base.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_readAsyncSemaphore.Release();
}
}

// Prevent the WriteAsync collisions by running the task in a Semaphore Slim
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
await base.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_writeAsyncSemaphore.Release();
}
}
}
}

0 comments on commit 1ee0cba

Please sign in to comment.