Rustで深層学習フレームワークを開発しています
リポジトリ
⭐️がつくとやる気が出ます!
サンプルコード(GAN など)
ちなみに上記のサンプルを実行すると、以下のような画像が生成できました。
|  |  |  |  | 
|---|
はじめに
- Rust で深層学習フレームワーク Zenu を開発しています。
- CPU / GPU 両方に対応し、型安全とメモリ安全を重視した設計です。
- MNIST や GAN などのサンプルコードがありますので、ぜひ試してみてください。
モチベーション
Python / PyTorch での開発のつらさ
深層学習といえば Python + PyTorch が定番ですが、実際の研究・開発では以下のような問題に悩まされることが多いです。
- 
静的型がない(動的型付け) 
 GPU / CPU 間のデバイス不一致や、float32 と float64 の混在など、型にまつわるバグがランタイムエラーで起こりがち。長時間学習してやっと失敗がわかることも…。
- 
ランタイムエラーが学習後に発覚しがち 
 大規模モデルや長時間学習を回してから「デバイスが違う」などのエラーに直面すると、とてもつらい。再学習のコストが大きい。学習が一通り終わって最後のtestを回し始めた時にruntime errorが出ると、とても辛い。
import torch
x = torch.ones((2, 2), device="cuda")
y = torch.ones((2, 2), device="cpu")
# GPU テンソルと CPU テンソルを足す → 実行時エラー
z = x + y
実行すると…
RuntimeError: Expected all tensors to be on the same device...
学習や推論の途中で気づくと精神的なダメージが大きいですよね。
Rust でやるメリット
- 
型安全 & 所有権 
 Rust は静的型で、かつ型推論が優秀です。テンソルのデバイス (CPU / GPU) やスカラー型 (f32 / f64) を型レベルで区別でき、ミスマッチはコンパイルエラーに。
 例:Matrix<Owned<f32>, Dim2, Cpu>とMatrix<Owned<f32>, Dim2, Nvidia>は別の型として扱うので、CPU テンソルに GPU テンソルを足すようなコードはコンパイルが通りません。
- 
Cargo によるテストが標準で充実 
 C++ と比べて、ユニットテストやベンチマークを導入しやすい。CMake などの外部ツールをあまり使わなくて済む。
 この利点はかなり大きいと思っています。
- 
Rust が好き 
 皆さんはRust好きですよね?ね?ね?じゃあ、Rust で深層学習フレームワークを作りましょう!
Zenu の概要
Zenu は、大きく分けると以下の 6 クレートに分割して開発しています。
- 
zenu - トップレベルクレート。利用するときはこれを指定。
- 
feature = ["nvidia"]を有効にすると CUDA / cuBLAS / cuDNN を使った GPU 演算が可能。
 
- 
zenu-cuda - CUDA runtime / driver / cuBLAS / cuDNN など、NVIDIA 系の低レベル API をまとめたクレート。
- CUDA カーネルのビルドや呼び出し周りをカプセル化。
 
- 
zenu-matrix - pythonでいうところの numpy に相当する多次元配列クレート。
- ndarray を参考にさせていただきました。(圧倒的感謝)
- CPU / GPU 両方に対応。
 
- 
zenu-autograd - 自動微分&演算グラフクレート。
- PyTorch の torch.autogradみたいなイメージで、forward / backward を実装。
- 
Variable型を用いて勾配を保持。
 
- 
zenu-layers - PyTorch の torch.nn相当の高レベルレイヤ (Linear, Conv, BatchNorm など) をまとめるクレート。
- 
Moduleトレイトによるインターフェースを提供。
 
- PyTorch の 
- 
zenu-optimizer - パラメータ更新アルゴリズム (SGD, Adam, AdamW など) をまとめるクレート。
- PyTorch の torch.optimに近いイメージ。
 
クレート分割の課題
- 細かく分けすぎて 共通の変更が複数クレートに及ぶ と、API の食い違いが起きやすい。
- 自動微分はテンソル演算と密接なので、切り離すのが意外と大変だった。
 feature="nvidia" で GPU サポートを ON にする
Cargo.toml の dependencies に次のように書きます:
[dependencies.zenu]
version = "*"
features = ["nvidia"]  # GPU 機能を有効化
これで内部的に zenu-cuda がビルドされ、Nvidia デバイスが使えるようになります。
use zenu::matrix::device::nvidia::Nvidia;
let model_gpu = SimpleModel::<f32, Nvidia>::new();
もし CPU から GPU へ転送するなら
let model_gpu = model.to::<Nvidia>();
と書くだけです。
主な機能
- 最適化アルゴリズム: SGD / Adam / AdamW など
- NN レイヤー: Linear, Convolution, Pooling, Dropout, BatchNorm など
- 活性化関数: ReLU, Sigmoid, Tanh, Softmax
- 損失関数: 二乗誤差、クロスエントロピー
MNIST を使った実装例
ここでは Zenu を使って MNIST で学習する簡単なサンプルコードを載せます。
MNISTコード例
use zenu::{
    autograd::{
        activation::relu::relu, creator::from_vec::from_vec, loss::cross_entropy::cross_entropy,
        no_train, set_train, Variable,
    },
    dataset::{train_val_split, DataLoader, Dataset},
    dataset_loader::mnist_dataset,
    layer::{layers::linear::Linear, Module},
    matrix::{
        device::{cpu::Cpu, Device},
        num::Num,
    },
    optimizer::{sgd::SGD, Optimizer},
};
use zenu_macros::Parameters;
// モデル定義
#[derive(Parameters)]
#[parameters(num=T, device=D)]
pub struct SimpleModel<T: Num, D: Device> {
    pub linear_1: Linear<T, D>,
    pub linear_2: Linear<T, D>,
}
impl<D: Device> SimpleModel<f32, D> {
    #[must_use]
    pub fn new() -> Self {
        Self {
            linear_1: Linear::new(28 * 28, 512, true),
            linear_2: Linear::new(512, 10, true),
        }
    }
}
impl<D: Device> Default for SimpleModel<f32, D> {
    fn default() -> Self {
        Self::new()
    }
}
// Module トレイト実装 (forward 計算)
impl<D: Device> Module<f32, D> for SimpleModel<f32, D> {
    type Input = Variable<f32, D>;
    type Output = Variable<f32, D>;
    fn call(&self, inputs: Variable<f32, D>) -> Variable<f32, D> {
        let x = self.linear_1.call(inputs);
        let x = relu(x);
        self.linear_2.call(x)
    }
}
// MNIST データセット
struct MnistDataset {
    data: Vec<(Vec<u8>, u8)>,
}
impl Dataset<f32> for MnistDataset {
    type Item = (Vec<u8>, u8);
    fn item(&self, item: usize) -> Vec<Variable<f32, Cpu>> {
        let (x, y) = &self.data[item];
        let x_f32 = x.iter().map(|&xi| xi as f32).collect::<Vec<_>>();
        let x = from_vec::<f32, _, Cpu>(x_f32, [784]);
        x.get_data_mut().to_ref_mut().div_scalar_assign(127.5);
        x.get_data_mut().to_ref_mut().sub_scalar_assign(1.0);
        let y_onehot = (0..10)
            .map(|i| if i == *y as usize { 1.0 } else { 0.0 })
            .collect::<Vec<_>>();
        let y = from_vec(y_onehot, [10]);
        vec![x, y]
    }
    fn len(&self) -> usize {
        self.data.len()
    }
    fn all_data(&mut self) -> &mut [Self::Item] {
        &mut self.data
    }
}
#[expect(clippy::cast_precision_loss)]
fn main() {
    // モデルを CPU デバイスで作成 (GPU の場合は SimpleModel::<f32, Nvidia>::new())
    let model = SimpleModel::<f32, Cpu>::new();
    // MNIST データ読み込み
    let (train, test) = mnist_dataset().unwrap();
    let (train, val) = train_val_split(&train, 0.8, true);
    // DataLoader の作成 (PyTorch でいう DataLoader に近い)
    let test_dataloader = DataLoader::new(MnistDataset { data: test }, 1);
    // 最適化アルゴリズム: SGD
    let optimizer = SGD::<f32, Cpu>::new(0.01);
    for epoch in 0..20 {
        set_train();  // PyTorch の with torch.no_grad() の“逆”
        let mut train_dataloader = DataLoader::new(
            MnistDataset {
                data: train.clone(),
            },
            32,
        );
        train_dataloader.shuffle();
        let mut train_loss = 0.0;
        let mut num_iter = 0;
        for batch in train_dataloader {
            let input = batch[0].clone();
            let target = batch[1].clone();
            let pred = model.call(input);
            let loss = cross_entropy(pred, target);
            let loss_asum = loss.get_data().asum();
            // バックワードとパラメータ更新
            loss.backward();
            optimizer.update(&model);
            loss.clear_grad();
            train_loss += loss_asum;
            num_iter += 1;
        }
        train_loss /= num_iter as f32;
        // バリデーション
        no_train(); // PyTorch の with torch.no_grad()
        let val_loader = DataLoader::new(MnistDataset { data: val.clone() }, 1);
        let mut val_loss = 0.0;
        let mut num_val_iter = 0;
        for batch in val_loader {
            let input = batch[0].clone();
            let target = batch[1].clone();
            let pred = model.call(input);
            let loss = cross_entropy(pred, target);
            val_loss += loss.get_data().asum();
            num_val_iter += 1;
        }
        val_loss /= num_val_iter as f32;
        println!("Epoch: {epoch}, Train Loss: {train_loss}, Val Loss: {val_loss}");
    }
    // テスト
    let mut test_loss = 0.0;
    let mut num_test_iter = 0;
    for batch in test_dataloader {
        let input = batch[0].clone();
        let target = batch[1].clone();
        let pred = model.call(input);
        let loss = cross_entropy(pred, target);
        test_loss += loss.get_data().asum();
        num_test_iter += 1;
    }
    println!("Test Loss: {}", test_loss / num_test_iter as f32);
}
GPU で走らせたい場合は、feature で nvidia を有効にして、Variableやモデルに対して.to::<Nvidia>() を呼び出すだけです。
他の Rust 製フレームワークとの比較
- 
candle
 Rust で自前実装を頑張っているフレームワーク。Zenu と方向性は似ていますが、Metal や FlashAttention など先進的なバックエンドもサポートしていてすごい。
- 
tch-rs / burn
- tch-rs は PyTorch C++ バインディング。中身は PyTorch と同じなので実績は十分だけど、Rust の所有権や型システムを最大限活かすわけではない。
- burn は独自 DSL を構築するアプローチが印象的。Zenu とは別のベクトルで面白い。
 
今後の展望
- 
SIMD
 CPU での SSE / AVX などを利用した高速化
- マルチ GPU / 分散学習
- Transformer 系モデルの実装
- PyTorch モデルの読み書き / ONNX 対応
- 
ドキュメントの充実
- チュートリアルや API リファレンスを拡充
 
- 
サンプルコードの拡充
- 画像分類、物体検出、セグメンテーションなど
- 音声系の会社に転職したので音声系も頑張りたい
 
- 
エラーハンドリングの改善
- よりわかりやすいエラーメッセージを目指す
 
まとめ
- Rust 製の深層学習フレームワーク Zenu を作っています。
- CPU / GPU 両方に対応し、型安全とメモリ安全を重視した設計です。
- MNIST や GAN などのサンプルコードがありますので、ぜひ試してみてください。
- Issues、PR、マサカリ、大歓迎です!スターをいただけるとモチベが爆上がりします。
参考リンク
- 
ndarray
- zenu-matrix の開発にあたり、多次元配列の実装でお世話になりました。
 
- 
ゼロから作る Deep Learning 3
- 自動微分の実装で参考にしました。
 
- 
【作品紹介】Common Lispで深層学習フレームワークを0から作ってる話
- 同じく「自作フレームワークをやってみるぞ!」という気持ちになったきっかけの記事です。
 



Discussion