👏

Luaでワンショット限定継続

2022/12/09に公開

これは言語実装 Advent Calendar 2022の9日目の記事です。Luaをターゲットとする言語を実装する際に役立つかもしれない情報です。

後付けですがLua Advent Calendar 2022にも登録させていただきました。


私が作っているStandard ML処理系LunarMLでは限定継続を提供しています。限定継続についての話はブログの方に何回か書きました:

これまでのLunarMLではJS-CPSバックエンドでCPS変換による限定継続の実装を行なっていましたが、最近Luaバックエンド(の亜種)でコルーチンによるワンショット限定継続の実装を追加しました。

この記事では、Luaのコルーチンを使った限定継続の実装の解説を試みます。

Luaのコルーチン

コルーチンというのは途中で中断できる関数みたいなやつのことです(ざっくり)。

いろんな言語でのコルーチンについてはn月刊ラムダノートの遠藤さんの記事を読んでください:

ここではLuaのコルーチンを扱います。例を見てみましょう:

local co = coroutine.create(function()
  print("Hello")
  local a = coroutine.yield(1 + 2)
  print(a..a..a)
end)
print("created a coroutine")
local _, r1 = coroutine.resume(co)
print("result 1", r1)
coroutine.resume(co, "ABC")

このLuaコードの出力は

created a coroutine
Hello
result1 3
ABCABCABC

となります。

coroutine.create がコルーチンを作る関数、 coroutine.yield はコルーチンを中断する関数、 coroutine.resume はコルーチンを起動・再開する関数です。

coroutine.create を呼んだ段階ではコルーチンの処理は実行されません。coroutine.resume を呼ぶことによってコルーチンの処理が実行されます。コルーチンの中で coroutine.yield が呼ばれると処理がそこで中断して、制御が coroutine.resume の呼び出し元に返ります。もう一度 coroutine.resume を呼ぶとコルーチンが中断箇所から再開します。

Luaのコルーチンを気取った専門用語で言うとstackful asymmetric coroutineとなります。

stackfulというのは中断操作が関数呼び出しを跨げるということです。例えば、coroutine.yield をラップした関数を作ることができます:

function ask_yes_no()
  local result = coroutine.yield()
  if result == "Yes" then
    return true
  elseif result == "No" then
    return false
  else
    error("Invalid input")
  end
end

local co = coroutine.create(function()
  ...
  if ask_yes_no() then
    ...
  else
    ...
  end
  ...
end)

また、 coroutine.yield は普通の関数として定義されており、変数に入れたり他の関数に渡したりできます。

逆にstackfulでないコルーチン(stacklessコルーチン)というのは、言語に yield キーワードみたいなのが組み込まれていて、関数呼び出しを跨ぐには呼び出し側で yield* みたいな構文を使う必要があるやつです。ジェネレーターと呼ばれることも多いです。

asymmetric(非対称)というのは、親子関係があるということです。制御を移す操作が coroutine.resumecoroutine.yield に分かれており、普通の関数呼び出しのような親子関係がコルーチンにもあります。逆にasymmetricではないコルーチン(symmetric(対称)コルーチン)というのは、制御を移す操作が一種類しかなく、常に遷移先のコルーチンを明示します。

Luaのコルーチンについて論文を読みたい人は

を読んでください。

Luaのコルーチンを使うと色々面白いことができます。コルーチンは一種の埋め込みDSLを作るためのツールだと思うことができ、 coroutine.create に渡す関数がDSLの記述、coroutine.yield をラップする関数やメインルーチン(coroutine.resume を呼び出す側)がDSLの実装だと思えます。

コルーチンと限定継続

Luaのコルーチンはワンショット限定継続と等価らしいです。ワンショット限定継続というのは、継続の呼び出しが1回に制限されたもののことです。

対応関係はざっくりこんな感じです:

コルーチン 限定継続 解説
coroutine.create reset, prompt 中断可能な処理の作成
coroutine.yield shift, control, ... 処理の中断・継続のキャプチャー
coroutine.resume k 継続の起動

この記事ではコルーチンを使って(ワンショット)限定継続を実装する方をやってみます。ここでは control0 / prompt0 を実装してみます。お馴染みの shift / resetcontrol0 / prompt0 を使うことで実装できます。

まず、prompt0 では coroutine.create でコルーチンを作ってすぐに呼び出し(resume)します。コルーチンが return で制御を返したら返ってきた値をそのまま返します。関数が yield で制御および「継続を受け取る関数」を返したら control0 が呼ばれたとみなし、「継続を表す値」を作って「継続を受け取る関数」を呼び出します。

継続が起動されたら、コルーチンを resume で再開します。コルーチンが return で制御を返したら値をそのまま返し、yield で制御を返したら control0 が呼ばれたとみなして先ほどと同様の処理を行います。

コルーチンを何回も resume すると毎回違う場所から処理が再開します。つまり、コルーチンをコピーできない以上、同じ場所からは1回しか復帰できません。このことと継続は1回しか使えないことが対応します。継続を起動する際に「使用済み」フラグを検査して、未使用の場合にのみ、「使用済み」フラグを立てた上で継続を起動します。

single promptな限定継続

まずは「区切り」が一種類の限定継続を実装してみます。さっき言葉で説明したことをLuaコードで書けばOKです。

-- 継続を関数っぽく呼び出せるようにするためのメタテーブル
-- 継続はフィールドcoとしてコルーチンを、フィールドdoneとして使用済みか否かを保持する。
local sk_meta = {}

-- コルーチンを起動(再開)する。
-- コルーチンがreturnで制御を返したら値を返し、yieldで制御を返したら継続を構築する。
local function run(co, ...)
  local status, a, b = coroutine.resume(co, ...)
  if status then
    if a == "return" then
      -- bには返すべき値が入っている
      return b
    elseif a == "capture" then
      -- bには呼び出すべき関数が入っている
      local k = setmetatable({co=co, done=false}, sk_meta)
      return b(k)
    else
      error("unexpected result from coroutine: "..tostring(a))
    end
  else
    error(a)
  end
end

-- prompt0ではコルーチンを作成する。
function prompt0(f)
  local co = coroutine.create(function()
    -- returnで制御を返したのかyieldで制御を返したのか判別できるように、最初の返り値は "return" とする。
    return "return", f()
  end)
  return run(co)
end

-- 継続をキャプチャーする。
function control0(f)
  local command, g = coroutine.yield("capture", f)
  if command == "resume" then
    return g()
  else
    error("unexpected command to coroutine: "..tostring(command))
  end
end

-- 継続を起動する。
function pushSubCont(subcont, f)
  if subcont.done then
    error("cannot resume continuation multiple times")
  end
  subcont.done = true
  return run(subcont.co, "resume", f)
end
function sk_meta:__call(a)
  return pushSubCont(self, function() return a end)
end

reset = prompt0
function shift(f)
  return control0(function(k)
    return prompt0(function()
      return f(function(x)
        return prompt0(function()
          return k(x)
        end)
      end)
    end)
  end)
end

試してみましょう。

まずは簡単な例から。

local result = reset(function()
  return 3 * shift(function(k)
    -- k には 3 * _ が束縛される。
    return 1 + k(5)
  end)
end)
print("result1", result) -- 16

shift 自身も「区切り」として動作する例:

local result = reset(function()
  return 1 + shift(function(k)
    -- k = 1 + _
    return 2 * shift(function(l)
      -- l = 2 * _
      return k(l(5))
    end)
  end)
end)
print("result2", result) -- 11

キャプチャーした継続を reset の外で利用する例:

local k = reset(function()
  local f = shift(function(k) return k end)
  return 3 * f()
end)
-- k(f) = reset(function() return 3 * f() end)
print("result3", k(function() return 7 end)) -- 21

他の例はGitHubを見てください:

multi promptな限定継続

一手間加えると「区切り」が複数種類の限定継続も実装できます。Monadic Frameworkの論文のやつを実装してみます。

-- プロンプトのタグとしてテーブル(の同一性)を使う
function newPromptTag()
  return {}
end

-- 継続を関数っぽく呼び出せるようにするためのメタテーブル
local sk_meta = {}

-- コルーチンを起動(再開)する
local function runWithTag(tag, co, ...)
  local status, a, b, c = coroutine.resume(co, ...)
  if status then
    if a == "return" then
      -- 値が返ってきたらそれでよし
      return b
    elseif a == "capture" then
      -- 継続のキャプチャー要求が飛んできた
      -- b: tag
      -- c: callback
      if b == tag then
        local k = setmetatable({co=co, done=false}, sk_meta) -- キャプチャーした継続
        return c(k)
      else
        -- より上位のハンドラーに処理してもらう
        return runWithTag(tag, co, coroutine.yield("capture", b, c))
      end
    else
      error("unexpected result from the function: "..tostring(a))
    end
  else
    error(a)
  end
end

-- 区切りの中で処理を実行する
function pushPrompt(tag, f)
  local co = coroutine.create(function()
    return "return", f()
  end)
  return runWithTag(tag, co)
end

-- 継続のキャプチャーと大域脱出
function withSubCont(tag, f)
  local command, a = coroutine.yield("capture", tag, f)
  if command == "resume" then
    return a()
  else
    error("unexpected command to coroutine: "..tostring(command))
  end
end

-- 継続の起動
function pushSubCont(subcont, f)
  if subcont.done then
    error("cannot resume captured continuation multiple times")
  end
  subcont.done = true
  return runWithTag(nil, subcont.co, "resume", f)
end
function sk_meta:__call(a)
  return pushSubCont(self, function() return a end)
end

resetAt = pushPrompt
function shiftAt(tag, f)
  return withSubCont(tag, function(k)
    return pushPrompt(tag, function()
      return f(function(x)
        return pushPrompt(tag, function()
          return k(x)
        end)
      end)
    end)
  end)
end

指定されたタグに対応するプロンプトがない場合は、「メインスレッドで coroutine.yield しようとした」という旨のエラーが出ます。

タグを固定してやるとsingle promptな限定継続もエミュレートできます:

local tag = newPromptTag()
local function reset(f)
  return resetAt(tag, f)
end
local function shift(f)
  return shiftAt(tag, f)
end

サンプルコードとかはGitHubを見てください:

おまけ:LuaとCスタック

LuaからLuaで実装された関数を呼び出す際はCスタックは消費されません。

しかし、 coroutine.resume は(PUC Luaでは)Cで実装されており、呼び出しの際にCスタックを消費します。したがって、以下のプログラムはCスタックオーバーフローを起こします:

local function recur(n)
  if n == 0 then
    return "OK!!!"
  else
    return reset(function()
      return recur(n - 1)
    end)
  end
end
local result = recur(500)
print("Does not consume C stack?", result)

もっと身近な例としては、 pcall を深くネストさせるとCスタックオーバーフローします:

local function recur(n)
  if n == 0 then
    return "OK!!!"
  else
    local success, result = pcall(function()
      return recur(n - 1)
    end)
    if success then
      return result
    else
      error(result)
    end
  end
end
local result = recur(500)
print("Does not consume C stack?", result)

普通にLuaコードを書く分には pcall を100段ネストするようなことはあまりないかと思われますが、大規模なプログラムをLuaにコンパイルするとこの問題を踏むことになります。LunarMLの場合はHaMLetをLuaにコンパイルするときに遭遇しました。

こういうCスタックオーバーフローは、末尾呼び出し最適化のためのトランポリンみたいなことをコルーチンを使ってやってやると回避できます:

local _depth = 0
function pcallX(f)
  --[[
    単に
    local c = coroutine.create(function()
        return "return", f()
    end)
    return coroutine.yield("handle", c)
    でも良いが、ネストが浅い場合はpcallを使った方が速い
  ]]
  local success, result
  if _depth > 150 then
    local c = coroutine.create(function()
        return "return", f()
    end)
    local olddepth = _depth
    _depth = 0
    success, result = coroutine.yield("handle", c)
    _depth = olddepth
  else
    local olddepth = _depth
    _depth = olddepth + 1
    success, result = pcall(f)
    _depth = olddepth
  end
  return success, result
end

-- インタープリターっぽい
function _run(f)
  local c = coroutine.create(function()
      return "return", f()
  end)
  local stack = {c}
  local values = {}
  while #stack > 0 do
    local status, a, b = coroutine.resume(stack[#stack], table.unpack(values))
    if status then
      if a == "return" then
        table.remove(stack)
        values = {true, b}
      elseif a == "handle" then
        table.insert(stack, b)
        values = {}
      else
        error("unexpected result from the function: " .. tostring(a))
      end
    else
      table.remove(stack)
      if #stack > 0 then
        values = {false, a}
      else
        error(a)
      end
    end
  end
  return table.unpack(values)
end

_run(function()
  local function recur(n)
    if n == 0 then
      return "OK!!!"
    else
      local success, result = pcallX(function()
        return recur(n - 1)
      end)
      if success then
        return result
      else
        error(result)
      end
    end
  end
  local result = recur(500)
  print("Does not consume C stack?", result)
end)

従来はトップレベルに書いていた内容を _run という関数呼び出しで囲う必要があります。これは末尾呼び出し最適化のためのトランポリンと同様です。

この手法は限定継続にも適用できます。詳しくはGitHubのコードを読んでください:

なお、LuaJITでは coroutine.resumepcall が組み込み扱いされているのか、(少なくとも500段程度では)Cスタックオーバーフローは起きないようです。

Discussion