Skip to content

Commit

Permalink
Fixes "InvalidOperationException" errors by performing async operatio…
Browse files Browse the repository at this point in the history
…ns in SemaphoreSlim (#796)
  • Loading branch information
cheenamalhotra committed Nov 19, 2020
1 parent f0572f3 commit cde615e
Show file tree
Hide file tree
Showing 10 changed files with 201 additions and 67 deletions.
Expand Up @@ -442,6 +442,7 @@
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIPhysicalHandle.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIProxy.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNITcpHandle.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIStreams.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SslOverTdsStream.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNICommon.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SspiClientContextStatus.cs" />
Expand Down
Expand Up @@ -93,7 +93,7 @@ public SNINpHandle(string serverName, string pipeName, long timerExpire, object
}

_sslOverTdsStream = new SslOverTdsStream(_pipeStream);
_sslStream = new SslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate), null);
_sslStream = new SNISslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate));

_stream = _pipeStream;
_status = TdsEnums.SNI_SUCCESS;
Expand Down Expand Up @@ -286,7 +286,7 @@ public override uint Send(SNIPacket packet)
}

// this lock ensures that two packets are not being written to the transport at the same time
// so that sending a standard and an out-of-band packet are both written atomically no data is
// so that sending a standard and an out-of-band packet are both written atomically no data is
// interleaved
lock (_sendSync)
{
Expand Down
@@ -0,0 +1,99 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Net.Security;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using System.Net.Sockets;

namespace Microsoft.Data.SqlClient.SNI
{
/// <summary>
/// This class extends SslStream to customize stream behavior for Managed SNI implementation.
/// </summary>
internal class SNISslStream : SslStream
{
private readonly ConcurrentQueueSemaphore _writeAsyncSemaphore;
private readonly ConcurrentQueueSemaphore _readAsyncSemaphore;

public SNISslStream(Stream innerStream, bool leaveInnerStreamOpen, RemoteCertificateValidationCallback userCertificateValidationCallback)
: base(innerStream, leaveInnerStreamOpen, userCertificateValidationCallback)
{
_writeAsyncSemaphore = new ConcurrentQueueSemaphore(1);
_readAsyncSemaphore = new ConcurrentQueueSemaphore(1);
}

// Prevent ReadAsync collisions by running the task in a Semaphore Slim
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
return await base.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_readAsyncSemaphore.Release();
}
}

// Prevent the WriteAsync collisions by running the task in a Semaphore Slim
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
await base.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_writeAsyncSemaphore.Release();
}
}
}

/// <summary>
/// This class extends NetworkStream to customize stream behavior for Managed SNI implementation.
/// </summary>
internal class SNINetworkStream : NetworkStream
{
private readonly ConcurrentQueueSemaphore _writeAsyncSemaphore;
private readonly ConcurrentQueueSemaphore _readAsyncSemaphore;

public SNINetworkStream(Socket socket, bool ownsSocket) : base(socket, ownsSocket)
{
_writeAsyncSemaphore = new ConcurrentQueueSemaphore(1);
_readAsyncSemaphore = new ConcurrentQueueSemaphore(1);
}

// Prevent ReadAsync collisions by running the task in a Semaphore Slim
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
return await base.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_readAsyncSemaphore.Release();
}
}

// Prevent the WriteAsync collisions by running the task in a Semaphore Slim
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
await base.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_writeAsyncSemaphore.Release();
}
}
}
}
Expand Up @@ -143,7 +143,7 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
bool reportError = true;

// We will always first try to connect with serverName as before and let the DNS server to resolve the serverName.
// If the DSN resolution fails, we will try with IPs in the DNS cache if existed. We try with IPv4 first and followed by IPv6 if
// If the DSN resolution fails, we will try with IPs in the DNS cache if existed. We try with IPv4 first and followed by IPv6 if
// IPv4 fails. The exceptions will be throw to upper level and be handled as before.
try
{
Expand All @@ -160,14 +160,14 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
{
// Retry with cached IP address
if (ex is SocketException || ex is ArgumentException || ex is AggregateException)
{
{
if (hasCachedDNSInfo == false)
{
throw;
}
else
{
int portRetry = String.IsNullOrEmpty(cachedDNSInfo.Port) ? port : Int32.Parse(cachedDNSInfo.Port);
int portRetry = String.IsNullOrEmpty(cachedDNSInfo.Port) ? port : Int32.Parse(cachedDNSInfo.Port);

try
{
Expand All @@ -180,9 +180,9 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
_socket = Connect(cachedDNSInfo.AddrIPv4, portRetry, ts, isInfiniteTimeOut, cachedFQDN, ref pendingDNSInfo);
}
}
catch(Exception exRetry)
catch (Exception exRetry)
{
if (exRetry is SocketException || exRetry is ArgumentNullException
if (exRetry is SocketException || exRetry is ArgumentNullException
|| exRetry is ArgumentException || exRetry is ArgumentOutOfRangeException || exRetry is AggregateException)
{
if (parallel)
Expand All @@ -199,7 +199,7 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
throw;
}
}
}
}
}
else
{
Expand All @@ -223,10 +223,10 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
}

_socket.NoDelay = true;
_tcpStream = new NetworkStream(_socket, true);
_tcpStream = new SNINetworkStream(_socket, true);

_sslOverTdsStream = new SslOverTdsStream(_tcpStream);
_sslStream = new SslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate), null);
_sslStream = new SNISslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate));
}
catch (SocketException se)
{
Expand Down Expand Up @@ -331,7 +331,7 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo
}

CancellationTokenSource cts = null;

void Cancel()
{
for (int i = 0; i < sockets.Length; ++i)
Expand All @@ -355,7 +355,7 @@ void Cancel()
}

Socket availableSocket = null;
try
try
{
for (int i = 0; i < sockets.Length; ++i)
{
Expand Down Expand Up @@ -566,45 +566,45 @@ public override uint Send(SNIPacket packet)
{
bool releaseLock = false;
try
{
// is the packet is marked out out-of-band (attention packets only) it must be
// sent immediately even if a send of recieve operation is already in progress
// because out of band packets are used to cancel ongoing operations
// so try to take the lock if possible but continue even if it can't be taken
if (packet.IsOutOfBand)
{
Monitor.TryEnter(this, ref releaseLock);
}
else
{
Monitor.Enter(this);
releaseLock = true;
}

// this lock ensures that two packets are not being written to the transport at the same time
// so that sending a standard and an out-of-band packet are both written atomically no data is
// interleaved
lock (_sendSync)
{
try
{
packet.WriteToStream(_stream);
return TdsEnums.SNI_SUCCESS;
}
catch (ObjectDisposedException ode)
// is the packet is marked out out-of-band (attention packets only) it must be
// sent immediately even if a send of recieve operation is already in progress
// because out of band packets are used to cancel ongoing operations
// so try to take the lock if possible but continue even if it can't be taken
if (packet.IsOutOfBand)
{
return ReportTcpSNIError(ode);
Monitor.TryEnter(this, ref releaseLock);
}
catch (SocketException se)
else
{
return ReportTcpSNIError(se);
Monitor.Enter(this);
releaseLock = true;
}
catch (IOException ioe)

// this lock ensures that two packets are not being written to the transport at the same time
// so that sending a standard and an out-of-band packet are both written atomically no data is
// interleaved
lock (_sendSync)
{
return ReportTcpSNIError(ioe);
try
{
packet.WriteToStream(_stream);
return TdsEnums.SNI_SUCCESS;
}
catch (ObjectDisposedException ode)
{
return ReportTcpSNIError(ode);
}
catch (SocketException se)
{
return ReportTcpSNIError(se);
}
catch (IOException ioe)
{
return ReportTcpSNIError(ioe);
}
}
}
}
finally
{
if (releaseLock)
Expand Down Expand Up @@ -633,7 +633,8 @@ public override uint Receive(out SNIPacket packet, int timeoutInMilliseconds)
_socket.ReceiveTimeout = timeoutInMilliseconds;
}
else if (timeoutInMilliseconds == -1)
{ // SqlCient internally represents infinite timeout by -1, and for TcpClient this is translated to a timeout of 0
{
// SqlClient internally represents infinite timeout by -1, and for TcpClient this is translated to a timeout of 0
_socket.ReceiveTimeout = 0;
}
else
Expand Down Expand Up @@ -706,12 +707,17 @@ public override void SetAsyncCallbacks(SNIAsyncCallback receiveCallback, SNIAsyn
/// <returns>SNI error code</returns>
public override uint SendAsync(SNIPacket packet, SNIAsyncCallback callback = null)
{
SNIAsyncCallback cb = callback ?? _sendCallback;
lock (this)
long scopeID = SqlClientEventSource.Log.TrySNIScopeEnterEvent("<sc.SNI.SNIMarsHandle.SendAsync |SNI|INFO|SCOPE>");
try
{
SNIAsyncCallback cb = callback ?? _sendCallback;
packet.WriteToStreamAsync(_stream, cb, SNIProviders.TCP_PROV);
return TdsEnums.SNI_SUCCESS_IO_PENDING;
}
finally
{
SqlClientEventSource.Log.TrySNIScopeLeaveEvent(scopeID);
}
return TdsEnums.SNI_SUCCESS_IO_PENDING;
}

/// <summary>
Expand Down Expand Up @@ -745,15 +751,15 @@ public override uint CheckConnection()
{
try
{
// _socket.Poll method with argument SelectMode.SelectRead returns
// _socket.Poll method with argument SelectMode.SelectRead returns
// True : if Listen has been called and a connection is pending, or
// True : if data is available for reading, or
// True : if the connection has been closed, reset, or terminated, i.e no active connection.
// False : otherwise.
// _socket.Available property returns the number of bytes of data available to read.
//
// Since _socket.Connected alone doesn't guarantee if the connection is still active, we use it in
// combination with _socket.Poll method and _socket.Available == 0 check. When both of them
// Since _socket.Connected alone doesn't guarantee if the connection is still active, we use it in
// combination with _socket.Poll method and _socket.Available == 0 check. When both of them
// return true we can safely determine that the connection is no longer active.
if (!_socket.Connected || (_socket.Poll(100, SelectMode.SelectRead) && _socket.Available == 0))
{
Expand Down
Expand Up @@ -12,24 +12,16 @@ namespace Microsoft.Data.SqlClient.SNI
internal sealed partial class SslOverTdsStream
{
public override int Read(byte[] buffer, int offset, int count)
{
return Read(buffer.AsSpan(offset, count));
}
=> Read(buffer.AsSpan(offset, count));

public override void Write(byte[] buffer, int offset, int count)
{
Write(buffer.AsSpan(offset, count));
}
=> Write(buffer.AsSpan(offset, count));

public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
}
=> ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();

public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken).AsTask();
}
=> WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken).AsTask();

public override int Read(Span<byte> buffer)
{
Expand Down Expand Up @@ -288,7 +280,6 @@ public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, Cancella

await _stream.FlushAsync().ConfigureAwait(false);


remaining = remaining.Slice(dataLength);
}
}
Expand Down
Expand Up @@ -1319,7 +1319,7 @@ private void ThrowIfReconnectionHasBeenCanceled()
if (_stateObj == null)
{
var reconnectionCompletionSource = _reconnectionCompletionSource;
if (reconnectionCompletionSource != null && reconnectionCompletionSource.Task.IsCanceled)
if (reconnectionCompletionSource != null && reconnectionCompletionSource.Task != null && reconnectionCompletionSource.Task.IsCanceled)
{
throw SQL.CR_ReconnectionCancelled();
}
Expand Down

0 comments on commit cde615e

Please sign in to comment.