diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index d4d8d25cb7..e46d123395 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -442,6 +442,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNINpHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNINpHandle.cs index 692aa9b7fe..0132d7df58 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNINpHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNINpHandle.cs @@ -93,7 +93,7 @@ public SNINpHandle(string serverName, string pipeName, long timerExpire, object } _sslOverTdsStream = new SslOverTdsStream(_pipeStream); - _sslStream = new SslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate), null); + _sslStream = new SNISslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate)); _stream = _pipeStream; _status = TdsEnums.SNI_SUCCESS; @@ -286,7 +286,7 @@ public override uint Send(SNIPacket packet) } // this lock ensures that two packets are not being written to the transport at the same time - // so that sending a standard and an out-of-band packet are both written atomically no data is + // so that sending a standard and an out-of-band packet are both written atomically no data is // interleaved lock (_sendSync) { diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.cs new file mode 100644 index 0000000000..eb8661d022 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.cs @@ -0,0 +1,99 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Net.Security; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using System.Net.Sockets; + +namespace Microsoft.Data.SqlClient.SNI +{ + /// + /// This class extends SslStream to customize stream behavior for Managed SNI implementation. + /// + internal class SNISslStream : SslStream + { + private readonly ConcurrentQueueSemaphore _writeAsyncSemaphore; + private readonly ConcurrentQueueSemaphore _readAsyncSemaphore; + + public SNISslStream(Stream innerStream, bool leaveInnerStreamOpen, RemoteCertificateValidationCallback userCertificateValidationCallback) + : base(innerStream, leaveInnerStreamOpen, userCertificateValidationCallback) + { + _writeAsyncSemaphore = new ConcurrentQueueSemaphore(1); + _readAsyncSemaphore = new ConcurrentQueueSemaphore(1); + } + + // Prevent ReadAsync collisions by running the task in a Semaphore Slim + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + try + { + return await base.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); + } + finally + { + _readAsyncSemaphore.Release(); + } + } + + // Prevent the WriteAsync collisions by running the task in a Semaphore Slim + public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + try + { + await base.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); + } + finally + { + _writeAsyncSemaphore.Release(); + } + } + } + + /// + /// This class extends NetworkStream to customize stream behavior for Managed SNI implementation. + /// + internal class SNINetworkStream : NetworkStream + { + private readonly ConcurrentQueueSemaphore _writeAsyncSemaphore; + private readonly ConcurrentQueueSemaphore _readAsyncSemaphore; + + public SNINetworkStream(Socket socket, bool ownsSocket) : base(socket, ownsSocket) + { + _writeAsyncSemaphore = new ConcurrentQueueSemaphore(1); + _readAsyncSemaphore = new ConcurrentQueueSemaphore(1); + } + + // Prevent ReadAsync collisions by running the task in a Semaphore Slim + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + try + { + return await base.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); + } + finally + { + _readAsyncSemaphore.Release(); + } + } + + // Prevent the WriteAsync collisions by running the task in a Semaphore Slim + public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + try + { + await base.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); + } + finally + { + _writeAsyncSemaphore.Release(); + } + } + } +} diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs index b072a4fa01..ef85841d24 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs @@ -143,7 +143,7 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba bool reportError = true; // We will always first try to connect with serverName as before and let the DNS server to resolve the serverName. - // If the DSN resolution fails, we will try with IPs in the DNS cache if existed. We try with IPv4 first and followed by IPv6 if + // If the DSN resolution fails, we will try with IPs in the DNS cache if existed. We try with IPv4 first and followed by IPv6 if // IPv4 fails. The exceptions will be throw to upper level and be handled as before. try { @@ -160,14 +160,14 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba { // Retry with cached IP address if (ex is SocketException || ex is ArgumentException || ex is AggregateException) - { + { if (hasCachedDNSInfo == false) { throw; } else { - int portRetry = String.IsNullOrEmpty(cachedDNSInfo.Port) ? port : Int32.Parse(cachedDNSInfo.Port); + int portRetry = String.IsNullOrEmpty(cachedDNSInfo.Port) ? port : Int32.Parse(cachedDNSInfo.Port); try { @@ -180,9 +180,9 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba _socket = Connect(cachedDNSInfo.AddrIPv4, portRetry, ts, isInfiniteTimeOut, cachedFQDN, ref pendingDNSInfo); } } - catch(Exception exRetry) + catch (Exception exRetry) { - if (exRetry is SocketException || exRetry is ArgumentNullException + if (exRetry is SocketException || exRetry is ArgumentNullException || exRetry is ArgumentException || exRetry is ArgumentOutOfRangeException || exRetry is AggregateException) { if (parallel) @@ -199,7 +199,7 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba throw; } } - } + } } else { @@ -223,10 +223,10 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba } _socket.NoDelay = true; - _tcpStream = new NetworkStream(_socket, true); + _tcpStream = new SNINetworkStream(_socket, true); _sslOverTdsStream = new SslOverTdsStream(_tcpStream); - _sslStream = new SslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate), null); + _sslStream = new SNISslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate)); } catch (SocketException se) { @@ -331,7 +331,7 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo } CancellationTokenSource cts = null; - + void Cancel() { for (int i = 0; i < sockets.Length; ++i) @@ -355,7 +355,7 @@ void Cancel() } Socket availableSocket = null; - try + try { for (int i = 0; i < sockets.Length; ++i) { @@ -566,45 +566,45 @@ public override uint Send(SNIPacket packet) { bool releaseLock = false; try - { - // is the packet is marked out out-of-band (attention packets only) it must be - // sent immediately even if a send of recieve operation is already in progress - // because out of band packets are used to cancel ongoing operations - // so try to take the lock if possible but continue even if it can't be taken - if (packet.IsOutOfBand) - { - Monitor.TryEnter(this, ref releaseLock); - } - else - { - Monitor.Enter(this); - releaseLock = true; - } - - // this lock ensures that two packets are not being written to the transport at the same time - // so that sending a standard and an out-of-band packet are both written atomically no data is - // interleaved - lock (_sendSync) { - try - { - packet.WriteToStream(_stream); - return TdsEnums.SNI_SUCCESS; - } - catch (ObjectDisposedException ode) + // is the packet is marked out out-of-band (attention packets only) it must be + // sent immediately even if a send of recieve operation is already in progress + // because out of band packets are used to cancel ongoing operations + // so try to take the lock if possible but continue even if it can't be taken + if (packet.IsOutOfBand) { - return ReportTcpSNIError(ode); + Monitor.TryEnter(this, ref releaseLock); } - catch (SocketException se) + else { - return ReportTcpSNIError(se); + Monitor.Enter(this); + releaseLock = true; } - catch (IOException ioe) + + // this lock ensures that two packets are not being written to the transport at the same time + // so that sending a standard and an out-of-band packet are both written atomically no data is + // interleaved + lock (_sendSync) { - return ReportTcpSNIError(ioe); + try + { + packet.WriteToStream(_stream); + return TdsEnums.SNI_SUCCESS; + } + catch (ObjectDisposedException ode) + { + return ReportTcpSNIError(ode); + } + catch (SocketException se) + { + return ReportTcpSNIError(se); + } + catch (IOException ioe) + { + return ReportTcpSNIError(ioe); + } } } - } finally { if (releaseLock) @@ -633,7 +633,8 @@ public override uint Receive(out SNIPacket packet, int timeoutInMilliseconds) _socket.ReceiveTimeout = timeoutInMilliseconds; } else if (timeoutInMilliseconds == -1) - { // SqlCient internally represents infinite timeout by -1, and for TcpClient this is translated to a timeout of 0 + { + // SqlClient internally represents infinite timeout by -1, and for TcpClient this is translated to a timeout of 0 _socket.ReceiveTimeout = 0; } else @@ -706,12 +707,17 @@ public override void SetAsyncCallbacks(SNIAsyncCallback receiveCallback, SNIAsyn /// SNI error code public override uint SendAsync(SNIPacket packet, SNIAsyncCallback callback = null) { - SNIAsyncCallback cb = callback ?? _sendCallback; - lock (this) + long scopeID = SqlClientEventSource.Log.TrySNIScopeEnterEvent(""); + try { + SNIAsyncCallback cb = callback ?? _sendCallback; packet.WriteToStreamAsync(_stream, cb, SNIProviders.TCP_PROV); + return TdsEnums.SNI_SUCCESS_IO_PENDING; + } + finally + { + SqlClientEventSource.Log.TrySNIScopeLeaveEvent(scopeID); } - return TdsEnums.SNI_SUCCESS_IO_PENDING; } /// @@ -745,15 +751,15 @@ public override uint CheckConnection() { try { - // _socket.Poll method with argument SelectMode.SelectRead returns + // _socket.Poll method with argument SelectMode.SelectRead returns // True : if Listen has been called and a connection is pending, or // True : if data is available for reading, or // True : if the connection has been closed, reset, or terminated, i.e no active connection. // False : otherwise. // _socket.Available property returns the number of bytes of data available to read. // - // Since _socket.Connected alone doesn't guarantee if the connection is still active, we use it in - // combination with _socket.Poll method and _socket.Available == 0 check. When both of them + // Since _socket.Connected alone doesn't guarantee if the connection is still active, we use it in + // combination with _socket.Poll method and _socket.Available == 0 check. When both of them // return true we can safely determine that the connection is no longer active. if (!_socket.Connected || (_socket.Poll(100, SelectMode.SelectRead) && _socket.Available == 0)) { diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs index cb634cb6af..97a1181ae3 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs @@ -12,24 +12,16 @@ namespace Microsoft.Data.SqlClient.SNI internal sealed partial class SslOverTdsStream { public override int Read(byte[] buffer, int offset, int count) - { - return Read(buffer.AsSpan(offset, count)); - } + => Read(buffer.AsSpan(offset, count)); public override void Write(byte[] buffer, int offset, int count) - { - Write(buffer.AsSpan(offset, count)); - } + => Write(buffer.AsSpan(offset, count)); public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - return ReadAsync(new Memory(buffer, offset, count), cancellationToken).AsTask(); - } + => ReadAsync(new Memory(buffer, offset, count), cancellationToken).AsTask(); public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - return WriteAsync(new ReadOnlyMemory(buffer, offset, count), cancellationToken).AsTask(); - } + => WriteAsync(new ReadOnlyMemory(buffer, offset, count), cancellationToken).AsTask(); public override int Read(Span buffer) { @@ -288,7 +280,6 @@ public override async ValueTask WriteAsync(ReadOnlyMemory buffer, Cancella await _stream.FlushAsync().ConfigureAwait(false); - remaining = remaining.Slice(dataLength); } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs index e19ee3eba0..10d5064a87 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs @@ -1319,7 +1319,7 @@ private void ThrowIfReconnectionHasBeenCanceled() if (_stateObj == null) { var reconnectionCompletionSource = _reconnectionCompletionSource; - if (reconnectionCompletionSource != null && reconnectionCompletionSource.Task.IsCanceled) + if (reconnectionCompletionSource != null && reconnectionCompletionSource.Task != null && reconnectionCompletionSource.Task.IsCanceled) { throw SQL.CR_ReconnectionCancelled(); } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs index 5ddd978093..bf21c8db3a 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Data; using System.Diagnostics; @@ -2131,4 +2132,38 @@ public static MethodInfo GetPromotedToken } } } + + /// + /// This class implements a FIFO Queue with SemaphoreSlim for ordered execution of parallel tasks. + /// Currently used in Managed SNI (SNISslStream) to override SslStream's WriteAsync implementation. + /// + internal class ConcurrentQueueSemaphore + { + private readonly SemaphoreSlim _semaphore; + private readonly ConcurrentQueue> _queue = + new ConcurrentQueue>(); + + public ConcurrentQueueSemaphore(int initialCount) + { + _semaphore = new SemaphoreSlim(initialCount); + } + + public Task WaitAsync(CancellationToken cancellationToken) + { + var tcs = new TaskCompletionSource(); + _queue.Enqueue(tcs); + _semaphore.WaitAsync().ContinueWith(t => + { + if (_queue.TryDequeue(out TaskCompletionSource popped)) + popped.SetResult(true); + }, cancellationToken); + return tcs.Task; + } + + public void Release() + { + _semaphore.Release(); + } + } + }//namespace diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index 7e9f5a6d67..5d36213755 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -3266,7 +3266,7 @@ internal Task WritePacket(byte flushMode, bool canAccumulate = false) if (willCancel) { - // If we have been cancelled, then ensure that we write the ATTN packet as well + // If we have been canceled, then ensure that we write the ATTN packet as well task = AsyncHelper.CreateContinuationTask(task, CancelWritePacket); } return task; diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs index ea802107ff..9e9e281a32 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs @@ -72,13 +72,13 @@ internal override void AssignPendingDNSInfo(string userProtocol, string DNSCache internal void ReadAsyncCallback(SNIPacket packet, uint error) { ReadAsyncCallback(IntPtr.Zero, PacketHandle.FromManagedPacket(packet), error); - _sessionHandle.ReturnPacket(packet); + _sessionHandle?.ReturnPacket(packet); } internal void WriteAsyncCallback(SNIPacket packet, uint sniError) { WriteAsyncCallback(IntPtr.Zero, PacketHandle.FromManagedPacket(packet), sniError); - _sessionHandle.ReturnPacket(packet); + _sessionHandle?.ReturnPacket(packet); } protected override void RemovePacketFromPendingList(PacketHandle packet) diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AsyncTest/AsyncCancelledConnectionsTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AsyncTest/AsyncCancelledConnectionsTest.cs index dfcd78bc15..7ff00b8335 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AsyncTest/AsyncCancelledConnectionsTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AsyncTest/AsyncCancelledConnectionsTest.cs @@ -12,7 +12,9 @@ namespace Microsoft.Data.SqlClient.ManualTesting.Tests public class AsyncCancelledConnectionsTest { private readonly ITestOutputHelper _output; + private const int NumberOfTasks = 100; // How many attempts to poison the connection pool we will try + private const int NumberOfNonPoisoned = 10; // Number of normal requests for each attempt public AsyncCancelledConnectionsTest(ITestOutputHelper output)