diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.cs index aed7eac964..cd91b5e8e7 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.cs @@ -15,23 +15,22 @@ namespace Microsoft.Data.SqlClient.SNI /// internal class SNISslStream : SslStream { - private readonly SemaphoreSlim _writeAsyncSemaphore; - private readonly SemaphoreSlim _readAsyncSemaphore; + private readonly ConcurrentQueueSemaphore _writeAsyncSemaphore; + private readonly ConcurrentQueueSemaphore _readAsyncSemaphore; public SNISslStream(Stream innerStream, bool leaveInnerStreamOpen, RemoteCertificateValidationCallback userCertificateValidationCallback) : base(innerStream, leaveInnerStreamOpen, userCertificateValidationCallback) { - _writeAsyncSemaphore = new SemaphoreSlim(1); - _readAsyncSemaphore = new SemaphoreSlim(1); + _writeAsyncSemaphore = new ConcurrentQueueSemaphore(1); + _readAsyncSemaphore = new ConcurrentQueueSemaphore(1); } // Prevent ReadAsync collisions by running the task in a Semaphore Slim public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - await _readAsyncSemaphore.WaitAsync().ConfigureAwait(false); try { - return await base.ReadAsync(buffer, offset, count, cancellationToken); + return await await _readAsyncSemaphore.WaitAsync().ContinueWith(t => base.ReadAsync(buffer, offset, count, cancellationToken)); } finally { @@ -42,10 +41,9 @@ public override async Task ReadAsync(byte[] buffer, int offset, int count, // Prevent the WriteAsync collisions by running the task in a Semaphore Slim public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - await _writeAsyncSemaphore.WaitAsync().ConfigureAwait(false); try { - await base.WriteAsync(buffer, offset, count, cancellationToken); + await await _writeAsyncSemaphore.WaitAsync().ContinueWith(t => base.WriteAsync(buffer, offset, count, cancellationToken)); } finally { @@ -59,22 +57,21 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc /// internal class SNINetworkStream : NetworkStream { - private readonly SemaphoreSlim _writeAsyncSemaphore; - private readonly SemaphoreSlim _readAsyncSemaphore; + private readonly ConcurrentQueueSemaphore _writeAsyncSemaphore; + private readonly ConcurrentQueueSemaphore _readAsyncSemaphore; public SNINetworkStream(Socket socket, bool ownsSocket) : base(socket, ownsSocket) { - _writeAsyncSemaphore = new SemaphoreSlim(1); - _readAsyncSemaphore = new SemaphoreSlim(1); + _writeAsyncSemaphore = new ConcurrentQueueSemaphore(1); + _readAsyncSemaphore = new ConcurrentQueueSemaphore(1); } - // Prevent the ReadAsync collisions by running the task in a Semaphore Slim + // Prevent ReadAsync collisions by running the task in a Semaphore Slim public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - await _readAsyncSemaphore.WaitAsync().ConfigureAwait(false); try { - return await base.ReadAsync(buffer, offset, count, cancellationToken); + return await await _readAsyncSemaphore.WaitAsync().ContinueWith(t => base.ReadAsync(buffer, offset, count, cancellationToken)); } finally { @@ -85,10 +82,9 @@ public override async Task ReadAsync(byte[] buffer, int offset, int count, // Prevent the WriteAsync collisions by running the task in a Semaphore Slim public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - await _writeAsyncSemaphore.WaitAsync().ConfigureAwait(false); try { - await base.WriteAsync(buffer, offset, count, cancellationToken); + await await _writeAsyncSemaphore.WaitAsync().ContinueWith(t => base.WriteAsync(buffer, offset, count, cancellationToken)); } finally { diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs index dda5f8e958..a6b9688afa 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs @@ -2133,4 +2133,47 @@ public static MethodInfo GetPromotedToken } } + /// + /// This class implements a FIFO Queue with SemaphoreSlim for ordered execution of parallel tasks. + /// Currently used in Managed SNI (SNISslStream) to override SslStream's WriteAsync implementation. + /// + internal class ConcurrentQueueSemaphore + { + private readonly SemaphoreSlim _semaphore; + private readonly ConcurrentQueue> _queue = + new ConcurrentQueue>(); + + public ConcurrentQueueSemaphore(int initialCount) + { + _semaphore = new SemaphoreSlim(initialCount); + } + + public ConcurrentQueueSemaphore(int initialCount, int maxCount) + { + _semaphore = new SemaphoreSlim(initialCount, maxCount); + } + + public void Wait() + { + WaitAsync().Wait(); + } + + public Task WaitAsync() + { + var tcs = new TaskCompletionSource(); + _queue.Enqueue(tcs); + _semaphore.WaitAsync().ContinueWith(t => + { + if (_queue.TryDequeue(out TaskCompletionSource popped)) + popped.SetResult(true); + }); + return tcs.Task; + } + + public void Release() + { + _semaphore.Release(); + } + } + }//namespace