
ONNX Optimization of the distilgpt2 Model

Published on 2024/03/07

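Overview

This article uses the Hugging Face Transformers library to export the distilgpt2 model to ONNX format, then optimizes the ONNX model and compares the model before and after optimization with ssc4onnx.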

https://huggingface.co/docs/transformers/model_doc/gpt2

Exporting the model

# !pip install transformers
# !pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# !pip3 install jupyterlab
# !pip install onnx
# !pip install onnxsim
# !pip install ssc4onnx

from transformers import GPT2Model, GPT2Tokenizer
import torch
from pathlib import Path

model_name = "distilgpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2Model.from_pretrained(model_name)

# Dummy input used to trace the model (6 tokens)
input_ids = torch.tensor([tokenizer.encode("Hello, my dog is cute")])

onnx_output_path = Path("distilgpt2_20230307.onnx")
torch.onnx.export(
    model,
    input_ids,
    onnx_output_path,
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=['input_ids'],
    output_names=['outputs'],
    # Only the batch dimension (axis 0) is dynamic here;
    # the sequence length stays fixed at the 6 tokens of the dummy input.
    dynamic_axes={'input_ids': {0: 'batch_size'}, 'outputs': {0: 'batch_size'}},
)

# Optimize the ONNX model. Other optimizations are possible, but only simplify is applied this time.
!onnxsim distilgpt2_20230307.onnx distilgpt2_20230307_simplified.onnx
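
The same simplification step can also be run from Python rather than through the `onnxsim` command. The following is a minimal sketch, assuming the `onnx` and `onnxsim` packages installed above. `simplified_path` and `simplify_file` are hypothetical helper names introduced here for illustration and are not part of either library.

```python
from pathlib import Path


def simplified_path(src):
    """Derive the output name used in this article: <stem>_simplified.onnx."""
    src = Path(src)
    return src.with_name(src.stem + "_simplified.onnx")


def simplify_file(src):
    """Simplify an ONNX file with onnxsim and return (nodes_before, nodes_after)."""
    # Imported lazily so that simplified_path() works without onnx installed.
    import onnx
    from onnxsim import simplify

    model = onnx.load(str(src))
    # Count nodes first, in case simplify() does not preserve the input object.
    nodes_before = len(model.graph.node)

    simplified, ok = simplify(model)
    if not ok:
        raise RuntimeError("onnxsim could not validate the simplified model")

    onnx.save(simplified, str(simplified_path(src)))
    return nodes_before, len(simplified.graph.node)
```

Calling `simplify_file("distilgpt2_20230307.onnx")` would write `distilgpt2_20230307_simplified.onnx` next to the source file and return the node counts before and after simplification.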

Comparing the model before and after optimization with ssc4onnx

https://github.com/PINTO0309/ssc4onnx

Before optimization

!ssc4onnx -if distilgpt2_20230307.onnx

┌────────────────────────┬────────────┬────────────┐
│ OP Type                │ OPs        │ Sizes      │
├────────────────────────┼────────────┼────────────┤
│ Add                    │ 57         │ 39.0KiB    │
│ Cast                   │ 19         │ 532.0B     │
│ Concat                 │ 74         │ 1.7KiB     │
│ Constant               │ 420        │ 1.0MiB     │
│ ConstantOfShape        │ 6          │ 24.0B      │
│ Div                    │ 19         │ 0.0B       │
│ Gather                 │ 99         │ 150.2MiB   │
│ Gemm                   │ 24         │ 162.2MiB   │
│ MatMul                 │ 12         │ 0.0B       │
│ Mul                    │ 37         │ 39.0KiB    │
│ Pow                    │ 25         │ 0.0B       │
│ Range                  │ 1          │ 0.0B       │
│ ReduceMean             │ 26         │ 1.6KiB     │
│ Reshape                │ 74         │ 0.0B       │
│ Shape                  │ 140        │ 0.0B       │
│ Slice                  │ 55         │ 0.0B       │
│ Softmax                │ 6          │ 168.0B     │
│ Split                  │ 6          │ 648.0B     │
│ Sqrt                   │ 13         │ 0.0B       │
│ Squeeze                │ 43         │ 2.7KiB     │
│ Sub                    │ 19         │ 0.0B       │
│ Tanh                   │ 6          │ 0.0B       │
│ Transpose              │ 30         │ 2.6KiB     │
│ Unsqueeze              │ 142        │ 8.9KiB     │
│ Where                  │ 6          │ 0.0B       │
│ ---------------------- │ ---------- │ ---------- │
│ Total number of OPs    │ 1,359      │            │
│ ---------------------- │ ---------- │ ---------- │
│ Total params           │ 79.1M      │            │
│ ====================== │ ========== │ ========== │
│ Model Size             │ 313.6MiB   │ 313.5MiB   │
└────────────────────────┴────────────┴────────────┘
INFO: file: distilgpt2_20230307.onnx
INFO: producer: pytorch 2.2.1
INFO: opset: 11
INFO: input_name.1: input_ids shape: ['batch_size', 6] dtype: int64
INFO: output_name.1: outputs shape: ['batch_size', 'Reshapeoutputs_dim_1', 'Reshapeoutputs_dim_2'] dtype: float32
INFO: output_name.2: key.3 shape: ['Transposekey.3_dim_0', 12, 'Transposekey.3_dim_2', 64] dtype: float32
INFO: output_name.3: value.3 shape: ['Transposevalue.3_dim_0', 12, 'Transposevalue.3_dim_2', 64] dtype: float32
INFO: output_name.4: key.11 shape: ['Transposekey.11_dim_0', 12, 'Transposekey.11_dim_2', 64] dtype: float32
INFO: output_name.5: value.11 shape: ['Transposevalue.11_dim_0', 12, 'Transposevalue.11_dim_2', 64] dtype: float32
INFO: output_name.6: key.19 shape: ['Transposekey.19_dim_0', 12, 'Transposekey.19_dim_2', 64] dtype: float32
INFO: output_name.7: value.19 shape: ['Transposevalue.19_dim_0', 12, 'Transposevalue.19_dim_2', 64] dtype: float32
INFO: output_name.8: key.27 shape: ['Transposekey.27_dim_0', 12, 'Transposekey.27_dim_2', 64] dtype: float32
INFO: output_name.9: value.27 shape: ['Transposevalue.27_dim_0', 12, 'Transposevalue.27_dim_2', 64] dtype: float32
INFO: output_name.10: key.35 shape: ['Transposekey.35_dim_0', 12, 'Transposekey.35_dim_2', 64] dtype: float32
INFO: output_name.11: value.35 shape: ['Transposevalue.35_dim_0', 12, 'Transposevalue.35_dim_2', 64] dtype: float32
INFO: output_name.12: key.43 shape: ['Transposekey.43_dim_0', 12, 'Transposekey.43_dim_2', 64] dtype: float32
INFO: output_name.13: value.43 shape: ['Transposevalue.43_dim_0', 12, 'Transposevalue.43_dim_2', 64] dtype: float32
INFO: Finish!

After optimization

!ssc4onnx -if distilgpt2_20230307_simplified.onnx

┌────────────────────────┬────────────┬────────────┐
│ OP Type                │ OPs        │ Sizes      │
├────────────────────────┼────────────┼────────────┤
│ Add                    │ 51         │ 57.1KiB    │
│ Div                    │ 19         │ 24.0B      │
│ Gather                 │ 1          │ 147.2MiB   │
│ Gemm                   │ 24         │ 162.2MiB   │
│ MatMul                 │ 12         │ 0.0B       │
│ Mul                    │ 37         │ 39.1KiB    │
│ Pow                    │ 19         │ 76.0B      │
│ ReduceMean             │ 26         │ 1.6KiB     │
│ Reshape                │ 66         │ 1.5KiB     │
│ Softmax                │ 6          │ 168.0B     │
│ Split                  │ 6          │ 648.0B     │
│ Sqrt                   │ 13         │ 0.0B       │
│ Sub                    │ 13         │ 0.0B       │
│ Tanh                   │ 6          │ 0.0B       │
│ Transpose              │ 30         │ 2.6KiB     │
│ Where                  │ 6          │ 240.0B     │
│ ---------------------- │ ---------- │ ---------- │
│ Total number of OPs    │ 335        │            │
│ ---------------------- │ ---------- │ ---------- │
│ Total params           │ 77.4M      │            │
│ ====================== │ ========== │ ========== │
│ Model Size             │ 309.5MiB   │ 309.5MiB   │
└────────────────────────┴────────────┴────────────┘
INFO: file: distilgpt2_20230307_simplified.onnx
INFO: producer: pytorch 2.2.1
INFO: opset: 11
INFO: input_name.1: input_ids shape: ['unk__0', 6] dtype: int64
INFO: output_name.1: outputs shape: ['batch_size', 6, 768] dtype: float32
INFO: output_name.2: key.3 shape: ['Transposekey.3_dim_0', 12, 6, 64] dtype: float32
INFO: output_name.3: value.3 shape: ['Transposevalue.3_dim_0', 12, 6, 64] dtype: float32
INFO: output_name.4: key.11 shape: ['Transposekey.11_dim_0', 12, 6, 64] dtype: float32
INFO: output_name.5: value.11 shape: ['Transposevalue.11_dim_0', 12, 6, 64] dtype: float32
INFO: output_name.6: key.19 shape: ['Transposekey.19_dim_0', 12, 6, 64] dtype: float32
INFO: output_name.7: value.19 shape: ['Transposevalue.19_dim_0', 12, 6, 64] dtype: float32
INFO: output_name.8: key.27 shape: ['Transposekey.27_dim_0', 12, 6, 64] dtype: float32
INFO: output_name.9: value.27 shape: ['Transposevalue.27_dim_0', 12, 6, 64] dtype: float32
INFO: output_name.10: key.35 shape: ['Transposekey.35_dim_0', 12, 6, 64] dtype: float32
INFO: output_name.11: value.35 shape: ['Transposevalue.35_dim_0', 12, 6, 64] dtype: float32
INFO: output_name.12: key.43 shape: ['Transposekey.43_dim_0', 12, 6, 64] dtype: float32
INFO: output_name.13: value.43 shape: ['Transposevalue.43_dim_0', 12, 6, 64] dtype: float32
INFO: Finish!
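  • The number of operators drops from 1,359 to 335.
  • The model size decreases from 313.6MiB to 309.5MiB.
  • Gather is strongly affected by the optimization: its count drops from 99 to 1. The one remaining Gather holds 147.2MiB, which matches the size of the token-embedding table (50257 × 768 float32 values), so it is most likely the embedding lookup rather than a compute-heavy op.
  • After simplification, dynamic dimensions seem to be shown as unk__xxx (the batch axis of input_ids changed from batch_size to unk__0). Is this the expected behavior?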
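
The figures above can be cross-checked with a few lines of arithmetic. The sketch below hard-codes the per-op counts from the two ssc4onnx tables. It also assumes distilgpt2's embedding dimensions (vocabulary 50257, 1024 positions, hidden size 768, float32), which the article does not state. `diff_ops` and `mib` are hypothetical helpers.

```python
# Op counts copied from the two ssc4onnx tables in this article.
BEFORE = {
    "Add": 57, "Cast": 19, "Concat": 74, "Constant": 420, "ConstantOfShape": 6,
    "Div": 19, "Gather": 99, "Gemm": 24, "MatMul": 12, "Mul": 37, "Pow": 25,
    "Range": 1, "ReduceMean": 26, "Reshape": 74, "Shape": 140, "Slice": 55,
    "Softmax": 6, "Split": 6, "Sqrt": 13, "Squeeze": 43, "Sub": 19, "Tanh": 6,
    "Transpose": 30, "Unsqueeze": 142, "Where": 6,
}
AFTER = {
    "Add": 51, "Div": 19, "Gather": 1, "Gemm": 24, "MatMul": 12, "Mul": 37,
    "Pow": 19, "ReduceMean": 26, "Reshape": 66, "Softmax": 6, "Split": 6,
    "Sqrt": 13, "Sub": 13, "Tanh": 6, "Transpose": 30, "Where": 6,
}


def diff_ops(before, after):
    """Split the change into op types removed entirely and op types reduced."""
    removed = {op: n for op, n in before.items() if op not in after}
    reduced = {op: n - after[op] for op, n in before.items()
               if op in after and after[op] < n}
    return removed, reduced


def mib(num_float32):
    """Size in MiB of a tensor holding num_float32 float32 values."""
    return num_float32 * 4 / 2**20


removed, reduced = diff_ops(BEFORE, AFTER)
print("totals:", sum(BEFORE.values()), "->", sum(AFTER.values()))
print("removed entirely:", sorted(removed))
print("reduced:", reduced)

# Assumed distilgpt2 embedding dimensions (not taken from the article).
VOCAB, POSITIONS, HIDDEN = 50257, 1024, 768
token_mib = mib(VOCAB * HIDDEN)
position_mib = mib(POSITIONS * HIDDEN)
print(f"token embedding: {token_mib:.1f} MiB, position embedding: {position_mib:.1f} MiB")
```

Under these assumptions, op types that disappear entirely (Cast, Concat, Constant, Shape, Unsqueeze and others) account for 900 nodes and the reduced op types for another 124, which together equal the 1,024-node difference. The 3.0MiB drop in Gather size equals the size of the position-embedding table. This suggests, but does not confirm, that the position lookup was folded into a constant because the sequence length is fixed.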

TODO:

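  • Check whether the model can also support a dynamic axis for the sequence length
  • Test inference with different sequence lengths
  • Test inference with variable-size inputs
    e.g.
    texts = ["Hello, my dog is cute", "Good morning", "I enjoyed the movie last night very much"]
  • Test how the model differs (mainly in size) across opset versions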
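
One possible starting point for the variable-length items is sketched below. It assumes three things that do not appear in this article: the model has been re-exported with the sequence axis also declared dynamic (for example `{0: 'batch_size', 1: 'sequence_length'}` in `dynamic_axes`), `onnxruntime` and `numpy` are installed, and right-padding is acceptable. `right_pad` and `run_onnx` are hypothetical helpers.

```python
def right_pad(sequences, pad_id):
    """Pad token-id lists on the right so that they all share one length."""
    max_len = max(len(seq) for seq in sequences)
    return [list(seq) + [pad_id] * (max_len - len(seq)) for seq in sequences]


def run_onnx(onnx_path, batch):
    """Run a rectangular batch of token ids and return the first model output."""
    # Imported lazily so that right_pad() works without these packages.
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
    input_ids = np.asarray(batch, dtype=np.int64)
    # "input_ids" is the input name given to torch.onnx.export in this article.
    return session.run(None, {"input_ids": input_ids})[0]
```

With the tokenizer from the export step, a batch could be built as `right_pad([tokenizer.encode(t) for t in texts], tokenizer.eos_token_id)` and passed to `run_onnx` together with the path of the re-exported model. Using the end-of-text id as padding is an assumption, since GPT-2 has no dedicated pad token. Because GPT-2 attention is causal, right-padding should leave the outputs at real token positions unchanged, but the outputs at padded positions are not meaningful.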
