Open5
[WIP]Rustでplamo-embedding-1bを使ってベクトル化する
流れ
- PythonのTransformersを使ってplamo-embedding-1bのonnxファイルとtokenizerを生成
- Rustでsentencepieceとortを使ってベクトル化
- Transformersを使ってplamo-embedding-1bのonnxファイルとtokenizerを生成
from transformers import AutoModel, AutoTokenizer
from pathlib import Path
import torch

# Load the model and tokenizer. trust_remote_code is required because
# pfnet/plamo-embedding-1b ships custom modeling/tokenizer code.
tokenizer = AutoTokenizer.from_pretrained("pfnet/plamo-embedding-1b", trust_remote_code=True)
model = AutoModel.from_pretrained("pfnet/plamo-embedding-1b", trust_remote_code=True)
model.eval()  # inference mode so the exported graph is deterministic

# Save the tokenizer files (tokenizer.model, tokenizer_config.json, ...).
tokenizer.save_pretrained("./tokenizer_dir")

# ONNX output path; create the directory first or torch.onnx.export fails.
onnx_path = Path("./onnx/plamo-embedding-1b.onnx")
onnx_path.parent.mkdir(parents=True, exist_ok=True)

dummy_text = "これはサンプル文です。"
inputs = tokenizer(dummy_text, return_tensors="pt")

# Export to ONNX. Both the batch axis AND the sequence axis must be dynamic;
# marking only axis 0 would freeze the sequence length to the dummy input's
# token count, so the Rust side could only ever feed that exact length.
torch.onnx.export(
    model,
    (inputs["input_ids"], inputs.get("attention_mask", None)),
    str(onnx_path),
    input_names=["input_ids", "attention_mask"],
    output_names=["output"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "seq_len"},
        "attention_mask": {0: "batch_size", 1: "seq_len"},
        "output": {0: "batch_size"},
    },
    opset_version=14,
    use_external_data_format=True,  # model weights exceed the 2 GB protobuf limit
)
use std::sync::Arc;
use std::fs;
use serde::Deserialize;
use std::collections::HashMap;
use sentencepiece::SentencePieceProcessor;
use ort::{Environment, SessionBuilder, Value};
use ndarray::{Array2, CowArray};
/// One entry of `added_tokens_decoder` in a HuggingFace
/// `tokenizer_config.json` — describes a token added on top of the
/// base SentencePiece vocabulary.
#[derive(Debug, Deserialize)]
pub struct AddedToken {
    /// The literal token text.
    pub content: String,
    /// Whether whitespace to the left of the token is stripped.
    pub lstrip: bool,
    /// Whether the token is subject to the tokenizer's normalization.
    pub normalized: bool,
    /// Whether whitespace to the right of the token is stripped.
    pub rstrip: bool,
    /// Whether the token only matches as a whole word.
    pub single_word: bool,
    /// Whether this is a special (control) token.
    pub special: bool,
}
/// Mirrors the fields of a HuggingFace `tokenizer_config.json`
/// (here: the one saved for pfnet/plamo-embedding-1b).
///
/// NOTE(review): every field is required by serde as written, so
/// deserialization fails if the JSON omits any of them — confirm each
/// key exists in the actual saved tokenizer_config.json, or switch the
/// optional ones to `#[serde(default)]`.
#[derive(Debug, Deserialize)]
pub struct TokenizerConfig {
    /// Whether a BOS token is prepended when encoding.
    pub add_bos_token: bool,
    /// Whether an EOS token is appended when encoding.
    pub add_eos_token: bool,
    /// Tokens layered on top of the base vocabulary, keyed by id (as string).
    pub added_tokens_decoder: HashMap<String, AddedToken>,
    /// HuggingFace auto-class mapping for the custom tokenizer code.
    pub auto_map: HashMap<String, Vec<Option<String>>>,
    /// Beginning-of-sequence token text, if defined.
    pub bos_token: Option<String>,
    pub clean_up_tokenization_spaces: bool,
    pub cls_token: Option<String>,
    /// End-of-sequence token text, if defined.
    pub eos_token: Option<String>,
    pub extra_special_tokens: HashMap<String, String>,
    pub local_file_only: bool,
    pub mask_token: Option<String>,
    /// Maximum sequence length; u128 because HF configs use huge sentinels
    /// (e.g. 1e30-style values) for "unbounded".
    pub model_max_length: u128,
    /// Padding token text, if defined.
    pub pad_token: Option<String>,
    pub sep_token: Option<String>,
    pub sp_model_kwargs: HashMap<String, String>,
    pub tokenizer_class: String,
    pub unk_token: Option<String>,
}
fn main() -> anyhow::Result<()> {
let expected_len = 8;
// tokenizer_config.jsonの読み込み
let config_json = fs::read_to_string("./tokenizer/tokenizer_config.json")?;
let config: TokenizerConfig = serde_json::from_str(&config_json)?;
let bos_token = config.bos_token.unwrap();
let eos_token = config.eos_token.unwrap();
let pad_token = config.pad_token.unwrap();
let mut sp = SentencePieceProcessor::open("./tokenizer/tokenizer.model").unwrap();
// 特殊トークンIDの取得
let bos_id = sp.piece_to_id(&bos_token)?;
let bos_id = bos_id.unwrap() as i64;
let eos_id = sp.piece_to_id(&eos_token)?;
let eos_id = eos_id.unwrap() as i64;
let pad_id = sp.piece_to_id(&pad_token)?;
let pad_id = pad_id.unwrap() as i64;
let input_text = "これはサンプル文です。";
let mut input_ids: Vec<i64> = sp.encode(input_text)?
.iter()
.map(|piece| piece.id as i64)
.collect();
input_ids.insert(0, bos_id);
input_ids.push(eos_id);
let mut attention_mask: Vec<i64> = vec![1; input_ids.len()];
match input_ids.len().cmp(&expected_len) {
std::cmp::Ordering::Less => {
input_ids.resize(expected_len, pad_id);
attention_mask.resize(expected_len, 0);
}
std::cmp::Ordering::Greater => {
input_ids.truncate(expected_len);
attention_mask.truncate(expected_len);
}
std::cmp::Ordering::Equal => {}
}
let input_ids = Array2::from_shape_vec((1, input_ids.len()), input_ids)?.into_dyn();
let attention_mask = Array2::from_shape_vec((1, attention_mask.len()), attention_mask)?.into_dyn();
let input_ids_cow = CowArray::from(input_ids).into_dyn();
let attention_mask_cow = CowArray::from(attention_mask).into_dyn();
let environment = Arc::new(Environment::builder().build()?);
let session = SessionBuilder::new(&environment)?
.with_model_from_file("./onnx/plamo-embedding-1b.onnx")?;
let input_ids_value = Value::from_array(session.allocator(), &input_ids_cow)?;
let attention_mask_value = Value::from_array(session.allocator(), &attention_mask_cow)?;
let outputs = session.run(vec![input_ids_value, attention_mask_value])?;
// 出力ベクトル取得
let embedding_tensor = outputs[0].try_extract::<f32>()?;
let embedding: Vec<f32> = embedding_tensor.view().iter().cloned().collect();
Ok(())
}