🤖

Rustで自動微分 with Algebraic Effects

2023/07/26に公開

はじめに

前回はC++でAlgebraic Effectsを使って自動微分を実装しました。
https://zenn.dev/catminusminus/articles/aed3b5b7c5ac31

実はRustでやろうと思ったんですが、型が合わず上手くいかなかったのでC++でやりました。
今回黒魔術に手を出すことでそれっぽく出来たのでRust版の記事を書きます。
なお本記事では、上の記事もしくは大元のEffekt言語のドキュメントに書いてあることは仮定します。

RustでAlgebraic Effects

RustでAlgebraic Effectsをやるにはいくつか選択肢がありますが、今回はeffective-rustを使います。
https://github.com/pandaman64/effective-rust

ただし、そのままだとビルドできないので、issueの記載のように修正してください。

順伝播の実装

そうしたら以下のように実装します。

#[derive(Clone, Debug)]
struct Num {
    data: f64,
    grad: Rc<RefCell<f64>>,
}

struct NumOp(f64);
impl Effect for NumOp {
    type Output = Num;
}

struct AddOp(Num, Num);
impl Effect for AddOp {
    type Output = Num;
}

struct MulOp(Num, Num);
impl Effect for MulOp {
    type Output = Num;
}

こうして、以下のように3.0x + x^2の計算に使えます。

#[eff(NumOp, AddOp, MulOp)]
fn prog(x: Num) -> Num {
    let x_2 = perform!(MulOp(x.clone(), x.clone()));
    let ret = perform!(AddOp(perform!(MulOp(perform!(NumOp(3.0)), x)), x_2));
    ret
}

あとはこれをハンドルして。。。と行きたいところですが、こんな自動微分ライブラリは嫌だと思います。
演算子をオーバーロードすれば、と最初は思いますが、

#[eff(AddOp)]
fn add(x: Num, y: Num) -> Num {
    let ret = perform!(AddOp(x, y));
    ret
}

impl Add for Num {
    type Output = impl Effectful<Output = Num, Effect = eff::coproduct::Either<AddOp, !>>;
    fn add(self, rhs: Self) -> Self::Output {
        add(self, rhs)
    }
}
// MulOpも同様

とすると、prog内で型エラーが出てしまいます(OutputがNumではないので)。
そこで今回は、こちらを参考にproc_macro_attributeを実装しました。
実装は単純で、#left.saturating_add(#right)perform!(AddOp(#left, #right))(掛け算も同様)にするだけなのでコードは省略します。
こうしてcomputeマクロを実装したら、

#[eff(NumOp, AddOp, MulOp)]
#[compute]
fn prog(x: Num) -> Num {
    let ret = perform!(NumOp(3.0)) * x.clone() + x.clone() * x;
    ret
}

と普通の計算のように書けます。
後はハンドラを定義してやって、

fn main() {
   let result = prog(Num {
        data: 2.0,
        grad: Rc::new(1.0.into()),
    })
    .handle(handler!(
        x => x,
        NumOp(x), k => perform!(k.resume(Num { data: x, grad: Rc::new(1.0.into())})),
        AddOp(x, y), k => perform!(k.resume(Num { data: x.data + y.data, grad: Rc::new((*x.grad.borrow() + *y.grad.borrow()).into())})),
        MulOp(x, y), k => perform!(k.resume(Num { data: x.data * y.data, grad: Rc::new((*x.grad.borrow() * y.data + *y.grad.borrow() * x.data).into())})),
    )).block_on();
    println!("{:?}", result.data);
}

とすると、計算結果の10.0が表示されます。

逆伝播

#[eff(NumOp, AddOp, MulOp)]
fn prog_(x: Num) -> Num {
    let ret = perform_from!(prog(x));
    *ret.grad.borrow_mut() += 1.0;
    ret
}

とします。ここで、prog(x)Numを返すわけではないので、perform_from!で再度performして使います。ちなみにEffekt言語だとこういうことはせずにprog(input).push(1.0)とやっているので(AD effectはNumのような型とは「別枠」なので)、やはり言語に組み込まれていると良いですね。
さて、後はハンドルするだけです。

fn main() {
    let input = Num {
        data: 2.0,
        grad: Rc::new(0.0.into()),
    };
    let _ = prog_(input.clone())
    .handle(handler!(
        x => x,
        NumOp(x), k => perform!(k.resume(Num { data: x, grad: Rc::new(0.0.into())})),
        AddOp(x, y), k => {
            let z = Num { data: x.data + y.data, grad: Rc::new(0.0.into())};
            perform!(k.resume(z.clone()));
            *x.grad.borrow_mut() += *z.grad.borrow();
            *y.grad.borrow_mut() += *z.grad.borrow();
            z
        },
        MulOp(x, y), k => {
            let z = Num { data: x.data * y.data, grad: Rc::new((*x.grad.borrow() * y.data + *y.grad.borrow() * x.data).into())};
            perform!(k.resume(z.clone()));
            *x.grad.borrow_mut() += y.data * *z.grad.borrow();
            *y.grad.borrow_mut() += x.data * *z.grad.borrow();
            z
        },
    )).block_on();
    println!("{:?}", *input.grad.borrow());
}

これで導関数の3.0 + 2xxに2.0を代入した7.0が表示されます。

おわりに

RustでもAlgebraic Effectsを使って自動微分をやることができました。

Discussion