🐖

Rustで計算グラフを実装中

に公開

AIの勉強を始めたので「ゼロからつくるDeep Learning(以降、ゼロつく)」を読みながら、Rustで計算グラフを実装することにしました。

大枠ではこんな感じです。

まずは計算グラフのノードを用意。

struct Node {
    parents: Vec<Rc<RefCell<FunctionNode>>>,
    data: Tensor,
    // 逆伝播まで値がないのでOptionに包む
    grad: Option<Tensor>,
}

それから、ノードの親としてFunctionNodeを用意しました。

struct FunctionNode {
    input: Vec<Rc<Node>>,
    output: Tensor,
    func: Box<dyn Function>,
}

あとは各々の計算として関数のtraitを準備。

trait Function {
    fn apply(&self, tensors: &[&Tensor]) -> Tensor;
    fn backward(&self, grad: Tensor) -> Tensor;
}

これでforwardとbackwardの計算ができます。

Rcは当然使うのが初めてなんですけど、Pythonの変数は全部Rcで使うようなもの、という理解です。(参照でどんどん渡せる便利な道具。けどコストがあるって感じでしょうか)
(Pythonはゼロつくで少し触った程度なんですけど…。とっても便利な言語というイメージです)

Forwardは完全に理解した!

(この構文で合っているよね?)
Forwardは問題ないかなと思います。
下記の順番で処理を行っていきます。
①最初に流し込む値としてnodeを作成
②そのnode に適用する関数と合わせて処理をするnodeを渡して、計算の記録をfunctionNodeに保持しながら、次の計算に使うためのnodeを吐き出す

fn apply<F>(self: Rc<Self>, func: F, nodes: &[Rc<Node>] ) -> Node 
    where F: Function + 'static
        // まだmainから呼んでないのですが、staticじゃ制約が強い気がしています…。
    {
        let mut tensors = vec![&self.data];
        tensors.extend(nodes.iter().map(|n| &n.data));

        let output = func.apply(&tensors);

        let mut inputs = vec![self];
        inputs.extend(nodes.iter().map(|n| Rc::clone(n)));

        // 次のNodeの親。新しいNodeを生み出したものを記録
        let funcnode = FunctionNode{
            input: inputs,
            output: output.clone(),
            func: Box::new(func),
        };
        // 次の計算の元になるNode
        Node {
            parents: vec![Rc::new(RefCell::new(funcnode))],
            data: output,
            grad: None,
        }
    }

他のnodeとの演算は「あるだろう」くらいの勢いで、実のところ、今はAffineやReLu、Sigmoid、Convくらいしか想定はしていないので、他nodeの引数は仮で持たせています。

    // ここでnodesと複数Nodeを受け取るように一応設計。
    fn apply<F>(self: Rc<Self>, func: F, nodes: &[Rc<Node>] )

Backwardがまだわからない。

ずっと一本道で来ていたら、逆伝播もわかるのですが、addとかで道が別れた時に、どうnode(勾配)を集めていけばいいのか、すこし迷子です。
今後しっかりゼロつくを読み込むつもりです。

今後の道筋

Trace関数でどのNode順で追えばいいかを集める。
(これは来た道の逆順だよね? 複数ある時にどう設計すれば思ったような形になるのか悩み中)

fn trace(self: Rc<Self>, is_traced: &mut HashSet<*const Node>, nodes_piled: &Vec<Rc<Node>>){
        // すでにtrace済みのNodeじゃないかを確認
        let ptr = Rc::as_ptr(&self);
        if is_traced.contains(&ptr) {
            return;
        }
        // なかったらInsertする。
        is_traced.insert(ptr);
        // 親を遡るよ
        let parents = &self.parents;
        for p in parents.iter() {
            // refcellを剥き剥きするときはborrow
            let nodes = &p.borrow().input;
            for n in nodes.iter() {
                // rcなのでcloneしてtraceします。
                n.clone().trace(is_traced, nodes_piled);
            }
        }
    }

損失を求めて、最適化を行う。

上記を現状の形で一通り回したいです。
そのあとは、batchとかchannelも考慮した拡張版をかんがえていきたいです。

現状での振り返り

問題は2つあって、一つは計算グラフの設計そのもの。もう一つはRustの学習曲線が急なことです。
nodeを記録するところで結構時間がかかりました。どのように問題を切り分けるかが難しくて。
gradをnodeに置くかFunctionNodeに置くかも丸一日は悩み、traceの中でbackwardをやってしまいたいという気持ちに今なっていたりもします。正直まだまだわからないことだらけ。

Rustに関しては、この取り組みでかなり所有権がピンとくるようになりました。
最初はかなり「わかったつもり」になっていたのですが、やればやるほどわかってないことがわかり…。
「あー! そういうことか!!」というのをなん度も繰り返しました。
特に自動の参照解決でかなり勘違いが多かったです。自動で参照解決されるから、参照のところを参照じゃないって思い込んだり、map内で変数を渡す時は参照だから、中身も参照だろ、って思い込んだり。

RcもRefCellも今ソースコードで書いている書き方以外は多分わかってません。

わからんわからんばっかり書いてますが、良いニュースもありまして。
この計算グラフ全部で大体200行くらいになっているんですが(matmul等も合わせて)、全部何もみずにゼロから書けます。(手が覚えました)
一ヶ月ほど悩み続けたのは無駄ではなかったです。

テストもTensor絡みではすっと書けるようになりましたし。
結構いいRustスタートが切れていると思います。

Discussion