Skip to content

Commit

Permalink
Exchange should wait for remote sinks (#108337)
Browse files Browse the repository at this point in the history
Today, we do not wait for remote sinks to stop before completing the
main request. While this doesn't affect correctness, it's important that
we do not spawn child requests after the parent request is completed.

Closes #105859
  • Loading branch information
dnhatn committed May 8, 2024
1 parent 5a622b0 commit ab40808
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 5 deletions.
Expand Up @@ -10,13 +10,15 @@
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.RefCountingListener;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.transport.TransportException;

import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
Expand Down Expand Up @@ -89,6 +91,20 @@ public int bufferSize() {
}
}

/**
 * Subscribes {@code listener} to the completion of this handler.
 * <p>
 * Nothing is wired up until the buffer's own completion listener fires. At that point one
 * reference on {@code listener} is acquired for the outstanding sinks and one for the
 * outstanding sources, and each reference is handed to the corresponding tracker's
 * {@code completion} listener.
 *
 * @param listener the listener wrapped by the {@link RefCountingListener}; presumably it is
 *                 completed once every acquired reference has been released (confirm against
 *                 {@link RefCountingListener})
 */
public void addCompletionListener(ActionListener<Void> listener) {
    final Runnable onBufferCompleted = () -> {
        try (RefCountingListener completionRefs = new RefCountingListener(listener)) {
            List.of(outstandingSinks, outstandingSources).forEach(tracker -> {
                // A tracker that never had an exchange sink or source registered would never
                // reach zero on its own. Registering a temporary instance and finishing it
                // right after subscribing lets its completion fire in that case too.
                tracker.trackNewInstance();
                tracker.completion.addListener(completionRefs.acquire());
                tracker.finishInstance();
            });
        }
    };
    buffer.addCompletionListener(ActionListener.running(onBufferCompleted));
}

/**
* Create a new {@link ExchangeSource} for exchanging data
*
Expand Down Expand Up @@ -253,10 +269,10 @@ public Releasable addEmptySink() {

private static class PendingInstances {
private final AtomicInteger instances = new AtomicInteger();
private final Releasable onComplete;
private final SubscribableListener<Void> completion = new SubscribableListener<>();

PendingInstances(Releasable onComplete) {
this.onComplete = onComplete;
PendingInstances(Runnable onComplete) {
completion.addListener(ActionListener.running(onComplete));
}

void trackNewInstance() {
Expand All @@ -268,7 +284,7 @@ void finishInstance() {
int refs = instances.decrementAndGet();
assert refs >= 0;
if (refs == 0) {
onComplete.close();
completion.onResponse(null);
}
}
}
Expand Down
Expand Up @@ -55,6 +55,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.function.Supplier;
Expand Down Expand Up @@ -94,6 +95,8 @@ public void testBasic() throws Exception {
ExchangeSink sink1 = sinkExchanger.createExchangeSink();
ExchangeSink sink2 = sinkExchanger.createExchangeSink();
ExchangeSourceHandler sourceExchanger = new ExchangeSourceHandler(3, threadPool.executor(ESQL_TEST_EXECUTOR));
PlainActionFuture<Void> sourceCompletion = new PlainActionFuture<>();
sourceExchanger.addCompletionListener(sourceCompletion);
ExchangeSource source = sourceExchanger.createExchangeSource();
sourceExchanger.addRemoteSink(sinkExchanger::fetchPageAsync, 1);
SubscribableListener<Void> waitForReading = source.waitForReading();
Expand Down Expand Up @@ -133,7 +136,9 @@ public void testBasic() throws Exception {
sink2.finish();
assertTrue(sink2.isFinished());
assertTrue(source.isFinished());
assertFalse(sourceCompletion.isDone());
source.finish();
sourceCompletion.actionGet(10, TimeUnit.SECONDS);
ESTestCase.terminate(threadPool);
for (Page page : pages) {
page.releaseBlocks();
Expand Down Expand Up @@ -320,7 +325,9 @@ protected void start(Driver driver, ActionListener<Void> listener) {

public void testConcurrentWithHandlers() {
BlockFactory blockFactory = blockFactory();
PlainActionFuture<Void> sourceCompletionFuture = new PlainActionFuture<>();
var sourceExchanger = new ExchangeSourceHandler(randomExchangeBuffer(), threadPool.executor(ESQL_TEST_EXECUTOR));
sourceExchanger.addCompletionListener(sourceCompletionFuture);
List<ExchangeSinkHandler> sinkHandlers = new ArrayList<>();
Supplier<ExchangeSink> exchangeSink = () -> {
final ExchangeSinkHandler sinkHandler;
Expand All @@ -336,6 +343,7 @@ public void testConcurrentWithHandlers() {
final int maxInputSeqNo = rarely() ? -1 : randomIntBetween(0, 50_000);
final int maxOutputSeqNo = rarely() ? -1 : randomIntBetween(0, 50_000);
runConcurrentTest(maxInputSeqNo, maxOutputSeqNo, sourceExchanger::createExchangeSource, exchangeSink);
sourceCompletionFuture.actionGet(10, TimeUnit.SECONDS);
}

public void testEarlyTerminate() {
Expand All @@ -358,7 +366,7 @@ public void testEarlyTerminate() {
assertTrue(sink.isFinished());
}

public void testConcurrentWithTransportActions() throws Exception {
public void testConcurrentWithTransportActions() {
MockTransportService node0 = newTransportService();
ExchangeService exchange0 = new ExchangeService(Settings.EMPTY, threadPool, ESQL_TEST_EXECUTOR, blockFactory());
exchange0.registerTransportHandler(node0);
Expand All @@ -371,12 +379,15 @@ public void testConcurrentWithTransportActions() throws Exception {
String exchangeId = "exchange";
Task task = new Task(1, "", "", "", null, Collections.emptyMap());
var sourceHandler = new ExchangeSourceHandler(randomExchangeBuffer(), threadPool.executor(ESQL_TEST_EXECUTOR));
PlainActionFuture<Void> sourceCompletionFuture = new PlainActionFuture<>();
sourceHandler.addCompletionListener(sourceCompletionFuture);
ExchangeSinkHandler sinkHandler = exchange1.createSinkHandler(exchangeId, randomExchangeBuffer());
Transport.Connection connection = node0.getConnection(node1.getLocalNode());
sourceHandler.addRemoteSink(exchange0.newRemoteSink(task, exchangeId, node0, connection), randomIntBetween(1, 5));
final int maxInputSeqNo = rarely() ? -1 : randomIntBetween(0, 50_000);
final int maxOutputSeqNo = rarely() ? -1 : randomIntBetween(0, 50_000);
runConcurrentTest(maxInputSeqNo, maxOutputSeqNo, sourceHandler::createExchangeSource, sinkHandler::createExchangeSink);
sourceCompletionFuture.actionGet(10, TimeUnit.SECONDS);
}
}

Expand Down Expand Up @@ -427,6 +438,8 @@ public void sendResponse(TransportResponse transportResponse) {
String exchangeId = "exchange";
Task task = new Task(1, "", "", "", null, Collections.emptyMap());
var sourceHandler = new ExchangeSourceHandler(randomIntBetween(1, 128), threadPool.executor(ESQL_TEST_EXECUTOR));
PlainActionFuture<Void> sourceCompletionFuture = new PlainActionFuture<>();
sourceHandler.addCompletionListener(sourceCompletionFuture);
ExchangeSinkHandler sinkHandler = exchange1.createSinkHandler(exchangeId, randomIntBetween(1, 128));
Transport.Connection connection = node0.getConnection(node1.getLocalDiscoNode());
sourceHandler.addRemoteSink(exchange0.newRemoteSink(task, exchangeId, node0, connection), randomIntBetween(1, 5));
Expand All @@ -438,6 +451,7 @@ public void sendResponse(TransportResponse transportResponse) {
assertNotNull(cause);
assertThat(cause.getMessage(), equalTo("page is too large"));
sinkHandler.onFailure(new RuntimeException(cause));
sourceCompletionFuture.actionGet(10, TimeUnit.SECONDS);
}
}

Expand Down
Expand Up @@ -205,6 +205,7 @@ public void execute(
RefCountingListener refs = new RefCountingListener(listener.map(unused -> new Result(collectedPages, collectedProfiles)))
) {
// run compute on the coordinator
exchangeSource.addCompletionListener(refs.acquire());
runCompute(
rootTask,
new ComputeContext(sessionId, RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, List.of(), configuration, exchangeSource, null),
Expand Down Expand Up @@ -722,6 +723,7 @@ private void runComputeOnDataNode(
var externalSink = exchangeService.getSinkHandler(externalId);
task.addListener(() -> exchangeService.finishSinkHandler(externalId, new TaskCancelledException(task.getReasonCancelled())));
var exchangeSource = new ExchangeSourceHandler(1, esqlExecutor);
exchangeSource.addCompletionListener(refs.acquire());
exchangeSource.addRemoteSink(internalSink::fetchPageAsync, 1);
ActionListener<Void> reductionListener = cancelOnFailure(task, cancelled, refs.acquire());
runCompute(
Expand Down Expand Up @@ -854,6 +856,7 @@ void runComputeOnRemoteCluster(
RefCountingListener refs = new RefCountingListener(listener.map(unused -> new ComputeResponse(collectedProfiles)))
) {
exchangeSink.addCompletionListener(refs.acquire());
exchangeSource.addCompletionListener(refs.acquire());
PhysicalPlan coordinatorPlan = new ExchangeSinkExec(
plan.source(),
plan.output(),
Expand Down

0 comments on commit ab40808

Please sign in to comment.