Skip to content

Commit

Permalink
Reduce the UnicastProcessor footprint
Browse files Browse the repository at this point in the history
Uses volatile references and a single CaS for cancellation.
  • Loading branch information
jponge committed Sep 30, 2021
1 parent d880fd4 commit 380be0d
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 22 deletions.
Expand Up @@ -2,10 +2,9 @@

import java.util.Objects;
import java.util.Queue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;

import org.reactivestreams.Processor;
import org.reactivestreams.Subscriber;
Expand All @@ -32,12 +31,17 @@ public class UnicastProcessor<T> extends AbstractMulti<T> implements Processor<T
private final Runnable onTermination;
private final Queue<T> queue;

private final AtomicBoolean done = new AtomicBoolean();
private final AtomicReference<Throwable> failure = new AtomicReference<>();
private final AtomicBoolean cancelled = new AtomicBoolean();
private volatile boolean done = false;
private volatile Throwable failure = null;
private volatile boolean cancelled = false;

private volatile Subscriber<? super T> downstream = null;
private static final AtomicReferenceFieldUpdater<UnicastProcessor, Subscriber> DOWNSTREAM_UPDATER = AtomicReferenceFieldUpdater
.newUpdater(UnicastProcessor.class, Subscriber.class, "downstream");

private final AtomicInteger wip = new AtomicInteger();
private final AtomicLong requested = new AtomicLong();
private final AtomicReference<Subscriber<? super T>> downstream = new AtomicReference<>();

private volatile boolean hasUpstream;

/**
Expand Down Expand Up @@ -84,12 +88,11 @@ void drainWithDownstream(Subscriber<? super T> actual) {
long e = 0L;

while (r != e) {
boolean d = done.get();

T t = q.poll();
boolean empty = t == null;

if (isCancelledOrDone(d, empty)) {
if (isCancelledOrDone(done, empty)) {
return;
}

Expand All @@ -103,7 +106,7 @@ void drainWithDownstream(Subscriber<? super T> actual) {
}

if (r == e) {
if (isCancelledOrDone(done.get(), q.isEmpty())) {
if (isCancelledOrDone(done, q.isEmpty())) {
return;
}
}
Expand All @@ -126,7 +129,7 @@ private void drain() {

int missed = 1;
for (;;) {
Subscriber<? super T> actual = downstream.get();
Subscriber<? super T> actual = downstream;
if (actual != null) {
drainWithDownstream(actual);
return;
Expand All @@ -139,13 +142,13 @@ private void drain() {
}

private boolean isCancelledOrDone(boolean isDone, boolean isEmpty) {
Subscriber<? super T> subscriber = downstream.get();
if (cancelled.get()) {
Subscriber<? super T> subscriber = downstream;
if (cancelled) {
queue.clear();
return true;
}
if (isDone && isEmpty) {
Throwable failed = failure.get();
Throwable failed = failure;
if (failed != null) {
subscriber.onError(failed);
} else {
Expand Down Expand Up @@ -174,9 +177,9 @@ public void onSubscribe(Subscription upstream) {
@Override
public void subscribe(MultiSubscriber<? super T> downstream) {
ParameterValidation.nonNull(downstream, "downstream");
if (this.downstream.compareAndSet(null, downstream)) {
if (DOWNSTREAM_UPDATER.compareAndSet(this, null, downstream)) {
downstream.onSubscribe(this);
if (!cancelled.get()) {
if (!cancelled) {
drain();
}
} else {
Expand All @@ -198,7 +201,7 @@ public synchronized void onNext(T t) {
}

private boolean isDoneOrCancelled() {
return done.get() || cancelled.get();
return done || cancelled;
}

@Override
Expand All @@ -209,8 +212,8 @@ public void onError(Throwable failure) {
}

onTerminate();
this.failure.set(failure);
this.done.set(true);
this.failure = failure;
this.done = true;

drain();
}
Expand All @@ -221,7 +224,7 @@ public void onComplete() {
return;
}
onTerminate();
this.done.set(true);
this.done = true;
drain();
}

Expand All @@ -235,12 +238,15 @@ public void request(long n) {

@Override
public void cancel() {
if (cancelled.compareAndSet(false, true)) {
if (cancelled) {
return;
}
this.cancelled = true;
if (DOWNSTREAM_UPDATER.getAndSet(this, null) != null) {
onTerminate();
if (wip.getAndIncrement() == 0) {
queue.clear();
}
downstream.set(null);
}
}

Expand All @@ -251,7 +257,7 @@ public void cancel() {
* @return {@code true} if there is a subscriber, {@code false} otherwise
*/
public boolean hasSubscriber() {
return downstream.get() != null;
return downstream != null;
}

public SerializedProcessor<T, T> serialized() {
Expand Down
Expand Up @@ -6,8 +6,11 @@

import java.io.IOException;
import java.util.Queue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.atomic.AtomicInteger;

import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -102,6 +105,35 @@ public void testWithImmediateCancellationFromDownstreamWhileWaitingForUpstreamSu
.assertHasNotReceivedAnyItem();
}

@RepeatedTest(10)
public void testWithConcurrentCancellations() {
AtomicInteger counter = new AtomicInteger();
Queue<Integer> queue = Queues.<Integer> get(1).get();
UnicastProcessor<Integer> processor = UnicastProcessor.create(queue, counter::incrementAndGet);

AssertSubscriber<Integer> sub = processor.subscribe().withSubscriber(AssertSubscriber.create(Long.MAX_VALUE));

ForkJoinPool pool = ForkJoinPool.commonPool();
CountDownLatch start = new CountDownLatch(10);
CountDownLatch done = new CountDownLatch(10);
for (int i = 0; i < 10; i++) {
pool.execute(() -> {
start.countDown();
try {
start.await();
} catch (InterruptedException e) {
e.printStackTrace();
}
processor.cancel();
done.countDown();
});
}

await().until(() -> done.getCount() == 0);
sub.assertSubscribed().assertHasNotReceivedAnyItem().assertNotTerminated();
assertThat(counter).hasValue(1);
}

@Test
public void testOverflow() {
Queue<Integer> queue = Queues.<Integer> get(1).get();
Expand Down

0 comments on commit 380be0d

Please sign in to comment.