Open12

Kotlinでshift/reset

りんすりんす

まずは runContを実装する。Kotlinで継続で実装したものとは、contextを引数で受ける点で少し違っているが、基本的なアイデアは変わらない。

fun <T> runCont(context: CoroutineContext, cont: (T) -> Unit, block: suspend () -> T) {
    block.startCoroutine(object : Continuation<T> {
        override val context = context
        override fun resumeWith(result: Result<T>) {
            cont(result.getOrThrow())
        }
    })
}
りんすりんす

次にKotlinでcall/ccを実装する。現在のcoroutineContextを引き継ぐようにしてある点以外は、特別な工夫をすることなく実装できる。

suspend fun <T, U> callCC(block: suspend (suspend (T) -> U) -> T): T {
    val context = coroutineContext
    return suspendCoroutine { cont ->
        runCont(context, { cont.resume(it) }) {
            block { t ->
                cont.resume(t)
                suspendCoroutine<U> {
                    // dispose continuation
                }
            }
        }
    }
}
りんすりんす

このcall/ccは普通に使うことができる。

runCont {
    val n: Int = callCC<Int, String> { k ->
        val s = k(20)
        println("Hello, $s")
        10
    }
    println("${n + 1}")
}

kが呼び出された時点で残りの継続は破棄されて、出力は 21 となる。

りんすりんす

次に用意するのは、callCCの2つ目の型引数をVoidに固定した変種であるescapeだ。

Voidは値を作ることのできないような型として作ったもので、このようなescapeを用意することでcallCCに渡されるラムダ式の中でkが呼ばれることを保証できる。

Voidはカリーハワード同型対応でいうところの偽であるため、爆発律にあたる関数を書ける。これを使えば、callCCの型引数をVoidに固定⇔任意の型Uに戻すという操作を行うことで、型をcallCCのままkの呼び出しの強制できる。

class Void private constructor()

@Suppress("UNUSED_PARAMETER")
fun <T> coerce(void: Void): T {
    throw Exception("Coerce void")
}

suspend fun <T> escape(block: suspend (suspend (T) -> Void) -> T): T {
    return callCC(block)
}
りんすりんす

次に、CoroutineContextに継続を保存するための状態を用意する。

// 文脈の状態を引くキー
object StateCellKey : CoroutineContext.Key<StateCell<*>>

// 文脈に持たせる状態
class StateCell<T>(
    var state: suspend (T) -> Void,
) : CoroutineContext.Element, AbstractCoroutineContextElement(StateCellKey)

// 文脈に持たせる初期状態
val InitialStateCell = StateCell<Any>(state = {
    throw MissingResetException("Reset is missing", null)
})

class MissingResetException(message: String, cause: Throwable?) : Exception(message, cause)
りんすりんす

次にCoroutineContext内の状態を更新する関数と、CoroutineContext内に保存されている継続を実行する関数を書く。
そもそも文脈に状態が載っていないとNullPointerExceptionが起こるので、それを防ぐためにScopeを導入する。(Scope内に状態を載せてもいいのかもしれない)

// NullPointerExceptionを防ぐ工夫
object ShiftResetScope

// 文脈に状態を設定する
suspend fun <T> ShiftResetScope.setStateCell(newState: suspend (T) -> Void) {
    @Suppress("UNCHECKED_CAST")
    val cell: StateCell<T> = coroutineContext[StateCellKey]!! as StateCell<T>
    cell.state = newState
}

// 文脈に保存されている継続を実行する。現在の継続は破棄されるため値は返らない
// Tは文脈のStateCell<T>に一致する必要がある
suspend fun <T, U> ShiftResetScope.abort(v: T): U {
    val cell = coroutineContext[StateCellKey]!!

    @Suppress("UNCHECKED_CAST")
    val typedCell = cell as StateCell<T>
    return coerce(typedCell.state(v))
}
りんすりんす

setStateCell, abortを実行できる文脈を導入する関数を用意する。まだ定義していないresetを使っているが、これはユーザー側がresetを実行せずともshiftを使えるようにする工夫なので説明は後にまわす。

suspend fun <T> shiftResetScope(
    block: suspend ShiftResetScope.() -> T,
): T {
    val context = coroutineContext
    return suspendCoroutine { cont ->
        runCont(context + InitialStateCell, cont::resume) {
            with(ShiftResetScope) {
                reset {
                    block()
                }
            }
        }
    }
}
りんすりんす

resetの定義は次の通りだ。

kの型はsuspend (T) -> Voidであり、vの型はTだ。

suspend fun <T> ShiftResetScope.reset(block: suspend () -> T): T {
    return escape { k ->
        val preservedState = coroutineContext[StateCellKey]!!.state
        setStateCell { v ->
            setStateCell(preservedState)
            k(v)
        }
        abort(block())
    }
}
りんすりんす

shiftの定義は次の通り。

suspend fun <T, U> ShiftResetScope.shift(block: suspend (suspend (T) -> U) -> U): T {
    return escape { k ->
        val r = block { v ->
            reset {
                coerce(k(v))
            }
        }
        abort(r)
    }
}
りんすりんす

全体の使い心地はこのような感じ。これはshift/resetを使って例外機構を実装した例だ。

runCont {
    shiftResetScope {
        val t = ThrowableContext { message ->
            shift { _ ->
                println("Error! Message: $message")
            }
        }
        with(t) {
            val r = safeDiv(1, 0)
            println("Result: $r")
        }
    }
}

fun interface ThrowableContext {
    suspend fun throwException(message: String): Void
}

suspend fun ThrowableContext.safeDiv(n: Int, m: Int): Int {
    if (m == 0) {
        throwException("Zero Division")
    }
    return n / m
}
りんすりんす

今のところの不満は以下の点がある:

  1. Kotlinの継続がone shot継続なので、shiftのkを複数回呼び出すと死ぬ