Open5

RustでHuggingFace HubのDatasetsを読む

PlatPlat

candle-datasets, hf-hub, parquet を使うことで HuggingFace Datasets を Rust から扱うことができる。

全然知らないAPIばかり+ググっても出てこない情報が多かったのでメモ

PlatPlat

使うもの

hf-hub

https://github.com/huggingface/hf-hub

https://crates.io/crates/hf-hub

https://docs.rs/hf-hub/0.3.2/hf_hub/

huggingface hub の API を叩くための Rust ライブラリ。今回はAPIクライアントを作成するだけ。

candle-datasets

https://github.com/huggingface/candle/tree/main/candle-datasets

https://docs.rs/candle-datasets/0.4.1/candle_datasets/

huggingface が作ってる PyTorch の Rust 代替である candle の一部。データセットを雑にダウンロードするのに使う。

parquet

https://docs.rs/parquet/latest/parquet/

huggingface datasets でよく見かけるよくわからない怪しい .parquet 拡張子のファイルの読み書き(今回は書かない)ができる。

parquet ついてはこの記事を参照:

https://dev.classmethod.jp/articles/convert-from-csv-to-parquet-using-arrrow-rs/

PlatPlat

使うクレートを追加していく

hf-hub

cargo add hf-hub 

tokio feature があるらしいが、なくても動くので今回は使わない。お好みで。

candle-datasets

GitHub の最新版を使う。

cargo add --git https://github.com/huggingface/candle.git candle-datasets

features は特に指定しなくていい。

参考:

https://huggingface.github.io/candle/guide/installation.html

parquet

cargo add parquet
完了

Cargo.toml の dependencies が多分こんな感じになる

[dependencies]
candle-datasets = { git = "https://github.com/huggingface/candle", version = "0.4.2" }
hf-hub = "0.3.2"
parquet = "51.0.0"
PlatPlat

現状(2024/04/02)、candle-datasets の使用例が見つからない(少なくとも公式のはない)のでかなりめんどかった。一応下のコードで最低限動いた。Rust初心者なので、変な書き方があるかもしれないが許して欲しい。

main.rs
use candle_datasets::hub::from_hub as load_dataset;
use hf_hub::api::sync::Api;
use parquet::file::reader::FileReader;

fn main() {
    let api = Api::new().unwrap();
    let repo_name = "elyza/ELYZA-tasks-100".to_string();
    let ds = load_dataset(&api, repo_name).unwrap();

    while let Some(file) = ds.iter().next() {
        let schema = file.metadata().file_metadata().schema();
        println!("schema: {:?}", schema);
        if let Ok(row_iter) = file.get_row_iter(Some(schema.clone())) {
            for row in row_iter {
                if let Ok(row) = row {
                    println!("{:?}", row);
                }
                break;
            }
        }
        break;
    }
}
出力
default/test/0000.parquet [00:00:01] [████████████████████████████] 67.19 KiB/67.19 KiB 41.16 KiB/s (0s)
schema: GroupType { basic_info: BasicTypeInfo { name: "schema", repetition: None, converted_type: NONE, logical_type: None, id: None }, fields: [PrimitiveType { basic_info: BasicTypeInfo { name: "input", repetition: Some(OPTIONAL), converted_type: UTF8, logical_type: Some(String), id: None }, physical_type: BYTE_ARRAY, type_length: -1, scale: -1, precision: -1 }, PrimitiveType { basic_info: BasicTypeInfo { name: "output", repetition: Some(OPTIONAL), converted_type: UTF8, logical_type: Some(String), id: None }, physical_type: BYTE_ARRAY, type_length: -1, scale: -1, precision: -1 }, PrimitiveType { basic_info: BasicTypeInfo { name: "eval_aspect", repetition: Some(OPTIONAL), converted_type: UTF8, logical_type: Some(String), id: None }, physical_type: BYTE_ARRAY, type_length: -1, scale: -1, precision: -1 }] }
Row { fields: [("input", Str("仕事の熱意を取り戻すためのアイデアを5つ挙げてください。")), ("output", Str("1. 自分の仕事に対する興味を再発見するために、新しい技能や知識を学ぶこと。\n2. カレッジやセミナーなどで講演を聴くことで、仕事に対する新しいアイデアや視点を得ること。\n3. 仕事に対してストレスを感じている場合 は、ストレスマネジメントのテクニックを学ぶこと。\n4. 仕事以外の楽しいことをすることで、ストレスを発散す ること。\n5. 仕事に対して自己評価をすることで、自分がどのように進化しているのかを知ること。")), ("eval_aspect", Str("- 熱意を取り戻すのではなく、仕事の効率化・スキルアップのような文脈になっていたら1点減点\n- 出したアイデアが5つより多い、少ない場合は1点減点\n- 5つのアイデアのうち、内容が重複しているものがあれば1点減点\n\n"))] }

ここではサイズが小さくて検証によかったので elyza/ELYZA-tasks-100 を使った。

説明

冒頭で指定している candle_datasets::hub::from_hub を使ってデータセットをダウンロード + Parquet の読み込みを行っている。ここのダウンロードはいつもの HuggingFace エコシステムと同じような感じになる。from_hub の名前なんのことかよくわからないので datasets ライブラリの雰囲気で load_dataset という名前にした。

キャッシュディレクトリの指定方法は後述。

    let api = Api::new().unwrap();
    let repo_name = "elyza/ELYZA-tasks-100".to_string();
    let ds = load_dataset(&api, repo_name).unwrap();

API クライアントを作ってレポを読み込んでいる。

    while let Some(file) = ds.iter().next() {

Python の datasets ライブラリと異なり、返ってくるのは Parquet ファイルのリストなので、それぞれのリストごとにループする。

    let schema = file.metadata().file_metadata().schema();

スキーマは、データセットの構造が指定されているのだと思う。

    if let Ok(row) = file.get_row_iter(Some(schema.clone())) {

スキーマに従って行のイテレーターを取得する。(スキーマも get_row_iterfile に由来しているのに、一発で行えるメソッドは自分の調べた限りだとなさそう...? こちらは parquet クレートの領域なので、もしかしたらもっとスマートなやりかたがあるかもしれない)

            for iter in row {
                if let Ok(iter) = iter {
                    println!("{:?}", iter);
                }
                break;
            }

1行ずつループしていって、正常に読めたら内容を出力する。今回は1行目で終わり。

PlatPlat

キャッシュディレクトリの指定

ファイルをダウンロードしてくるためのキャッシュディレクトリは API クライアントの方で指定する。(Pythonのdatasetsでは読み込み時に指定できたが、それはできないみたい?)

use hf_hub::api::sync::Api;

fn main() {
    let api = Api::new().unwrap();

の部分を以下に変更する

use hf_hub::api::sync::ApiBuilder; // 変更
use std::path::PathBuf; // 追加。パスの指定用

fn main() {
    let builder = ApiBuilder::new();
    let builder = builder.with_cache_dir(PathBuf::from("/huggingface/cache"));

    let api = builder.build().unwrap();

ApiBuilder を使ってキャッシュディレクトリを指定し、 build して Api を作成している。