🕰️

[Swift] 不安定なConcurrencyのテストをContinuationとClockで解決する

2024/04/30に公開

問題設定

以下のようなViewModelがあったとします。

struct Post: Identifiable, Equatable {
    var id: Int
}

struct PostRepository {
    let fetchPosts: @Sendable () async throws -> [Post]
}

@MainActor
final class ViewModel: ObservableObject {
    @Published var isLoading = false
    @Published var posts: [Post] = []

    let postRepository: PostRepository

    init(postRepository: PostRepository) {
        self.postRepository = postRepository
    }
}

このViewModelに以下のような[Post]をロードするメソッドがあったとします。
このメソッドで、以下の三つの状態変化(A, B, C)をテストしたい。

extension ViewModel {
    func onAppear() async {
        isLoading = true // A: ここの状態をテスト

        do {
            posts = try await postRepository.fetchPosts() // B: ここの状態をテスト
        } catch {
            // error handing
        }

        isLoading = false // C: ここの状態をテスト
    }
}

問題点

このテストは実は一筋縄ではいきません。

NGケース1

愚直にテストケースのルートでawait viewModel.onAppear()を実行すると、最終状態(B,C)まで実行されてしまうのでAのテストができません。

@MainActor
func testNG1() async {
    let posts: [Post] = [Post(id: 1), Post(id: 2), Post(id: 3)]
    let viewModel = ViewModel(postRepository: .init(fetchPosts: { posts }))

    // onAppear終了まで実行してしまう
    await viewModel.onAppear()

    // A
    XCTAssertTrue(viewModel.isLoading) // 🔴 XCTAssertTrue failed

    // B
    XCTAssertEqual(viewModel.posts, posts)

    // C
    XCTAssertFalse(viewModel.isLoading)
}

NGケース2

別タスクを起動してawait viewModel.onAppear()を実行すると
Aのチェックの時点ではTaskの中身が実行されておらず、失敗してしまいます。

func testNG2() async {
    let posts: [Post] = [Post(id: 1), Post(id: 2), Post(id: 3)]
    let viewModel = ViewModel(postRepository: .init(fetchPosts: { posts }))

    let task = Task {
        await viewModel.onAppear() // Aのチェック時点で呼ばれていない
    }

    // A
    XCTAssertTrue(viewModel.isLoading) // 🔴 XCTAssertTrue failed

    // すべて実行されるまで待つ
    await task.value

    // B
    XCTAssertEqual(viewModel.posts, posts)

    // C
    XCTAssertFalse(viewModel.isLoading)
}

NGケース3: 1回のTask.yield

https://developer.apple.com/documentation/swift/task/yield()

A task can voluntarily suspend itself in the middle of a long-running operation that doesn’t contain any suspension points, to let other tasks run for a while before execution returns to this task.

Task.yield()を使うことで、現在のタスクを中断して別のタスクを実行させることができます。

@MainActor
func testYield() async {
    let posts: [Post] = [Post(id: 1), Post(id: 2), Post(id: 3)]
    let viewModel = ViewModel(postRepository: .init(fetchPosts: { posts }))

    let task = Task {
        await viewModel.onAppear()
    }

    // Task { ... } を実行
    await Task.yield()

    // A
    XCTAssertTrue(viewModel.isLoading)

    // すべて実行されるまで待つ
    await task.value

    // B
    XCTAssertEqual(viewModel.posts, posts)

    // C
    XCTAssertFalse(viewModel.isLoading)
}

しかし、このテストコードは結果が不安定になります(成功する時もあれば失敗する時もある)

viewModel.onAppear()postRepository.fetchPostsasync funcですが、これらの呼び出しで実際に(何回)中断するかどうかはコンパイラが決めるため、中断しない場合もあります。中断しない場合は一回のawait Task.yield()viewModel.onAppear()が最後まで実行されてしまうため、XCTAssertTrue(viewModel.isLoading)が失敗します。

この問題に対して、fetchPostsawait Task.yield()を追加して手動で中断するワークアラウンドがあります。これによって成功する確率が上がりますが100%ではありません。

@MainActor
func testYield2() async {
    let posts: [Post] = [Post(id: 1), Post(id: 2), Post(id: 3)]
    let viewModel = ViewModel(postRepository: .init(fetchPosts: {
        await Task.yield() // 手動で中断
        return posts
    }))

    let task = Task {
        await viewModel.onAppear()
    }

    await Task.yield()

    // A
    XCTAssertTrue(viewModel.isLoading)

    await task.value

    // B
    XCTAssertEqual(viewModel.posts, posts)

    // C
    XCTAssertFalse(viewModel.isLoading)
}

解決策

Continuationを用いて、手動で再開しない限り中断し続けるポイントを作ります。
このポイントがAをチェックするタイミングとなります。

actor AtomicValue<T> {
    var value: T

    init(_ value: T) {
        self.value = value
    }

    func set(_ value: T) {
        self.value = value
    }
}

@MainActor
func testContinuation() async {
    let postContinuation = AtomicValue<UnsafeContinuation<Void, Never>?>(nil)
    let posts: [Post] = [Post(id: 1), Post(id: 2), Post(id: 3)]
    let viewModel = ViewModel(postRepository: .init(fetchPosts: {
        await withUnsafeContinuation { continuation in // ここで中断し続ける
            Task {
                await postContinuation.set(continuation)
            }
        }
        return posts
    }))

    let task = Task {
        await viewModel.onAppear()
    }

    // withUnsafeContinuationが呼び出されるまで(中断するまで)待ち続ける
    while await postContinuation.value == nil {
        await Task.yield()
    }

    // A
    XCTAssertTrue(viewModel.isLoading)

    // Aのチェックが終わったのでcontinuationを再開
    await postContinuation.value?.resume()
    await postContinuation.set(nil)

    await task.value

    // B
    XCTAssertEqual(viewModel.posts, posts)

    // C
    XCTAssertFalse(viewModel.isLoading)
}

参考

https://github.com/koher/swift-async-test-experiment

https://github.com/pointfreeco/swift-concurrency-extras?tab=readme-ov-file#serial-execution

Clockとの統合

上のContinuationの例を実際に活用しようとすると、ボイラープレートが多くなることが予想されます。ここで、Swift5.7からのClockに統合して使いやすくします。

https://developer.apple.com/documentation/swift/clock

SE-0329 Clock, Instant, and Duration

import Foundation

public final class TestClock: Clock, @unchecked Sendable {
    public struct Instant: InstantProtocol {
        public static var zero: Instant { Instant(offset: .zero) }

        public var offset: Duration

        public init(offset: Duration) {
            self.offset = offset
        }

        public func advanced(by duration: Duration) -> Instant {
            Instant(offset: offset + duration)
        }

        public func duration(to other: Instant) -> Duration {
            other.offset - offset
        }

        public static func < (_ lhs: Instant, _ rhs: Instant) -> Bool {
            lhs.offset < rhs.offset
        }

        public static func += (_ lhs: inout Instant, _ rhs: Duration) {
            lhs = lhs.advanced(by: rhs)
        }
    }

    struct WakeUp {
        var when: Instant
        var continuation: AsyncStream<Void>.Continuation

        init(when: Instant, continuation: AsyncStream<Void>.Continuation) {
            self.when = when
            self.continuation = continuation
        }
    }

    public var minimumResolution: Duration = .zero
    public private(set) var now: Instant

    private var noIdSleepCount = 0
    private var wakeUps: [AnyHashable: WakeUp] = [:]
    private let lock = NSLock()

    public init(initialInstant: Instant) {
        self.now = initialInstant
    }

    deinit {
        lock.withLock {
            wakeUps.values.forEach { $0.continuation.finish() }
        }
    }

    public func getAutoId(index: Int) -> String {
        "_auto_id_\(index)"
    }

    public func isSleeping<ID: Hashable>(id: ID) -> Bool {
        return lock.withLock {
            wakeUps[AnyHashable(id)] != nil
        }
    }

    public func sleep<ID: Hashable>(untilSuspendBy id: ID) async throws {
        while !isSleeping(id: id) {
            await Task.yield()
            try Task.checkCancellation()
        }
    }

    public func sleep<ID: Hashable>(id: ID, for duration: Duration, tolerance: Duration? = nil) async throws {
        try await sleep(id: id, until: lock.withLock({ now.advanced(by: duration) }), tolerance: tolerance)
    }

    public func sleep(until deadline: Instant, tolerance: Duration? = nil) async throws {
        let index = lock.withLock {
            let count = noIdSleepCount
            noIdSleepCount += 1
            return count
        }
        return try await sleep(id: getAutoId(index: index), until: deadline, tolerance: tolerance)
    }

    public func sleep<ID: Hashable>(id: ID, until deadline: Instant, tolerance: Duration? = nil) async throws {
        try Task.checkCancellation()

        let stream = AsyncStream<Void> { continuation in
            lock.withLock {
                if deadline <= now {
                    continuation.finish()
                } else {
                    wakeUps[AnyHashable(id)] = WakeUp(when: deadline, continuation: continuation)
                }
            }
        }
        // AsyncStreamはTaskのcancelでfinishが走るため
        // cancelをした瞬間にCancellelationErrorを投げることができる
        // 普通のContinuationでは無理
        for await _ in stream {}

        try Task.checkCancellation()
    }

    public func advance(by amount: Duration) {
        var shouldWakeUps = [WakeUp]()

        lock.withLock {
            now += amount
            for key in wakeUps.keys {
                guard let wakeup = wakeUps[key] else { continue }
                if wakeup.when <= now {
                    shouldWakeUps.append(wakeup)
                    wakeUps[key] = nil
                }
            }
        }

        shouldWakeUps.sort { $0.when < $1.when }
        for item in shouldWakeUps {
            item.continuation.finish()
        }
    }
}

Task.cancel()の挙動も利用できるようにAsyncStreamを使っていますが、基本的にはContinuationの使い方は変わっていません。

これで以下のメリットが得られます。

@MainActor
func testClock() async {
    let fetchPostId = "fetchPosts"
    let clock = TestClock(initialInstant: .init())
    let posts: [Post] = [Post(id: 1), Post(id: 2), Post(id: 3)]
    let viewModel = ViewModel(postRepository: .init(fetchPosts: {
        try await clock.sleep(id: fetchPostId, for: .seconds(1))
        return posts
    }))

    let task = Task {
        await viewModel.onAppear()
    }

    // Clock.sleepが呼び出されるまで(中断するまで)待ち続ける
    await clock.sleep(untilSuspendBy: fetchPostId)

    // A
    XCTAssertTrue(viewModel.isLoading)

    // Aのチェックが終わったので時間を進める
    clock.advance(by: .seconds(1))
    await task.value

    // B
    XCTAssertEqual(viewModel.posts, posts)

    // C
    XCTAssertFalse(viewModel.isLoading)
}

Discussion