diff --git a/reactor-core/src/main/java/reactor/core/publisher/MonoCollect.java b/reactor-core/src/main/java/reactor/core/publisher/MonoCollect.java index e5f87051e3..2753b01146 100644 --- a/reactor-core/src/main/java/reactor/core/publisher/MonoCollect.java +++ b/reactor-core/src/main/java/reactor/core/publisher/MonoCollect.java @@ -64,6 +64,8 @@ static final class CollectSubscriber extends Operators.MonoSubscriber action; + R container; + Subscription s; boolean done; @@ -73,35 +75,17 @@ static final class CollectSubscriber extends Operators.MonoSubscriber c = (Collection) v; - Operators.onDiscardMultiple(c, actual.currentContext()); - } - else { - super.discard(v); - } - } - @Override public void onSubscribe(Subscription s) { if (Operators.validate(this.s, s)) { @@ -119,15 +103,22 @@ public void onNext(T t) { Operators.onNextDropped(t, actual.currentContext()); return; } - - try { - action.accept(value, t); - } - catch (Throwable e) { - Context ctx = actual.currentContext(); - Operators.onDiscard(t, ctx); - onError(Operators.onOperatorError(this, e, t, ctx)); + R c; + synchronized (this) { + c = container; + if (c != null) { + try { + action.accept(c, t); + } + catch (Throwable e) { + Context ctx = actual.currentContext(); + Operators.onDiscard(t, ctx); + onError(Operators.onOperatorError(this, e, t, ctx)); + } + return; + } } + Operators.onDiscard(t, actual.currentContext()); } @Override @@ -137,9 +128,12 @@ public void onError(Throwable t) { return; } done = true; - R v = value; - discard(v); - value = null; + R c; + synchronized (this) { + c = container; + container = null; + } + discard(c); actual.onError(t); } @@ -149,13 +143,46 @@ public void onComplete() { return; } done = true; - complete(value); + R c; + synchronized (this) { + c = container; + container = null; + } + if (c != null) { + complete(c); + } } @Override - public void setValue(R value) { - // value is constant + protected void discard(R v) { + if (v instanceof Collection) { + Collection c = (Collection) v; + Operators.onDiscardMultiple(c, actual.currentContext()); + } + else { + super.discard(v); + } } + @Override + public void cancel() { + int state; + R c; + synchronized (this) { + state = STATE.getAndSet(this, CANCELLED); + if (state <= HAS_REQUEST_NO_VALUE) { + c = container; + value = null; + container = null; + } + else { + c = null; + } + } + if (c != null) { + s.cancel(); + discard(c); + } + } } } diff --git a/reactor-core/src/main/java/reactor/core/publisher/MonoCollectList.java b/reactor-core/src/main/java/reactor/core/publisher/MonoCollectList.java index f5e54dcf32..81ba58b1f6 100644 --- a/reactor-core/src/main/java/reactor/core/publisher/MonoCollectList.java +++ b/reactor-core/src/main/java/reactor/core/publisher/MonoCollectList.java @@ -42,10 +42,10 @@ public CoreSubscriber subscribeOrReturn(CoreSubscriber extends Operators.MonoSubscriber> { - Subscription s; - List list; + Subscription s; + boolean done; MonoCollectListSubscriber(CoreSubscriber> actual) { @@ -92,7 +92,7 @@ public void onNext(T t) { @Override public void onError(Throwable t) { - if(done) { + if (done) { Operators.onErrorDropped(t, actual.currentContext()); return; } @@ -102,7 +102,7 @@ public void onError(Throwable t) { l = list; list = null; } - Operators.onDiscardMultiple(l, actual.currentContext()); + discard(l); actual.onError(t); } diff --git a/reactor-core/src/test/java/reactor/core/publisher/MonoCollectTest.java b/reactor-core/src/test/java/reactor/core/publisher/MonoCollectTest.java index fda0f06614..9453a1f2c3 100644 --- a/reactor-core/src/test/java/reactor/core/publisher/MonoCollectTest.java +++ b/reactor-core/src/test/java/reactor/core/publisher/MonoCollectTest.java @@ -30,13 +30,21 @@ import reactor.core.CoreSubscriber; import reactor.core.Fuseable; import reactor.core.Scannable; +import reactor.core.publisher.MonoCollect.CollectSubscriber; import reactor.test.StepVerifier; import reactor.test.subscriber.AssertSubscriber; +import reactor.test.util.RaceTestUtils; +import reactor.util.Logger; +import reactor.util.Loggers; +import reactor.util.context.Context; import static org.assertj.core.api.Assertions.assertThat; public class MonoCollectTest { + static final Logger LOGGER = Loggers.getLogger(MonoCollectListTest.class); + + @Test(expected = NullPointerException.class) public void nullSource() { new MonoCollect<>(null, () -> 1, (a, b) -> { @@ -57,7 +65,7 @@ public void nullAction() { public void normal() { AssertSubscriber> ts = AssertSubscriber.create(); - Flux.range(1, 10).collect(ArrayList::new, (a, b) -> a.add(b)).subscribe(ts); + Flux.range(1, 10).collect(ArrayList::new, ArrayList::add).subscribe(ts); ts.assertValues(new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))) .assertNoError() @@ -126,7 +134,7 @@ public void actionThrows() { @Test public void scanSubscriber() { CoreSubscriber> actual = new LambdaMonoSubscriber<>(null, e -> {}, null, null); - MonoCollect.CollectSubscriber> test = new MonoCollect.CollectSubscriber<>( + CollectSubscriber> test = new CollectSubscriber<>( actual, (l, v) -> l.add(v), new ArrayList<>()); Subscription parent = Operators.emptySubscription(); test.onSubscribe(parent); @@ -270,4 +278,62 @@ public void discardWholeArrayOnCancel() { assertThat((Object[]) discarded.get(0)).containsExactly(0L, 1L, null, null); } + @Test + public void discardCancelNextRace() { + AtomicInteger doubleDiscardCounter = new AtomicInteger(); + Context discardingContext = Operators.enableOnDiscard(null, o -> { + AtomicBoolean ab = (AtomicBoolean) o; + if (ab.getAndSet(true)) { + doubleDiscardCounter.incrementAndGet(); + throw new RuntimeException("test"); + } + }); + for (int i = 0; i < 100_000; i++) { + AssertSubscriber> testSubscriber = new AssertSubscriber<>(discardingContext); + CollectSubscriber> subscriber = + new CollectSubscriber<>(testSubscriber, List::add, new ArrayList<>()); + subscriber.onSubscribe(Operators.emptySubscription()); + + AtomicBoolean extraneous = new AtomicBoolean(false); + + RaceTestUtils.race(subscriber::cancel, + () -> subscriber.onNext(extraneous)); + + testSubscriber.assertNoValues(); + if (!extraneous.get()) { + LOGGER.info(""+subscriber.container); + } + assertThat(extraneous).as("released " + i).isTrue(); + } + LOGGER.info("discarded twice or more: {}", doubleDiscardCounter.get()); + } + + @Test + public void discardCancelCompleteRace() { + AtomicInteger doubleDiscardCounter = new AtomicInteger(); + Context discardingContext = Operators.enableOnDiscard(null, o -> { + AtomicBoolean ab = (AtomicBoolean) o; + if (ab.getAndSet(true)) { + doubleDiscardCounter.incrementAndGet(); + throw new RuntimeException("test"); + } + }); + for (int i = 0; i < 100_000; i++) { + AssertSubscriber> testSubscriber = new AssertSubscriber<>(discardingContext); + CollectSubscriber> subscriber = + new CollectSubscriber<>(testSubscriber, List::add, new ArrayList<>()); + subscriber.onSubscribe(Operators.emptySubscription()); + + AtomicBoolean resource = new AtomicBoolean(false); + subscriber.onNext(resource); + + RaceTestUtils.race(subscriber::cancel, subscriber::onComplete); + + if (testSubscriber.values().isEmpty()) { + assertThat(resource).as("not completed and released " + i).isTrue(); + } + } + LOGGER.info("discarded twice or more: {}", doubleDiscardCounter.get()); + } + }