diff --git a/kotlinx-coroutines-core/common/src/channels/AbstractChannel.kt b/kotlinx-coroutines-core/common/src/channels/AbstractChannel.kt index c968a48b25..5fadbb52c4 100644 --- a/kotlinx-coroutines-core/common/src/channels/AbstractChannel.kt +++ b/kotlinx-coroutines-core/common/src/channels/AbstractChannel.kt @@ -178,7 +178,9 @@ internal abstract class AbstractSendChannel( private suspend fun sendSuspend(element: E): Unit = suspendCancellableCoroutineReusable sc@ { cont -> loop@ while (true) { if (isFullImpl) { - val send = SendElement(element, onUndeliveredElement, cont) + val send = if (onUndeliveredElement == null) + SendElement(element, cont) else + SendElementWithUndeliveredHandler(element, cont, onUndeliveredElement) val enqueueResult = enqueueSend(send) when { enqueueResult == null -> { // enqueued successfully @@ -574,7 +576,9 @@ internal abstract class AbstractChannel( @Suppress("UNCHECKED_CAST") private suspend fun receiveSuspend(receiveMode: Int): R = suspendCancellableCoroutineReusable sc@ { cont -> - val receive = ReceiveElement(onUndeliveredElement, cont as CancellableContinuation, receiveMode) + val receive = if (onUndeliveredElement == null) + ReceiveElement(cont as CancellableContinuation, receiveMode) else + ReceiveElementWithUndeliveredHandler(cont as CancellableContinuation, receiveMode, onUndeliveredElement) while (true) { if (enqueueReceive(receive)) { removeReceiveOnCancel(cont, receive) @@ -846,7 +850,7 @@ internal abstract class AbstractChannel( } private suspend fun hasNextSuspend(): Boolean = suspendCancellableCoroutineReusable sc@ { cont -> - val receive = ReceiveHasNext(channel.onUndeliveredElement, this, cont) + val receive = ReceiveHasNext(this, cont) while (true) { if (channel.enqueueReceive(receive)) { channel.removeReceiveOnCancel(cont, receive) @@ -883,8 +887,7 @@ internal abstract class AbstractChannel( } } - private class ReceiveElement( - @JvmField val onUndeliveredElement: OnUndeliveredElement?, + private open class ReceiveElement( @JvmField val cont: CancellableContinuation, @JvmField val receiveMode: Int ) : Receive() { @@ -893,10 +896,6 @@ internal abstract class AbstractChannel( else -> value } - fun resumeOnCancellationFun(value: E): ((Throwable) -> Unit)? = - onUndeliveredElement?.bindCancellationFun(value, cont.context) - - @Suppress("IMPLICIT_CAST_TO_ANY") override fun tryResumeReceive(value: E, otherOp: PrepareOp?): Symbol? { val token = cont.tryResume(resumeValue(value), otherOp?.desc, resumeOnCancellationFun(value)) ?: return null assert { token === RESUME_TOKEN } // the only other possible result @@ -917,13 +916,21 @@ internal abstract class AbstractChannel( override fun toString(): String = "ReceiveElement@$hexAddress[receiveMode=$receiveMode]" } - private class ReceiveHasNext( - @JvmField val onUndeliveredElement: OnUndeliveredElement?, + private class ReceiveElementWithUndeliveredHandler( + cont: CancellableContinuation, + receiveMode: Int, + @JvmField val onUndeliveredElement: OnUndeliveredElement + ) : ReceiveElement(cont, receiveMode) { + override fun resumeOnCancellationFun(value: E): ((Throwable) -> Unit)? = + onUndeliveredElement.bindCancellationFun(value, cont.context) + } + + private open class ReceiveHasNext( @JvmField val iterator: Itr, @JvmField val cont: CancellableContinuation ) : Receive() { override fun tryResumeReceive(value: E, otherOp: PrepareOp?): Symbol? { - val token = cont.tryResume(true, otherOp?.desc, onUndeliveredElement?.bindCancellationFun(value, cont.context)) + val token = cont.tryResume(true, otherOp?.desc, resumeOnCancellationFun(value)) ?: return null assert { token === RESUME_TOKEN } // the only other possible result // We can call finishPrepare only after successful tryResume, so that only good affected node is saved @@ -951,6 +958,10 @@ internal abstract class AbstractChannel( cont.completeResume(token) } } + + override fun resumeOnCancellationFun(value: E): ((Throwable) -> Unit)? = + iterator.channel.onUndeliveredElement?.bindCancellationFun(value, cont.context) + override fun toString(): String = "ReceiveHasNext@$hexAddress" } @@ -968,7 +979,7 @@ internal abstract class AbstractChannel( block.startCoroutineCancellable( if (receiveMode == RECEIVE_RESULT) ValueOrClosed.value(value) else value, select.completion, - channel.onUndeliveredElement?.bindCancellationFun(value, select.completion.context) + resumeOnCancellationFun(value) ) } @@ -990,6 +1001,9 @@ internal abstract class AbstractChannel( channel.onReceiveDequeued() // notify cancellation of receive } + override fun resumeOnCancellationFun(value: E): ((Throwable) -> Unit)? = + channel.onUndeliveredElement?.bindCancellationFun(value, select.completion.context) + override fun toString(): String = "ReceiveSelect@$hexAddress[$select,receiveMode=$receiveMode]" } } @@ -1056,10 +1070,8 @@ internal interface ReceiveOrClosed { /** * Represents sender for a specific element. */ -@Suppress("UNCHECKED_CAST") -internal class SendElement( +internal open class SendElement( override val pollResult: E, - @JvmField val onUndeliveredElement: OnUndeliveredElement?, @JvmField val cont: CancellableContinuation ) : Send() { override fun tryResumeSend(otherOp: PrepareOp?): Symbol? { @@ -1072,8 +1084,14 @@ internal class SendElement( override fun completeResumeSend() = cont.completeResume(RESUME_TOKEN) override fun resumeSendClosed(closed: Closed<*>) = cont.resumeWithException(closed.sendException) - override fun toString(): String = "SendElement@$hexAddress($pollResult)" + override fun toString(): String = "$classSimpleName@$hexAddress($pollResult)" +} +internal class SendElementWithUndeliveredHandler( + pollResult: E, + cont: CancellableContinuation, + @JvmField val onUndeliveredElement: OnUndeliveredElement +) : SendElement(pollResult, cont) { override fun remove(): Boolean { if (!super.remove()) return false // if the node was successfully removed (meaning it was added but was not received) then we have undelivered element @@ -1082,7 +1100,7 @@ internal class SendElement( } override fun undeliveredElement() { - onUndeliveredElement?.callUndeliveredElement(pollResult, cont.context) + onUndeliveredElement.callUndeliveredElement(pollResult, cont.context) } } @@ -1108,6 +1126,7 @@ internal class Closed( internal abstract class Receive : LockFreeLinkedListNode(), ReceiveOrClosed { override val offerResult get() = OFFER_SUCCESS abstract fun resumeReceiveClosed(closed: Closed<*>) + open fun resumeOnCancellationFun(value: E): ((Throwable) -> Unit)? = null } @Suppress("NOTHING_TO_INLINE", "UNCHECKED_CAST")