Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes "InvalidOperationException" errors by performing async operations in SemaphoreSlim #796

Merged
merged 20 commits into from Nov 19, 2020
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -442,6 +442,7 @@
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIPhysicalHandle.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIProxy.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNITcpHandle.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIStreams.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SslOverTdsStream.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNICommon.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SspiClientContextStatus.cs" />
Expand Down
Expand Up @@ -93,7 +93,7 @@ public SNINpHandle(string serverName, string pipeName, long timerExpire, object
}

_sslOverTdsStream = new SslOverTdsStream(_pipeStream);
_sslStream = new SslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate), null);
_sslStream = new SNISslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate));

_stream = _pipeStream;
_status = TdsEnums.SNI_SUCCESS;
Expand Down Expand Up @@ -286,7 +286,7 @@ public override uint Send(SNIPacket packet)
}

// this lock ensures that two packets are not being written to the transport at the same time
// so that sending a standard and an out-of-band packet are both written atomically no data is
// so that sending a standard and an out-of-band packet are both written atomically no data is
// interleaved
lock (_sendSync)
{
Expand Down
@@ -0,0 +1,99 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Net.Security;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using System.Net.Sockets;

namespace Microsoft.Data.SqlClient.SNI
{
/// <summary>
/// This class extends SslStream to customize stream behavior for Managed SNI implementation.
/// </summary>
internal class SNISslStream : SslStream
{
private readonly SemaphoreSlim _writeAsyncSemaphore;
private readonly SemaphoreSlim _readAsyncSemaphore;

public SNISslStream(Stream innerStream, bool leaveInnerStreamOpen, RemoteCertificateValidationCallback userCertificateValidationCallback)
: base(innerStream, leaveInnerStreamOpen, userCertificateValidationCallback)
{
_writeAsyncSemaphore = new SemaphoreSlim(1);
_readAsyncSemaphore = new SemaphoreSlim(1);
}

// Prevent ReadAsync collisions by running the task in a Semaphore Slim
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _readAsyncSemaphore.WaitAsync().ConfigureAwait(false);
try
{
return await base.ReadAsync(buffer, offset, count, cancellationToken);
}
finally
{
_readAsyncSemaphore.Release();
}
}

// Prevent the WriteAsync collisions by running the task in a Semaphore Slim
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _writeAsyncSemaphore.WaitAsync().ConfigureAwait(false);
try
{
await base.WriteAsync(buffer, offset, count, cancellationToken);
}
finally
{
_writeAsyncSemaphore.Release();
}
}
}

/// <summary>
/// This class extends NetworkStream to customize stream behavior for Managed SNI implementation.
/// </summary>
internal class SNINetworkStream : NetworkStream
{
private readonly SemaphoreSlim _writeAsyncSemaphore;
private readonly SemaphoreSlim _readAsyncSemaphore;

public SNINetworkStream(Socket socket, bool ownsSocket) : base(socket, ownsSocket)
{
_writeAsyncSemaphore = new SemaphoreSlim(1);
_readAsyncSemaphore = new SemaphoreSlim(1);
}

// Prevent the ReadAsync collisions by running the task in a Semaphore Slim
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _readAsyncSemaphore.WaitAsync().ConfigureAwait(false);
try
{
return await base.ReadAsync(buffer, offset, count, cancellationToken);
}
finally
{
_readAsyncSemaphore.Release();
}
}

// Prevent the WriteAsync collisions by running the task in a Semaphore Slim
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _writeAsyncSemaphore.WaitAsync().ConfigureAwait(false);
try
{
await base.WriteAsync(buffer, offset, count, cancellationToken);
}
finally
{
_writeAsyncSemaphore.Release();
}
}
}
}
Expand Up @@ -143,7 +143,7 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
bool reportError = true;

// We will always first try to connect with serverName as before and let the DNS server to resolve the serverName.
// If the DSN resolution fails, we will try with IPs in the DNS cache if existed. We try with IPv4 first and followed by IPv6 if
// If the DSN resolution fails, we will try with IPs in the DNS cache if existed. We try with IPv4 first and followed by IPv6 if
// IPv4 fails. The exceptions will be throw to upper level and be handled as before.
try
{
Expand All @@ -160,14 +160,14 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
{
// Retry with cached IP address
if (ex is SocketException || ex is ArgumentException || ex is AggregateException)
{
{
if (hasCachedDNSInfo == false)
{
throw;
}
else
{
int portRetry = String.IsNullOrEmpty(cachedDNSInfo.Port) ? port : Int32.Parse(cachedDNSInfo.Port);
int portRetry = String.IsNullOrEmpty(cachedDNSInfo.Port) ? port : Int32.Parse(cachedDNSInfo.Port);

try
{
Expand All @@ -180,9 +180,9 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
_socket = Connect(cachedDNSInfo.AddrIPv4, portRetry, ts, isInfiniteTimeOut, cachedFQDN, ref pendingDNSInfo);
}
}
catch(Exception exRetry)
catch (Exception exRetry)
{
if (exRetry is SocketException || exRetry is ArgumentNullException
if (exRetry is SocketException || exRetry is ArgumentNullException
|| exRetry is ArgumentException || exRetry is ArgumentOutOfRangeException || exRetry is AggregateException)
{
if (parallel)
Expand All @@ -199,7 +199,7 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
throw;
}
}
}
}
}
else
{
Expand All @@ -223,10 +223,10 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
}

_socket.NoDelay = true;
_tcpStream = new NetworkStream(_socket, true);
_tcpStream = new SNINetworkStream(_socket, true);
cheenamalhotra marked this conversation as resolved.
Show resolved Hide resolved

_sslOverTdsStream = new SslOverTdsStream(_tcpStream);
_sslStream = new SslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate), null);
_sslStream = new SNISslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate));
}
catch (SocketException se)
{
Expand Down Expand Up @@ -331,7 +331,7 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo
}

CancellationTokenSource cts = null;

void Cancel()
{
for (int i = 0; i < sockets.Length; ++i)
Expand All @@ -355,7 +355,7 @@ void Cancel()
}

Socket availableSocket = null;
try
try
{
for (int i = 0; i < sockets.Length; ++i)
{
Expand Down Expand Up @@ -566,45 +566,45 @@ public override uint Send(SNIPacket packet)
{
bool releaseLock = false;
try
{
// is the packet is marked out out-of-band (attention packets only) it must be
// sent immediately even if a send of recieve operation is already in progress
// because out of band packets are used to cancel ongoing operations
// so try to take the lock if possible but continue even if it can't be taken
if (packet.IsOutOfBand)
{
Monitor.TryEnter(this, ref releaseLock);
}
else
{
Monitor.Enter(this);
releaseLock = true;
}

// this lock ensures that two packets are not being written to the transport at the same time
// so that sending a standard and an out-of-band packet are both written atomically no data is
// interleaved
lock (_sendSync)
{
try
{
packet.WriteToStream(_stream);
return TdsEnums.SNI_SUCCESS;
}
catch (ObjectDisposedException ode)
// is the packet is marked out out-of-band (attention packets only) it must be
// sent immediately even if a send of recieve operation is already in progress
// because out of band packets are used to cancel ongoing operations
// so try to take the lock if possible but continue even if it can't be taken
if (packet.IsOutOfBand)
{
return ReportTcpSNIError(ode);
Monitor.TryEnter(this, ref releaseLock);
}
catch (SocketException se)
else
{
return ReportTcpSNIError(se);
Monitor.Enter(this);
releaseLock = true;
}
catch (IOException ioe)

// this lock ensures that two packets are not being written to the transport at the same time
// so that sending a standard and an out-of-band packet are both written atomically no data is
// interleaved
lock (_sendSync)
{
return ReportTcpSNIError(ioe);
try
{
packet.WriteToStream(_stream);
return TdsEnums.SNI_SUCCESS;
}
catch (ObjectDisposedException ode)
{
return ReportTcpSNIError(ode);
}
catch (SocketException se)
{
return ReportTcpSNIError(se);
}
catch (IOException ioe)
{
return ReportTcpSNIError(ioe);
}
}
}
}
finally
{
if (releaseLock)
Expand Down Expand Up @@ -633,7 +633,8 @@ public override uint Receive(out SNIPacket packet, int timeoutInMilliseconds)
_socket.ReceiveTimeout = timeoutInMilliseconds;
}
else if (timeoutInMilliseconds == -1)
{ // SqlCient internally represents infinite timeout by -1, and for TcpClient this is translated to a timeout of 0
{
// SqlClient internally represents infinite timeout by -1, and for TcpClient this is translated to a timeout of 0
_socket.ReceiveTimeout = 0;
}
else
Expand Down Expand Up @@ -706,12 +707,17 @@ public override void SetAsyncCallbacks(SNIAsyncCallback receiveCallback, SNIAsyn
/// <returns>SNI error code</returns>
public override uint SendAsync(SNIPacket packet, SNIAsyncCallback callback = null)
{
SNIAsyncCallback cb = callback ?? _sendCallback;
lock (this)
cheenamalhotra marked this conversation as resolved.
Show resolved Hide resolved
long scopeID = SqlClientEventSource.Log.TrySNIScopeEnterEvent("<sc.SNI.SNIMarsHandle.SendAsync |SNI|INFO|SCOPE>");
try
{
SNIAsyncCallback cb = callback ?? _sendCallback;
packet.WriteToStreamAsync(_stream, cb, SNIProviders.TCP_PROV);
return TdsEnums.SNI_SUCCESS_IO_PENDING;
}
finally
{
SqlClientEventSource.Log.TrySNIScopeLeaveEvent(scopeID);
}
return TdsEnums.SNI_SUCCESS_IO_PENDING;
}

/// <summary>
Expand Down Expand Up @@ -745,15 +751,15 @@ public override uint CheckConnection()
{
try
{
// _socket.Poll method with argument SelectMode.SelectRead returns
// _socket.Poll method with argument SelectMode.SelectRead returns
// True : if Listen has been called and a connection is pending, or
// True : if data is available for reading, or
// True : if the connection has been closed, reset, or terminated, i.e no active connection.
// False : otherwise.
// _socket.Available property returns the number of bytes of data available to read.
//
// Since _socket.Connected alone doesn't guarantee if the connection is still active, we use it in
// combination with _socket.Poll method and _socket.Available == 0 check. When both of them
// Since _socket.Connected alone doesn't guarantee if the connection is still active, we use it in
// combination with _socket.Poll method and _socket.Available == 0 check. When both of them
// return true we can safely determine that the connection is no longer active.
if (!_socket.Connected || (_socket.Poll(100, SelectMode.SelectRead) && _socket.Available == 0))
{
Expand Down
Expand Up @@ -12,24 +12,16 @@ namespace Microsoft.Data.SqlClient.SNI
internal sealed partial class SslOverTdsStream
{
public override int Read(byte[] buffer, int offset, int count)
{
return Read(buffer.AsSpan(offset, count));
}
=> Read(buffer.AsSpan(offset, count));

public override void Write(byte[] buffer, int offset, int count)
{
Write(buffer.AsSpan(offset, count));
}
=> Write(buffer.AsSpan(offset, count));

public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
}
=> ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();

public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken).AsTask();
}
=> WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken).AsTask();

public override int Read(Span<byte> buffer)
{
Expand Down Expand Up @@ -288,7 +280,6 @@ public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, Cancella

await _stream.FlushAsync().ConfigureAwait(false);


remaining = remaining.Slice(dataLength);
}
}
Expand Down
Expand Up @@ -1319,7 +1319,7 @@ private void ThrowIfReconnectionHasBeenCanceled()
if (_stateObj == null)
{
var reconnectionCompletionSource = _reconnectionCompletionSource;
if (reconnectionCompletionSource != null && reconnectionCompletionSource.Task.IsCanceled)
if (reconnectionCompletionSource != null && reconnectionCompletionSource.Task != null && reconnectionCompletionSource.Task.IsCanceled)
{
throw SQL.CR_ReconnectionCancelled();
}
Expand Down
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Concurrent;
cheenamalhotra marked this conversation as resolved.
Show resolved Hide resolved
using System.Collections.Generic;
using System.Data;
using System.Diagnostics;
Expand Down Expand Up @@ -2131,4 +2132,5 @@ public static MethodInfo GetPromotedToken
}
}
}

}//namespace
Expand Up @@ -3266,7 +3266,7 @@ internal Task WritePacket(byte flushMode, bool canAccumulate = false)

if (willCancel)
{
// If we have been cancelled, then ensure that we write the ATTN packet as well
// If we have been canceled, then ensure that we write the ATTN packet as well
task = AsyncHelper.CreateContinuationTask(task, CancelWritePacket);
}
return task;
Expand Down