⚙️

『ゼロから作るDeep Learning ❸』をRustに翻訳してみた : step3まで

2024/06/12に公開

はじめに

元々Zennで記事をはじめたコンセプト「機械学習の活用に役立つコンピュータサイエンスを学びたい」にしたがって、今後はDeep Learningをコンピュータサイエンスの観点で深掘りしたいと思っています。
その第一歩として『ゼロから作るDeep Learning ❸ ―フレームワーク編』で構築されているDeep LearningのフレームワークDeZeroをRustで実装し、フレームワークを一段レイヤーを下げて理解することを試みます。

上記の目的のため、Rustコードを示すだけでなく、pythonによる実装との差異について深掘りして、pythonがいい感じに勝手に捌いてくれている事項をできるだけ明るみにしたいと思います。そのため、pythonに精通しているものの、Rustはあまり慣れていない方にも読んでいただける内容になったかなと思います。

本記事では、全60ステップのうち本当の最初であるStep3まで(変数の定義〜関数の連結)を実装したいと思います。ただ、たった3ステップとはいえ、numpyにかわるモジュールはどうするのか、pythonのクラスをどう再現するのか、関数に値をどのように渡すのかなど、Rustへの翻訳にあたり結構検討すべき事項は多くなりました。

コーディングの理解が浅い状態で進めており、こうすべき等の事項がありましたら是非ご指摘をお願いいたします。

環境

  • MacBookPro(14inch) M3 MAX (128GB)
  • OS : Sonoma 14.5
  • Rust Edition : 2021

元のコードと仕上がり

元のコード

原点のgithubリポジトリはコチラです。
https://github.com/oreilly-japan/deep-learning-from-scratch-3/blob/master/steps/step03.py

仕上がり

先に仕上がりのコードを示します。
もしCargoを用いた実行方法に不明点があれば、こちら等を参照ください。
https://qiita.com/yoshii0110/items/6d70323f01fefcf09682

依存関係

Cargo.toml
[package]
name = "step01"
version = "0.1.0"
edition = "2021"

[dependencies]
ndarray = "0.15.4"

mainファイル

main.rs
// pythonのnp.arrayはndarrayとして提供されている
// IxDynはコンパイル時には未定義の次元数を表現する型
use ndarray::{Array, IxDyn};

// Variableを構造体で定義
struct Variable {
    data: Array<f64, IxDyn>,
}

// Variableのコンストラクタ(pythonの__init__に相当)を定義
impl Variable {
    fn new(data: Array<f64, IxDyn>) -> Variable {
        Variable { data }
    }
}

// Functionをトレイトで実装する
trait Function {
    // 関数を呼び出した時の挙動を定義
    // Variableおよびself(内部変数)を参照として受け取り、Variableとして返す
    // 今回は呼び出し時にselfは書き換えないため、参照(&)で指定(書き換える場合は&mut)
    // インプットのVariableも書き換えせず、また再利用できるように&で指定
    fn call(&self, input: &Variable) -> Variable {
        let x = &input.data;
        let y = self.forward(x);
        Variable::new(y)
    }

    // トレイトを実装するさいに定義すべきメソッドを示す
    // (未定義の場合はコンパイルエラーが出る)
    fn forward(&self, x: &Array<f64, IxDyn>) -> Array<f64, IxDyn>;
}

// Squareを構造体として定義する
struct Square;

// FunctionトレイトをSquareとして実装。
// Functionトレイトで未定義だったfowardを定義
impl Function for Square {
    fn forward(&self, x: &Array<f64, IxDyn>) -> Array<f64, IxDyn> {
        x.mapv(|a| a.powi(2))
    }
}

// Squareと同様にExpを実装
struct Exp;

impl Function for Exp {
    fn forward(&self, x: &Array<f64, IxDyn>) -> Array<f64, IxDyn> {
        x.mapv(|a| a.exp())
    }
}

fn main() {
    // 関数のインスタンスを生成
    let layer_a = Square {};
    let layer_b = Exp {};
    let layer_c = Square {};

    // xを定義
    let data = Array::from_elem(IxDyn(&[]), 0.5);
    let x = Variable::new(data);

    // 順伝播の計算
    let a = layer_a.call(&x);
    let b = layer_b.call(&a);
    let y = layer_c.call(&b);
    println!("x = {}, y = {}", x.data, y.data);
}

出力

出力
x = 0.5, y = 1.648721270700128

コードの解説

利用する外部クレート

ndarray

元のpythonコードでは、インプットとなる配列をnp.arrayを用いて定義しています。コードをRustに翻訳するにあたり、このライブラリを置き換える必要があります。
幸いなことに、RustにもNumPyライクで配列の定義・操作が行えるndarrayクレートがあります。こちらのドキュメントにNumPyユーザー向けの解説がありますので、必要に応じ参照ください。
https://docs.rs/ndarray/latest/ndarray/doc/ndarray_for_numpy_users/index.html

上記ページに主要な差異がまとめられていますが、ピックアップすると以下の通りです

所有権とビュー

  • NumPy
    • pythonなので所有権、ビュー、可変の区別はありません
  • ndarray
    • 配列の型により所有権がことなり、明示する必要があります(例 : Array型は所有権あり、ArrayViewはビュー、ArrayViewMutは可変)

次元のとりあつかい

  • NumPy
    • 配列は動的で、ユーザーの入力により決定します
  • ndarray
    • 固定次元の配列を用いることができます(例 : Array3は3次元配列)
    • 今回のように可変次元を用いたい場合、後述のとおり次元をIxDyn型として定義します

「ライブラリのインポート」の翻訳

以後、順次元のpythonコードの翻訳を進めます。
まず、以下のライブラリのインポート箇所を翻訳しましょう。

python
import numpy as np

前述のとおり、np.arrayライクな配列の扱いができるndarrayを用います。配列の型については、1次元の系列データや画像の2次元配列など、ユーザー入力に応じて受け入れることができるよう、任意の配列の型をサポートするArrayを用います。また、ユーザー入力時に次元を決定(すなわちコンパイル時には入力次元は不定)できるよう、IxDynを入力次元の型として用います。
結果、Rustコードでは以下としました。

Rust
use ndarray::{Array, IxDyn};

上記でPythonと同様、ユーザー入力に応じた任意の入力に応じて配列を定義できます。ただ、当然ながらこのような型で動的な定義を採用すると、オーバーヘッドによるパフォーマンスの若干の低下が生じます。そのようなオーバーヘッドがネックになるような場合であれば、汎用なArray型ではなく、Array2Array3を用いることができます。このオーバーヘッドの影響については、今後多次元配列のインプットを用いるさいに検証してみたいと思います。

「Variableクラスの定義」の翻訳

次にVariableクラス(変数を収める箱)の定義です。pythonコードでは以下のとおり定義されます。

python
class Variable:
    def __init__(self, data):
        self.data = data

上記のpythonコードでは以下を行なっています。

  • Variableというクラスを定義
  • コンストラクタ(__init__)でVariableクラスのインスタンス生成時に、入力したdataselfに格納。

pythonでは上記を一度に行えますが、Rustでは二段階で行う必要があります。

まず、Variableのクラスの定義です。Variableクラスはとりあえず、インスタンス生成時に入力するdataを格納する必要があります。この機能は、Rustでは構造体(Struct)で実装できます。

また、インプットするdataは前述のとおりndarrayのArrayを使用します。Rustのコーディングではここでインプットの型を指定しなければなりません。Arrayでは以下の二つを定義します。

  • 配列の要素の型
  • 配列の次元

ここでは、pythonにあわせて以下としています

  • 配列の要素の型 : f64(pythonのfloatと同じく64bit浮動小数点)
  • 配列の次元 : IxDyn(ユーザー入力に応じた次元)

次に、コンストラクタ(__init__)を翻訳します。Rustではコンストラクタは一般的にメソッドnew()として定義されます。Rustでは構造体にメソッドを定義するとき、構造体の定義の中では行えず、定義した構造体に対し別途implキーワードにより定義します。

結果、Rustコードでは以下としました。

Rust
// Variableを構造体で定義
struct Variable {
    data: Array<f64, IxDyn>,
}

// Variableのコンストラクタ(pythonの__init__に相当)を定義
impl Variable {
    // new()メソッドを定義。Arrayをインプットとして受け取り、Variable構造体に格納し返す
    fn new(data: Array<f64, IxDyn>) -> Variable {
        Variable { data }
    }
}

「Functionクラスの定義」の翻訳

続いてFunctionを定義します。FunctionクラスはVariableクラスをインプットとして受け取り、forward関数を適用した後、その実行結果をVariableクラスとして返します。pythonコードでは以下のとおり定義されます。

python
class Function:
    def __call__(self, input):
        x = input.data
        y = self.forward(x)
        output = Variable(y)
        return output

    def forward(self, x):
        raise NotImplementedError()

上記のpythonコードでは以下を行っています。

  • Functionを呼び出した時の動作を__call__で以下の通り定義
    • インプットデータを取得
    • forward関数を適用
    • 出力をVariableに格納し返す
  • forward関数をインスタンス生成時に定義しなければならないことを明示

それでは、Functionの翻訳を進めます。

pythonではVariableと同じくclassで定義されていましたが、Rustでは今回のような"関数の雛形"はtraitで定義できます。定義したtraitを用いることで、SquareExpなど各関数の共通部分をひとまとめに定義できます。

Functionトレイトは、Pythonでの定義と同様、以下の2ステップで定義します。

  • 関数を呼び出した時の挙動を定義(fn call
  • トレイトを各関数に展開するときに定義すべき順伝播関数のインプット・アウトプットの型だけ定義(fn forward

このあたりから、所有権を意識した実装が必要になります。Rustでは、関数に入力xを渡したさい、関数にxの所有権が写り、関数の実行完了後にxが破棄されてしまいます。入力xは再利用するケースが多々あると考えられるため、基本的な動作ではxは破棄しないように実装したいです。
そのため、関数への入力は所有権が移らない参照(&)で渡します。インプットはVariableクラスのインスタンスのため、インプットの型は&Variableとなります。また、関数自身を指すselfについても、再利用できるよう&selfとして呼び出します。
以上から、Functionは以下のとおり定義できます。

Rust
// Functionをトレイトで実装する
trait Function {
    fn call(&self, input: &Variable) -> Variable {
        let x = &input.data;
        let y = self.forward(x);
        Variable::new(y)
    }

    // トレイトを実装するさいに定義すべきメソッドを示す
    // (未定義の場合はコンパイルエラーが出る)
    fn forward(&self, x: &Array<f64, IxDyn>) -> Array<f64, IxDyn>;
}

上記でfn forwardの中身は定義されていません。そのため、インスタンスを生成するさいにこのforwardを上書きしないと、コンパイルエラーが発生します。Pythonではraiseを用いて明示的にエラー表示を実装する必要がありますが、Rustではその必要はありません。

Squareの実装

次に、Squareを定義しましょう。pythonコードでは以下となっています。

python
class Square(Function):
    def forward(self, x):
        return x ** 2

Rustでは、まずSquareを空の構造体として定義します。

Rust
// Squareを構造体として定義する
struct Square;

そのうえで、Functionトレイトをimplで実装することでcallメソッドを引き継げます。ここで、未定義だったforwardを定義します。

Rust
// Functionトレイトで未定義のfowardを定義したうえで、Squareとして実装。
impl Function for Square {
    fn forward(&self, x: &Array<f64, IxDyn>) -> Array<f64, IxDyn> {
        x.mapv(|a| a.powi(2))
    }
}

関数の中身の書き方

x.mapv(|a| a.powi(2))について解説します。
まず、forwardの引数にあるように、xArrayの参照です。しかしRustではなんらかのメソッド(今回はmapv)を適用するときに、自動で参照外しが行われるため、コードで明示する必要はありません。なお、以下のように参照外しを明示した場合でも同じ動作をします。

(*x).mapv(|a| a.powi(2))

続いてmapvについて解説します。これはArrayに対して定義されているメソッドで、xの各要素に対して()内の関数を適用し、xと同じ次元のArrayとして返します。
https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html#method.mapv

次に、|a| a.powi(2)ですが、これはクロージャという匿名関数です。この場合、aを引数として、a.powi(2)(aの整数乗、今回は2乗)を返します。クロージャは、Rustの通常の関数のように関数の引数と戻り値の型の注釈を省略できます。これは、クロージャはコードの内部でしか使用されず、加えてに狭い文脈でしか使用しないためです。

Expの実装

ExpSquareと同様に、以下で定義します。

Rust
struct Exp;

impl Function for Exp {
    fn forward(&self, x: &Array<f64, IxDyn>) -> Array<f64, IxDyn> {
        x.mapv(|a| a.exp())
    }
}

動作確認の翻訳

最後に、定義したVariableおよび関数たちの動作確認を行います。pythonコードでは以下のとおりとなっています。

python
A = Square()
B = Exp()
C = Square()

x = Variable(np.array(0.5))
a = A(x)
b = B(a)
y = C(b)
print(y.data)d

上記のコードは、以下のステップで動作確認を行っています。

  • SquareクラスおよびExpクラスのインスタンスの生成
    • SquareクラスがAおよびBクラスと別物として生成されているのが重要
  • xをスカラーの配列として定義
  • 順伝播を計算

上記をRustに翻訳します。
まず、SquareおよびExp構造体のインスタンスを生成します。Rustでの構造体のインスタンスの生成は、例えばnameを定義すべき変数として通常以下のように行います。

let instance_A = Struct_A { name: "Bob"}

今回、Square構造体およびExp構造体では、定義すべき変数はすべて定義されているため、何も指定せず生成できます。そのため、コードは以下となります

Rust
let layer_a = Square {};
let layer_b = Exp {};
let layer_c = Square {};

なお、Rustでは変数名に大文字を使うのは標準のコード規則に沿わない(コンパイラに注意される)ため、A, B, Cという関数名を置き換えています。

次に、xの定義です。pythonではnp.arrayにスカラーまたは配列を与えてあげればいい感じに次元を解釈してくれますが、Rustではそうはいきません。ここでは、要素がすべて指定した値(elem)でshapeがdimになるArray::from_elem(dim, elem)を用いて定義します。これをVariable::newメソッドによりインスタンスとして生成します。
以上をコードにすると以下となります。

Rust
// xを定義
let data = Array::from_elem(IxDyn(&[]), 0.5);
let x = Variable::new(data);

ここでIxDyn(&[])について説明したいと思います。IxDynに渡す引数は、例えば2 \times 3 \times 4の配列を入力したい場合、&[2, 3, 4]を入力します。今回は次元なしのスカラーを入力するので、長さゼロの配列の参照&[]をインプットしています。以下では、ただの配列ではなく参照(&)を入力する理由について説明します。
Rustでは参照でない配列を指定する場合、コンパイル時にその配列の長さを固定する必要があります。今回、入力の配列の長さはインプットの次元に相当します。しかし、それではVariableをユーザー定義に応じた次元にできないし、そもそも次元数がコンパイル時に可変にできるIxDynを使っている意味がなくなります。そのため、IxDynにインプットする配列は、その長さを可変としたままコンパイルできる、配列の参照としてインプットします。
ちなみに、配列の参照はRustではスライスと呼ばれます。Rustにおけるスライスはpythonと似たような機能であり配列の一部を抜き出してくるのが主な使い方で、以下のような形で定義します。

fn main() {
    let array1 = [1, 2, 3, 4, 5];
    let slice = &array1[1..4]; // スライスを作成
    
    println!("{:?}", slice); // 出力: [2, 3, 4]
}

配列の一部を抜き出すという性質上、配列の長さは可変とする必要があるため、スライスは可変長に対応しています。IxDynのインプットにスライスが指定されているのは、この可変長に対応しているという性質を応用している、私は理解しています。
なお、可変長の入力に対応できるものとして他にVec型があり、こちらもIxDynの入力に使えますが、わざわざ次元の指定のためにVecを定義するのは冗長なので、ここではスライスによる指定としています。
https://docs.rs/ndarray/latest/ndarray/type.IxDyn.html

最後に、順伝播の計算です。
これは、Function::callメソッドを順次呼び出すだけです。なお、callの引数は上で定義したように、入力した値を再利用できるよう参照(&)で渡すよう定義したので、それに従います。上記を踏まえると、順伝播の計算および結果の出力は以下となります。

Rust
// 順伝播の計算
let a = layer_a.call(&x);
let b = layer_b.call(&a);
let y = layer_c.call(&b);
println!("x = {}, y = {}", x.data, y.data);
出力
x = 0.5, y = 1.648721270700128

最後のprintln!()について補足します。これは文字列を標準出力に表示する機能です。最初の引数の文字列に、{}で示した位置に後続の引数の値を順次入れて出力したのち、改行を表示します(改行なしはprint!())。ぱっと見関数に見えますが、名前の最後に!がついており、このようなキーワードはマクロと呼ばれます。マクロは関数に見えますが、実際は一連のコードを省略したものであり、コンパイル時に展開されます。また、マクロはインプットした変数を参照として用いるため、参照として入力しなくても所有権が移らず、マクロを呼び出した後もインプットした変数を使えます。なので、以下のように2回呼び出してもエラーになりません。

println!("x = {}, y = {}", x.data, y.data);
println!("x = {}, y = {}", x.data, y.data);

まとめ

『ゼロから作るDeep Learning ❸ ―フレームワーク編』のstep3までをRust翻訳し、そのなかでpythonとRustの差異を自分が納得できるまで解説してみました。引き続き、step4以降の翻訳と解説を進めていきたいと思います。

初っ端なので解説ばかりになってしまいましたが、以降は既出の内容が増えていくと思うので、スピードアップできると思います。

繰り返しになりますが、Rustコーディングの理解が浅い中で進めているため、よりよいコーディング方法があったり、認識に誤りがある場合などはご指摘いただけますと大変ありがたいです。

Discussion