💨

Stable DiffusionのUnetをjson化する

2024/03/26に公開

Stable DiffusionのネットワークをAIアクセラレーターに適した形で読ませるためにonnxをjson化して編集したい。

使用したライブラリのVersionは
onnx 1.15.0
torch 2.2.1
transformers 4.39.1
accelerate 0.28.0
huggingface-hub 0.21.4

これより古いとエラーがでてくる(dockerを使ったほうが良さそう)。
例えばtorchのaten::scaled_dot_product_attentionがONNX似できない問題は
https://github.com/pytorch/pytorch/issues/97262
で対応済み

diffusersのソースコード内のdiffusers/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
あるいはここに書かれているところからスクリプトをダウンロードして実行すると
VAE encoder,VAE decoder,text encoder, safety_checker,Unetのonnxファイルが生成される。safety_checker(センシティブ画像のチェックのためのモジュール)が最も大きいのは興味深い。Unetはパラメーターサイズが2GBを超えたため分離されパラメーターはweights.pbとして保存された。

ファイルサイズ
ファイルサイズ
VAE encoder
VAE encoder
Unet拡大
Unet拡大
Attention拡大
Attention拡大
Unet全体
Unet全体

diffusersのUnetをそのままonnx化したものは巨大すぎてNETRONでは見れたもののonnx2jsonやprotocol bufferのダンプは強制終了してしまった。そこでネットワークの層数を減らして出力してみることにした。以下のようにconvert_stable_diffusion_checkpoint_to_onnx.pyを変更した。

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import os
import shutil
from pathlib import Path

import torch
from torch.onnx import export

import onnx
from diffusers.models import UNet2DConditionModel
from diffusers.onnx_utils import OnnxRuntimeModel
from packaging import version

is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")

def onnx_export(
    model,
    model_args: tuple,
    output_path: Path,
    ordered_input_names,
    output_names,
    dynamic_axes,
    opset,
    use_external_data_format=False,
):
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
    # so we check the torch version for backwards compatibility
    if is_torch_less_than_1_11:
        export(
            model,
            model_args,
            f=output_path.as_posix(),
            input_names=ordered_input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            use_external_data_format=use_external_data_format,
            enable_onnx_checker=True,
            opset_version=opset,
        )
    else:
        export(
            model,
            model_args,
            f=output_path.as_posix(),
            input_names=ordered_input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            opset_version=opset,
        )


@torch.no_grad()
def convert_models(output_path: str, opset: int, fp16: bool = False):
    dtype = torch.float16 if fp16 else torch.float32
    if fp16 and torch.cuda.is_available():
        device = "cuda"
    elif fp16 and not torch.cuda.is_available():
        raise ValueError("`float16` model export is only supported on GPUs with CUDA")
    else:
        device = "cpu"
#    pipeline = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
    unet=UNet2DConditionModel(
        down_block_types= ["DownBlock2D"],# "CrossAttnDownBlock2D"),
        up_block_types = ["UpBlock2D"],# "CrossAttnUpBlock2D"),
        mid_block_type = ("UNetMidBlock2DCrossAttn"),
        #block_out_channels = (320, 640, 1280, 1280)
        block_out_channels = (1280,1280)
    )
    output_path = Path(output_path)
    
    num_tokens=  77
    text_hidden_size= 1280 #768
    unet_sample_size = 64
    unet_in_channels = unet.config.in_channels #4
    #mat1 and mat2 shapes cannot be multiplied (154x768 and 1280x320)

    unet_path = output_path / "model.onnx"
    print(unet_in_channels ,unet_sample_size )
    onnx_export(
        unet,
        model_args=(
            torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype),
            torch.randn(2).to(device=device, dtype=dtype),
            torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype),
            False,
        ),
        output_path=unet_path,
        ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"],
        output_names=["out_sample"],  # has to be different from "sample" for correct tracing
        dynamic_axes={
            "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            "timestep": {0: "batch"},
            "encoder_hidden_states": {0: "batch", 1: "sequence"},
        },
        opset=opset,
        use_external_data_format=True,  # UNet is > 2GB, so the weights need to be split
    )
    del unet

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--output_path", type=str, required=True, help="Path to the output model.")

    parser.add_argument(
        "--opset",
        default=14,
        type=int,
        help="The version of the ONNX operator set to use.",
    )
    parser.add_argument("--fp16", action="store_true", default=False, help="Export the models in `float16` mode")

    args = parser.parse_args()

    convert_models(args.output_path, args.opset, args.fp16)

convert_stable_diffusion_checkpoint_to_onnx.pyとの主たる違いは

  • Unetのexport部分のみ残した
  • pretrained modelを読み込まずファイルパスを削除、直接Unetを宣言する。
  • 解像度を下げる層(down_block),上げる層()をそれぞれ一層のみに変更
    unet=UNet2DConditionModel(
        down_block_types= ["DownBlock2D"],# "CrossAttnDownBlock2D"),
        up_block_types = ["UpBlock2D"],# "CrossAttnUpBlock2D"),
        mid_block_type = ("UNetMidBlock2DCrossAttn"),
        #block_out_channels = 
        block_out_channels = (1280,1280)#(320, 640, 1280, 1280)
    )

onnx export時に設定するUnetへの入力パラメーターを

    num_tokens=  77
    text_hidden_size= 1280 #768
    unet_sample_size = 64
    unet_in_channels = unet.config.in_channels #4

と決め打ち

ただし

RuntimeError: Failed to import diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion because of the following error (look up to see its traceback):
Failed to import transformers.models.clip.image_processing_clip because of the following error (look up to see its traceback):
Descriptors cannot be created directly.
If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.
If you cannot immediately regenerate your protos, some other possible workarounds are:
 1. Downgrade the protobuf package to 3.20.x or lower.
 2. Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python (but this will use pure-Python parsing and will be much slower).

More information: https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates

というエラーが出てきたため

export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

とした。その結果をNETRONでみたものが以下

このjsonをから演算スケジュールを決定できるようにテンソルのメモリ配置などを決定するコンパイラを作っていきたい。

参考

既にIntel CPUやAMD GPU用にonnx runtimeを用いて実行する方法が多く解説され、ベンチマークが取られている。
https://onnxruntime.ai/generative-ai
https://gist.github.com/harishanand95/75f4515e6187a6aa3261af6ac6f61269

ORTStableDiffusionXLPipelineをもちいてモデルを読み込む
https://huggingface.co/docs/diffusers/optimization/onnx

https://github.com/Amblyopius/Stable-Diffusion-ONNX-FP16

https://onnxruntime.ai/docs/tutorials/csharp/stable-diffusion-csharp.html

各GPU向けの最適化とベンチマーク
https://medium.com/microsoftazure/accelerating-stable-diffusion-inference-with-onnx-runtime-203bd7728540
https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md

https://rocm.blogs.amd.com/artificial-intelligence/stable-diffusion-onnx-runtime/README.html

独自実装したネットワークのONNXへのエクスポート
https://zenn.dev/turing_motors/articles/5b56edb7da1d30

Discussion