From 9004a29cead4952633b199a5f8c24fe353141990 Mon Sep 17 00:00:00 2001 From: Cheena Malhotra Date: Tue, 17 Nov 2020 01:40:54 -0800 Subject: [PATCH] Implementing SNINetworkStream fixes issue 422 for non-encrypted TCP connections. --- .../src/Microsoft.Data.SqlClient.csproj | 2 +- .../SNI/{SNISslStream.cs => SNIStreams.cs} | 42 ++++++++++++++++++- .../Data/SqlClient/SNI/SNITcpHandle.cs | 2 +- 3 files changed, 43 insertions(+), 3 deletions(-) rename src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/{SNISslStream.cs => SNIStreams.cs} (54%) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index 70d506bbf3..2bb10aeca7 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -442,7 +442,7 @@ - + diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNISslStream.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.cs similarity index 54% rename from src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNISslStream.cs rename to src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.cs index a49cc1ca0d..a4beece738 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNISslStream.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.cs @@ -6,6 +6,7 @@ using System.IO; using System.Threading; using System.Threading.Tasks; +using System.Net.Sockets; namespace Microsoft.Data.SqlClient.SNI { @@ -30,7 +31,46 @@ public override Task ReadAsync(byte[] buffer, int offset, int count, Cancel try { return _readAsyncQueueSemaphore.WaitAsync() - .ContinueWith(_ => base.ReadAsync(buffer, offset, count, cancellationToken).GetAwaiter().GetResult()); + .ContinueWith(_ => base.ReadAsync(buffer, offset, count, cancellationToken).GetAwaiter().GetResult()); + } + finally + { + _readAsyncQueueSemaphore.Release(); + } + } + + // Prevent the WriteAsync's collision by running task in Semaphore Slim + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + try + { + return _writeAsyncQueueSemaphore.WaitAsync().ContinueWith(_ => base.WriteAsync(buffer, offset, count, cancellationToken)); + } + finally + { + _writeAsyncQueueSemaphore.Release(); + } + } + } + + internal class SNINetworkStream : NetworkStream + { + private readonly ConcurrentQueueSemaphore _writeAsyncQueueSemaphore; + private readonly ConcurrentQueueSemaphore _readAsyncQueueSemaphore; + + public SNINetworkStream(Socket socket, bool ownsSocket) : base(socket, ownsSocket) + { + _writeAsyncQueueSemaphore = new ConcurrentQueueSemaphore(1); + _readAsyncQueueSemaphore = new ConcurrentQueueSemaphore(1); + } + + // Prevent the ReadAsync's collision by running task in Semaphore Slim + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + try + { + return _readAsyncQueueSemaphore.WaitAsync() + .ContinueWith(_ => base.ReadAsync(buffer, offset, count, cancellationToken).GetAwaiter().GetResult()); } finally { diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs index f919d3cf52..3bf828d749 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs @@ -223,7 +223,7 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba } _socket.NoDelay = true; - _tcpStream = new NetworkStream(_socket, true); + _tcpStream = new SNINetworkStream(_socket, true); _sslOverTdsStream = new SslOverTdsStream(_tcpStream); _sslStream = new SNISslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate));