Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes "InvalidOperationException" errors by performing async operations in SemaphoreSlim #796

Merged
merged 20 commits into from Nov 19, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -442,6 +442,7 @@
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIPhysicalHandle.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIProxy.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNITcpHandle.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNISslStream.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SslOverTdsStream.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNICommon.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SspiClientContextStatus.cs" />
Expand Down
Expand Up @@ -93,7 +93,7 @@ public SNINpHandle(string serverName, string pipeName, long timerExpire, object
}

_sslOverTdsStream = new SslOverTdsStream(_pipeStream);
_sslStream = new SslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate), null);
_sslStream = new SNISslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate));

_stream = _pipeStream;
_status = TdsEnums.SNI_SUCCESS;
Expand Down
@@ -0,0 +1,44 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Net.Security;
using System.IO;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Data.SqlClient.SNI
{
/// <summary>
/// This class extends SslStream to customize stream behavior for Managed SNI implementation.
/// </summary>
internal class SNISslStream : SslStream
{
private readonly ConcurrentQueueSemaphore _writeQueueSemaphore;
private readonly ConcurrentQueueSemaphore _readQueueSemaphore;

public SNISslStream(Stream innerStream, bool leaveInnerStreamOpen, RemoteCertificateValidationCallback userCertificateValidationCallback)
: base(innerStream, leaveInnerStreamOpen, userCertificateValidationCallback)
{
_writeQueueSemaphore = new ConcurrentQueueSemaphore(1);
_readQueueSemaphore = new ConcurrentQueueSemaphore(1);
}

// Prevent the ReadAsync's collision by running task in Semaphore Slim
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
_readQueueSemaphore.Wait();
Task<int> t = base.ReadAsync(buffer, offset, count, cancellationToken);
_readQueueSemaphore.Release();
cheenamalhotra marked this conversation as resolved.
Show resolved Hide resolved
return t;
}

// Prevent the WriteAsync's collision by running task in Semaphore Slim
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
_writeQueueSemaphore.Wait();
return base.WriteAsync(buffer, offset, count, cancellationToken)
.ContinueWith(t => _writeQueueSemaphore.Release());
}
}
}
Expand Up @@ -160,14 +160,14 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
{
// Retry with cached IP address
if (ex is SocketException || ex is ArgumentException || ex is AggregateException)
{
{
if (hasCachedDNSInfo == false)
{
throw;
}
else
{
int portRetry = String.IsNullOrEmpty(cachedDNSInfo.Port) ? port : Int32.Parse(cachedDNSInfo.Port);
int portRetry = String.IsNullOrEmpty(cachedDNSInfo.Port) ? port : Int32.Parse(cachedDNSInfo.Port);

try
{
Expand All @@ -180,9 +180,9 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
_socket = Connect(cachedDNSInfo.AddrIPv4, portRetry, ts, isInfiniteTimeOut, cachedFQDN, ref pendingDNSInfo);
}
}
catch(Exception exRetry)
catch (Exception exRetry)
{
if (exRetry is SocketException || exRetry is ArgumentNullException
if (exRetry is SocketException || exRetry is ArgumentNullException
|| exRetry is ArgumentException || exRetry is ArgumentOutOfRangeException || exRetry is AggregateException)
{
if (parallel)
Expand All @@ -199,7 +199,7 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
throw;
}
}
}
}
}
else
{
Expand All @@ -226,7 +226,7 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
_tcpStream = new NetworkStream(_socket, true);

_sslOverTdsStream = new SslOverTdsStream(_tcpStream);
_sslStream = new SslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate), null);
_sslStream = new SNISslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate));
}
catch (SocketException se)
{
Expand Down Expand Up @@ -331,7 +331,7 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo
}

CancellationTokenSource cts = null;

void Cancel()
{
for (int i = 0; i < sockets.Length; ++i)
Expand All @@ -355,7 +355,7 @@ void Cancel()
}

Socket availableSocket = null;
try
try
{
for (int i = 0; i < sockets.Length; ++i)
{
Expand Down Expand Up @@ -706,12 +706,17 @@ public override void SetAsyncCallbacks(SNIAsyncCallback receiveCallback, SNIAsyn
/// <returns>SNI error code</returns>
public override uint SendAsync(SNIPacket packet, SNIAsyncCallback callback = null)
{
SNIAsyncCallback cb = callback ?? _sendCallback;
lock (this)
cheenamalhotra marked this conversation as resolved.
Show resolved Hide resolved
long scopeID = SqlClientEventSource.Log.TrySNIScopeEnterEvent("<sc.SNI.SNIMarsHandle.SendAsync |SNI|INFO|SCOPE>");
try
{
SNIAsyncCallback cb = callback ?? _sendCallback;
packet.WriteToStreamAsync(_stream, cb, SNIProviders.TCP_PROV);
return TdsEnums.SNI_SUCCESS_IO_PENDING;
}
finally
{
SqlClientEventSource.Log.TrySNIScopeLeaveEvent(scopeID);
}
return TdsEnums.SNI_SUCCESS_IO_PENDING;
}

/// <summary>
Expand Down
Expand Up @@ -12,24 +12,16 @@ namespace Microsoft.Data.SqlClient.SNI
internal sealed partial class SslOverTdsStream
{
public override int Read(byte[] buffer, int offset, int count)
{
return Read(buffer.AsSpan(offset, count));
}
=> Read(buffer.AsSpan(offset, count));

public override void Write(byte[] buffer, int offset, int count)
{
Write(buffer.AsSpan(offset, count));
}
=> Write(buffer.AsSpan(offset, count));

public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
}
=> ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();

public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken).AsTask();
}
=> WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken).AsTask();

public override int Read(Span<byte> buffer)
{
Expand Down Expand Up @@ -288,7 +280,6 @@ public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, Cancella

await _stream.FlushAsync().ConfigureAwait(false);


remaining = remaining.Slice(dataLength);
}
}
Expand Down
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Concurrent;
cheenamalhotra marked this conversation as resolved.
Show resolved Hide resolved
using System.Collections.Generic;
using System.Data;
using System.Diagnostics;
Expand Down Expand Up @@ -2131,4 +2132,48 @@ public static MethodInfo GetPromotedToken
}
}
}

/// <summary>
/// This class implements a FIFO Queue with SemaphoreSlim for ordered execution of parallel tasks.
/// Currently used in Managed SNI (SNISslStream) to override SslStream's WriteAsync implementation.
/// </summary>
internal class ConcurrentQueueSemaphore
{
private readonly SemaphoreSlim _semaphore;
private readonly ConcurrentQueue<TaskCompletionSource<bool>> _queue =
new ConcurrentQueue<TaskCompletionSource<bool>>();

public ConcurrentQueueSemaphore(int initialCount)
{
_semaphore = new SemaphoreSlim(initialCount);
}

public ConcurrentQueueSemaphore(int initialCount, int maxCount)
{
_semaphore = new SemaphoreSlim(initialCount, maxCount);
}

public void Wait()
{
WaitAsync().Wait();
}

public Task WaitAsync()
{
var tcs = new TaskCompletionSource<bool>();
_queue.Enqueue(tcs);
_semaphore.WaitAsync().ContinueWith(t =>
{
if (_queue.TryDequeue(out TaskCompletionSource<bool> popped))
popped.SetResult(true);
});
return tcs.Task;
}
cheenamalhotra marked this conversation as resolved.
Show resolved Hide resolved

public void Release()
{
_semaphore.Release();
}
}

}//namespace
Expand Up @@ -3266,7 +3266,7 @@ internal Task WritePacket(byte flushMode, bool canAccumulate = false)

if (willCancel)
{
// If we have been cancelled, then ensure that we write the ATTN packet as well
// If we have been canceled, then ensure that we write the ATTN packet as well
task = AsyncHelper.CreateContinuationTask(task, CancelWritePacket);
}
return task;
Expand Down
Expand Up @@ -72,13 +72,13 @@ internal override void AssignPendingDNSInfo(string userProtocol, string DNSCache
internal void ReadAsyncCallback(SNIPacket packet, uint error)
{
ReadAsyncCallback(IntPtr.Zero, PacketHandle.FromManagedPacket(packet), error);
_sessionHandle.ReturnPacket(packet);
_sessionHandle?.ReturnPacket(packet);
}

internal void WriteAsyncCallback(SNIPacket packet, uint sniError)
{
WriteAsyncCallback(IntPtr.Zero, PacketHandle.FromManagedPacket(packet), sniError);
_sessionHandle.ReturnPacket(packet);
_sessionHandle?.ReturnPacket(packet);
cheenamalhotra marked this conversation as resolved.
Show resolved Hide resolved
}

protected override void RemovePacketFromPendingList(PacketHandle packet)
Expand Down