🔥

RustのDeep Learningフレームワークburnで足りない関数の実装例(onehot matrixを作る関数とか)

2024/06/05に公開

はじめに

  • この記事は2023/12月に記述したもので、最新の情報ではないので注意してください。

趣味でRustのDeep Learningフレームワークのburnを使用していて、burnに組み込まれていない関数が欲しいという状態になりました。
そこで実装に足りていなかった関数の実装方法をメモとして残します。
burnは2023年12月現在活発に開発されているフレームワークで、onnx形式のimportも実装中となっているのでonnxを読み込んでburnを使ったモデル定義のコードを自動生成するような仕組みもあります。
サポートされている関数などが限定されるので今後に期待です。
burnを使ってモデルを学習し、デプロイもburnを使うというようなこともできます。

https://github.com/tracel-ai/burn

onehot行列の作成

onehot vectorを作成する関数は実装されているが、onehotのindexベクトル([1, 10, 17]みたいな)をonehotのマトリックスに変換する関数は実装されていない。

fn one_hot_int<B: Backend, const D: usize, const D2: usize>(
    indices: Tensor<B, D, Int>,
    num_classes: usize,
) -> Tensor<B, D2, Int> {
    debug_assert!(D + 1 == D2);
    let dims = {
        let mut dims = [0; D2];
        let (last, init) = dims.split_last_mut().unwrap();
        *last = num_classes;
        init.copy_from_slice(&indices.dims());
        dims
    };

    let alt_dims = {
        let mut alt_dims = dims.clone();
        alt_dims[D] = 1;
        alt_dims
    };

    let indices = indices
        .unsqueeze::<D2>()
        .reshape(alt_dims)
        .repeat(D, num_classes);

    Tensor::zeros(dims).scatter(D, indices, Tensor::ones(dims)) / num_classes as f32
}

randint

この実装でいいのか微妙だが、Uniformで作ったFloatのTensorをintにcastしている。

fn randint<B: Backend, const D: usize, S: Into<burn::tensor::Shape<D>>>(
    low: i32,
    high: i32,
    shape: S,
) -> Tensor<B, D, Int> {
    let mut t: Tensor<B, D, Float> = Tensor::random(
        shape,
        burn::tensor::Distribution::Uniform(low as f64, high as f64),
    );
    let t = t.int();
    t
}

Discussion