確率的プログラミング言語とAlgebraic Effects
はじめに
この記事は確率的プログラミング言語 Advent Calendar 2023
の23日目の記事です。
この記事を読むとPyro Poutineの気持ちが少しわかります。多分。
Algebraic Effectsとは
Algebraic Effects (and Handlers)は、大枠で言うと言語機能の一種です。
大雑把に説明すると「もとの場所に戻れる例外」です。
さらに言うと戻らなくても良いし、また一度もとの場所に戻った後に何か動作させることもできるし、(multi-shotであれば)複数回もとの場所に戻ることもできます。
そしてこれが非同期プログラミング、パーサー、自動微分や確率的プログラミング等の実装に利用できます。
コードで例を見てみましょう。
Algebraic Effectsを(Web系で)有名にしたのはDan Abramov氏の記事だと思います。
この記事のJavaScriptの例を使います。とはいっても、JavaScriptには現状存在しない機能なのでフィーリングで読んでください。
function getName(user) {
let name = user.name;
if (name === null) {
name = perform 'ask_name';
}
return name;
}
function makeFriends(user1, user2) {
user1.friendNames.push(getName(user2));
user2.friendNames.push(getName(user1));
}
const arya = { name: null, friendNames: [] };
const gendry = { name: 'Gendry', friendNames: [] };
try {
makeFriends(arya, gendry);
} handle (effect) {
if (effect === 'ask_name') {
resume with 'Arya Stark';
}
}
user.name
がnull
の場合に、getName
関数の中でperform
でask_name
なるEffectを発生させています。
そしてgetName
関数を使う側でtry~handle
して、ask_name
Effectが発生した場合にArya Stark
という値をもとの場所に戻しています。
そうして、getName
関数の中でname
にArya Stark
が代入されます。
次に戻らない例を見てみましょう。
try {
makeFriends(arya, gendry);
} handle (effect) {
if (effect === 'ask_name') {
console.error('ERROR');
}
}
こうすると普通の例外と同じような感じで、user.name
がnull
のときは、ERROR
と出力されて終わりです。
最後に戻った後に何かするケースも紹介します。
try {
makeFriends(arya, gendry);
} handle (effect) {
if (effect === 'ask_name') {
resume with 'Arya Stark';
console.log('resumed');
}
}
とすると、最初の例のようにArya Stark
がname
に代入され、makeFriends
の処理が終わったのちに、resumed
と出力されます。
ここで重要な点として、makeFriends
関数は変更することなく、その「ハンドラ」(ここではhandle (effect)
以降)を変更するだけでプログラムの動作が変わることが挙げられます。特に、resume with
の右の値を変えれば、user.name
がnull
のときに代入される値が変わります。なので、Algebraic EffectsはDependency injectionにも使うことができます。
もう一度書きますが、現状Algebraic EffectsはJavaScriptには存在しません。
逆にAlgebraic Effectsを(ライブラリレベルではなく、言語レベルで)備えている言語には、Eff、Koka、Effekt、そしてUnisonがあります。
限定継続
PPLの話をする前に、Algebraic Effectsにおける限定継続について書きます。
継続、とはプログラムの残りの計算です。限定継続とは、プログラムのある場所までの残りの計算という、範囲を限定された継続です。
先ほどのJSにおけるAlgebraic Effectsの例に戻ります。以下のようにCallbackを使った記法を考えてみます。もちろんこれも実際のJSにはない機能です。実際の言語の例を出すより、まずは比較的馴染みの深い言語を使ってそれっぽい記法を使った方が理解しやすいだろう、という意図です。
try_handle(() => {
makeFriends(arya, gendry);
}, (effect, k) => {
if (effect === 'ask_name') {
k('Arya Stark');
}
}
)
これは
try {
makeFriends(arya, gendry);
} handle (effect) {
if (effect === 'ask_name') {
resume with 'Arya Stark';
}
}
をCallbackを使って実装したものを思ってください。さて、ここでk
という謎の引数が出てきました。これが限定継続(のつもり)です。
try_handle
は第一引数に与えられたCallback f
をf()
として呼び出します(ここではそういうものとします)。そうしてEffectが発生したら第二引数に与えられたCallback g
を呼び出します。g
はeffect
とk
という2つの引数を持ちます。effect
はこれまで通り発生したEffectです。そしてk
は、Effectの発生したところから、f()
の終わりまで残りの計算です。もしf
の呼び出しが値を返すのならば、それがk
の返す値で、k
の引数はperform 'ask_name'
に代入されます。
さて、このCallbackを使った書き換えは何が嬉しいのでしょうか。ここでは限定継続がk
という変数に束縛されています。なので、例えばlistに保存しておいて、後で取り出して計算を再開する、といったことができます。
const cont_list = [];
try_handle(() => {
makeFriends(arya, gendry);
}, (effect, k) => {
if (effect === 'ask_name') {
cont_list.push(k);
}
}
)
cont_list.forEach((k) => k('Arya Stark')); // あとでまとめて処理
これが限定継続です。
Effekt言語におけるSequential Monte Carlo
それでは、Effekt言語でPPLをやる例を見てみましょう。コードはすべて以下にあるものです。
まず使用するEffectを定義します。
effect SMC {
def resample(): Unit
def uniform(): Double
def score(d: Double): Unit
}
先ほどのJavaScriptの例では、Effectは単なる文字列でしたが、Algebraic Effectsを備えている言語では、このようにEffectを先に定義することが多いと思います。
またEffectは、
effect Yield(n: Int): Boolean
のように「引数(の型)」と「返り値(の型)」相当を(静的型付き言語であれば)書きます。
ここで、複数のEffectをまとめて扱うことがあります。この場合一つ一つの「Effect」はOperationと呼ばれます。ここでは、resample
、uniform
、score
の3つのOperationが存在します。
…というのは微妙な書き方で、Effectは本来Operationの集合で表され、Effectの持つOperationがただ一つで、かつそれがEffectと同名なら(言語によっては)略記できる、といった方が正しいと思います。
ともかく、上記のEffectを使って計算対象の分布を定義します。見ての通りdo
でOperationを使い、Effectを発生させます(最初のJSの例ではperform
)。
def bernoulli(p: Double) = do uniform() < p
def biasedGeometric(p: Double): Int / SMC = {
do resample();
val x = bernoulli(p);
if (x) {
do score(log(1.5));
1 + biasedGeometric(p)
} else { 1 }
}
さて、データ型を3つ定義します。ここでCont
は限定継続(の型)です。
record Particle(weight: Double, age: Int, cont: Cont[Unit, Unit])
record Measurement[R](weight: Double, data: R)
record Particles[R](moving: List[Particle], done: List[Measurement[R]])
これを使ってsmcHandler
を定義します。
def smcHandler[R](numberOfParticles: Int) { // (A)
resample: Particles[R] => Particles[R]
} { p: () => R / SMC } = { // (E')
var currentWeight = 1.0;
var particles: List[Particle] = Nil() // (B)
var measurements: List[Measurement[R]] = Nil()
var currentAge = 0;
def checkpoint(cont: Cont[Unit, Unit]) =
particles = Cons(Particle(currentWeight, currentAge, cont), particles)
def run(p: Particle): Unit = { // (C)
currentWeight = p.weight;
currentAge = p.age;
p.cont.apply(()) // (D)
}
def run() = {
val Particles(ps, ms) = resample(Particles(particles, measurements));
particles = Nil();
measurements = ms;
ps.foreach { p => p.run } // (C')
}
repeat(numberOfParticles) {
currentWeight = 1.0;
try {
val res = p(); // (E)
measurements = Cons(Measurement(currentWeight, res), measurements)
} with SMC { // (F)
def resample() = checkpoint(cont { t => resume(t) })
def uniform() = resume(random())
def score(d) = { currentWeight = currentWeight * d; resume(()) }
}
}
while (not(particles.isEmpty)) { run() }
measurements
}
ここで、(A)については、
def smcHandler[R](numberOfParticles: Int) {
resample: Particles[R] => Particles[R]
} { p: () => R / SMC }
は高階関数です。要は最終的にnumberOfParticles
とresample
とp
を埋めます。
(B)については、
var particles: List[Particle] = Nil()
は空のリストで、またnewlist = Cons(elem, list)
でlist
にelem
を追加したnewlist
が得られる、という感じです。正確には[1,2]
をCons(1, Cons(2, Nil))
で表すみたいなやつです。
(C')のp.run
のrun
は(C)で定義されているものです。syntax sugarですね。
(E)を見てみましょう。try~catchのようなtry~withという構文があります。これは記事の最初に紹介したtry~handleのような書き方です。
(E)でp
が呼ばれています。p
の返り値はR / SMC
です。R
はいわゆるジェネリクスですが、/
の後ろに発生するEffectの型であるSMC
があります。
実際、(F)ではwith
の後にはSMC
と書かれています。
SMC
Effectには3つのOperationがあるため、それぞれどのようにハンドルするかを記載します。
まずresample
関数としてナイーブな実装を考えます。詳細はURL先を参照してください。
def resampleUniform[R](particles: Particles[R]): Particles[R] = {
// 中略
}
これを用いて以下を定義します。
def smc[R](numberOfParticles: Int) { p: () => R / SMC } =
smcHandler[R](numberOfParticles) { ps => resampleUniform(ps) } { p() }
最後に
def runSMC(numberOfParticles: Int) =
smc(numberOfParticles) { biasedGeometric(0.5) }
として、粒子の数を指定してrunSMC
を呼びます。
smc
のp
にbiasedGeometric(0.5)
が入ります。R
はInt
です。
よってsmcHandler
のresample
にps => resampleUniform(ps)
が、p
にそのままbiasedGeometric(0.5)
が入ります。
そうして、smcHandler
の中でbiasedGeometric(0.5)
が呼ばれ、各Operationが呼ばれたら対応するハンドラが呼ばれ、resume
で戻ります。
例えば
def bernoulli(p: Double) = do uniform() < p
// 中略
val x = bernoulli(p);
ではハンドラがdef uniform() = resume(random())
となっているので、結局x
はrandom()
の結果がP
より小さいかどうかの真偽値が入ります。
というのがrepeat(numberOfParticles)
より、 粒子の数だけ繰り返されます。
最後に(C')=>(C)=>(D)の部分です。biasedGeometric
の中でdo resample();
されるたびにparticles = Cons(Particle(currentWeight, currentAge, cont), particles)
としてParticle
が記録されていきました。そして、resume
はまだされていないので、do resample();
の後ろの計算はされません。これは(D)のapply
によって発現します。この後にval x = bernoulli(p);
にいき、do uniform()
によってハンドラに飛び、という流れです。
ちなみにハンドラ周りを入れ替えるとimportance samplingができます。
def importance[R](n: Int) { p : => R / SMC } = {
var measurements: List[Measurement[R]] = Nil()
n.repeat {
var currentWeight = 1.0;
try {
val result = p();
measurements = Cons(Measurement(currentWeight, result), measurements)
} with SMC {
def resample() = resume(()) // resampleでは何もしない
def uniform() = resume(random())
def score(d) = { currentWeight = currentWeight * d; resume(()) }
}
}
measurements
}
importance(numberOfParticles) { biasedGeometric(0.5) }
こうした「Effectの発生」と「そのhandle方法」を分けられる柔軟性がAlgebraic Effectsの特徴です。
Pyro Poutine
実はPythonのPPLライブラリのPyroのPoutineはAlgebraic Effectsに基づいています。
それを今から見ていきます。コードは上のドキュメントのものです。
def scale(guess):
weight = pyro.sample("weight", dist.Normal(guess, 1.0))
return pyro.sample("measurement", dist.Normal(weight, 0.75))
は
なる
このlog jointは
dist.Normal(guess, 1.0).log_prob(weight).sum() + dist.Normal(weight, 0.75).log_prob(measurement).sum()
で計算できます。できます・・・が、わざわざ定義したscale
の中身をバラバラにして計算する必要があるのは嫌です。scale
を突っ込んだらlog jointが出てきてくれるのがPPLとしては理想です。
そしてPyro Poutineはそれができます。
def make_log_joint(model):
def _log_joint(cond_data, *args, **kwargs):
conditioned_model = poutine.condition(model, data=cond_data)
trace = poutine.trace(conditioned_model).get_trace(*args, **kwargs)
return trace.log_prob_sum()
return _log_joint
scale_log_joint = make_log_joint(scale)
# 例
print(scale_log_joint({"measurement": 9.5, "weight": 8.23}, 8.5))
これはpoutine.condition
とpoutine.trace
がブラックボックスなので、それを開いてみるとこうなります。
from pyro.poutine.trace_messenger import TraceMessenger
from pyro.poutine.condition_messenger import ConditionMessenger
def make_log_joint_2(model):
def _log_joint(cond_data, *args, **kwargs):
with TraceMessenger() as tracer:
with ConditionMessenger(data=cond_data):
model(*args, **kwargs)
trace = tracer.trace
logp = 0.
for name, node in trace.nodes.items():
if node["type"] == "sample":
if node["is_observed"]:
assert node["value"] is cond_data[name]
logp = logp + node["fn"].log_prob(node["value"]).sum()
return logp
return _log_joint
scale_log_joint = make_log_joint_2(scale)
というわけで、それぞれ内部ではContext Managerが使われています。
実はscale
関数の中で呼ばれていたpyro.sample
では、その内部でPyro内部のGlobalなStackに積まれたMessenger
による処理が走ります(apply_stack
)。
TraceMessenger
もConditionMessenger
もMessenger
のサブクラスで、Context Managerとしてwith
文で使うと自身をGlobalなStackに積みます。
apply_stack
の中ではMessenger
の_postprocess_message
が呼ばれます。
そしてTraceMessenger
の場合、_postprocess_message
では_pyro_post_sample
が呼ばれます。
そして_pyro_post_sample
ではself.trace
にNodeを追加しています。
なので、これを後からfor name, node in trace.nodes.items()
などとして引き出せるわけです。
ここでは本物のPyroのコードを引き合いに出しましたが、実はこのコンセプトを使用した最小限の実装であるMini Pyroを公式が用意してくれています。全部で400行ちょっとしかないのでこちらを読むと理解が深まると思います(丸投げ)。
ここまでだとAlgebraic Effectsとの関係がわかりにくいと思います。
ですので、PythonにAlgebraic Effectsがあるとして上の例と同等のものを書いてみます。
が、その前に元のコードのMessenger
達を一つにまとめます。
class LogJointMessenger2(poutine.messenger.Messenger):
def __init__(self, cond_data):
self.data = cond_data
def __call__(self, fn):
def _fn(*args, **kwargs):
with self:
fn(*args, **kwargs)
return self.logp.clone()
return _fn
def __enter__(self):
self.logp = torch.tensor(0.)
return super().__enter__()
def __exit__(self, exc_type, exc_value, traceback):
self.logp = torch.tensor(0.)
return super().__exit__(exc_type, exc_value, traceback)
def _pyro_sample(self, msg):
if msg["name"] in self.data:
msg["value"] = self.data[msg["name"]]
msg["done"] = True
def _pyro_post_sample(self, msg):
assert msg["done"] # the "done" flag asserts that no more modifications to value and fn will be performed.
self.logp = self.logp + (msg["scale"] * msg["fn"].log_prob(msg["value"])).sum()
こうすると、
with LogJointMessenger2(cond_data={"measurement": 9.5, "weight": 8.23}) as m:
scale(8.5)
print(m.logp)
という感じで使えます。
ではこれをAlgebraic Effectsが使えるとして書き直してみます。まずscale
は以下のようにします。
def scale(guess):
weight = perform sample("weight", dist.Normal(guess, 1.0))
measurement = perform sample("measurement", dist.Normal(weight, 0.75))
return measurement
もちろんPythonにperform
などはないのですが、ここは最初のJSの例と同じようフィーリングでお願いします。要はsample
Effectを発生させているつもりです。
これを以下のように使います。
with handle LogJointMessenger2(cond_data={"measurement": 9.5, "weight": 8.23}) as m:
scale(8.5)
print(m.logp)
またwith handle
という本来のPythonには存在しない構文を使いましたが、これはハンドラを定めていると思ってください。ハンドラのas
って何?って言われると困ってしまいますが、ここは元のContext Manager的なサムシングということで・・・
ともかく、sample
Effectがperform
されると、with handle
で定められたハンドラMessengerの適切なメソッドが呼ばれます。そうしてsample
が返り値を返し、scale
関数内の処理が再開される、というつもりです。
このように考えると元のPyroのコードもAlgebra Effectsっぽく見えてきませんでしょうか。もちろんPythonに限定継続があるわけではないので、実装としてはContext Managerを使ってDIしているだけですが、これによりPPLとして使いやすいインターフェースを実現しています。詳細な実装はMini Pyroを読んでください。
ちなみにEffectに対するハンドラをwith
で定めてやる記法自体はKoka言語やEff言語などでみられます。
おわりに
この記事ではPPLとAlgebraic Effectsの関係について書きました。ここまで書いておいてなんですが筆者はAlgebraic Effectsも(限定)継続もPPLもそこまで詳しいわけではないので、あらゆる記述が間違っていて大炎上しないか震えています。なんでも訂正するので許してください。
Discussion