⏱️

[Swift] async関数とAsyncStreamのキャンセル

2024/05/25に公開

async関数のキャンセル

Task.cancel()

Task.cancel()Task.isCancelledのフラグを立てるだけ です。
実際に処理を中断する処理は自分で実装しなければいけません。

https://developer.apple.com/documentation/swift/task/cancel()

Calling this method on a task that doesn’t support cancellation has no effect.

キャンセルをサポートしていないタスクでこのメソッドを呼んでも何の効果もない。

「キャンセルをサポートしていないタスク」と言う表現が、実際に処理を途中終了する処理は自分で実装しなければいけないことを意味しています。

よって以下の実験では途中終了せず1が表示されます。

func doSomething() async -> Int {
    await withCheckedContinuation { continuation in
        // 3秒後に1を返す、`withCheckedContinuation`はそれまでsuspendする。
        DispatchQueue.main.asyncAfter(deadline: .now() + 3) {
            continuation.resume(returning: 1)
        }
    }
}

let task = Task {
    let result = await doSomething()
    print(result)
}

task.cancel()
print("task.cancel")

await task.value
// 経過時間を出力するデバッグ用関数
func wait(seconds: Int, printTime: Bool) async {
    if printTime {
        print("time: 0")
    }

    for i in 1...seconds {
        try? await Task.sleep(for: .seconds(1))
        if printTime {
            print("time: \(i)")
        }
    }
}

let timer = Task {
    await wait(seconds: 5, printTime: true)
}

let task = Task {
    let result = await doSomething()
    print(result)
}

task.cancel()
print("task.cancel")

await task.value
await timer.value

出力

time: 0
task.cancel
time: 1
time: 2
time: 3
1
time: 4
time: 5

実際に途中終了する処理を実装すると以下のようになります。

func doSomethingSupportCancel() async throws -> Int {
    if Task.isCancelled { throw CancellationError() }

    let result = await withCheckedContinuation { continuation in
        DispatchQueue.main.asyncAfter(deadline: .now() + 3) {
            continuation.resume(returning: 1)
        }
    }
    if Task.isCancelled { throw CancellationError() }

    return result
}

CancellationError が標準で用意されているので、基本的にはこれを throw して終了する処理を実装するのが良いと思います。
そのため、実質的にキャンセル対応をする場合はasync throws関数に変更することになります。

https://developer.apple.com/documentation/swift/cancellationerror

また、if Task.isCancelled { throw CancellationError() }と同値である関数Task.checkCancellation()が用意されているのでこちらを使うと便利です。

https://developer.apple.com/documentation/swift/task/checkcancellation()

https://github.com/apple/swift/blob/3693592c622aafa12c69fa0c8af54282cd867880/stdlib/public/Concurrency/TaskCancellation.swift#L121-L125

func doSomethingSupportCancel2() async throws -> Int {

    try Task.checkCancellation()
    let result = try await withCheckedThrowingContinuation { continuation in
        DispatchQueue.main.asyncAfter(deadline: .now() + 3) {
            continuation.resume(returning: 1)
        }
    }
    try Task.checkCancellation()

    return result
}

実行例

let task = Task {
    do {
        let result = try await doSomethingSupportCancel()
        print(result)
    } catch {
        print("error: \(error)")
    }
}

task.cancel()
print("task.cancel")

await task.value
await timer.value

出力

time: 0
task.cancel
error: CancellationError()
time: 1
time: 2
time: 3
time: 4
time: 5

CancellationErrorを使わないキャンセル処理

実質的にキャンセル対応をする場合はasync throws関数に変更することになります。

と言いましたが、asyncのままでもキャンセル対応できます。CancellationErrorを使わない途中終了処理を書けば良いです。(よくないデザインになることもあるので、ケースによります)

func doSomethingSupportCancelNil() async -> Int? {
    if Task.isCancelled { return nil }

    let result =  try await withCheckedContinuation { continuation in
        DispatchQueue.main.asyncAfter(deadline: .now() + 3) {
            continuation.resume(returning: 1)
        }
    }

    if Task.isCancelled { return nil }
    return result
}

let task = Task {
    let result = await doSomethingSupportCancelNil()
    print(result)
}

task.cancel()
print("task.cancel")

await task.value
await timer.value
time: 0
task.cancel
nil
time: 1
time: 2
time: 3
time: 4
time: 5

キャンセルされるタイミング

ここで、Task.cancel()=「キャンセルのフラグを立てる」タイミングとtry Task.checkCancellation()=「キャンセルのフラグを見て実際に処理を途中終了する」タイミングは違うことに意識しなければいけません。

キャンセルのフラグを立てるタイミングを1.5秒後にしてみます。

let task = Task {
    do {
        let result = try await doSomethingSupportCancel()
        print(result)
    } catch {
        print("error: \(error)")
    }
}

// 1.5秒後にキャンセル
try? await Task.sleep(for: .seconds(1.5))
task.cancel()
print("task.cancel")

await task.value
await timer.value

実行結果

time: 0
time: 1
task.cancel
time: 2
time: 3
error: CancellationError()
time: 4
time: 5

1.5秒の時点でtask.cancel()が実行されていますが、実際に関数が途中終了しているのは3秒の時点です。なぜなら、try Task.checkCancellation()が実行されるタイミングでしか途中終了できないからです。(考えてみれば自然なことだと思います。)

なので、現在の実装では「Task開始直後」か「コールバック終了後 (3秒後)」でしか途中終了できません。

func doSomethingSupportCancel2() async throws -> Int {

    try Task.checkCancellation() // ここを通ったタイミング(開始直後)でしか途中終了できない

    // ここのawaitは3秒経つまで終了しない
    let result = try await withCheckedThrowingContinuation { continuation in
        DispatchQueue.main.asyncAfter(deadline: .now() + 3) {
            continuation.resume(returning: 1)
        }
    }
    try Task.checkCancellation() // ここを通ったタイミング(コールバック終了後)でしか途中終了できない

    return result
}

この場合にはwithTaskCancellationHandlerを使います。
また、コールバック自体に途中終了する機能が必要なので、より現実に近い実装で実験します。

// URLSessionのDataTask等のイメージ
final class APITask {
    private var completion: ((Result<Int, any Error>) -> Void)?

    init(completion: @escaping (Result<Int, any Error>) -> Void) {
        self.completion = completion

        DispatchQueue.main.asyncAfter(deadline: .now() + 3) {
            self.completion?(.success(1))
        }
    }

    func cancel() {
        completion?(.failure(CancellationError()))
        completion = nil
    }
}

func callAPI(completion: @escaping (Result<Int, any Error>) -> Void) -> APITask {
    APITask(completion: completion)
}
// スレッド危険の問題でconcurrency境界を跨いでvarな変数をキャプチャできないので、その対処
final class LockedTask: @unchecked Sendable {
    private let lock = NSLock()

    private var task: APITask?

    func cancel() {
        lock.withLock {
            task?.cancel()
        }
    }

    func set(task: APITask) {
        lock.withLock {
            self.task = task
        }
    }
}

func doSomethingSupportCancelRealTime() async throws -> Int {
    let apiTask = LockedTask()
    let result = try await withTaskCancellationHandler {
        try await withCheckedThrowingContinuation { continuation in
            let api = callAPI { result in
                continuation.resume(with: result)
            }
            apiTask.set(task: api)
        }
    } onCancel: {
        // ここはTask.cancel()を呼んだ直後に呼ばれる
        // apiTask.cancel()はコールバックを即時終了させるため、`withCheckedThrowingContinuation`も即時終了する。
        apiTask.cancel()
    }

    return result
}

実行結果

time: 0
time: 1
task.cancel
error: CancellationError()
time: 2
time: 3
time: 4
time: 5

AsyncStreamのキャンセル

AsyncStreamのキャンセルの挙動を見てみましょう。3秒後に一つの値を返すAsyncStreamを作ります。

func doSomethingAsyncStream() -> AsyncStream<Int> {
    AsyncStream { continuation in
        DispatchQueue.main.asyncAfter(deadline: .now() + 3) {
            continuation.yield(1)
            continuation.finish()
        }
    }
}

let task = Task {
    for await result in doSomethingAsyncStream() {
        print(result)
    }
    print("finish")
}

try? await Task.sleep(for: .seconds(1.5))

task.cancel()
print("task.cancel")

await task.value
await timer.value

実行結果

time: 0
time: 1
task.cancel
finish
time: 2
time: 3
time: 4
time: 5

値が出力されずfinishが呼ばれているので、AsyncStreamは自分でキャンセル対応せずともキャンセル処理が行われていることがわかります。
さらに、1秒と2秒の間にfinishが出力されているので、前提知識で書いたwithTaskCancellationHandlerの挙動になっていることがわかります。

この挙動の原因はどこにあるのでしょうか?より詳しく見てみましょう。
for await result inは以下のAsyncIteratorProtocol.next()while letで呼び出すシンタックスシュガーです。シンタックスシュガーを解いて実行してみましょう。

https://github.com/apple/swift-evolution/blob/main/proposals/0298-asyncsequence.md#detailed-design

let task = Task {
    var iterator = doSomethingAsyncStream().makeAsyncIterator()
    while let result = await iterator.next() {
        print(result)
    }
    print("finish")
}

同じ結果になりました。

time: 0
time: 1
task.cancel
finish
time: 2
time: 3
time: 4
time: 5

さらに分解します。while letを外してみましょう。

let task = Task {
    var iterator = doSomethingAsyncStream().makeAsyncIterator()
    let result = await iterator.next() 
    print(result)
    print("finish")
}
time: 0
time: 1
task.cancel
nil
finish
time: 2
time: 3
time: 4
time: 5

nilが出力されています。つまり、AsyncStream.Iterator.next()CancellationErrorではなくnilを返すことでキャンセルを実装していることがわかります。
実際この挙動はAsyncStream.Iteratorのドキュメントに書いてあります。

https://developer.apple.com/documentation/swift/asyncstream/iterator/next()

When next() returns nil, this signifies the end of the AsyncStream.
It is a programmer error to invoke next() from a concurrent context that contends with another such call, which results in a call to fatalError().
If you cancel the task this iterator is running in while next() is awaiting a value, the AsyncStream terminates. In this case, next() might return nil immediately, or return nil on subsequent calls.

実際にAsyncStream._Storage.next() (AsyncStream.Iterator.next()の実態)をのぞいてみると、withTaskCancellationHandlerを利用して、キャンセルが呼ばれたら

  • AsyncStreamにfinishを伝える(onTerminationを呼ぶ)
  • next()nilを返す

などの処理が書いてあります。

https://github.com/apple/swift/blob/525e245e0afa3d32e57ca8d2d1b8f1e61baee1ad/stdlib/public/Concurrency/AsyncStreamBuffer.swift#L255-L263

AsyncStream(unfold:onCancel)

https://developer.apple.com/documentation/swift/asyncstream/init(unfolding:oncancel:)

上記のAsyncStream(build:)AsyncStream(unfold:onCancel)のキャンセルの挙動は異なります。
AsyncStream(unfold:onCancel)はasync関数を受け取るinitです。asyncをAsyncStreamに変換したい場合に使います。

import Foundation

print("Stream continuation")
let streamA = AsyncStream { continuation in
    DispatchQueue.main.asyncAfter(deadline: .now() + 5) {
        continuation.yield(1)
    }

    continuation.onTermination = { _ in
        print("cancelled")
    }
}

let timerA = Task {
    await wait(seconds: 7, printTime: true)
}
let taskA = Task {
    for await i in streamA {
        print("value: \(i)")
    }
    print("finish")
}

try? await Task.sleep(for: .seconds(1))
taskA.cancel()

await taskA.value
await timerA.value

print("Stream unfold")
let streamB = AsyncStream {
    // キャンセルがsleepまで伝搬しないようにする
    // streamAと挙動を合わせるため(必ず5秒後にreturnするようにする)
    await Task {
        try? await Task.sleep(for: .seconds(5))
    }.value

    return 1
} onCancel: {
    print("cancelled")
}

let timerB = Task {
    await wait(seconds: 7, printTime: true)
}
let taskB = Task {
    for await i in streamB {
        print("value: \(i)")
    }
    print("finish")
}

try? await Task.sleep(for: .seconds(1))
taskB.cancel()

await taskB.value
await timerB.value

実行結果

Stream continuation
time: 0
time: 1
cancelled
finish
time: 2
time: 3
time: 4
time: 5
time: 6
time: 7
Stream unfold
time: 0
time: 1
cancelled
time: 2
time: 3
time: 4
time: 5
value: 1
cancelled
finish
time: 6
time: 7

AsyncStream(build:)によって生成されたAsyncStreamのIterator.next()は(build:の実装に関わらず)先述の通りcancelを呼ばれたら即nilをreturnします。

一方で

AsyncStream(unfold:onCancel:)によって生成されたAsyncStreamのIterator.next()はcancelを呼ばれるとunfold引数のasync関数にcancelを伝えます。その後の挙動はunfoldの実装次第です。

以上より、AsyncStream(unfold:onCancel:)を使う場合はキャンセル時の挙動に気を使う必要があります。

https://github.com/apple/swift/blob/a2eb9e04cc8ae4d7e951924738ad6606cdfbfbb0/stdlib/public/Concurrency/AsyncStream.swift#L333-L352

(おまけ) AsyncStreamは利用者側からキャンセルか終了かわからない

AsyncStreamCancellelationErrorをthrowしないので、利用側から「キャンセル」なのか「終了」なのかわかりません。

AsyncStreamの作成側からはわかります。Streamがterminateした際に呼ばれるonTerminationの引数に、「キャンセル」なのか「終了」なのかのenumが渡されます。

let stream = AsyncStream { continuation in
    continuation.onTermination = { termination in
        switch termination {
        case .cancelled:
            print("cancelled")
        case .finished:
            print("finished")
        }
    }

    continuation.yield(1)
}

以下の例は1秒毎、10秒間に渡って乱数を生成するStreamの合計を計算するプログラムです。

let stream = AsyncStream { continuation in
    Task {
        for _ in 0..<10 {
            try? await Task.sleep(for: .seconds(1))
            continuation.yield(Int.random(in: 0..<100))
        }

        continuation.finish()
    }
}
let task = Task {
    let sum = await stream.reduce(0, +)
    print(sum)
}

問題は、途中でtask.cancel()を読ぶとその時点までの合計値を出力してしまいます。
10秒間の合計値ではない場合は不正な合計値なのでエラーとして扱いたい、そんな場合もあります。

そういったケースではstreamを扱ったその下でtry Task.checkCancellation()を呼ぶ必要があります。

func sumStream() async throws -> Int {
    let sum = await stream.reduce(0, +)
    try Task.checkCancellation()
    return sum
}

let task = Task {
    do {
        let sum = try await sumStream()
        print(sum)
    } catch {
        print("error: \(error)")
    }
}

Discussion