ONNX Optimization of the distilgpt2 Model
Overview
In this article, I export the distilgpt2 model to ONNX using Hugging Face's Transformers library, optimize the exported ONNX model, and compare the models before and after optimization with ssc4onnx.
Exporting the Model
# !pip install transformers
# !pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# !pip3 install jupyterlab
# !pip install onnx
# !pip install onnxsim
from transformers import GPT2Model, GPT2Tokenizer
import torch
from pathlib import Path
model_name = "distilgpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2Model.from_pretrained(model_name)
# Dummy input for tracing
input_ids = torch.tensor([tokenizer.encode("Hello, my dog is cute")])
onnx_output_path = Path("distilgpt2_20230307.onnx")
torch.onnx.export(
    model,
    input_ids,
    onnx_output_path,
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=['input_ids'],
    output_names=['outputs'],
    # support variable-length (batched) input
    dynamic_axes={'input_ids': {0: 'batch_size'}, 'outputs': {0: 'batch_size'}},
)
# Optimize the ONNX model; other optimizations are possible, but this time only simplify
!onnxsim distilgpt2_20230307.onnx distilgpt2_20230307_simplified.onnx
Comparing Before and After Optimization with ssc4onnx
Before optimization
!ssc4onnx -if distilgpt2_20230307.onnx
┌────────────────────────┬────────────┬────────────┐
│ OP Type │ OPs │ Sizes │
├────────────────────────┼────────────┼────────────┤
│ Add │ 57 │ 39.0KiB │
│ Cast │ 19 │ 532.0B │
│ Concat │ 74 │ 1.7KiB │
│ Constant │ 420 │ 1.0MiB │
│ ConstantOfShape │ 6 │ 24.0B │
│ Div │ 19 │ 0.0B │
│ Gather │ 99 │ 150.2MiB │
│ Gemm │ 24 │ 162.2MiB │
│ MatMul │ 12 │ 0.0B │
│ Mul │ 37 │ 39.0KiB │
│ Pow │ 25 │ 0.0B │
│ Range │ 1 │ 0.0B │
│ ReduceMean │ 26 │ 1.6KiB │
│ Reshape │ 74 │ 0.0B │
│ Shape │ 140 │ 0.0B │
│ Slice │ 55 │ 0.0B │
│ Softmax │ 6 │ 168.0B │
│ Split │ 6 │ 648.0B │
│ Sqrt │ 13 │ 0.0B │
│ Squeeze │ 43 │ 2.7KiB │
│ Sub │ 19 │ 0.0B │
│ Tanh │ 6 │ 0.0B │
│ Transpose │ 30 │ 2.6KiB │
│ Unsqueeze │ 142 │ 8.9KiB │
│ Where │ 6 │ 0.0B │
│ ---------------------- │ ---------- │ ---------- │
│ Total number of OPs │ 1,359 │ │
│ ---------------------- │ ---------- │ ---------- │
│ Total params │ 79.1M │ │
│ ====================== │ ========== │ ========== │
│ Model Size │ 313.6MiB │ 313.5MiB │
└────────────────────────┴────────────┴────────────┘
INFO: file: distilgpt2_20230307.onnx
INFO: producer: pytorch 2.2.1
INFO: opset: 11
INFO: input_name.1: input_ids shape: ['batch_size', 6] dtype: int64
INFO: output_name.1: outputs shape: ['batch_size', 'Reshapeoutputs_dim_1', 'Reshapeoutputs_dim_2'] dtype: float32
INFO: output_name.2: key.3 shape: ['Transposekey.3_dim_0', 12, 'Transposekey.3_dim_2', 64] dtype: float32
INFO: output_name.3: value.3 shape: ['Transposevalue.3_dim_0', 12, 'Transposevalue.3_dim_2', 64] dtype: float32
INFO: output_name.4: key.11 shape: ['Transposekey.11_dim_0', 12, 'Transposekey.11_dim_2', 64] dtype: float32
INFO: output_name.5: value.11 shape: ['Transposevalue.11_dim_0', 12, 'Transposevalue.11_dim_2', 64] dtype: float32
INFO: output_name.6: key.19 shape: ['Transposekey.19_dim_0', 12, 'Transposekey.19_dim_2', 64] dtype: float32
INFO: output_name.7: value.19 shape: ['Transposevalue.19_dim_0', 12, 'Transposevalue.19_dim_2', 64] dtype: float32
INFO: output_name.8: key.27 shape: ['Transposekey.27_dim_0', 12, 'Transposekey.27_dim_2', 64] dtype: float32
INFO: output_name.9: value.27 shape: ['Transposevalue.27_dim_0', 12, 'Transposevalue.27_dim_2', 64] dtype: float32
INFO: output_name.10: key.35 shape: ['Transposekey.35_dim_0', 12, 'Transposekey.35_dim_2', 64] dtype: float32
INFO: output_name.11: value.35 shape: ['Transposevalue.35_dim_0', 12, 'Transposevalue.35_dim_2', 64] dtype: float32
INFO: output_name.12: key.43 shape: ['Transposekey.43_dim_0', 12, 'Transposekey.43_dim_2', 64] dtype: float32
INFO: output_name.13: value.43 shape: ['Transposevalue.43_dim_0', 12, 'Transposevalue.43_dim_2', 64] dtype: float32
INFO: Finish!
After optimization
!ssc4onnx -if distilgpt2_20230307_simplified.onnx
┌────────────────────────┬────────────┬────────────┐
│ OP Type │ OPs │ Sizes │
├────────────────────────┼────────────┼────────────┤
│ Add │ 51 │ 57.1KiB │
│ Div │ 19 │ 24.0B │
│ Gather │ 1 │ 147.2MiB │
│ Gemm │ 24 │ 162.2MiB │
│ MatMul │ 12 │ 0.0B │
│ Mul │ 37 │ 39.1KiB │
│ Pow │ 19 │ 76.0B │
│ ReduceMean │ 26 │ 1.6KiB │
│ Reshape │ 66 │ 1.5KiB │
│ Softmax │ 6 │ 168.0B │
│ Split │ 6 │ 648.0B │
│ Sqrt │ 13 │ 0.0B │
│ Sub │ 13 │ 0.0B │
│ Tanh │ 6 │ 0.0B │
│ Transpose │ 30 │ 2.6KiB │
│ Where │ 6 │ 240.0B │
│ ---------------------- │ ---------- │ ---------- │
│ Total number of OPs │ 335 │ │
│ ---------------------- │ ---------- │ ---------- │
│ Total params │ 77.4M │ │
│ ====================== │ ========== │ ========== │
│ Model Size │ 309.5MiB │ 309.5MiB │
└────────────────────────┴────────────┴────────────┘
INFO: file: distilgpt2_20230307_simplified.onnx
INFO: producer: pytorch 2.2.1
INFO: opset: 11
INFO: input_name.1: input_ids shape: ['unk__0', 6] dtype: int64
INFO: output_name.1: outputs shape: ['batch_size', 6, 768] dtype: float32
INFO: output_name.2: key.3 shape: ['Transposekey.3_dim_0', 12, 6, 64] dtype: float32
INFO: output_name.3: value.3 shape: ['Transposevalue.3_dim_0', 12, 6, 64] dtype: float32
INFO: output_name.4: key.11 shape: ['Transposekey.11_dim_0', 12, 6, 64] dtype: float32
INFO: output_name.5: value.11 shape: ['Transposevalue.11_dim_0', 12, 6, 64] dtype: float32
INFO: output_name.6: key.19 shape: ['Transposekey.19_dim_0', 12, 6, 64] dtype: float32
INFO: output_name.7: value.19 shape: ['Transposevalue.19_dim_0', 12, 6, 64] dtype: float32
INFO: output_name.8: key.27 shape: ['Transposekey.27_dim_0', 12, 6, 64] dtype: float32
INFO: output_name.9: value.27 shape: ['Transposevalue.27_dim_0', 12, 6, 64] dtype: float32
INFO: output_name.10: key.35 shape: ['Transposekey.35_dim_0', 12, 6, 64] dtype: float32
INFO: output_name.11: value.35 shape: ['Transposevalue.35_dim_0', 12, 6, 64] dtype: float32
INFO: output_name.12: key.43 shape: ['Transposekey.43_dim_0', 12, 6, 64] dtype: float32
INFO: output_name.13: value.43 shape: ['Transposevalue.43_dim_0', 12, 6, 64] dtype: float32
INFO: Finish!
- Operator count reduced from 1,359 to 335
- Model size decreased from 313.6MiB to 309.5MiB
- The Gather ops are strongly affected by the optimization, dropping from 99 to 1. The single remaining Gather (147.2MiB) is presumably the token-embedding lookup, while the others were shape-computation ops that onnxsim folded away
- After simplification, variable-length tensor dimensions seem to be rendered as unk__xxx?
TODO:
- Check whether the model also supports a dynamic axis for the sequence length
- Test inference with different sequence lengths
- Test inference with variable batch sizes
  ex. texts = ["Hello, my dog is cute", "Good morning", "I enjoyed the movie last night very much"]
- Test how the choice of opset changes the model (mainly its size)