
[WIP] Vectorizing with plamo-embedding-1b in Rust


Motivation: large documents, or a growing number of them, make processing slow, so this comes from a modest hope that doing it in Rust would cut the processing time.


Steps

  1. Use Transformers in Python to export plamo-embedding-1b to an ONNX file and save the tokenizer
  2. Vectorize in Rust with sentencepiece and ort (the assumed crate versions are sketched right after this list)
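
The Rust step below targets the ort 1.x API (Environment, SessionBuilder, Value::from_array), not the reworked 2.x one. As a minimal sketch of the manifest, with crate versions that are my assumption rather than anything this scrap pins:

[dependencies]
# Assumed versions: the code in this scrap compiles against the ort 1.x API
anyhow = "1"
ndarray = "0.15"
ort = "1.16"
sentencepiece = "0.11"
serde = { version = "1", features = ["derive"] }
serde_json = "1"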
  1. Exporting the plamo-embedding-1b ONNX file and tokenizer with Transformers
from transformers import AutoModel, AutoTokenizer
from pathlib import Path
import torch


# Load the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("pfnet/plamo-embedding-1b", trust_remote_code=True)
model = AutoModel.from_pretrained("pfnet/plamo-embedding-1b", trust_remote_code=True)
model.eval()

# Save the tokenizer files (tokenizer.model, tokenizer_config.json, ...) for the Rust side
tokenizer.save_pretrained("./tokenizer")

# ONNX output path
onnx_path = Path("./onnx/plamo-embedding-1b.onnx")
onnx_path.parent.mkdir(parents=True, exist_ok=True)

# Japanese sample sentence; the exported graph's sequence length is fixed to this input's token count
dummy_text = "これはサンプル文です。"
inputs = tokenizer(dummy_text, return_tensors="pt")

# Export to ONNX. Only the batch axis is declared dynamic, so the sequence
# length stays fixed; the Rust side pads/truncates to it (expected_len).
torch.onnx.export(
    model,
    (inputs["input_ids"], inputs.get("attention_mask", None)),
    str(onnx_path),
    input_names=["input_ids", "attention_mask"],
    output_names=["output"],
    dynamic_axes={"input_ids": {0: "batch_size"}, "attention_mask": {0: "batch_size"}},
    opset_version=14,
    # Needed because the model exceeds the 2 GB protobuf limit; recent torch
    # versions have removed this argument and handle external data automatically
    use_external_data_format=True
)
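
Before writing the inference code, it can help to confirm what the exported graph actually expects. A minimal sketch, assuming ort 1.x exposes input/output metadata as the inputs and outputs fields on Session (each entry carrying a name and per-axis dimensions, with None marking a dynamic axis):

use std::sync::Arc;

use ort::{Environment, SessionBuilder};

fn main() -> anyhow::Result<()> {
    let environment = Arc::new(Environment::builder().build()?);
    let session = SessionBuilder::new(&environment)?
        .with_model_from_file("./onnx/plamo-embedding-1b.onnx")?;

    // Print each input/output name with its per-axis sizes (None = dynamic)
    for input in &session.inputs {
        println!("input  {}: {:?}", input.name, input.dimensions);
    }
    for output in &session.outputs {
        println!("output {}: {:?}", output.name, output.dimensions);
    }
    Ok(())
}

If the sequence axis prints as a fixed number rather than None (as expected here, since only the batch axis was dynamic at export), the inference side has to pad or truncate every input to exactly that length.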
  2. Vectorizing in Rust with sentencepiece and ort
use std::sync::Arc;
use std::fs;
use serde::Deserialize;
use std::collections::HashMap;

use sentencepiece::SentencePieceProcessor;
use ort::{Environment, SessionBuilder, Value};
use ndarray::{Array2, CowArray};


#[derive(Debug, Deserialize)]
pub struct AddedToken {
    pub content: String,
    pub lstrip: bool,
    pub normalized: bool,
    pub rstrip: bool,
    pub single_word: bool,
    pub special: bool,
}

#[derive(Deserialize)]
pub struct TokenizerConfig {
    pub add_bos_token: bool,
    pub add_eos_token: bool,
    pub added_tokens_decoder: HashMap<String, AddedToken>,
    pub auto_map: HashMap<String, Vec<Option<String>>>,
    pub bos_token: Option<String>,
    pub clean_up_tokenization_spaces: bool,
    pub cls_token: Option<String>,
    pub eos_token: Option<String>,
    pub extra_special_tokens: HashMap<String, String>,
    pub local_file_only: bool,
    pub mask_token: Option<String>,
    pub model_max_length: u128,
    pub pad_token: Option<String>,
    pub sep_token: Option<String>,
    pub sp_model_kwargs: HashMap<String, String>,
    pub tokenizer_class: String,
    pub unk_token: Option<String>,
}


fn main() -> anyhow::Result<()> {
    // The exported graph has a fixed sequence length (only the batch axis was
    // dynamic at export time); 8 is assumed to match the dummy input's token count
    let expected_len = 8;

    // Load the tokenizer_config.json saved by the Python step
    let config_json = fs::read_to_string("./tokenizer/tokenizer_config.json")?;
    let config: TokenizerConfig = serde_json::from_str(&config_json)?;

    let bos_token = config.bos_token.unwrap();
    let eos_token = config.eos_token.unwrap();
    let pad_token = config.pad_token.unwrap();

    let sp = SentencePieceProcessor::open("./tokenizer/tokenizer.model")?;

    // Look up the special-token IDs (piece_to_id returns Ok(None) for unknown pieces)
    let bos_id = sp.piece_to_id(&bos_token)?.unwrap() as i64;
    let eos_id = sp.piece_to_id(&eos_token)?.unwrap() as i64;
    let pad_id = sp.piece_to_id(&pad_token)?.unwrap() as i64;

    let input_text = "これはサンプル文です。";

    // SentencePiece itself does not add special tokens, so wrap the IDs with BOS/EOS manually
    let mut input_ids: Vec<i64> = sp.encode(input_text)?
        .iter()
        .map(|piece| piece.id as i64)
        .collect();

    input_ids.insert(0, bos_id);
    input_ids.push(eos_id);

    // 1 for real tokens; padding positions get 0 below
    let mut attention_mask: Vec<i64> = vec![1; input_ids.len()];

    // Pad (with pad_id, mask 0) or truncate to the graph's fixed sequence length
    match input_ids.len().cmp(&expected_len) {
        std::cmp::Ordering::Less => {
            input_ids.resize(expected_len, pad_id);
            attention_mask.resize(expected_len, 0);
        }
        std::cmp::Ordering::Greater => {
            input_ids.truncate(expected_len);
            attention_mask.truncate(expected_len);
        }
        std::cmp::Ordering::Equal => {}
    }

    // Shape to (1, seq_len) and wrap as CowArray<IxDyn>, which ort 1.x's Value::from_array expects
    let input_ids = Array2::from_shape_vec((1, input_ids.len()), input_ids)?;
    let attention_mask = Array2::from_shape_vec((1, attention_mask.len()), attention_mask)?;

    let input_ids_cow = CowArray::from(input_ids).into_dyn();
    let attention_mask_cow = CowArray::from(attention_mask).into_dyn();

    let environment = Arc::new(Environment::builder().build()?);
    let session = SessionBuilder::new(&environment)?
        .with_model_from_file("./onnx/plamo-embedding-1b.onnx")?;

    // Inputs are passed positionally, in the order they were named at export time
    let input_ids_value = Value::from_array(session.allocator(), &input_ids_cow)?;
    let attention_mask_value = Value::from_array(session.allocator(), &attention_mask_cow)?;

    let outputs = session.run(vec![input_ids_value, attention_mask_value])?;

    // Extract the raw output (assumed shape [1, seq_len, hidden_dim]) and flatten it;
    // pooling it into a single sentence vector is sketched below
    let embedding_tensor = outputs[0].try_extract::<f32>()?;
    let embedding: Vec<f32> = embedding_tensor.view().iter().cloned().collect();
    println!("raw output values: {}", embedding.len());

    Ok(())
}
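
The run above only flattens the raw output, which is assumed to be the last hidden state of shape [1, seq_len, hidden_dim], not yet a sentence vector. A common pooling choice for embedding models is masked mean pooling over the token axis; whether that matches plamo-embedding-1b's own encode path should be checked against the model card, so treat this as a hedged sketch. An ArrayView3 can be obtained from the ort view with into_dimensionality::<Ix3>().

use ndarray::{s, Array1, ArrayView3};

/// Masked mean pooling: average the hidden states of the non-padding tokens.
/// `hidden` is assumed to be [batch=1, seq_len, hidden_dim]; `mask` holds 1/0 per token.
fn mean_pool(hidden: ArrayView3<f32>, mask: &[i64]) -> Array1<f32> {
    let (_, seq_len, dim) = hidden.dim();
    let mut sum = Array1::<f32>::zeros(dim);
    let mut count = 0.0_f32;
    for t in 0..seq_len {
        if mask[t] == 1 {
            sum += &hidden.slice(s![0, t, ..]);
            count += 1.0;
        }
    }
    // Guard against an all-zero mask
    sum / count.max(1.0)
}

/// Cosine similarity between two pooled sentence vectors.
fn cosine(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
    let dot = a.dot(b);
    let norm = a.dot(a).sqrt() * b.dot(b).sqrt();
    dot / norm
}

Pooling two texts this way and calling cosine(&vec_a, &vec_b) then gives a similarity score that can be compared against the Python model's output as a sanity check.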