Skip to content

Commit

Permalink
Add ValueTask stream overloads on SNI streams
Browse files Browse the repository at this point in the history
  • Loading branch information
Wraith2 committed Mar 2, 2021
1 parent ca2fe25 commit 466062b
Show file tree
Hide file tree
Showing 8 changed files with 307 additions and 104 deletions.
Expand Up @@ -282,7 +282,9 @@
<Compile Include="Microsoft\Data\SqlClient\SqlDiagnosticListener.NetStandard.cs" />
<Compile Include="Microsoft\Data\SqlClient\SqlDelegatedTransaction.NetStandard.cs" />
<Compile Include="Microsoft\Data\SqlClient\TdsParser.NetStandard.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\ConcurrentQueueSemaphore.NetStandard.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SslOverTdsStream.NetStandard.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIStreams.NetStandard.cs" />
</ItemGroup>
<ItemGroup Condition="'$(OSGroup)' != 'AnyOS' AND '$(TargetFramework)' != 'netstandard2.0'">
<Compile Include="..\..\src\Microsoft\Data\SqlClient\AlwaysEncryptedAttestationException.cs">
Expand Down Expand Up @@ -314,6 +316,8 @@
<Compile Include="Microsoft\Data\SqlClient\SNI\SslOverTdsStream.NetCoreApp.cs" />
<Compile Include="Microsoft\Data\SqlClient\SqlConnectionFactory.AssemblyLoadContext.cs" />
<Compile Include="Microsoft\Data\SqlClient\SqlDependencyUtils.AssemblyLoadContext.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\ConcurrentQueueSemaphore.NetCoreApp.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIStreams.NetCoreApp.cs" />
</ItemGroup>
<ItemGroup Condition="'$(OSGroup)' != 'AnyOS' AND '$(TargetGroup)' == 'netcoreapp' AND '$(BuildSimulator)' == 'true'">
<Compile Include="Microsoft\Data\SqlClient\SimulatorEnclaveProvider.NetCoreApp.cs" />
Expand Down Expand Up @@ -441,6 +445,7 @@
<Compile Include="$(CommonPath)\CoreLib\System\Threading\Tasks\TaskToApm.cs">
<Link>Common\CoreLib\System\Threading\Tasks\TaskToApm.cs</Link>
</Compile>
<Compile Include="Microsoft\Data\SqlClient\SNI\ConcurrentQueueSemaphore.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIError.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIHandle.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNILoadHandle.cs" />
Expand All @@ -452,8 +457,8 @@
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIPacketPool.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIPhysicalHandle.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIProxy.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNITcpHandle.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIStreams.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNITcpHandle.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SslOverTdsStream.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNICommon.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SspiClientContextStatus.cs" />
Expand Down
@@ -0,0 +1,32 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Data.SqlClient.SNI
{
internal sealed partial class ConcurrentQueueSemaphore
{
public ValueTask WaitAsync(CancellationToken cancellationToken)
{
// try sync wait with 0 which will not block to see if we need to do an async wait
if (_semaphore.Wait(0, cancellationToken))
{
return new ValueTask();
}
else
{
var tcs = new TaskCompletionSource<bool>();
_queue.Enqueue(tcs);
_semaphore.WaitAsync().ContinueWith(
continuationAction: s_continuePop,
state: _queue,
cancellationToken: cancellationToken
);
return new ValueTask(tcs.Task);
}
}
}
}
@@ -0,0 +1,24 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Data.SqlClient.SNI
{
internal sealed partial class ConcurrentQueueSemaphore
{
public Task WaitAsync(CancellationToken cancellationToken)
{
var tcs = new TaskCompletionSource<bool>();
_queue.Enqueue(tcs);
_semaphore.WaitAsync().ContinueWith(
continuationAction: s_continuePop,
state: _queue,
cancellationToken: cancellationToken
);
return tcs.Task;
}
}
}
@@ -0,0 +1,44 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Data.SqlClient.SNI
{
/// <summary>
/// This class implements a FIFO Queue with SemaphoreSlim for ordered execution of parallel tasks.
/// Currently used in Managed SNI (SNISslStream) to override SslStream's WriteAsync implementation.
/// </summary>
internal sealed partial class ConcurrentQueueSemaphore
{
private static readonly Action<Task, object> s_continuePop = ContinuePop;

private readonly SemaphoreSlim _semaphore;
private readonly ConcurrentQueue<TaskCompletionSource<bool>> _queue =
new ConcurrentQueue<TaskCompletionSource<bool>>();

public ConcurrentQueueSemaphore(int initialCount)
{
_semaphore = new SemaphoreSlim(initialCount);
}

public void Release()
{
_semaphore.Release();
}

private static void ContinuePop(Task task, object state)
{
ConcurrentQueue<TaskCompletionSource<bool>> queue = (ConcurrentQueue<TaskCompletionSource<bool>>)state;
if (queue.TryDequeue(out TaskCompletionSource<bool> popped))
{
popped.SetResult(true);
}
}
}

}
@@ -0,0 +1,126 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Net.Security;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using System.Net.Sockets;
using System;

namespace Microsoft.Data.SqlClient.SNI
{

internal sealed partial class SNISslStream
{
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
ValueTask<int> valueTask = ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken);
if (valueTask.IsCompletedSuccessfully)
{
return Task.FromResult(valueTask.Result);
}
else
{
return valueTask.AsTask();
}
}

public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
return await base.ReadAsync(buffer, cancellationToken).ConfigureAwait(false);
}
finally
{
_readAsyncSemaphore.Release();
}
}

public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
ValueTask valueTask = WriteAsync(new Memory<byte>(buffer, offset, count), cancellationToken);
if (valueTask.IsCompletedSuccessfully)
{
return Task.CompletedTask;
}
else
{
return valueTask.AsTask();
}
}

public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
await base.WriteAsync(buffer, cancellationToken).ConfigureAwait(false);
}
finally
{
_writeAsyncSemaphore.Release();
}
}
}


internal sealed partial class SNINetworkStream
{
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
ValueTask<int> valueTask = ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken);
if (valueTask.IsCompletedSuccessfully)
{
return Task.FromResult(valueTask.Result);
}
else
{
return valueTask.AsTask();
}
}

public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
return await base.ReadAsync(buffer, cancellationToken).ConfigureAwait(false);
}
finally
{
_readAsyncSemaphore.Release();
}
}

// Prevent the WriteAsync collisions by running the task in a Semaphore Slim
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
ValueTask valueTask = WriteAsync(new Memory<byte>(buffer, offset, count), cancellationToken);
if (valueTask.IsCompletedSuccessfully)
{
return Task.CompletedTask;
}
else
{
return valueTask.AsTask();
}
}

public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
await base.WriteAsync(buffer, cancellationToken).ConfigureAwait(false);
}
finally
{
_writeAsyncSemaphore.Release();
}
}
}
}
@@ -0,0 +1,72 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Net.Security;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using System.Net.Sockets;

namespace Microsoft.Data.SqlClient.SNI
{
internal sealed partial class SNISslStream
{
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
return await base.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_readAsyncSemaphore.Release();
}
}

public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
await base.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_writeAsyncSemaphore.Release();
}
}
}

internal sealed partial class SNINetworkStream
{

public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
return await base.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_readAsyncSemaphore.Release();
}
}

// Prevent the WriteAsync collisions by running the task in a Semaphore Slim
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
await base.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_writeAsyncSemaphore.Release();
}
}
}
}

0 comments on commit 466062b

Please sign in to comment.