RustでHuggingFace HubのDatasetsを読む
candle-datasets
, hf-hub
, parquet
を使うことで HuggingFace Datasets を Rust から扱うことができる。
全然知らないAPIばかり+ググっても出てこない情報が多かったのでメモ
使うもの
hf-hub
huggingface hub の API を叩くための Rust ライブラリ。今回はAPIクライアントを作成するだけ。
candle-datasets
huggingface が作ってる PyTorch の Rust 代替である candle の一部。データセットを雑にダウンロードするのに使う。
parquet
huggingface datasets でよく見かけるよくわからない怪しい .parquet
拡張子のファイルの読み書き(今回は書かない)ができる。
parquet
ついてはこの記事を参照:
使うクレートを追加していく
hf-hub
cargo add hf-hub
tokio
feature があるらしいが、なくても動くので今回は使わない。お好みで。
candle-datasets
GitHub の最新版を使う。
cargo add --git https://github.com/huggingface/candle.git candle-datasets
features は特に指定しなくていい。
参考:
parquet
cargo add parquet
完了
Cargo.toml
の dependencies が多分こんな感じになる
[dependencies]
candle-datasets = { git = "https://github.com/huggingface/candle", version = "0.4.2" }
hf-hub = "0.3.2"
parquet = "51.0.0"
現状(2024/04/02)、candle-datasets
の使用例が見つからない(少なくとも公式のはない)のでかなりめんどかった。一応下のコードで最低限動いた。Rust初心者なので、変な書き方があるかもしれないが許して欲しい。
use candle_datasets::hub::from_hub as load_dataset;
use hf_hub::api::sync::Api;
use parquet::file::reader::FileReader;
fn main() {
let api = Api::new().unwrap();
let repo_name = "elyza/ELYZA-tasks-100".to_string();
let ds = load_dataset(&api, repo_name).unwrap();
while let Some(file) = ds.iter().next() {
let schema = file.metadata().file_metadata().schema();
println!("schema: {:?}", schema);
if let Ok(row_iter) = file.get_row_iter(Some(schema.clone())) {
for row in row_iter {
if let Ok(row) = row {
println!("{:?}", row);
}
break;
}
}
break;
}
}
default/test/0000.parquet [00:00:01] [████████████████████████████] 67.19 KiB/67.19 KiB 41.16 KiB/s (0s)
schema: GroupType { basic_info: BasicTypeInfo { name: "schema", repetition: None, converted_type: NONE, logical_type: None, id: None }, fields: [PrimitiveType { basic_info: BasicTypeInfo { name: "input", repetition: Some(OPTIONAL), converted_type: UTF8, logical_type: Some(String), id: None }, physical_type: BYTE_ARRAY, type_length: -1, scale: -1, precision: -1 }, PrimitiveType { basic_info: BasicTypeInfo { name: "output", repetition: Some(OPTIONAL), converted_type: UTF8, logical_type: Some(String), id: None }, physical_type: BYTE_ARRAY, type_length: -1, scale: -1, precision: -1 }, PrimitiveType { basic_info: BasicTypeInfo { name: "eval_aspect", repetition: Some(OPTIONAL), converted_type: UTF8, logical_type: Some(String), id: None }, physical_type: BYTE_ARRAY, type_length: -1, scale: -1, precision: -1 }] }
Row { fields: [("input", Str("仕事の熱意を取り戻すためのアイデアを5つ挙げてください。")), ("output", Str("1. 自分の仕事に対する興味を再発見するために、新しい技能や知識を学ぶこと。\n2. カレッジやセミナーなどで講演を聴くことで、仕事に対する新しいアイデアや視点を得ること。\n3. 仕事に対してストレスを感じている場合 は、ストレスマネジメントのテクニックを学ぶこと。\n4. 仕事以外の楽しいことをすることで、ストレスを発散す ること。\n5. 仕事に対して自己評価をすることで、自分がどのように進化しているのかを知ること。")), ("eval_aspect", Str("- 熱意を取り戻すのではなく、仕事の効率化・スキルアップのような文脈になっていたら1点減点\n- 出したアイデアが5つより多い、少ない場合は1点減点\n- 5つのアイデアのうち、内容が重複しているものがあれば1点減点\n\n"))] }
ここではサイズが小さくて検証によかったので elyza/ELYZA-tasks-100 を使った。
説明
冒頭で指定している candle_datasets::hub::from_hub
を使ってデータセットをダウンロード + Parquet の読み込みを行っている。ここのダウンロードはいつもの HuggingFace エコシステムと同じような感じになる。from_hub
の名前なんのことかよくわからないので datasets ライブラリの雰囲気で load_dataset
という名前にした。
キャッシュディレクトリの指定方法は後述。
let api = Api::new().unwrap();
let repo_name = "elyza/ELYZA-tasks-100".to_string();
let ds = load_dataset(&api, repo_name).unwrap();
API クライアントを作ってレポを読み込んでいる。
while let Some(file) = ds.iter().next() {
Python の datasets ライブラリと異なり、返ってくるのは Parquet ファイルのリストなので、それぞれのリストごとにループする。
let schema = file.metadata().file_metadata().schema();
スキーマは、データセットの構造が指定されているのだと思う。
if let Ok(row) = file.get_row_iter(Some(schema.clone())) {
スキーマに従って行のイテレーターを取得する。(スキーマも get_row_iter
も file
に由来しているのに、一発で行えるメソッドは自分の調べた限りだとなさそう...? こちらは parquet
クレートの領域なので、もしかしたらもっとスマートなやりかたがあるかもしれない)
for iter in row {
if let Ok(iter) = iter {
println!("{:?}", iter);
}
break;
}
1行ずつループしていって、正常に読めたら内容を出力する。今回は1行目で終わり。
キャッシュディレクトリの指定
ファイルをダウンロードしてくるためのキャッシュディレクトリは API クライアントの方で指定する。(Pythonのdatasetsでは読み込み時に指定できたが、それはできないみたい?)
use hf_hub::api::sync::Api;
fn main() {
let api = Api::new().unwrap();
の部分を以下に変更する
use hf_hub::api::sync::ApiBuilder; // 変更
use std::path::PathBuf; // 追加。パスの指定用
fn main() {
let builder = ApiBuilder::new();
let builder = builder.with_cache_dir(PathBuf::from("/huggingface/cache"));
let api = builder.build().unwrap();
ApiBuilder
を使ってキャッシュディレクトリを指定し、 build
して Api
を作成している。