Skip to content

Commit

Permalink
Ensure that the OnAbort message is sent if the testhost aborts early (#…
Browse files Browse the repository at this point in the history
…3993)

* Invoke disconnected handler if client is already disconnected

If the testhost disconnects before the Discover/Run request is called
we would return back the abort message.
  • Loading branch information
drognanar committed Nov 24, 2022
1 parent 94cde14 commit 107873f
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Diagnostics;
using System.Globalization;
using System.Threading;
using System.Threading.Tasks;

using Microsoft.VisualStudio.TestPlatform.CommunicationUtilities.Interfaces;
using Microsoft.VisualStudio.TestPlatform.CommunicationUtilities.ObjectModel;
Expand Down Expand Up @@ -40,6 +41,7 @@ public class TestRequestSender : ITestRequestSender

private ICommunicationChannel? _channel;
private EventHandler<MessageReceivedEventArgs>? _onMessageReceived;
private DisconnectedEventArgs? _disconnectedInfo;
private Action<DisconnectedEventArgs>? _onDisconnected;
// Set to 1 if Discovery/Execution is complete, i.e. complete handlers have been invoked
private int _operationCompleted;
Expand Down Expand Up @@ -151,8 +153,14 @@ public int InitializeCommunication()
};

_communicationEndpoint.Disconnected += (sender, args) =>
{
// Store the disconnected info, so that any further DiscoverTests,
// RunTests methods can immediately bail.
_disconnectedInfo = args;
// If there's an disconnected event handler, call it
_onDisconnected?.Invoke(args);
InvokeDisconnectedHandler(args);
};

// Server start returns the listener port
// return int.Parse(this.communicationServer.Start());
Expand All @@ -161,6 +169,46 @@ public int InitializeCommunication()
return endpoint.GetIpEndPoint().Port;
}


private bool TrySetupMessageReceiver(
EventHandler<MessageReceivedEventArgs> onMessageReceived,
Action<DisconnectedEventArgs> onDisconnected)
{
TPDebug.Assert(_channel is not null, "_channel is null");

// Note: Attempts to setup a message receiver.
// It's possible that the testhost was already disconnected and in that case we should
// immediately call the disconnected callback.

// Design: The current method is needed because the request sender sets up
// the disconnect handler late. If the first thing that is done by the class
// is to setup the disconnect handler, then we'd only need to fire the handler
// when the disconnect event fires.

_onDisconnected = onDisconnected;

// If the testhost was already disconnected, trigger the handler immediately.
if (_disconnectedInfo is DisconnectedEventArgs args)
{
InvokeDisconnectedHandler(args);
return false;
}

_onMessageReceived = onMessageReceived;
_channel.MessageReceived += _onMessageReceived;

return true;
}

private void InvokeDisconnectedHandler(DisconnectedEventArgs args)
{
// Note: If the endpoint is disconnected at the same time as the
// disconnected handler is setup, it's possible for this method
// to be invoked twice. Ensure that the handler ever gets invoked once.
var handler = Interlocked.Exchange(ref _onDisconnected, null);
handler?.Invoke(args);
}

/// <inheritdoc />
public bool WaitForRequestHandlerConnection(int connectionTimeout, CancellationToken cancellationToken)
{
Expand Down Expand Up @@ -189,7 +237,8 @@ public void CheckVersionWithTestHost()
// Test host sends back the lower number of the two. So the highest protocol version, that both sides support is used.
// Error case: test host can send a protocol error if it cannot find a supported version
var protocolNegotiated = new ManualResetEvent(false);
_onMessageReceived = (sender, args) =>

EventHandler<MessageReceivedEventArgs> onMessageReceived = (sender, args) =>
{
var message = _dataSerializer.DeserializeMessage(args.Data!);
Expand Down Expand Up @@ -221,7 +270,7 @@ public void CheckVersionWithTestHost()
protocolNegotiated.Set();
};
_channel.MessageReceived += _onMessageReceived;
_channel.MessageReceived += onMessageReceived;

try
{
Expand All @@ -242,8 +291,7 @@ public void CheckVersionWithTestHost()
}
finally
{
_channel.MessageReceived -= _onMessageReceived;
_onMessageReceived = null;
_channel.MessageReceived -= onMessageReceived;
}
}

Expand All @@ -270,10 +318,13 @@ public void DiscoverTests(DiscoveryCriteria discoveryCriteria, ITestDiscoveryEve
_messageEventHandler = discoveryEventsHandler;
// When testhost disconnects, it normally means there was an error in the testhost and it exited unexpectedly.
// But when it was us who aborted the run and killed the testhost, we don't want to wait for it to report error, because there won't be any.
_onDisconnected = disconnectedEventArgs => OnDiscoveryAbort(discoveryEventsHandler, disconnectedEventArgs.Error, getClientError: !_isDiscoveryAborted);
_onMessageReceived = (sender, args) => OnDiscoveryMessageReceived(discoveryEventsHandler, args);
if (!TrySetupMessageReceiver(
onMessageReceived: (_, args) => OnDiscoveryMessageReceived(discoveryEventsHandler, args),
onDisconnected: disconnectedEventArgs => OnDiscoveryAbort(discoveryEventsHandler, disconnectedEventArgs.Error, getClientError: !_isDiscoveryAborted)))
{
return;
}

_channel.MessageReceived += _onMessageReceived;
var message = _dataSerializer.SerializePayload(
MessageType.StartDiscovery,
discoveryCriteria,
Expand Down Expand Up @@ -320,10 +371,13 @@ public void StartTestRun(TestRunCriteriaWithSources runCriteria, IInternalTestRu
{
TPDebug.Assert(_channel is not null, "_channel is null");
_messageEventHandler = eventHandler;
_onDisconnected = (disconnectedEventArgs) => OnTestRunAbort(eventHandler, disconnectedEventArgs.Error, true);

_onMessageReceived = (sender, args) => OnExecutionMessageReceived(args, eventHandler);
_channel.MessageReceived += _onMessageReceived;
if (!TrySetupMessageReceiver(
onMessageReceived: (_, args) => OnExecutionMessageReceived(args, eventHandler),
onDisconnected: disconnectedEventArgs => OnTestRunAbort(eventHandler, disconnectedEventArgs.Error, true)))
{
return;
}

// This code section is needed because we altered the old testhost launch process for
// the debugging workflow. Now we don't ask VS to launch and attach to the testhost
Expand Down Expand Up @@ -360,10 +414,13 @@ public void StartTestRun(TestRunCriteriaWithTests runCriteria, IInternalTestRunE
{
TPDebug.Assert(_channel is not null, "_channel is null");
_messageEventHandler = eventHandler;
_onDisconnected = (disconnectedEventArgs) => OnTestRunAbort(eventHandler, disconnectedEventArgs.Error, true);

_onMessageReceived = (sender, args) => OnExecutionMessageReceived(args, eventHandler);
_channel.MessageReceived += _onMessageReceived;
if (!TrySetupMessageReceiver(
onMessageReceived: (_, args) => OnExecutionMessageReceived(args, eventHandler),
onDisconnected: disconnectedEventArgs => OnTestRunAbort(eventHandler, disconnectedEventArgs.Error, true)))
{
return;
}

// This code section is needed because we altered the old testhost launch process for
// the debugging workflow. Now we don't ask VS to launch and attach to the testhost
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;

using Microsoft.VisualStudio.TestPlatform.CommunicationUtilities;
using Microsoft.VisualStudio.TestPlatform.CommunicationUtilities.Interfaces;
Expand Down Expand Up @@ -470,6 +471,18 @@ public void DiscoverTestShouldNotifyLogMessageIfClientDisconnectedWithClientExit
_mockDiscoveryEventsHandler.Verify(eh => eh.HandleRawMessage(It.Is<string>(s => !string.IsNullOrEmpty(s) && s.Equals("Serialized Stderr"))), Times.Once);
}

[TestMethod]
public void DiscoverTestShouldNotifyDiscoveryCompleteIfClientDisconnectedBeforeDiscovery()
{
SetupFakeCommunicationChannel();

RaiseClientDisconnectedEvent();

_testRequestSender.DiscoverTests(new DiscoveryCriteria(), _mockDiscoveryEventsHandler.Object);

_mockDiscoveryEventsHandler.Verify(eh => eh.HandleDiscoveryComplete(It.Is<DiscoveryCompleteEventArgs>(dc => dc.IsAborted == true && dc.TotalCount == -1), null));
}

[TestMethod]
public void DiscoverTestShouldNotifyDiscoveryCompleteIfClientDisconnected()
{
Expand Down Expand Up @@ -746,6 +759,52 @@ public void StartTestRunShouldNotifyErrorLogMessageIfClientDisconnectedWithClien
_mockExecutionEventsHandler.Verify(eh => eh.HandleLogMessage(TestMessageLevel.Error, It.Is<string>(s => s.Contains(expectedErrorMessage))), Times.Once);
}

[TestMethod]
public void StartTestRunShouldNotifyExecutionCompleteIfClientDisconnectedBeforeRun()
{
SetupOperationAbortedPayload();
SetupFakeCommunicationChannel();

RaiseClientDisconnectedEvent();

_testRequestSender.StartTestRun(_testRunCriteriaWithSources, _mockExecutionEventsHandler.Object);

_mockExecutionEventsHandler.Verify(eh => eh.HandleTestRunComplete(It.Is<TestRunCompleteEventArgs>(t => t.IsAborted), null, null, null), Times.Once);
_mockExecutionEventsHandler.Verify(eh => eh.HandleRawMessage("SerializedAbortedPayload"), Times.Once);
}

[TestMethod]
public void StartTestRunWithTestsShouldNotifyExecutionCompleteIfClientDisconnectedBeforeRun()
{
var runCriteria = new TestRunCriteriaWithTests(new TestCase[2], "runsettings", null, null!);
SetupOperationAbortedPayload();
SetupFakeCommunicationChannel();

RaiseClientDisconnectedEvent();

_testRequestSender.StartTestRun(runCriteria, _mockExecutionEventsHandler.Object);

_mockExecutionEventsHandler.Verify(eh => eh.HandleTestRunComplete(It.Is<TestRunCompleteEventArgs>(t => t.IsAborted), null, null, null), Times.Once);
_mockExecutionEventsHandler.Verify(eh => eh.HandleRawMessage("SerializedAbortedPayload"), Times.Once);
}

[TestMethod]
public async Task StartTestRunWithTestsShouldNotifyExecutionCompleteIfClientDisconnectedBeforeRunInAThreadSafeWay()
{
var runCriteria = new TestRunCriteriaWithTests(new TestCase[2], "runsettings", null, null!);
SetupOperationAbortedPayload();
SetupFakeCommunicationChannel();

// Note: Even if the calls get invoked on separate threads, the request sender should send back the complete message just once.
var t1 = Task.Run(RaiseClientDisconnectedEvent);
var t2 = Task.Run(() => _testRequestSender.StartTestRun(runCriteria, _mockExecutionEventsHandler.Object));

await Task.WhenAll(t1, t2);

_mockExecutionEventsHandler.Verify(eh => eh.HandleTestRunComplete(It.Is<TestRunCompleteEventArgs>(t => t.IsAborted), null, null, null), Times.Once);
_mockExecutionEventsHandler.Verify(eh => eh.HandleRawMessage("SerializedAbortedPayload"), Times.Once);
}

[TestMethod]
public void StartTestRunShouldNotifyExecutionCompleteIfClientDisconnected()
{
Expand Down

0 comments on commit 107873f

Please sign in to comment.