Rustで計算グラフを実装中
AIの勉強を始めたので「ゼロからつくるDeep Learning(以降、ゼロつく)」を読みながら、Rustで計算グラフを実装することにしました。
大枠ではこんな感じです。
まずは計算グラフのノードを用意。
struct Node {
parents: Vec<Rc<RefCell<FunctionNode>>>,
data: Tensor,
// 逆伝播まで値がないのでOptionに包む
grad: Option<Tensor>,
}
それから、ノードの親としてFunctionNodeを用意しました。
struct FunctionNode {
input: Vec<Rc<Node>>,
output: Tensor,
func: Box<dyn Function>,
}
あとは各々の計算として関数のtraitを準備。
trait Function {
fn apply(&self, tensors: &[&Tensor]) -> Tensor;
fn backward(&self, grad: Tensor) -> Tensor;
}
これでforwardとbackwardの計算ができます。
Rcは当然使うのが初めてなんですけど、Pythonの変数は全部Rcで使うようなもの、という理解です。(参照でどんどん渡せる便利な道具。けどコストがあるって感じでしょうか)
(Pythonはゼロつくで少し触った程度なんですけど…。とっても便利な言語というイメージです)
Forwardは完全に理解した!
(この構文で合っているよね?)
Forwardは問題ないかなと思います。
下記の順番で処理を行っていきます。
①最初に流し込む値としてnodeを作成
②そのnode に適用する関数と合わせて処理をするnodeを渡して、計算の記録をfunctionNodeに保持しながら、次の計算に使うためのnodeを吐き出す
fn apply<F>(self: Rc<Self>, func: F, nodes: &[Rc<Node>] ) -> Node
where F: Function + 'static
// まだmainから呼んでないのですが、staticじゃ制約が強い気がしています…。
{
let mut tensors = vec![&self.data];
tensors.extend(nodes.iter().map(|n| &n.data));
let output = func.apply(&tensors);
let mut inputs = vec![self];
inputs.extend(nodes.iter().map(|n| Rc::clone(n)));
// 次のNodeの親。新しいNodeを生み出したものを記録
let funcnode = FunctionNode{
input: inputs,
output: output.clone(),
func: Box::new(func),
};
// 次の計算の元になるNode
Node {
parents: vec![Rc::new(RefCell::new(funcnode))],
data: output,
grad: None,
}
}
他のnodeとの演算は「あるだろう」くらいの勢いで、実のところ、今はAffineやReLu、Sigmoid、Convくらいしか想定はしていないので、他nodeの引数は仮で持たせています。
// ここでnodesと複数Nodeを受け取るように一応設計。
fn apply<F>(self: Rc<Self>, func: F, nodes: &[Rc<Node>] )
Backwardがまだわからない。
ずっと一本道で来ていたら、逆伝播もわかるのですが、addとかで道が別れた時に、どうnode(勾配)を集めていけばいいのか、すこし迷子です。
今後しっかりゼロつくを読み込むつもりです。
今後の道筋
Trace関数でどのNode順で追えばいいかを集める。
(これは来た道の逆順だよね? 複数ある時にどう設計すれば思ったような形になるのか悩み中)
fn trace(self: Rc<Self>, is_traced: &mut HashSet<*const Node>, nodes_piled: &Vec<Rc<Node>>){
// すでにtrace済みのNodeじゃないかを確認
let ptr = Rc::as_ptr(&self);
if is_traced.contains(&ptr) {
return;
}
// なかったらInsertする。
is_traced.insert(ptr);
// 親を遡るよ
let parents = &self.parents;
for p in parents.iter() {
// refcellを剥き剥きするときはborrow
let nodes = &p.borrow().input;
for n in nodes.iter() {
// rcなのでcloneしてtraceします。
n.clone().trace(is_traced, nodes_piled);
}
}
}
損失を求めて、最適化を行う。
上記を現状の形で一通り回したいです。
そのあとは、batchとかchannelも考慮した拡張版をかんがえていきたいです。
現状での振り返り
問題は2つあって、一つは計算グラフの設計そのもの。もう一つはRustの学習曲線が急なことです。
nodeを記録するところで結構時間がかかりました。どのように問題を切り分けるかが難しくて。
gradをnodeに置くかFunctionNodeに置くかも丸一日は悩み、traceの中でbackwardをやってしまいたいという気持ちに今なっていたりもします。正直まだまだわからないことだらけ。
Rustに関しては、この取り組みでかなり所有権がピンとくるようになりました。
最初はかなり「わかったつもり」になっていたのですが、やればやるほどわかってないことがわかり…。
「あー! そういうことか!!」というのをなん度も繰り返しました。
特に自動の参照解決でかなり勘違いが多かったです。自動で参照解決されるから、参照のところを参照じゃないって思い込んだり、map内で変数を渡す時は参照だから、中身も参照だろ、って思い込んだり。
RcもRefCellも今ソースコードで書いている書き方以外は多分わかってません。
わからんわからんばっかり書いてますが、良いニュースもありまして。
この計算グラフ全部で大体200行くらいになっているんですが(matmul等も合わせて)、全部何もみずにゼロから書けます。(手が覚えました)
一ヶ月ほど悩み続けたのは無駄ではなかったです。
テストもTensor絡みではすっと書けるようになりましたし。
結構いいRustスタートが切れていると思います。
Discussion