From 8cf0f36c10f349dbe8e7d15295928560bb7960af Mon Sep 17 00:00:00 2001 From: Vsevolod Tolstopyatov Date: Thu, 24 Jun 2021 12:48:38 +0300 Subject: [PATCH] Introduce Task.await and Task.asDeferred with CancellationTokenSource (#2786) * Support bi-directional cancellation for Task.asDeferred and Task.await via passed in CancellationTokenSource Fixes #2527 Co-authored-by: Alex Vanyo Co-authored-by: dkhalanskyjb <52952525+dkhalanskyjb@users.noreply.github.com> --- .../README.md | 10 + .../api/kotlinx-coroutines-play-services.api | 2 + .../src/Tasks.kt | 99 +++++-- .../test/TaskTest.kt | 265 ++++++++++++++++++ 4 files changed, 349 insertions(+), 27 deletions(-) diff --git a/integration/kotlinx-coroutines-play-services/README.md b/integration/kotlinx-coroutines-play-services/README.md index 4ee6bf427c..e5e0e613b3 100644 --- a/integration/kotlinx-coroutines-play-services/README.md +++ b/integration/kotlinx-coroutines-play-services/README.md @@ -6,6 +6,7 @@ Extension functions: | **Name** | **Description** | -------- | --------------- +| [Task.asDeferred][asDeferred] | Converts a Task into a Deferred | [Task.await][await] | Awaits for completion of the Task (cancellable) | [Deferred.asTask][asTask] | Converts a deferred value to a Task @@ -25,5 +26,14 @@ val snapshot = try { // Do stuff ``` +If the `Task` supports cancellation via passing a `CancellationToken`, pass the corresponding `CancellationTokenSource` to `asDeferred` or `await` to support bi-directional cancellation: + +```kotlin +val cancellationTokenSource = CancellationTokenSource() +val currentLocationTask = fusedLocationProviderClient.getCurrentLocation(PRIORITY_HIGH_ACCURACY, cancellationTokenSource.token) +val currentLocation = currentLocationTask.await(cancellationTokenSource) // cancelling `await` also cancels `currentLocationTask`, and vice versa +``` + +[asDeferred]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-play-services/kotlinx.coroutines.tasks/com.google.android.gms.tasks.-task/as-deferred.html [await]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-play-services/kotlinx.coroutines.tasks/com.google.android.gms.tasks.-task/await.html [asTask]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-play-services/kotlinx.coroutines.tasks/kotlinx.coroutines.-deferred/as-task.html diff --git a/integration/kotlinx-coroutines-play-services/api/kotlinx-coroutines-play-services.api b/integration/kotlinx-coroutines-play-services/api/kotlinx-coroutines-play-services.api index 9b2c4dd388..cc23e8db2e 100644 --- a/integration/kotlinx-coroutines-play-services/api/kotlinx-coroutines-play-services.api +++ b/integration/kotlinx-coroutines-play-services/api/kotlinx-coroutines-play-services.api @@ -1,6 +1,8 @@ public final class kotlinx/coroutines/tasks/TasksKt { public static final fun asDeferred (Lcom/google/android/gms/tasks/Task;)Lkotlinx/coroutines/Deferred; + public static final fun asDeferred (Lcom/google/android/gms/tasks/Task;Lcom/google/android/gms/tasks/CancellationTokenSource;)Lkotlinx/coroutines/Deferred; public static final fun asTask (Lkotlinx/coroutines/Deferred;)Lcom/google/android/gms/tasks/Task; + public static final fun await (Lcom/google/android/gms/tasks/Task;Lcom/google/android/gms/tasks/CancellationTokenSource;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static final fun await (Lcom/google/android/gms/tasks/Task;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } diff --git a/integration/kotlinx-coroutines-play-services/src/Tasks.kt b/integration/kotlinx-coroutines-play-services/src/Tasks.kt index d89d1aec7c..c37ac7a02d 100644 --- a/integration/kotlinx-coroutines-play-services/src/Tasks.kt +++ b/integration/kotlinx-coroutines-play-services/src/Tasks.kt @@ -6,15 +6,8 @@ package kotlinx.coroutines.tasks -import com.google.android.gms.tasks.CancellationTokenSource -import com.google.android.gms.tasks.RuntimeExecutionException -import com.google.android.gms.tasks.Task -import com.google.android.gms.tasks.TaskCompletionSource -import kotlinx.coroutines.CancellationException -import kotlinx.coroutines.CompletableDeferred -import kotlinx.coroutines.Deferred -import kotlinx.coroutines.Job -import kotlinx.coroutines.suspendCancellableCoroutine +import com.google.android.gms.tasks.* +import kotlinx.coroutines.* import kotlin.coroutines.* /** @@ -45,39 +38,85 @@ public fun Deferred.asTask(): Task { /** * Converts this task to an instance of [Deferred]. * If task is cancelled then resulting deferred will be cancelled as well. + * However, the opposite is not true: if the deferred is cancelled, the [Task] will not be cancelled. + * For bi-directional cancellation, an overload that accepts [CancellationTokenSource] can be used. */ -public fun Task.asDeferred(): Deferred { +public fun Task.asDeferred(): Deferred = asDeferredImpl(null) + +/** + * Converts this task to an instance of [Deferred] with a [CancellationTokenSource] to control cancellation. + * The cancellation of this function is bi-directional: + * * If the given task is cancelled, the resulting deferred will be cancelled. + * * If the resulting deferred is cancelled, the provided [cancellationTokenSource] will be cancelled. + * + * Providing a [CancellationTokenSource] that is unrelated to the receiving [Task] is not supported and + * leads to an unspecified behaviour. + */ +@ExperimentalCoroutinesApi // Since 1.5.1, tentatively until 1.6.0 +public fun Task.asDeferred(cancellationTokenSource: CancellationTokenSource): Deferred = + asDeferredImpl(cancellationTokenSource) + +private fun Task.asDeferredImpl(cancellationTokenSource: CancellationTokenSource?): Deferred { + val deferred = CompletableDeferred() if (isComplete) { val e = exception - return if (e == null) { - @Suppress("UNCHECKED_CAST") - CompletableDeferred().apply { if (isCanceled) cancel() else complete(result as T) } + if (e == null) { + if (isCanceled) { + deferred.cancel() + } else { + @Suppress("UNCHECKED_CAST") + deferred.complete(result as T) + } } else { - CompletableDeferred().apply { completeExceptionally(e) } + deferred.completeExceptionally(e) + } + } else { + addOnCompleteListener { + val e = it.exception + if (e == null) { + @Suppress("UNCHECKED_CAST") + if (it.isCanceled) deferred.cancel() else deferred.complete(it.result as T) + } else { + deferred.completeExceptionally(e) + } } } - val result = CompletableDeferred() - addOnCompleteListener { - val e = it.exception - if (e == null) { - @Suppress("UNCHECKED_CAST") - if (isCanceled) result.cancel() else result.complete(it.result as T) - } else { - result.completeExceptionally(e) + if (cancellationTokenSource != null) { + deferred.invokeOnCompletion { + cancellationTokenSource.cancel() } } - return result + // Prevent casting to CompletableDeferred and manual completion. + return object : Deferred by deferred {} } /** - * Awaits for completion of the task without blocking a thread. + * Awaits the completion of the task without blocking a thread. * * This suspending function is cancellable. * If the [Job] of the current coroutine is cancelled or completed while this suspending function is waiting, this function * stops waiting for the completion stage and immediately resumes with [CancellationException]. + * + * For bi-directional cancellation, an overload that accepts [CancellationTokenSource] can be used. + */ +public suspend fun Task.await(): T = awaitImpl(null) + +/** + * Awaits the completion of the task that is linked to the given [CancellationTokenSource] to control cancellation. + * + * This suspending function is cancellable and cancellation is bi-directional: + * * If the [Job] of the current coroutine is cancelled or completed while this suspending function is waiting, this function + * cancels the [cancellationTokenSource] and throws a [CancellationException]. + * * If the task is cancelled, then this function will throw a [CancellationException]. + * + * Providing a [CancellationTokenSource] that is unrelated to the receiving [Task] is not supported and + * leads to an unspecified behaviour. */ -public suspend fun Task.await(): T { +@ExperimentalCoroutinesApi // Since 1.5.1, tentatively until 1.6.0 +public suspend fun Task.await(cancellationTokenSource: CancellationTokenSource): T = awaitImpl(cancellationTokenSource) + +private suspend fun Task.awaitImpl(cancellationTokenSource: CancellationTokenSource?): T { // fast path if (isComplete) { val e = exception @@ -95,13 +134,19 @@ public suspend fun Task.await(): T { return suspendCancellableCoroutine { cont -> addOnCompleteListener { - val e = exception + val e = it.exception if (e == null) { @Suppress("UNCHECKED_CAST") - if (isCanceled) cont.cancel() else cont.resume(result as T) + if (it.isCanceled) cont.cancel() else cont.resume(it.result as T) } else { cont.resumeWithException(e) } } + + if (cancellationTokenSource != null) { + cont.invokeOnCancellation { + cancellationTokenSource.cancel() + } + } } } diff --git a/integration/kotlinx-coroutines-play-services/test/TaskTest.kt b/integration/kotlinx-coroutines-play-services/test/TaskTest.kt index 0f125ac98c..b125192e93 100644 --- a/integration/kotlinx-coroutines-play-services/test/TaskTest.kt +++ b/integration/kotlinx-coroutines-play-services/test/TaskTest.kt @@ -149,5 +149,270 @@ class TaskTest : TestBase() { } } + @Test + fun testCancellableTaskAsDeferred() = runTest { + val cancellationTokenSource = CancellationTokenSource() + val deferred = Tasks.forResult(42).asDeferred(cancellationTokenSource) + assertEquals(42, deferred.await()) + assertTrue(cancellationTokenSource.token.isCancellationRequested) + } + + @Test + fun testNullResultCancellableTaskAsDeferred() = runTest { + val cancellationTokenSource = CancellationTokenSource() + assertNull(Tasks.forResult(null).asDeferred(cancellationTokenSource).await()) + assertTrue(cancellationTokenSource.token.isCancellationRequested) + } + + @Test + fun testCancelledCancellableTaskAsDeferred() = runTest { + val cancellationTokenSource = CancellationTokenSource() + val deferred = Tasks.forCanceled().asDeferred(cancellationTokenSource) + + assertTrue(deferred.isCancelled) + try { + deferred.await() + fail("deferred.await() should be cancelled") + } catch (e: Exception) { + assertTrue(e is CancellationException) + } + assertTrue(cancellationTokenSource.token.isCancellationRequested) + } + + @Test + fun testCancellingCancellableTaskAsDeferred() = runTest { + val cancellationTokenSource = CancellationTokenSource() + val task = TaskCompletionSource(cancellationTokenSource.token).task + val deferred = task.asDeferred(cancellationTokenSource) + + deferred.cancel() + try { + deferred.await() + fail("deferred.await() should be cancelled") + } catch (e: Exception) { + assertTrue(e is CancellationException) + } + assertTrue(cancellationTokenSource.token.isCancellationRequested) + } + + @Test + fun testExternallyCancelledCancellableTaskAsDeferred() = runTest { + val cancellationTokenSource = CancellationTokenSource() + val task = TaskCompletionSource(cancellationTokenSource.token).task + val deferred = task.asDeferred(cancellationTokenSource) + + cancellationTokenSource.cancel() + + try { + deferred.await() + fail("deferred.await() should be cancelled") + } catch (e: Exception) { + assertTrue(e is CancellationException) + } + assertTrue(cancellationTokenSource.token.isCancellationRequested) + } + + @Test + fun testSeparatelyCancelledCancellableTaskAsDeferred() = runTest { + val cancellationTokenSource = CancellationTokenSource() + val task = TaskCompletionSource().task + task.asDeferred(cancellationTokenSource) + + cancellationTokenSource.cancel() + + assertTrue(cancellationTokenSource.token.isCancellationRequested) + } + + @Test + fun testFailedCancellableTaskAsDeferred() = runTest { + val cancellationTokenSource = CancellationTokenSource() + val deferred = Tasks.forException(TestException("something went wrong")).asDeferred(cancellationTokenSource) + + assertTrue(deferred.isCancelled && deferred.isCompleted) + val completionException = deferred.getCompletionExceptionOrNull()!! + assertTrue(completionException is TestException) + assertEquals("something went wrong", completionException.message) + + try { + deferred.await() + fail("deferred.await() should throw an exception") + } catch (e: Exception) { + assertTrue(e is TestException) + assertEquals("something went wrong", e.message) + } + assertTrue(cancellationTokenSource.token.isCancellationRequested) + } + + @Test + fun testFailingCancellableTaskAsDeferred() = runTest { + val cancellationTokenSource = CancellationTokenSource() + val lock = ReentrantLock().apply { lock() } + + val deferred: Deferred = Tasks.call { + lock.withLock { throw TestException("something went wrong") } + }.asDeferred(cancellationTokenSource) + + assertFalse(deferred.isCompleted) + lock.unlock() + + try { + deferred.await() + fail("deferred.await() should throw an exception") + } catch (e: Exception) { + assertTrue(e is TestException) + assertEquals("something went wrong", e.message) + assertSame(e.cause, deferred.getCompletionExceptionOrNull()) // debug mode stack augmentation + } + assertTrue(cancellationTokenSource.token.isCancellationRequested) + } + + @Test + fun testFastPathCompletedTaskWithCancelledTokenSourceAsDeferred() = runTest { + val cancellationTokenSource = CancellationTokenSource() + val deferred = Tasks.forResult(42).asDeferred(cancellationTokenSource) + cancellationTokenSource.cancel() + assertEquals(42, deferred.await()) + } + + @Test + fun testAwaitCancellableTask() = runTest { + val cancellationTokenSource = CancellationTokenSource() + val taskCompletionSource = TaskCompletionSource(cancellationTokenSource.token) + + val deferred: Deferred = async(start = CoroutineStart.UNDISPATCHED) { + taskCompletionSource.task.await(cancellationTokenSource) + } + + assertFalse(deferred.isCompleted) + taskCompletionSource.setResult(42) + + assertEquals(42, deferred.await()) + assertTrue(deferred.isCompleted) + } + + @Test + fun testFailedAwaitTask() = runTest(expected = { it is TestException }) { + val cancellationTokenSource = CancellationTokenSource() + val taskCompletionSource = TaskCompletionSource(cancellationTokenSource.token) + + val deferred: Deferred = async(start = CoroutineStart.UNDISPATCHED) { + taskCompletionSource.task.await(cancellationTokenSource) + } + + assertFalse(deferred.isCompleted) + taskCompletionSource.setException(TestException("something went wrong")) + + deferred.await() + } + + @Test + fun testCancelledAwaitCancellableTask() = runTest { + val cancellationTokenSource = CancellationTokenSource() + val taskCompletionSource = TaskCompletionSource(cancellationTokenSource.token) + + val deferred: Deferred = async(start = CoroutineStart.UNDISPATCHED) { + taskCompletionSource.task.await(cancellationTokenSource) + } + + assertFalse(deferred.isCompleted) + // Cancel the deferred + deferred.cancel() + + try { + deferred.await() + fail("deferred.await() should be cancelled") + } catch (e: Exception) { + assertTrue(e is CancellationException) + } + + assertTrue(cancellationTokenSource.token.isCancellationRequested) + } + + @Test + fun testExternallyCancelledAwaitCancellableTask() = runTest { + val cancellationTokenSource = CancellationTokenSource() + val taskCompletionSource = TaskCompletionSource(cancellationTokenSource.token) + + val deferred: Deferred = async(start = CoroutineStart.UNDISPATCHED) { + taskCompletionSource.task.await(cancellationTokenSource) + } + + assertFalse(deferred.isCompleted) + // Cancel the cancellation token source + cancellationTokenSource.cancel() + + try { + deferred.await() + fail("deferred.await() should be cancelled") + } catch (e: Exception) { + assertTrue(e is CancellationException) + } + + assertTrue(cancellationTokenSource.token.isCancellationRequested) + } + + @Test + fun testFastPathCancellationTokenSourceCancelledAwaitCancellableTask() = runTest { + val cancellationTokenSource = CancellationTokenSource() + // Construct a task without the cancellation token source + val taskCompletionSource = TaskCompletionSource() + + val deferred: Deferred = async(start = CoroutineStart.LAZY) { + taskCompletionSource.task.await(cancellationTokenSource) + } + + assertFalse(deferred.isCompleted) + cancellationTokenSource.cancel() + + // Cancelling the token doesn't cancel the deferred + assertTrue(cancellationTokenSource.token.isCancellationRequested) + assertFalse(deferred.isCompleted) + + // Cleanup + deferred.cancel() + } + + @Test + fun testSlowPathCancellationTokenSourceCancelledAwaitCancellableTask() = runTest { + val cancellationTokenSource = CancellationTokenSource() + // Construct a task without the cancellation token source + val taskCompletionSource = TaskCompletionSource() + + val deferred: Deferred = async(start = CoroutineStart.UNDISPATCHED) { + taskCompletionSource.task.await(cancellationTokenSource) + } + + assertFalse(deferred.isCompleted) + cancellationTokenSource.cancel() + + // Cancelling the token doesn't cancel the deferred + assertTrue(cancellationTokenSource.token.isCancellationRequested) + assertFalse(deferred.isCompleted) + + // Cleanup + deferred.cancel() + } + + @Test + fun testFastPathWithCompletedTaskAndCanceledTokenSourceAwaitTask() = runTest { + val firstCancellationTokenSource = CancellationTokenSource() + val secondCancellationTokenSource = CancellationTokenSource() + // Construct a task with a different cancellation token source + val taskCompletionSource = TaskCompletionSource(firstCancellationTokenSource.token) + + val deferred: Deferred = async(start = CoroutineStart.LAZY) { + taskCompletionSource.task.await(secondCancellationTokenSource) + } + + assertFalse(deferred.isCompleted) + secondCancellationTokenSource.cancel() + + assertFalse(deferred.isCompleted) + taskCompletionSource.setResult(42) + + assertEquals(42, deferred.await()) + assertTrue(deferred.isCompleted) + } + class TestException(message: String) : Exception(message) }