diff --git a/implementation/src/main/java/io/smallrye/mutiny/operators/multi/processors/UnicastProcessor.java b/implementation/src/main/java/io/smallrye/mutiny/operators/multi/processors/UnicastProcessor.java index f9dc531ce..0a8bc17a6 100644 --- a/implementation/src/main/java/io/smallrye/mutiny/operators/multi/processors/UnicastProcessor.java +++ b/implementation/src/main/java/io/smallrye/mutiny/operators/multi/processors/UnicastProcessor.java @@ -2,10 +2,9 @@ import java.util.Objects; import java.util.Queue; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import org.reactivestreams.Processor; import org.reactivestreams.Subscriber; @@ -32,12 +31,17 @@ public class UnicastProcessor extends AbstractMulti implements Processor queue; - private final AtomicBoolean done = new AtomicBoolean(); - private final AtomicReference failure = new AtomicReference<>(); - private final AtomicBoolean cancelled = new AtomicBoolean(); + private volatile boolean done = false; + private volatile Throwable failure = null; + private volatile boolean cancelled = false; + + private volatile Subscriber downstream = null; + private static final AtomicReferenceFieldUpdater DOWNSTREAM_UPDATER = AtomicReferenceFieldUpdater + .newUpdater(UnicastProcessor.class, Subscriber.class, "downstream"); + private final AtomicInteger wip = new AtomicInteger(); private final AtomicLong requested = new AtomicLong(); - private final AtomicReference> downstream = new AtomicReference<>(); + private volatile boolean hasUpstream; /** @@ -84,12 +88,11 @@ void drainWithDownstream(Subscriber actual) { long e = 0L; while (r != e) { - boolean d = done.get(); T t = q.poll(); boolean empty = t == null; - if (isCancelledOrDone(d, empty)) { + if (isCancelledOrDone(done, empty)) { return; } @@ -103,7 +106,7 @@ void drainWithDownstream(Subscriber actual) { } if (r == e) { - if (isCancelledOrDone(done.get(), q.isEmpty())) { + if (isCancelledOrDone(done, q.isEmpty())) { return; } } @@ -126,7 +129,7 @@ private void drain() { int missed = 1; for (;;) { - Subscriber actual = downstream.get(); + Subscriber actual = downstream; if (actual != null) { drainWithDownstream(actual); return; @@ -139,13 +142,13 @@ private void drain() { } private boolean isCancelledOrDone(boolean isDone, boolean isEmpty) { - Subscriber subscriber = downstream.get(); - if (cancelled.get()) { + Subscriber subscriber = downstream; + if (cancelled) { queue.clear(); return true; } if (isDone && isEmpty) { - Throwable failed = failure.get(); + Throwable failed = failure; if (failed != null) { subscriber.onError(failed); } else { @@ -174,9 +177,9 @@ public void onSubscribe(Subscription upstream) { @Override public void subscribe(MultiSubscriber downstream) { ParameterValidation.nonNull(downstream, "downstream"); - if (this.downstream.compareAndSet(null, downstream)) { + if (DOWNSTREAM_UPDATER.compareAndSet(this, null, downstream)) { downstream.onSubscribe(this); - if (!cancelled.get()) { + if (!cancelled) { drain(); } } else { @@ -198,7 +201,7 @@ public synchronized void onNext(T t) { } private boolean isDoneOrCancelled() { - return done.get() || cancelled.get(); + return done || cancelled; } @Override @@ -209,8 +212,8 @@ public void onError(Throwable failure) { } onTerminate(); - this.failure.set(failure); - this.done.set(true); + this.failure = failure; + this.done = true; drain(); } @@ -221,7 +224,7 @@ public void onComplete() { return; } onTerminate(); - this.done.set(true); + this.done = true; drain(); } @@ -235,12 +238,15 @@ public void request(long n) { @Override public void cancel() { - if (cancelled.compareAndSet(false, true)) { + if (cancelled) { + return; + } + this.cancelled = true; + if (DOWNSTREAM_UPDATER.getAndSet(this, null) != null) { onTerminate(); if (wip.getAndIncrement() == 0) { queue.clear(); } - downstream.set(null); } } @@ -251,7 +257,7 @@ public void cancel() { * @return {@code true} if there is a subscriber, {@code false} otherwise */ public boolean hasSubscriber() { - return downstream.get() != null; + return downstream != null; } public SerializedProcessor serialized() { diff --git a/implementation/src/test/java/io/smallrye/mutiny/operators/multi/processors/UnicastProcessorTest.java b/implementation/src/test/java/io/smallrye/mutiny/operators/multi/processors/UnicastProcessorTest.java index 3da1504bc..122ffa0f5 100644 --- a/implementation/src/test/java/io/smallrye/mutiny/operators/multi/processors/UnicastProcessorTest.java +++ b/implementation/src/test/java/io/smallrye/mutiny/operators/multi/processors/UnicastProcessorTest.java @@ -6,8 +6,11 @@ import java.io.IOException; import java.util.Queue; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.RepeatedTest; import org.junit.jupiter.api.Test; @@ -102,6 +105,35 @@ public void testWithImmediateCancellationFromDownstreamWhileWaitingForUpstreamSu .assertHasNotReceivedAnyItem(); } + @RepeatedTest(10) + public void testWithConcurrentCancellations() { + AtomicInteger counter = new AtomicInteger(); + Queue queue = Queues. get(1).get(); + UnicastProcessor processor = UnicastProcessor.create(queue, counter::incrementAndGet); + + AssertSubscriber sub = processor.subscribe().withSubscriber(AssertSubscriber.create(Long.MAX_VALUE)); + + ForkJoinPool pool = ForkJoinPool.commonPool(); + CountDownLatch start = new CountDownLatch(10); + CountDownLatch done = new CountDownLatch(10); + for (int i = 0; i < 10; i++) { + pool.execute(() -> { + start.countDown(); + try { + start.await(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + processor.cancel(); + done.countDown(); + }); + } + + await().until(() -> done.getCount() == 0); + sub.assertSubscribed().assertHasNotReceivedAnyItem().assertNotTerminated(); + assertThat(counter).hasValue(1); + } + @Test public void testOverflow() { Queue queue = Queues. get(1).get();