Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP | Prevent WriteAsync collision #579

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -3,8 +3,11 @@
// See the LICENSE file in the project root for more information.

using System;
using System.IO;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Data.SqlClient.SNI
{
Expand Down Expand Up @@ -99,6 +102,26 @@ internal enum SNISMUXFlags
SMUX_DATA = 8 // SMUX data packet
}

internal class SslStreamProxy : SslStream
{
private Task _currentTask;

public SslStreamProxy(Stream innerStream, bool leaveInnerStreamOpen, RemoteCertificateValidationCallback userCertificateValidationCallback)
: base(innerStream, leaveInnerStreamOpen, userCertificateValidationCallback)
{ }

// Prevent the WriteAsync's collision
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
if (_currentTask != null && _currentTask.Status != TaskStatus.RanToCompletion)
{
_currentTask.Wait(cancellationToken);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unless I'm mistaken, this is a synchronous call which blocks the current thread - that's a pretty bad idea in an async method (mixing sync and async can lead to all kinds of pain). Since _currentTask represents an asynchronous task (if I'm getting the code right), this is also a case of sync-over-async, which can cause deadlocks.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to fix the issue #459 (and some more related threads) and as we see in MS documentation for WriteAsync it says, "The SslStream class does not support multiple simultaneous write operations.", which I agree is against Async API design but I'm not sure if we have a better solution here. I guess that's why most of our APIs are sync-over-async but this mix and match of sync/async is one of the reasons behind colliding Async APIs.

Do you have a better solution in this case?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The driver definitely needs to avoid starting a write if another one is in progress (the same is true for reads); this isn't just true for SslStream, but also for NetworkStream - and pretty much anything that does networking. FWIW I don't think this restriction is against async API design - async is about not blocking the thread, rather than about allowing multiple concurrent operations to occur on the same resource (i.e. connection).

More concretely... I don't doubt the importance of this fix (I have no context here) - all I'm saying is that any waiting for previous writes to complete should be done asynchronously, otherwise you're likely creating a new category of future bugs. One easy way to do this would be to have a SemaphoreSlim on the connection representing writing, and to simply call await WaitAsync on it. This would ensure only one writing operation occurs at any given moment, while making all waiting happens asynchronously. There may be some perf implication here, but I really don't have enough context to know how hot this code path is etc.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW here's an example scenario on where sync-over-async could go wrong.

Say you have a lot of writers all of a sudden (e.g. because of a thundering herd effect or whatever). One of the writers wins and starts to do their thing; meanwhile, all the other enter the Wait call, blocking the thread. If your writers happen to be executed on the thread pool (which they are in ASP.NET, for example), then you've just exhausted the thread pool completely. The really bad part, is that when your very first writer completes - the one which actually started writing - it also needs a thread pool thread to execute the continuation (the code that runs after the async operation completed), but the pool is exhausted. This is a pseudo-deadlock, and your application will hang until the thread pool allocates new threads, which can take quite a while.

Since I'm lacking context, I don't know if the above can occur here specifically. But with all the bad scenarios around mixing sync and async code, it really best be avoided...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hi @roji
Thank you for the valuable point.

Since the sequence of calling the WriteAsync method is important in this scenario, and SemaphoreSlim doesn't respect the order, Do you have any solution/guidance for this case?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DavoudEshtehari apologies for disappearing for so long, I had some personal issues.

Looking at the current method as it is in the PR, I don't see any enforcing of ordering - if _currentTask is not-null (i.e. someone else is writing), we wait until that Task is over. I'm basically proposing substituting _currentTask for SemaphoreSlim as the synchronization check, that's all.

If you need to have multiple writers wait in line based on FIFO, then a more elaborate solution with ConcurrentQueue can be built - but I'd definitely need more context to understand needs here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@roji thank you for your kind reply.

The issue might occur when more than one write request waits behind the semaphore to get permission to run the WriteAsync function. I think this is a possible condition.

Also, I've tried queue patterns but touching the WriteAsync throws the exception even though I wrap it up with a cold task.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that with the current code, if more than one write request Waits on the Task, both will get unblocked the moment the task completes, and will concurrently call into base.WriteAsync - which is what this PR is trying to prevent. SemaphoreSlim would easily take care of that for you - when its WaitAsync method returns, it's guaranteed that only one request has acquired the semaphore. However, there is no guarantee of ordering - if more than one attempt did WaitAsync, they would get unblocked in "random" order.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @roji

Thanks for your tips, I got an implementation of Concurrent Queue with Semaphore Slim working (PR #796), would request you to take a look too.

However, this makes me think couldn't it be done by SslStream class itself? So that libraries don't land into surprises... It seems a basic case of ReadAsync/WriteAsync + parallel tasks is not supported by them and it can cause dangerous problems. Maybe a thread-safe design can be proposed to System.Net.Security?

}
_currentTask = base.WriteAsync(buffer, offset, count, cancellationToken);
return _currentTask;
}
}

internal class SNICommon
{
// Each error number maps to SNI_ERROR_* in String.resx
Expand Down
Expand Up @@ -93,7 +93,7 @@ public SNINpHandle(string serverName, string pipeName, long timerExpire, object
}

_sslOverTdsStream = new SslOverTdsStream(_pipeStream);
_sslStream = new SslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate), null);
_sslStream = new SslStreamProxy(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate));

_stream = _pipeStream;
_status = TdsEnums.SNI_SUCCESS;
Expand Down Expand Up @@ -325,12 +325,17 @@ public override uint Send(SNIPacket packet)
public override uint SendAsync(SNIPacket packet, SNIAsyncCallback callback = null)
{
long scopeID = SqlClientEventSource.Log.TrySNIScopeEnterEvent("<sc.SNI.SNINpHandle.SendAsync |SNI|SCOPE>");
SNIAsyncCallback cb = callback ?? _sendCallback;
try
{
SNIAsyncCallback cb = callback ?? _sendCallback;
packet.WriteToStreamAsync(_stream, cb, SNIProviders.NP_PROV);
return TdsEnums.SNI_SUCCESS_IO_PENDING;
}
catch (Exception e) when (e is ObjectDisposedException || e is InvalidOperationException || e is IOException)
{
SNIPacket errorPacket = packet;
return ReportErrorAndReleasePacket(errorPacket, e);
}
finally
{
SqlClientEventSource.Log.TrySNIScopeLeaveEvent(scopeID);
Expand Down
Expand Up @@ -304,16 +304,21 @@ public void WriteToStream(Stream stream)
public async void WriteToStreamAsync(Stream stream, SNIAsyncCallback callback, SNIProviders provider)
{
uint status = TdsEnums.SNI_SUCCESS;
try
{
await stream.WriteAsync(_data, 0, _dataLength, CancellationToken.None).ConfigureAwait(false);
}
catch (Exception e)

await stream.WriteAsync(_data, 0, _dataLength, CancellationToken.None).ContinueWith(t =>
{
SNILoadHandle.SingletonInstance.LastError = new SNIError(provider, SNICommon.InternalExceptionError, e);
status = TdsEnums.SNI_ERROR;
}
callback(this, status);
Exception e = t.Exception?.InnerException;
if (e != null)
{
SNILoadHandle.SingletonInstance.LastError = new SNIError(provider, SNICommon.InternalExceptionError, e);
status = TdsEnums.SNI_ERROR;
Release();
cheenamalhotra marked this conversation as resolved.
Show resolved Hide resolved
}
callback(this, status);
},
CancellationToken.None,
TaskContinuationOptions.DenyChildAttach,
TaskScheduler.Default);
}
}
}
Expand Up @@ -226,7 +226,7 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
_tcpStream = new NetworkStream(_socket, true);

_sslOverTdsStream = new SslOverTdsStream(_tcpStream);
_sslStream = new SslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate), null);
_sslStream = new SslStreamProxy(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate));
}
catch (SocketException se)
{
Expand Down Expand Up @@ -331,7 +331,7 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo
}

CancellationTokenSource cts = null;

void Cancel()
{
for (int i = 0; i < sockets.Length; ++i)
Expand All @@ -355,7 +355,7 @@ void Cancel()
}

Socket availableSocket = null;
try
try
{
for (int i = 0; i < sockets.Length; ++i)
{
Expand Down Expand Up @@ -706,12 +706,22 @@ public override void SetAsyncCallbacks(SNIAsyncCallback receiveCallback, SNIAsyn
/// <returns>SNI error code</returns>
public override uint SendAsync(SNIPacket packet, SNIAsyncCallback callback = null)
{
long scopeID = SqlClientEventSource.Log.TrySNIScopeEnterEvent("<sc.SNI.SNITcpHandle.SendAsync |SNI|SCOPE>");
SNIAsyncCallback cb = callback ?? _sendCallback;
lock (this)
try
{
packet.WriteToStreamAsync(_stream, cb, SNIProviders.TCP_PROV);
return TdsEnums.SNI_SUCCESS_IO_PENDING;
}
catch (Exception e) when (e is ObjectDisposedException || e is InvalidOperationException || e is IOException)
cheenamalhotra marked this conversation as resolved.
Show resolved Hide resolved
{
SNIPacket errorPacket = packet;
return ReportErrorAndReleasePacket(errorPacket, e);
}
finally
{
SqlClientEventSource.Log.TrySNIScopeLeaveEvent(scopeID);
}
return TdsEnums.SNI_SUCCESS_IO_PENDING;
}

/// <summary>
Expand Down
Expand Up @@ -66,6 +66,7 @@
<Compile Include="SQL\SqlBulkCopyTest\OrderHintMissingTargetColumn.cs" />
<Compile Include="SQL\SqlBulkCopyTest\OrderHintTransaction.cs" />
<Compile Include="SQL\DataClassificationTest\DataClassificationTest.cs" />
<Compile Include="SQL\SqlCommand\SqlCommandExecuteTest.cs" />
<Compile Include="TracingTests\EventSourceTest.cs" />
<Compile Include="SQL\AdapterTest\AdapterTest.cs" />
<Compile Include="SQL\AsyncTest\BeginExecAsyncTest.cs" />
Expand Down
@@ -1,4 +1,8 @@
using System.Data;
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Data;
using Xunit;

namespace Microsoft.Data.SqlClient.ManualTesting.Tests
Expand Down
@@ -0,0 +1,141 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Xunit;

namespace Microsoft.Data.SqlClient.ManualTesting.Tests
{
public static class SqlCommandExecuteTest
{
[Theory]
[MemberData(nameof(GetConnectionStrings))]
public static void ExecuteReaderAsyncTest(string connectionString)
{
int counter = 100;
while (counter-- > 0)
{
ExecuteReaderAsync(connectionString, CancellationToken.None).GetAwaiter().GetResult();
}
}

[Theory]
[MemberData(nameof(GetConnectionStrings))]
public static void ExecuteScalarAsyncTest(string connectionString)
{
int counter = 100;
while (counter-- > 0)
{
ExecuteScalarAsync(connectionString, CancellationToken.None).GetAwaiter().GetResult();
}
}

[Theory]
[MemberData(nameof(GetConnectionStrings))]
public static void ExecuteNonQueryAsyncTest(string connectionString)
{
int counter = 100;
while (counter-- > 0)
{
ExecuteNonQueryAsync(connectionString, CancellationToken.None).GetAwaiter().GetResult();
}
}

[Theory]
[MemberData(nameof(GetConnectionStrings))]
public static void ExecuteXmlReaderAsyncTest(string connectionString)
{
int counter = 100;
while (counter-- > 0)
{
ExecuteXmlReaderAsync(connectionString, CancellationToken.None).GetAwaiter().GetResult();
}
}

#region Execute Async
private static async Task ExecuteReaderAsync(string connectionString, CancellationToken token)
{
using (var connection = new SqlConnection(connectionString))
using (var cmd = GetCommand(connection))
using (var r = await cmd.ExecuteReaderAsync(token))
{
while (await r.ReadAsync(token))
{
await r.GetFieldValueAsync<string>(0);
await r.GetFieldValueAsync<string>(1);
await r.GetFieldValueAsync<string>(2);
}
}
}

private static async Task ExecuteScalarAsync(string connectionString, CancellationToken token)
{
using (var connection = new SqlConnection(connectionString))
using (var cmd = GetCommand(connection))
{
await cmd.ExecuteScalarAsync(token);
}
}

private static async Task ExecuteNonQueryAsync(string connectionString, CancellationToken token)
{
using (var connection = new SqlConnection(connectionString))
using (var cmd = GetCommand(connection))
{
await cmd.ExecuteNonQueryAsync(token);
}
}

private static async Task ExecuteXmlReaderAsync(string connectionString, CancellationToken token)
{
using (var connection = new SqlConnection(connectionString))
using (var cmd = GetCommand(connection))
{
cmd.CommandText += " FOR XML AUTO, XMLDATA";
var r = await cmd.ExecuteXmlReaderAsync(token);
}
}

private static SqlCommand GetCommand(SqlConnection cnn)
{
string aRecord = "('2455cf1b-ebcf-418d-8cce-88e21e1683e3', 'something', 'updated'),";
string query = "SELECT * FROM (VALUES"
+ string.Concat(Enumerable.Repeat(aRecord, 200)).Substring(0, (aRecord.Length * 200) - 1)
+ ") tbl_A ([Id], [Name], [State])";
cnn.Open();
var cmd = cnn.CreateCommand();
cmd.CommandText = query;
return cmd;
}
#endregion

public static IEnumerable<object[]> GetConnectionStrings()
{
SqlConnectionStringBuilder builder;
foreach (var item in DataTestUtility.ConnectionStrings)
{
builder = new SqlConnectionStringBuilder(item)
{
TrustServerCertificate = true,
Encrypt = true,
MultipleActiveResultSets = false,
ConnectTimeout = 10,
ConnectRetryCount = 3,
ConnectRetryInterval = 10,
LoadBalanceTimeout = 60,
MaxPoolSize = 10,
MinPoolSize = 0
};
yield return new object[] { builder.ConnectionString };

builder.TrustServerCertificate = false;
builder.Encrypt = false;
yield return new object[] { builder.ConnectionString };
}
}
}
}
@@ -1,8 +1,10 @@
using System;
using System.Collections.Generic;
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Data.SqlTypes;
using System.Reflection;
using System.Text;
using Xunit;

namespace Microsoft.Data.SqlClient.ManualTesting.Tests
Expand Down