Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes "InvalidOperationException" errors by performing async operations in SemaphoreSlim #796

Merged
merged 20 commits into from Nov 19, 2020
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -442,6 +442,7 @@
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIPhysicalHandle.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIProxy.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNITcpHandle.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNISslStream.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SslOverTdsStream.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNICommon.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SspiClientContextStatus.cs" />
Expand Down
Expand Up @@ -93,7 +93,7 @@ public SNINpHandle(string serverName, string pipeName, long timerExpire, object
}

_sslOverTdsStream = new SslOverTdsStream(_pipeStream);
_sslStream = new SslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate), null);
_sslStream = new SNISslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate));

_stream = _pipeStream;
_status = TdsEnums.SNI_SUCCESS;
Expand Down Expand Up @@ -189,39 +189,36 @@ public override uint Receive(out SNIPacket packet, int timeout)
try
{
SNIPacket errorPacket;
lock (this)
packet = null;
try
{
packet = null;
try
{
packet = RentPacket(headerSize: 0, dataSize: _bufferSize);
packet.ReadFromStream(_stream);
packet = RentPacket(headerSize: 0, dataSize: _bufferSize);
packet.ReadFromStream(_stream);

if (packet.Length == 0)
{
errorPacket = packet;
packet = null;
var e = new Win32Exception();
SqlClientEventSource.Log.TrySNITraceEvent("<sc.SNI.SNINpHandle.Receive |SNI|ERR> packet length is 0.");
return ReportErrorAndReleasePacket(errorPacket, (uint)e.NativeErrorCode, 0, e.Message);
}
}
catch (ObjectDisposedException ode)
if (packet.Length == 0)
{
errorPacket = packet;
packet = null;
SqlClientEventSource.Log.TrySNITraceEvent("<sc.SNI.SNINpHandle.Receive |SNI|ERR> ObjectDisposedException message = {0}.", ode.Message);
return ReportErrorAndReleasePacket(errorPacket, ode);
var e = new Win32Exception();
SqlClientEventSource.Log.TrySNITraceEvent("<sc.SNI.SNINpHandle.Receive |SNI|ERR> packet length is 0.");
return ReportErrorAndReleasePacket(errorPacket, (uint)e.NativeErrorCode, 0, e.Message);
}
catch (IOException ioe)
{
errorPacket = packet;
packet = null;
SqlClientEventSource.Log.TrySNITraceEvent("<sc.SNI.SNINpHandle.Receive |SNI|ERR> IOException message = {0}.", ioe.Message);
return ReportErrorAndReleasePacket(errorPacket, ioe);
}
return TdsEnums.SNI_SUCCESS;
}
catch (ObjectDisposedException ode)
{
errorPacket = packet;
packet = null;
SqlClientEventSource.Log.TrySNITraceEvent("<sc.SNI.SNINpHandle.Receive |SNI|ERR> ObjectDisposedException message = {0}.", ode.Message);
return ReportErrorAndReleasePacket(errorPacket, ode);
}
catch (IOException ioe)
{
errorPacket = packet;
packet = null;
SqlClientEventSource.Log.TrySNITraceEvent("<sc.SNI.SNINpHandle.Receive |SNI|ERR> IOException message = {0}.", ioe.Message);
return ReportErrorAndReleasePacket(errorPacket, ioe);
}
return TdsEnums.SNI_SUCCESS;
}
finally
{
Expand Down Expand Up @@ -286,7 +283,7 @@ public override uint Send(SNIPacket packet)
}

// this lock ensures that two packets are not being written to the transport at the same time
// so that sending a standard and an out-of-band packet are both written atomically no data is
// so that sending a standard and an out-of-band packet are both written atomically no data is
// interleaved
lock (_sendSync)
{
Expand Down
@@ -0,0 +1,41 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Net.Security;
using System.IO;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Data.SqlClient.SNI
{
/// <summary>
/// This class extends SslStream to customize stream behavior for Managed SNI implementation.
/// </summary>
internal class SNISslStream : SslStream
{
private readonly ConcurrentQueueSemaphore _writeAsyncQueueSemaphore;
private readonly ConcurrentQueueSemaphore _readAsyncQueueSemaphore;

public SNISslStream(Stream innerStream, bool leaveInnerStreamOpen, RemoteCertificateValidationCallback userCertificateValidationCallback)
: base(innerStream, leaveInnerStreamOpen, userCertificateValidationCallback)
{
_writeAsyncQueueSemaphore = new ConcurrentQueueSemaphore(1);
_readAsyncQueueSemaphore = new ConcurrentQueueSemaphore(1);
}

// Prevent the ReadAsync's collision by running task in Semaphore Slim
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return _readAsyncQueueSemaphore.WaitAsync().ContinueWith<int>(t => base.ReadAsync(buffer, offset, count, cancellationToken).Result)
cheenamalhotra marked this conversation as resolved.
Show resolved Hide resolved
.ContinueWith(t => _readAsyncQueueSemaphore.Release(t.Result));
}

// Prevent the WriteAsync's collision by running task in Semaphore Slim
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return _writeAsyncQueueSemaphore.WaitAsync().ContinueWith(t => base.WriteAsync(buffer, offset, count, cancellationToken))
.ContinueWith(t => _writeAsyncQueueSemaphore.Release());
}
}
}
Expand Up @@ -143,7 +143,7 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
bool reportError = true;

// We will always first try to connect with serverName as before and let the DNS server to resolve the serverName.
// If the DSN resolution fails, we will try with IPs in the DNS cache if existed. We try with IPv4 first and followed by IPv6 if
// If the DSN resolution fails, we will try with IPs in the DNS cache if existed. We try with IPv4 first and followed by IPv6 if
// IPv4 fails. The exceptions will be throw to upper level and be handled as before.
try
{
Expand All @@ -160,14 +160,14 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
{
// Retry with cached IP address
if (ex is SocketException || ex is ArgumentException || ex is AggregateException)
{
{
if (hasCachedDNSInfo == false)
{
throw;
}
else
{
int portRetry = String.IsNullOrEmpty(cachedDNSInfo.Port) ? port : Int32.Parse(cachedDNSInfo.Port);
int portRetry = String.IsNullOrEmpty(cachedDNSInfo.Port) ? port : Int32.Parse(cachedDNSInfo.Port);

try
{
Expand All @@ -180,9 +180,9 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
_socket = Connect(cachedDNSInfo.AddrIPv4, portRetry, ts, isInfiniteTimeOut, cachedFQDN, ref pendingDNSInfo);
}
}
catch(Exception exRetry)
catch (Exception exRetry)
{
if (exRetry is SocketException || exRetry is ArgumentNullException
if (exRetry is SocketException || exRetry is ArgumentNullException
|| exRetry is ArgumentException || exRetry is ArgumentOutOfRangeException || exRetry is AggregateException)
{
if (parallel)
Expand All @@ -199,7 +199,7 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
throw;
}
}
}
}
}
else
{
Expand All @@ -226,7 +226,7 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
_tcpStream = new NetworkStream(_socket, true);

_sslOverTdsStream = new SslOverTdsStream(_tcpStream);
_sslStream = new SslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate), null);
_sslStream = new SNISslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate));
}
catch (SocketException se)
{
Expand Down Expand Up @@ -331,7 +331,7 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo
}

CancellationTokenSource cts = null;

void Cancel()
{
for (int i = 0; i < sockets.Length; ++i)
Expand All @@ -355,7 +355,7 @@ void Cancel()
}

Socket availableSocket = null;
try
try
{
for (int i = 0; i < sockets.Length; ++i)
{
Expand Down Expand Up @@ -582,7 +582,7 @@ public override uint Send(SNIPacket packet)
}

// this lock ensures that two packets are not being written to the transport at the same time
// so that sending a standard and an out-of-band packet are both written atomically no data is
// so that sending a standard and an out-of-band packet are both written atomically no data is
// interleaved
lock (_sendSync)
{
Expand Down Expand Up @@ -623,67 +623,65 @@ public override uint Send(SNIPacket packet)
public override uint Receive(out SNIPacket packet, int timeoutInMilliseconds)
{
SNIPacket errorPacket;
lock (this)
packet = null;
try
{
packet = null;
try
if (timeoutInMilliseconds > 0)
{
if (timeoutInMilliseconds > 0)
{
_socket.ReceiveTimeout = timeoutInMilliseconds;
}
else if (timeoutInMilliseconds == -1)
{ // SqlCient internally represents infinite timeout by -1, and for TcpClient this is translated to a timeout of 0
_socket.ReceiveTimeout = 0;
}
else
{
// otherwise it is timeout for 0 or less than -1
ReportTcpSNIError(0, SNICommon.ConnTimeoutError, string.Empty);
return TdsEnums.SNI_WAIT_TIMEOUT;
}

packet = RentPacket(headerSize: 0, dataSize: _bufferSize);
packet.ReadFromStream(_stream);

if (packet.Length == 0)
{
errorPacket = packet;
packet = null;
var e = new Win32Exception();
return ReportErrorAndReleasePacket(errorPacket, (uint)e.NativeErrorCode, 0, e.Message);
}

return TdsEnums.SNI_SUCCESS;
_socket.ReceiveTimeout = timeoutInMilliseconds;
}
catch (ObjectDisposedException ode)
else if (timeoutInMilliseconds == -1)
{
errorPacket = packet;
packet = null;
return ReportErrorAndReleasePacket(errorPacket, ode);
// SqlClient internally represents infinite timeout by -1, and for TcpClient this is translated to a timeout of 0
_socket.ReceiveTimeout = 0;
}
catch (SocketException se)
else
{
errorPacket = packet;
packet = null;
return ReportErrorAndReleasePacket(errorPacket, se);
// otherwise it is timeout for 0 or less than -1
ReportTcpSNIError(0, SNICommon.ConnTimeoutError, string.Empty);
return TdsEnums.SNI_WAIT_TIMEOUT;
}
catch (IOException ioe)

packet = RentPacket(headerSize: 0, dataSize: _bufferSize);
packet.ReadFromStream(_stream);

if (packet.Length == 0)
{
errorPacket = packet;
packet = null;
uint errorCode = ReportErrorAndReleasePacket(errorPacket, ioe);
if (ioe.InnerException is SocketException socketException && socketException.SocketErrorCode == SocketError.TimedOut)
{
errorCode = TdsEnums.SNI_WAIT_TIMEOUT;
}

return errorCode;
var e = new Win32Exception();
return ReportErrorAndReleasePacket(errorPacket, (uint)e.NativeErrorCode, 0, e.Message);
}
finally

return TdsEnums.SNI_SUCCESS;
}
catch (ObjectDisposedException ode)
{
errorPacket = packet;
packet = null;
return ReportErrorAndReleasePacket(errorPacket, ode);
}
catch (SocketException se)
{
errorPacket = packet;
packet = null;
return ReportErrorAndReleasePacket(errorPacket, se);
}
catch (IOException ioe)
{
errorPacket = packet;
packet = null;
uint errorCode = ReportErrorAndReleasePacket(errorPacket, ioe);
if (ioe.InnerException is SocketException socketException && socketException.SocketErrorCode == SocketError.TimedOut)
{
_socket.ReceiveTimeout = 0;
errorCode = TdsEnums.SNI_WAIT_TIMEOUT;
}

return errorCode;
}
finally
{
_socket.ReceiveTimeout = 0;
}
}

Expand All @@ -706,12 +704,17 @@ public override void SetAsyncCallbacks(SNIAsyncCallback receiveCallback, SNIAsyn
/// <returns>SNI error code</returns>
public override uint SendAsync(SNIPacket packet, SNIAsyncCallback callback = null)
{
SNIAsyncCallback cb = callback ?? _sendCallback;
lock (this)
cheenamalhotra marked this conversation as resolved.
Show resolved Hide resolved
long scopeID = SqlClientEventSource.Log.TrySNIScopeEnterEvent("<sc.SNI.SNIMarsHandle.SendAsync |SNI|INFO|SCOPE>");
try
{
SNIAsyncCallback cb = callback ?? _sendCallback;
packet.WriteToStreamAsync(_stream, cb, SNIProviders.TCP_PROV);
return TdsEnums.SNI_SUCCESS_IO_PENDING;
}
finally
{
SqlClientEventSource.Log.TrySNIScopeLeaveEvent(scopeID);
}
return TdsEnums.SNI_SUCCESS_IO_PENDING;
}

/// <summary>
Expand Down Expand Up @@ -745,15 +748,15 @@ public override uint CheckConnection()
{
try
{
// _socket.Poll method with argument SelectMode.SelectRead returns
// _socket.Poll method with argument SelectMode.SelectRead returns
// True : if Listen has been called and a connection is pending, or
// True : if data is available for reading, or
// True : if the connection has been closed, reset, or terminated, i.e no active connection.
// False : otherwise.
// _socket.Available property returns the number of bytes of data available to read.
//
// Since _socket.Connected alone doesn't guarantee if the connection is still active, we use it in
// combination with _socket.Poll method and _socket.Available == 0 check. When both of them
// Since _socket.Connected alone doesn't guarantee if the connection is still active, we use it in
// combination with _socket.Poll method and _socket.Available == 0 check. When both of them
// return true we can safely determine that the connection is no longer active.
if (!_socket.Connected || (_socket.Poll(100, SelectMode.SelectRead) && _socket.Available == 0))
{
Expand Down