🎄

確率的プログラミング言語とAlgebraic Effects

2023/12/23に公開

はじめに

この記事は確率的プログラミング言語 Advent Calendar 2023
https://qiita.com/advent-calendar/2023/ppl

の23日目の記事です。
この記事を読むとPyro Poutineの気持ちが少しわかります。多分。

Algebraic Effectsとは

Algebraic Effects (and Handlers)は、大枠で言うと言語機能の一種です。
大雑把に説明すると「もとの場所に戻れる例外」です。
さらに言うと戻らなくても良いし、また一度もとの場所に戻った後に何か動作させることもできるし、(multi-shotであれば)複数回もとの場所に戻ることもできます。
そしてこれが非同期プログラミング、パーサー、自動微分や確率的プログラミング等の実装に利用できます。

コードで例を見てみましょう。
Algebraic Effectsを(Web系で)有名にしたのはDan Abramov氏の記事だと思います。
https://overreacted.io/algebraic-effects-for-the-rest-of-us/

この記事のJavaScriptの例を使います。とはいっても、JavaScriptには現状存在しない機能なのでフィーリングで読んでください。

function getName(user) {
  let name = user.name;
  if (name === null) {
      name = perform 'ask_name';
  }
  return name;
}
 
function makeFriends(user1, user2) {
  user1.friendNames.push(getName(user2));
  user2.friendNames.push(getName(user1));
}
 
const arya = { name: null, friendNames: [] };
const gendry = { name: 'Gendry', friendNames: [] };
try {
  makeFriends(arya, gendry);
} handle (effect) {
  if (effect === 'ask_name') {
      resume with 'Arya Stark';
  }
}

user.namenullの場合に、getName関数の中でperformask_nameなるEffectを発生させています。
そしてgetName関数を使う側でtry~handleして、ask_name Effectが発生した場合にArya Starkという値をもとの場所に戻しています。
そうして、getName関数の中でnameArya Starkが代入されます。

次に戻らない例を見てみましょう。

try {
  makeFriends(arya, gendry);
} handle (effect) {
  if (effect === 'ask_name') {
    console.error('ERROR');
  }
}

こうすると普通の例外と同じような感じで、user.namenullのときは、ERRORと出力されて終わりです。

最後に戻った後に何かするケースも紹介します。

try {
  makeFriends(arya, gendry);
} handle (effect) {
  if (effect === 'ask_name') {
    resume with 'Arya Stark';
    console.log('resumed');
  }
}

とすると、最初の例のようにArya Starknameに代入され、makeFriendsの処理が終わったのちに、resumedと出力されます。

ここで重要な点として、makeFriends関数は変更することなく、その「ハンドラ」(ここではhandle (effect)以降)を変更するだけでプログラムの動作が変わることが挙げられます。特に、resume withの右の値を変えれば、user.namenullのときに代入される値が変わります。なので、Algebraic EffectsはDependency injectionにも使うことができます。

もう一度書きますが、現状Algebraic EffectsはJavaScriptには存在しません。
逆にAlgebraic Effectsを(ライブラリレベルではなく、言語レベルで)備えている言語には、EffKokaEffekt、そしてUnisonがあります。

限定継続

PPLの話をする前に、Algebraic Effectsにおける限定継続について書きます。
継続、とはプログラムの残りの計算です。限定継続とは、プログラムのある場所までの残りの計算という、範囲を限定された継続です。
先ほどのJSにおけるAlgebraic Effectsの例に戻ります。以下のようにCallbackを使った記法を考えてみます。もちろんこれも実際のJSにはない機能です。実際の言語の例を出すより、まずは比較的馴染みの深い言語を使ってそれっぽい記法を使った方が理解しやすいだろう、という意図です。

try_handle(() => {
  makeFriends(arya, gendry);
}, (effect, k) => {
    if (effect === 'ask_name') {
        k('Arya Stark');
    }
  }
)

これは

try {
  makeFriends(arya, gendry);
} handle (effect) {
  if (effect === 'ask_name') {
      resume with 'Arya Stark';
  }
}

をCallbackを使って実装したものを思ってください。さて、ここでkという謎の引数が出てきました。これが限定継続(のつもり)です。
try_handleは第一引数に与えられたCallback ff()として呼び出します(ここではそういうものとします)。そうしてEffectが発生したら第二引数に与えられたCallback gを呼び出します。geffectkという2つの引数を持ちます。effectはこれまで通り発生したEffectです。そしてkは、Effectの発生したところから、f()の終わりまで残りの計算です。もしfの呼び出しが値を返すのならば、それがkの返す値で、kの引数はperform 'ask_name'に代入されます。

さて、このCallbackを使った書き換えは何が嬉しいのでしょうか。ここでは限定継続がkという変数に束縛されています。なので、例えばlistに保存しておいて、後で取り出して計算を再開する、といったことができます。

const cont_list = [];
try_handle(() => {
  makeFriends(arya, gendry);
}, (effect, k) => {
    if (effect === 'ask_name') {
        cont_list.push(k);
    }
  }
)
cont_list.forEach((k) => k('Arya Stark')); // あとでまとめて処理

これが限定継続です。

Effekt言語におけるSequential Monte Carlo

それでは、Effekt言語でPPLをやる例を見てみましょう。コードはすべて以下にあるものです。
https://effekt-lang.org/docs/casestudies/smc

まず使用するEffectを定義します。

effect SMC {
  def resample(): Unit
  def uniform(): Double
  def score(d: Double): Unit
}

先ほどのJavaScriptの例では、Effectは単なる文字列でしたが、Algebraic Effectsを備えている言語では、このようにEffectを先に定義することが多いと思います。
またEffectは、

effect Yield(n: Int): Boolean

のように「引数(の型)」と「返り値(の型)」相当を(静的型付き言語であれば)書きます。

ここで、複数のEffectをまとめて扱うことがあります。この場合一つ一つの「Effect」はOperationと呼ばれます。ここでは、resampleuniformscoreの3つのOperationが存在します。
…というのは微妙な書き方で、Effectは本来Operationの集合で表され、Effectの持つOperationがただ一つで、かつそれがEffectと同名なら(言語によっては)略記できる、といった方が正しいと思います。

ともかく、上記のEffectを使って計算対象の分布を定義します。見ての通りdoでOperationを使い、Effectを発生させます(最初のJSの例ではperform)。

def bernoulli(p: Double) = do uniform() < p

def biasedGeometric(p: Double): Int / SMC = {
  do resample();
  val x = bernoulli(p);
  if (x) {
    do score(log(1.5));
    1 + biasedGeometric(p)
  } else { 1 }
}

さて、データ型を3つ定義します。ここでContは限定継続(の型)です。

record Particle(weight: Double, age: Int, cont: Cont[Unit, Unit])
record Measurement[R](weight: Double, data: R)
record Particles[R](moving: List[Particle], done: List[Measurement[R]])

これを使ってsmcHandlerを定義します。

def smcHandler[R](numberOfParticles: Int) { // (A)
  resample: Particles[R] => Particles[R]
} { p: () => R / SMC } = { // (E')
  var currentWeight = 1.0;
  var particles: List[Particle] = Nil() // (B)
  var measurements: List[Measurement[R]] = Nil()
  var currentAge = 0;

  def checkpoint(cont: Cont[Unit, Unit]) =
    particles = Cons(Particle(currentWeight, currentAge, cont), particles)

  def run(p: Particle): Unit = { // (C)
    currentWeight = p.weight;
    currentAge = p.age;
    p.cont.apply(()) // (D)
  }

  def run() = {
    val Particles(ps, ms) = resample(Particles(particles, measurements));
    particles = Nil();
    measurements = ms;
    ps.foreach { p => p.run } // (C')
  }

  repeat(numberOfParticles) {
    currentWeight = 1.0;
    try {
      val res = p(); // (E)
      measurements = Cons(Measurement(currentWeight, res), measurements)
    } with SMC { // (F)
      def resample() = checkpoint(cont { t => resume(t) })
      def uniform() = resume(random())
      def score(d) = { currentWeight = currentWeight * d; resume(()) }
    }
  }

  while (not(particles.isEmpty)) { run() }
  measurements
}

ここで、(A)については、

def smcHandler[R](numberOfParticles: Int) {
  resample: Particles[R] => Particles[R]
} { p: () => R / SMC } 

は高階関数です。要は最終的にnumberOfParticlesresamplepを埋めます。
(B)については、

  var particles: List[Particle] = Nil()

は空のリストで、またnewlist = Cons(elem, list)listelemを追加したnewlistが得られる、という感じです。正確には[1,2]Cons(1, Cons(2, Nil))で表すみたいなやつです。
https://ja.wikipedia.org/wiki/Cons_(Lisp)
(C')のp.runrunは(C)で定義されているものです。syntax sugarですね。
(E)を見てみましょう。try~catchのようなtry~withという構文があります。これは記事の最初に紹介したtry~handleのような書き方です。
(E)でpが呼ばれています。pの返り値はR / SMCです。Rはいわゆるジェネリクスですが、/の後ろに発生するEffectの型であるSMCがあります。
実際、(F)ではwithの後にはSMCと書かれています。
SMC Effectには3つのOperationがあるため、それぞれどのようにハンドルするかを記載します。

まずresample関数としてナイーブな実装を考えます。詳細はURL先を参照してください。

def resampleUniform[R](particles: Particles[R]): Particles[R] = {
  // 中略
}

これを用いて以下を定義します。

def smc[R](numberOfParticles: Int) { p: () => R / SMC } =
  smcHandler[R](numberOfParticles) { ps => resampleUniform(ps) } { p() }

最後に

def runSMC(numberOfParticles: Int) =
    smc(numberOfParticles) { biasedGeometric(0.5) }

として、粒子の数を指定してrunSMCを呼びます。
smcpbiasedGeometric(0.5)が入ります。RIntです。
よってsmcHandlerresampleps => resampleUniform(ps)が、pにそのままbiasedGeometric(0.5)が入ります。
そうして、smcHandlerの中でbiasedGeometric(0.5)が呼ばれ、各Operationが呼ばれたら対応するハンドラが呼ばれ、resumeで戻ります。
例えば

def bernoulli(p: Double) = do uniform() < p
// 中略
  val x = bernoulli(p);

ではハンドラがdef uniform() = resume(random())となっているので、結局xrandom()の結果がPより小さいかどうかの真偽値が入ります。

というのがrepeat(numberOfParticles)より、 粒子の数だけ繰り返されます。

最後に(C')=>(C)=>(D)の部分です。biasedGeometricの中でdo resample();されるたびにparticles = Cons(Particle(currentWeight, currentAge, cont), particles)としてParticleが記録されていきました。そして、resumeはまだされていないので、do resample();の後ろの計算はされません。これは(D)のapplyによって発現します。この後にval x = bernoulli(p);にいき、do uniform()によってハンドラに飛び、という流れです。

ちなみにハンドラ周りを入れ替えるとimportance samplingができます。

def importance[R](n: Int) { p : => R / SMC } = {
  var measurements: List[Measurement[R]] = Nil()
  n.repeat {
    var currentWeight = 1.0;
    try {
      val result = p();
      measurements = Cons(Measurement(currentWeight, result), measurements)
    } with SMC {
      def resample() = resume(()) // resampleでは何もしない
      def uniform() = resume(random())
      def score(d) = { currentWeight = currentWeight * d; resume(()) }
    }
  }
  measurements
}

importance(numberOfParticles) { biasedGeometric(0.5) }

こうした「Effectの発生」と「そのhandle方法」を分けられる柔軟性がAlgebraic Effectsの特徴です。

Pyro Poutine

実はPythonのPPLライブラリのPyroのPoutineはAlgebraic Effectsに基づいています。
https://docs.pyro.ai/en/stable/poutine.html
https://pyro.ai/examples/effect_handlers.html

それを今から見ていきます。コードは上のドキュメントのものです。

def scale(guess):
    weight = pyro.sample("weight", dist.Normal(guess, 1.0))
    return pyro.sample("measurement", dist.Normal(weight, 0.75))

weight\mid guess\sim\mathrm{Normal}(guess, 1.0) \\ measurement\mid guess,weight\sim\mathrm{Normal}(weight, 0.75)

なるweightmeasurementの同時確率分布です。
このlog jointは

dist.Normal(guess, 1.0).log_prob(weight).sum() + dist.Normal(weight, 0.75).log_prob(measurement).sum()

で計算できます。できます・・・が、わざわざ定義したscaleの中身をバラバラにして計算する必要があるのは嫌です。scaleを突っ込んだらlog jointが出てきてくれるのがPPLとしては理想です。

そしてPyro Poutineはそれができます。

def make_log_joint(model):
    def _log_joint(cond_data, *args, **kwargs):
        conditioned_model = poutine.condition(model, data=cond_data)
        trace = poutine.trace(conditioned_model).get_trace(*args, **kwargs)
        return trace.log_prob_sum()
    return _log_joint

scale_log_joint = make_log_joint(scale)
# 例
print(scale_log_joint({"measurement": 9.5, "weight": 8.23}, 8.5))

これはpoutine.conditionpoutine.traceがブラックボックスなので、それを開いてみるとこうなります。

from pyro.poutine.trace_messenger import TraceMessenger
from pyro.poutine.condition_messenger import ConditionMessenger

def make_log_joint_2(model):
    def _log_joint(cond_data, *args, **kwargs):
        with TraceMessenger() as tracer:
            with ConditionMessenger(data=cond_data):
                model(*args, **kwargs)

        trace = tracer.trace
        logp = 0.
        for name, node in trace.nodes.items():
            if node["type"] == "sample":
                if node["is_observed"]:
                    assert node["value"] is cond_data[name]
                logp = logp + node["fn"].log_prob(node["value"]).sum()
        return logp
    return _log_joint

scale_log_joint = make_log_joint_2(scale)

というわけで、それぞれ内部ではContext Managerが使われています。

実はscale関数の中で呼ばれていたpyro.sampleでは、その内部でPyro内部のGlobalなStackに積まれたMessengerによる処理が走ります(apply_stack)。
https://github.com/pyro-ppl/pyro/blob/834ff633c9ecb4b8fac6c15fecaa35c460984aa3/pyro/primitives.py#L172-L189

TraceMessengerConditionMessengerMessengerのサブクラスで、Context Managerとしてwith文で使うと自身をGlobalなStackに積みます。
apply_stackの中ではMessenger_postprocess_messageが呼ばれます。
https://github.com/pyro-ppl/pyro/blob/834ff633c9ecb4b8fac6c15fecaa35c460984aa3/pyro/poutine/runtime.py#L293

そしてTraceMessengerの場合、_postprocess_messageでは_pyro_post_sampleが呼ばれます。
https://github.com/pyro-ppl/pyro/blob/834ff633c9ecb4b8fac6c15fecaa35c460984aa3/pyro/poutine/messenger.py#L182-L185

そして_pyro_post_sampleではself.traceにNodeを追加しています。
https://github.com/pyro-ppl/pyro/blob/834ff633c9ecb4b8fac6c15fecaa35c460984aa3/pyro/poutine/trace_messenger.py#L146

なので、これを後からfor name, node in trace.nodes.items()などとして引き出せるわけです。

ここでは本物のPyroのコードを引き合いに出しましたが、実はこのコンセプトを使用した最小限の実装であるMini Pyroを公式が用意してくれています。全部で400行ちょっとしかないのでこちらを読むと理解が深まると思います(丸投げ)。
https://github.com/pyro-ppl/pyro/blob/dev/pyro/contrib/minipyro.py

ここまでだとAlgebraic Effectsとの関係がわかりにくいと思います。
ですので、PythonにAlgebraic Effectsがあるとして上の例と同等のものを書いてみます。

が、その前に元のコードのMessenger達を一つにまとめます。

class LogJointMessenger2(poutine.messenger.Messenger):

    def __init__(self, cond_data):
        self.data = cond_data

    def __call__(self, fn):
        def _fn(*args, **kwargs):
            with self:
                fn(*args, **kwargs)
                return self.logp.clone()
        return _fn

    def __enter__(self):
        self.logp = torch.tensor(0.)
        return super().__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        self.logp = torch.tensor(0.)
        return super().__exit__(exc_type, exc_value, traceback)

    def _pyro_sample(self, msg):
        if msg["name"] in self.data:
            msg["value"] = self.data[msg["name"]]
            msg["done"] = True

    def _pyro_post_sample(self, msg):
        assert msg["done"]  # the "done" flag asserts that no more modifications to value and fn will be performed.
        self.logp = self.logp + (msg["scale"] * msg["fn"].log_prob(msg["value"])).sum()

こうすると、

with LogJointMessenger2(cond_data={"measurement": 9.5, "weight": 8.23}) as m:
    scale(8.5)
    print(m.logp)

という感じで使えます。

ではこれをAlgebraic Effectsが使えるとして書き直してみます。まずscaleは以下のようにします。

def scale(guess):
    weight = perform sample("weight", dist.Normal(guess, 1.0))
    measurement = perform sample("measurement", dist.Normal(weight, 0.75))
    return measurement

もちろんPythonにperformなどはないのですが、ここは最初のJSの例と同じようフィーリングでお願いします。要はsample Effectを発生させているつもりです。
これを以下のように使います。

with handle LogJointMessenger2(cond_data={"measurement": 9.5, "weight": 8.23}) as m:
    scale(8.5)
    print(m.logp)

またwith handleという本来のPythonには存在しない構文を使いましたが、これはハンドラを定めていると思ってください。ハンドラのasって何?って言われると困ってしまいますが、ここは元のContext Manager的なサムシングということで・・・

ともかく、sample Effectがperformされると、with handleで定められたハンドラMessengerの適切なメソッドが呼ばれます。そうしてsampleが返り値を返し、scale関数内の処理が再開される、というつもりです。

このように考えると元のPyroのコードもAlgebra Effectsっぽく見えてきませんでしょうか。もちろんPythonに限定継続があるわけではないので、実装としてはContext Managerを使ってDIしているだけですが、これによりPPLとして使いやすいインターフェースを実現しています。詳細な実装はMini Pyroを読んでください。

ちなみにEffectに対するハンドラをwithで定めてやる記法自体はKoka言語Eff言語などでみられます。

おわりに

この記事ではPPLとAlgebraic Effectsの関係について書きました。ここまで書いておいてなんですが筆者はAlgebraic Effectsも(限定)継続もPPLもそこまで詳しいわけではないので、あらゆる記述が間違っていて大炎上しないか震えています。なんでも訂正するので許してください。

Discussion