🎄

[Rust] サブグラフを生成する高階自動微分

2023/08/20に公開

前回は二重数の拡張によって高階自動微分する方法を Rust で実装しました。
しかし、リバースモードが使えない、多変数への拡張が困難、任意の関数の高階微分はやっぱり自分で求めなければいけない、などの問題があり、スケールアップするには難がありました。

そこでこのサブグラフを生成する手法が役に立つのですが、この手法、名前は何というのかよくわかりません。 PyTorchJaxTensorFlow に機能としてはあるのですが、呼び方はばらばらで、文献なども示されていません。しかし、ちょっと考えてみれば誰でも思いつくような方法なので、敢えて名前で呼ぶようなものでもないのかもしれません。

いつも通り全てのコードはこちらのリポジトリにあります。

https://github.com/msakuta/rustograd

理論

基本的なアイデアはいたってシンプルです。フォワードモードの要領でグラフを辿り、微分値を評価する代わりに微分値を出力するノードを生成するのです。

例えば、次のような非常にシンプルな式を考えます。

f(a) = (a - b) a

最初の状態では式のグラフは次のようになります(繋がっていない 0 と 1 は後で使います、と思ったら 0 は結局使いませんでした)。

1次の微分のサブグラフの計算が終わった時点では次のようになります。これは正しく次の計算結果に一致しています。

f'(a) = 2a - b

2次の微分は次のようになります。

f''(a) = 2

ここで微分した結果のノードはサブグラフの一部に元の関数の一部を含むのが見て取れると思います。これは計算量を最適化するためにかなり重要な性質で、複雑な関数であっても共有する計算ノードが多く出てくる傾向があるので、重複した計算を避けることができます。

全体の流れを次のようにアニメーションにすることもできますが、 graphviz はノードの位置を毎回変えてしまうのであまり見やすくはありません。それでもなんとなくアイデアはわかると思います。

もちろん、微分の階数を混合した次のような関数も扱えます。

f(x) = \frac{d g(x)}{dx} + g(x)

rustograd ではこのサブグラフを生成するメソッドを gen_graph と呼んでいます。これを繰り返し呼び出すことで高階微分が可能になります。

let dddaba = aba.gen_graph(&a).unwrap().gen_graph(&a);

さらに、微分する変数を毎回選べるので、交差項 \partial a \partial b なども簡単に計算できます。

欠点を挙げるとすると、生成されるサブグラフが階数に応じて指数的に大きくなる傾向があることです。例えば、次のような単純なガウス関数の3次微分までを求めてみます。

f(x) = \exp(-x^2)

計算グラフは次のようになります。

アニメーションもできますが、もはや何が何だかわかりません。

これは合成関数の微分が積の微分を生成し、それが項を増やしていく性質によるものです。よく見ると同じ値を持つのに共有されていないノードがたくさんあるので、最適化の余地はあると思うのですが、かなりの労力を必要としそうです。

実装

この実装は思ったよりも簡単でした。基本的にフォワードモードの微分のコードそのままで、間にノードを生成するコードを挟みつつ再帰呼び出しをしていきます。まずは大枠となる gen_graph の定義です。ここでは Tape の実装を使っているので、 Vec<TapeNode> の可変参照を引数に取ります。返り値として Option<TapeIndex> を返しているのがミソで、これが親ノードの生成したノードになります。

fn gen_graph<T: Tensor + 'static>(
    nodes: &mut Vec<TapeNode<T>>,
    idx: TapeIndex,
    wrt: TapeIndex,
    cb: &impl Fn(&[TapeNode<T>], TapeIndex, TapeIndex),
) -> Option<TapeIndex> {
    use TapeValue::*;
    let ret = match nodes[idx as usize].value {
        // ...
    }
}

マッチ式の中で式の種別ごとに動作を記述します。

入力変数

値の場合は単純に微分先の変数であれば 1、そうでなければ何も返しません。

        Value(_) => {
            if idx == wrt {
                Some(1)
            } else {
                None
            }
        }

加算

加算の場合はどちらか一方のサブグラフがノードを生成していればそれを返し、両方生成していれば加算し、どちらも生成していなければ None を返します。

        Add(lhs, rhs) => {
            let lhs = gen_graph(nodes, lhs, wrt, cb);
            let rhs = gen_graph(nodes, rhs, wrt, cb);
            match (lhs, rhs) {
                (Some(lhs), None) => Some(lhs),
                (None, Some(rhs)) => Some(rhs),
                (Some(lhs), Some(rhs)) => Some(add_add(nodes, lhs, rhs)),
                _ => None,
            }
        }

減算

引き算の場合は符号反転ノードを使うほかは加算と同じです。

        Sub(lhs, rhs) => {
            let lhs = gen_graph(nodes, lhs, wrt, cb);
            let rhs = gen_graph(nodes, rhs, wrt, cb);
            match (lhs, rhs) {
                (Some(lhs), None) => Some(lhs),
                (None, Some(rhs)) => Some(add_neg(nodes, rhs)),
                (Some(lhs), Some(rhs)) => Some(add_sub(nodes, lhs, rhs)),
                _ => None,
            }
        }

積算

ちょっと複雑になってくるのが掛け算です。オペランドの片方のみノードを生成している場合は、逆側のノードの値との積を生成して返します。両方のオペランドが生成している場合、それぞれの積を足したノードを生成して返しています。

        Mul(lhs, rhs) => {
            let dlhs = gen_graph(nodes, lhs, wrt, cb);
            let drhs = gen_graph(nodes, rhs, wrt, cb);
            match (dlhs, drhs) {
                (Some(dlhs), None) => Some(add_mul(nodes, dlhs, rhs)),
                (None, Some(drhs)) => Some(add_mul(nodes, lhs, drhs)),
                (Some(dlhs), Some(drhs)) => {
                    let plhs = add_mul(nodes, dlhs, rhs);
                    let prhs = add_mul(nodes, lhs, drhs);
                    let node = add_add(nodes, plhs, prhs);
                    Some(node)
                }
                _ => None,
            }
        }

商算

割り算の場合は掛け算と同じように中間計算ノードを生成しますが、ノードの数が増えるだけで考え方は一緒です。

        Div(lhs, rhs) => {
            let dlhs = gen_graph(nodes, lhs, wrt, cb);
            let drhs = gen_graph(nodes, rhs, wrt, cb);
            match (dlhs, drhs) {
                (Some(dlhs), None) => Some(add_div(nodes, dlhs, rhs)),
                (None, Some(drhs)) => {
                    let node = add_mul(nodes, lhs, drhs);
                    let node = add_div(nodes, node, rhs);
                    let node = add_div(nodes, node, rhs);
                    Some(add_neg(nodes, node))
                }
                (Some(dlhs), Some(drhs)) => {
                    let plhs = add_div(nodes, dlhs, rhs);
                    let node = add_mul(nodes, lhs, drhs);
                    let prhs = add_div(nodes, node, rhs);
                    let prhs = add_div(nodes, prhs, rhs);
                    Some(add_sub(nodes, plhs, prhs))
                }
                _ => None,
            }
        }

符号反転

符号反転は単純に、生成されたノードがあればそれを反転したノードを返すので一行で書けます。

Neg(term) => gen_graph(nodes, term, wrt, cb).map(|node| add_neg(nodes, node)),

任意の関数ノード

任意の関数は少しややこしいです。任意の関数についてノードを生成するか否かはケースバイケースで、関数の内容次第ですが、ライブラリとしては初等関数 exp や sin など以外にもユーザーが実装できるようにしたいところです。そこで rustograd では関数の挙動を UnaryFn トレイトとしてインターフェースを定義しています。

pub trait UnaryFn<T> {
    fn name(&self) -> String;
    fn f(&self, data: T) -> T;
    fn grad(&self, data: T) -> T;
    fn t(&self, data: T) -> T {
        data
    }
    fn gen_graph(
        &self,
        _nodes: &mut Vec<TapeNode<T>>,
        _input: TapeIndex,
        _output: TapeIndex,
        _derived: TapeIndex,
    ) -> Option<TapeIndex> {
        None
    }
}

ここで使うのは最後の gen_graph メソッドです。このメソッドはこのノードの局所的な導関数を生成するものです。デフォルトでは None を返しますがほとんどの関数は何らかの値を返すでしょう。

input, output, derived という3つのノードが引数として与えられています。それぞれこのノードへの入力、出力、入力の微分のノードを示します。これは関数によって最適なノードを生成する入力が異なるためです。例えば指数関数の場合は微分は次のように入力の微分と \exp(g) 自体の出力の積になります。

\frac{d \exp(g)}{dx} = \frac{dg}{dx} \exp(g)

この場合、 outputderived の積をノードとして生成して返せばよいです。

impl UnaryFn<f64> for ExpFn {
    fn gen_graph(
        &self,
        nodes: &mut Vec<TapeNode<f64>>,
        _input: TapeIndex,
        output: TapeIndex,
        derived: TapeIndex,
    ) -> Option<TapeIndex> {
        Some(add_mul(nodes, output, derived))
    }
}

これに対し、 sin などのように入力変数に「接ぎ木」した方が良い関数もあります。
まず、三角関数は微分の次数を状態変数として取る関数にしておきます。

struct SinFn(usize);

impl UnaryFn<f64> for SinFn {
    fn name(&self) -> String {
        match self.0 % 4 {
            0 => "sin",
            1 => "cos",
            2 => "-sin",
            3 => "-cos",
            _ => unreachable!(),
        }
        .to_string()
    }

    fn f(&self, data: f64) -> f64 {
        match self.0 % 4 {
            0 => data.sin(),
            1 => data.cos(),
            2 => -data.sin(),
            3 => -data.cos(),
            _ => unreachable!(),
        }
    }

    fn grad(&self, data: f64) -> f64 {
        Self(self.0 + 1).f(data)
    }
}

こうしておくと gen_graph は次数を上げた自分自身を新たなノードとして作り、同じ入力変数としたうえで、入力の微分との積をとればよいです。

\frac{d \sin(g)}{dx} = \frac{dg}{dx} \frac{d \sin(g)}{dg}
impl UnaryFn<f64> for SinFn {
    fn gen_graph(
        &self,
        nodes: &mut Vec<TapeNode<f64>>,
        input: TapeIndex,
        _output: TapeIndex,
        derived: TapeIndex,
    ) -> Option<TapeIndex> {
        let rhs = add_unary_fn(nodes, Box::new(Self(self.0 + 1)), input);
        Some(add_mul(nodes, derived, rhs))
    }
}

さて、このように関数が実装されていれば、次のように処理できます。ここで少しややこしいのは借用チェッカーで、 fUnaryFn トレイトオブジェクトとして実装されており、それが nodes 自体の一部なので、自分自身を含むコンテナを変更しようとするとコンパイルエラーになります。実はこのエラーは真っ当なもので、ノードを増やすことでコンテナの Vec がリアロケートされる可能性があるので、 f 自体の参照を持ちながらノードを増やすのは dangling reference となりうる危険があります。

これに対応する方法は Rc で包むか、所有権を一時的に奪ってしまうという方法があります。ここでは所有権を奪う方法で借用チェッカーエラーを回避しています。これを後で戻し忘れるのはありがちなバグなのでしっかり戻しましょう。

        UnaryFn(UnaryFnPayload { term, ref mut f }) => {
            let taken_f = f.take();
            let derived = gen_graph(nodes, term, wrt, cb);
            let ret = derived.and_then(|derived| {
                taken_f
                    .as_ref()
                    .unwrap()
                    .gen_graph(nodes, term, idx, derived)
            });
            if let UnaryFn(UnaryFnPayload { ref mut f, .. }) = nodes[idx as usize].value {
                *f = taken_f;
            } else {
                unreachable!()
            }
            ret
        }

まとめ

これで任意の階数の微分が好きな変数に対して行えるようになりました。得られた柔軟性は非常に高いものですが、同時にメモリ使用量も高くつきます。しかしこれはシンボリック微分でも複雑な数式になるのである程度は仕方のない部分もあるでしょう。

共通項が出てきたらまとめるなどの最適化も考えられますが、その最適化自体に処理時間がかかっては意味がありませんので、そこら辺のバランスは問題ごとに注意深く取っていく必要があると思います。

そんなわけで rustograd は機能していますが、再度お勧めしたいのは自分で実装することです。自分で実装すれば中身を理解できるのはもちろん、問題に応じたチューニングや最適化も可能になります。ライブラリを使うということはその労力をライブラリ実装者に丸投げするということになります(多くの場合、ライブラリに投入されている労力はアプリケーションよりも大きいですが、ライブラリも全てのユースケースをカバーできるわけではありません)。

今後の課題

今後拡張できるのは以下の点です。

  • テンソル型への対応
  • 任意の多変数関数ノードの追加
  • GPU 計算

一つ目のテンソル型への対応は少し説明が必要かもしれません。 rustograd はすでにジェネリック型としてテンソル型への対応は可能になっています。しかし問題は恒等変換を示す "1" です。テンソルの形状は一般に異なるので、 "1" を表すテンソルとそれ以外の演算が可能とは限りません。そんなわけで現時点では固定サイズではないテンソルのサブグラフを生成してリバースモードの自動微分はうまくいきません。

broadcasting を行えば少なくともスカラーの "1" とテンソルの演算は可能じゃないかと思われるかもしれませんが、それだとリバースモード時の逆変換が上手くいきません。この辺はうまい解決策が見つかったらまた記事にしようと思います。

Discussion