Translated by AI
One-shot Delimited Continuations in Lua
This is an article for Day 9 of the Language Implementation Advent Calendar 2022. This information might be useful when implementing a language that targets Lua.
As an afterthought, I also registered this for the Lua Advent Calendar 2022.
In the Standard ML compiler LunarML that I am developing, I provide delimited continuations. I've written about delimited continuations on my blog a few times:
In previous versions of LunarML, delimited continuations were implemented via CPS transformation in the JS-CPS backend. Recently, I added an implementation of one-shot delimited continuations using coroutines in the Lua backend (and its variants).
In this article, I will attempt to explain the implementation of delimited continuations using Lua's coroutines.
Lua Coroutines
Broadly speaking, coroutines are functions that can be suspended partway through and resumed later.
For information about coroutines in various languages, please read Mr. Endo's article in the n-Monthly Lambda Note:
Here, we will deal with Lua's coroutines. Let's look at an example:
```lua
local co = coroutine.create(function()
  print("Hello")
  local a = coroutine.yield(1 + 2)
  print(a..a..a)
end)
print("created a coroutine")
local _, r1 = coroutine.resume(co)
print("result1", r1)
coroutine.resume(co, "ABC")
```
The output of this Lua code is:
```
created a coroutine
Hello
result1	3
ABCABCABC
```
coroutine.create is the function to create a coroutine, coroutine.yield is the function to suspend it, and coroutine.resume is the function to start or resume it.
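The body of the coroutine does not run when coroutine.create is called; it starts running only when coroutine.resume is called. When coroutine.yield is called inside the coroutine, execution is suspended there and control returns to the caller of coroutine.resume, together with the arguments passed to coroutine.yield. (coroutine.resume returns true followed by those values, which is why the example above discards its first result.) Calling coroutine.resume again resumes the coroutine from where it was suspended, and the arguments of that call become the return values of coroutine.yield.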
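Values travel in both directions: the arguments of coroutine.resume become the function's parameters on the first resume and the results of coroutine.yield afterwards, while the arguments of coroutine.yield (or the function's return values) become the extra results of coroutine.resume. The small sketch below, with arbitrary variable names, makes this explicit and also records coroutine.status at each stage.

```lua
local co = coroutine.create(function(a, b)
  local c = coroutine.yield(a + b)   -- the first resume receives a + b
  return c * 10                      -- the second resume receives c * 10
end)

local s0 = coroutine.status(co)               -- created but not started yet
local ok1, sum = coroutine.resume(co, 1, 2)   -- a = 1, b = 2; runs until the yield
local s1 = coroutine.status(co)               -- suspended at the yield
local ok2, res = coroutine.resume(co, 4)      -- the yield returns 4; the body returns 40
local s2 = coroutine.status(co)               -- the body has finished

print(s0, ok1, sum)  -- suspended, true, 3
print(s1, ok2, res)  -- suspended, true, 40
print(s2)            -- dead
```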
In fancy technical terms, Lua's coroutines are called stackful asymmetric coroutines.
'Stackful' means that the suspension operation can span across function calls. For example, you can create a function that wraps coroutine.yield:
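```lua
-- A helper that wraps coroutine.yield. It suspends the *calling* coroutine,
-- even though the yield happens inside a nested function call.
local function ask_yes_no()
  local result = coroutine.yield()
  if result == "Yes" then
    return true
  elseif result == "No" then
    return false
  else
    error("Invalid input")
  end
end

local co = coroutine.create(function()
  print("Continue?")
  if ask_yes_no() then
    print("Continuing")
  else
    print("Stopping")
  end
end)

coroutine.resume(co)         -- prints "Continue?" and suspends inside ask_yes_no
coroutine.resume(co, "Yes")  -- ask_yes_no returns true; prints "Continuing"
```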
Also, coroutine.yield is defined as a regular function and can be stored in variables or passed to other functions.
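Because coroutine.yield is an ordinary function value, it can even be handed to code written without coroutines in mind. In this sketch, walk is a hypothetical recursive traversal that simply calls a callback on every leaf; passing coroutine.yield as that callback (and wrapping the whole thing with coroutine.wrap) turns it into a generator. Each yield suspends the coroutine from inside several levels of nested calls, which is exactly what "stackful" allows.

```lua
-- Hypothetical helper: a recursive traversal that knows nothing about
-- coroutines and just calls `visit` on each leaf, depth first.
local function walk(tree, visit)
  if type(tree) == "table" then
    for _, child in ipairs(tree) do
      walk(child, visit)
    end
  else
    visit(tree)
  end
end

-- Passing coroutine.yield itself as the callback makes walk a generator.
local function leaves(tree)
  return coroutine.wrap(function() walk(tree, coroutine.yield) end)
end

local collected = {}
for leaf in leaves({1, {2, {3, 4}}, 5}) do
  collected[#collected + 1] = leaf
end
print(table.concat(collected, " "))  -- 1 2 3 4 5
```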
In contrast, with coroutines that are not stackful (stackless coroutines), something like a yield keyword is built into the language and can only suspend the function that syntactically contains it; to suspend across a function call, the caller must use syntax such as yield*. These are often called generators.
'Asymmetric' means that there is a parent-child relationship. The operations to transfer control are divided into coroutine.resume and coroutine.yield, and coroutines have a parent-child relationship similar to regular function calls. On the other hand, coroutines that are not asymmetric (symmetric coroutines) have only one type of operation to transfer control and always explicitly specify the destination coroutine.
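Symmetric-style transfer can be emulated on top of asymmetric coroutines with a small dispatcher loop. The following is a minimal sketch of that idea, not something LunarML uses: transfer and run_symmetric are hypothetical names. Each coroutine yields the coroutine it wants to switch to, and the main routine performs the actual resume.

```lua
-- Portable helpers (table.pack / table.unpack are missing in Lua 5.1 / LuaJIT).
local unpack = table.unpack or unpack
local function pack(...)
  return { n = select("#", ...), ... }
end

-- Hypothetical symmetric "transfer": hand control (and values) to `target`.
-- Under the hood it is just an asymmetric yield back to the dispatcher.
local function transfer(target, ...)
  return coroutine.yield(target, ...)
end

-- Dispatcher: keeps resuming whichever coroutine was named last and returns
-- the final values once the current coroutine finishes. A bare
-- coroutine.yield() (no target) simply ends the loop.
local function run_symmetric(start, ...)
  local current, args = start, pack(...)
  while current do
    local r = pack(coroutine.resume(current, unpack(args, 1, args.n)))
    if not r[1] then
      error(r[2], 0)
    end
    if coroutine.status(current) == "dead" then
      return unpack(r, 2, r.n)      -- the coroutine returned normally
    end
    current = r[2]                  -- the target passed to transfer
    args = pack(unpack(r, 3, r.n))  -- values passed along with it
  end
end

local log = {}
local ping, pong
ping = coroutine.create(function()
  for i = 1, 2 do
    log[#log + 1] = "ping " .. i
    transfer(pong)
  end
  return "done"
end)
pong = coroutine.create(function()
  while true do
    log[#log + 1] = "pong"
    transfer(ping)
  end
end)

local result = run_symmetric(ping)
print(result)                   -- done
print(table.concat(log, ", "))  -- ping 1, pong, ping 2, pong
```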
If you want to read a paper about Lua's coroutines, please read:
- De Moura AL, Rodriguez N, Ierusalimschy R (2004) Coroutines in Lua. JUCS - Journal of Universal Computer Science 10(7): 910-925. https://doi.org/10.3217/jucs-010-07-0910
You can do many interesting things using Lua's coroutines. You can think of coroutines as a tool for creating a kind of embedded DSL; the function passed to coroutine.create is the DSL description, and the function wrapping coroutine.yield or the main routine (the side calling coroutine.resume) can be seen as the DSL implementation.
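As a minimal sketch of this view (the command names "ask" and "say" and the helpers ask, say, and interpret are all made up for illustration), the DSL description below is plain sequential code, while interpret plays the role of the DSL implementation by handling each yielded command. Here the answers come from a table rather than from real user input.

```lua
-- DSL primitives (hypothetical): each one just yields a command.
local function ask(question)
  return coroutine.yield("ask", question)
end
local function say(message)
  coroutine.yield("say", message)
end

-- DSL description: ordinary-looking sequential code.
local function script()
  local name = ask("What is your name?")
  say("Hello, " .. name)
  return #name
end

-- DSL implementation: the main routine interprets each yielded command.
local function interpret(body, answers)
  local co = coroutine.create(body)
  local transcript = {}
  local input
  while true do
    local ok, cmd, payload = coroutine.resume(co, input)
    if not ok then
      error(cmd, 0)
    end
    if coroutine.status(co) == "dead" then
      return cmd, transcript  -- cmd holds the body's return value here
    end
    input = nil
    if cmd == "ask" then
      input = answers[payload]  -- becomes the return value of ask
    elseif cmd == "say" then
      transcript[#transcript + 1] = payload
    else
      error("unknown command: " .. tostring(cmd))
    end
  end
end

local result, transcript = interpret(script, { ["What is your name?"] = "Lua" })
print(result)         -- 3
print(transcript[1])  -- Hello, Lua
```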
Coroutines and Delimited Continuations
Lua's coroutines are known to be equivalent in expressive power to one-shot delimited continuations. A one-shot delimited continuation is one that can be invoked at most once.
The correspondence is roughly as follows:
| Coroutine | Delimited Continuation | Description |
|---|---|---|
| coroutine.create | reset, prompt | Creating a suspendable process |
| coroutine.yield | shift, control, ... | Suspending the process and capturing the continuation |
| coroutine.resume | k | Invoking the continuation |
In this article, we will try implementing (one-shot) delimited continuations using coroutines. Specifically, we will implement control0 / prompt0. The familiar shift / reset can be implemented by using control0 / prompt0.
First, prompt0 creates a coroutine with coroutine.create and immediately starts it with coroutine.resume. If the coroutine gives control back via return, the returned value is returned as is. If it instead yields a "function that receives a continuation", this is treated as a call to control0: a "value representing the continuation" is created, and the "function that receives a continuation" is called with it.
When the continuation is invoked, the coroutine is resumed with coroutine.resume. If it then gives control back via return, the value is returned as is; if it gives control back via yield, this is again treated as a call to control0 and handled in the same way as described above.
Each time a coroutine is resumed, it continues from the point where it was last suspended, so resuming it again never re-enters an earlier point. In other words, since Lua coroutines cannot be copied, a given suspension point can be resumed only once. This corresponds to the fact that the continuation can be invoked only once. When invoking the continuation, we therefore check a "used" flag: if it is unset, we set it and then resume the coroutine; otherwise we raise an error.
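The one-shot restriction comes straight from how coroutines behave. In the sketch below, each suspension point can be resumed only once. Note that a second resume does not necessarily fail: if the coroutine has yielded again in the meantime, resume silently continues from the newer suspension point. That is why the continuation object needs its own "used" flag instead of relying only on Lua's dead-coroutine check.

```lua
local co = coroutine.create(function(x)
  local y = coroutine.yield(x + 1)  -- suspension point A
  local z = coroutine.yield(y * 2)  -- suspension point B
  return z
end)

local ok1, v1 = coroutine.resume(co, 10)  -- runs to A: v1 = 11
local ok2, v2 = coroutine.resume(co, 5)   -- continues from A to B: v2 = 10
-- A cannot be revisited: this resume silently continues from B instead.
local ok3, v3 = coroutine.resume(co, 7)   -- B returns 7 and the body finishes: v3 = 7
local ok4, msg = coroutine.resume(co, 0)  -- the coroutine is dead: ok4 = false
print(ok1, v1, ok2, v2, ok3, v3)
print(ok4, msg)
```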
Single-prompt Delimited Continuations
First, let's implement delimited continuations with a single kind of "delimiter." The Lua code is a direct translation of the description above.
```lua
-- Metatable to allow calling the continuation like a function
-- The continuation holds the coroutine in the field 'co' and whether it's used in 'done'.
local sk_meta = {}

-- Start (resume) the coroutine.
-- If the coroutine returns control with 'return', return the value; if it returns control with 'yield', construct the continuation.
local function run(co, ...)
  local status, a, b = coroutine.resume(co, ...)
  if status then
    if a == "return" then
      -- 'b' contains the value to be returned
      return b
    elseif a == "capture" then
      -- 'b' contains the function to be called
      local k = setmetatable({co=co, done=false}, sk_meta)
      return b(k)
    else
      error("unexpected result from coroutine: "..tostring(a))
    end
  else
    error(a)
  end
end

-- Create a coroutine in prompt0.
function prompt0(f)
  local co = coroutine.create(function()
    -- Return "return" as the first value so we can distinguish whether control was returned by 'return' or 'yield'.
    return "return", f()
  end)
  return run(co)
end

-- Capture the continuation.
function control0(f)
  local command, g = coroutine.yield("capture", f)
  if command == "resume" then
    return g()
  else
    error("unexpected command to coroutine: "..tostring(command))
  end
end

-- Invoke the continuation.
function pushSubCont(subcont, f)
  if subcont.done then
    error("cannot resume continuation multiple times")
  end
  subcont.done = true
  return run(subcont.co, "resume", f)
end

function sk_meta:__call(a)
  return pushSubCont(self, function() return a end)
end

reset = prompt0

function shift(f)
  return control0(function(k)
    return prompt0(function()
      return f(function(x)
        return prompt0(function()
          return k(x)
        end)
      end)
    end)
  end)
end
```
Let's try it out.
Starting with a simple example:
```lua
local result = reset(function()
  return 3 * shift(function(k)
    -- 'k' is bound to 3 * _
    return 1 + k(5)
  end)
end)
print("result1", result) -- 16
```
An example where the body of shift itself acts as a "delimiter" (a second shift inside it captures only up to that body):
```lua
local result = reset(function()
  return 1 + shift(function(k)
    -- k = 1 + _
    return 2 * shift(function(l)
      -- l = 2 * _
      return k(l(5))
    end)
  end)
end)
print("result2", result) -- 11
```
An example of using a captured continuation outside of reset:
```lua
local k = reset(function()
  local f = shift(function(k) return k end)
  return 3 * f()
end)
-- k(f) = reset(function() return 3 * f() end)
print("result3", k(function() return 7 end)) -- 21
```
Check GitHub for other examples:
Multi-prompt Delimited Continuations
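With a little extra effort, you can also implement delimited continuations with multiple kinds of "delimiters" (prompts). Here I implement the interface from the paper "A Monadic Framework for Delimited Continuations" (Dybvig, Peyton Jones, and Sabry): pushPrompt, withSubCont, and pushSubCont, plus a function that creates a fresh prompt tag.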
```lua
-- Use a table (its identity) as a prompt tag
function newPromptTag()
  return {}
end

-- Metatable to allow calling the continuation like a function
local sk_meta = {}

-- Start (resume) the coroutine
local function runWithTag(tag, co, ...)
  local status, a, b, c = coroutine.resume(co, ...)
  if status then
    if a == "return" then
      -- If a value is returned, that's fine
      return b
    elseif a == "capture" then
      -- A continuation capture request was received
      -- b: tag
      -- c: callback
      if b == tag then
        local k = setmetatable({co=co, done=false}, sk_meta) -- The captured continuation
        return c(k)
      else
        -- Let a higher-level handler handle it
        return runWithTag(tag, co, coroutine.yield("capture", b, c))
      end
    else
      error("unexpected result from the function: "..tostring(a))
    end
  else
    error(a)
  end
end

-- Execute processing within the delimiter
function pushPrompt(tag, f)
  local co = coroutine.create(function()
    return "return", f()
  end)
  return runWithTag(tag, co)
end

-- Continuation capture and non-local exit
function withSubCont(tag, f)
  local command, a = coroutine.yield("capture", tag, f)
  if command == "resume" then
    return a()
  else
    error("unexpected command to coroutine: "..tostring(command))
  end
end

-- Invoking the continuation
function pushSubCont(subcont, f)
  if subcont.done then
    error("cannot resume captured continuation multiple times")
  end
  subcont.done = true
  return runWithTag(nil, subcont.co, "resume", f)
end

function sk_meta:__call(a)
  return pushSubCont(self, function() return a end)
end

resetAt = pushPrompt

function shiftAt(tag, f)
  return withSubCont(tag, function(k)
    return pushPrompt(tag, function()
      return f(function(x)
        return pushPrompt(tag, function()
          return k(x)
        end)
      end)
    end)
  end)
end
```
If there is no enclosing prompt with the specified tag, the capture request is forwarded all the way out to the main thread, where calling coroutine.yield raises an error (yielding outside a coroutine is not allowed).
By fixing the tag, you can also emulate single-prompt delimited continuations:
```lua
local tag = newPromptTag()
local function reset(f)
  return resetAt(tag, f)
end
local function shift(f)
  return shiftAt(tag, f)
end
```
Check GitHub for sample code and other details:
Bonus: Lua and the C Stack
When calling a function implemented in Lua from Lua, the C stack is not consumed.
However, coroutine.resume is implemented in C (in PUC Lua) and consumes the C stack during invocation. Therefore, the following program will cause a C stack overflow:
```lua
local function recur(n)
  if n == 0 then
    return "OK!!!"
  else
    return reset(function()
      return recur(n - 1)
    end)
  end
end
local result = recur(500)
print("Does not consume C stack?", result)
```
As a more familiar example, deeply nesting pcall will cause a C stack overflow:
```lua
local function recur(n)
  if n == 0 then
    return "OK!!!"
  else
    local success, result = pcall(function()
      return recur(n - 1)
    end)
    if success then
      return result
    else
      error(result)
    end
  end
end
local result = recur(500)
print("Does not consume C stack?", result)
```
While you probably wouldn't nest pcall 100 levels deep when writing normal Lua code, you may encounter this issue when compiling large programs to Lua. In the case of LunarML, I encountered this when compiling HaMLet to Lua.
This kind of C stack overflow can be avoided by using coroutines to implement something like a trampoline for tail call optimization:
```lua
local unpack = table.unpack or unpack  -- table.unpack is absent in Lua 5.1 / LuaJIT

local _depth = 0
function pcallX(f)
  --[[
    Simply doing
      local c = coroutine.create(function()
        return "return", f()
      end)
      return coroutine.yield("handle", c)
    would also work, but using pcall is faster for shallow nests.
  ]]
  local success, result
  if _depth > 150 then
    local c = coroutine.create(function()
      return "return", f()
    end)
    local olddepth = _depth
    _depth = 0
    success, result = coroutine.yield("handle", c)
    _depth = olddepth
  else
    local olddepth = _depth
    _depth = olddepth + 1
    success, result = pcall(f)
    _depth = olddepth
  end
  return success, result
end

-- Interpreter-like driver: keeps a stack of coroutines and always resumes the top one,
-- so nested "handle" requests do not nest coroutine.resume calls on the C stack.
function _run(f)
  local c = coroutine.create(function()
    return "return", f()
  end)
  local stack = {c}
  local values = {}
  while #stack > 0 do
    local status, a, b = coroutine.resume(stack[#stack], unpack(values))
    if status then
      if a == "return" then
        table.remove(stack)
        values = {true, b}
      elseif a == "handle" then
        table.insert(stack, b)
        values = {}
      else
        error("unexpected result from the function: " .. tostring(a))
      end
    else
      table.remove(stack)
      if #stack > 0 then
        values = {false, a}
      else
        error(a)
      end
    end
  end
  return unpack(values)
end

_run(function()
  local function recur(n)
    if n == 0 then
      return "OK!!!"
    else
      local success, result = pcallX(function()
        return recur(n - 1)
      end)
      if success then
        return result
      else
        error(result)
      end
    end
  end
  local result = recur(500)
  print("Does not consume C stack?", result)
end)
```
As with a trampoline for tail call optimization, code that was previously written at the top level must be wrapped in a call to the _run function.
This technique can also be applied to delimited continuations. For more details, please see the code on GitHub:
In LuaJIT, on the other hand, coroutine.resume and pcall seem to be treated as built-ins, so C stack overflow does not seem to occur (at least at around 500 levels of nesting).