Zenn
🦜

CogView4-6B の量子化を試してみる

2025/03/19に公開

TL;DR

  • テキストエンコーダーをどれだけ精度落とさずに量子化できるか が重要そう
  • テキストエンコーダーには Optimum Quanto Int8 が消費 VRAM 削減しつつ品質あんまり落ちなくて良さそう
    • このとき、デノイザーは bitsandbytes の NF4 が一番消費 VRAM 削減できて、品質重視なら PyTorch float8_e4m3fnOptimum Quanto Int8 が良さそう

はじめに

CogView4-6B は Tsinghua University から公開された、(デノイズする transformer が) 6B パラメータの画像生成モデルです。

https://huggingface.co/THUDM/CogView4-6B

モデルカードの表には、消費 VRAM について以下のように書いてありました。

解像度 オフロードなし オフロードあり テキストエンコーダーを 4bit にしてオフロード
1024x1024 35 GB 20 GB 13 GB

公式ページによると、bfloat16 精度で 1024x1024 解像度の画像をそのまま生成する場合、35 GB VRAM を消費するため、家庭用の GPU では実行が難しいです。モデルの一部を少しずつ VRAM に載せるオフロードを行うと、20GB VRAM まで削減することができますが、それでも RTX 3090/4090/5090 などが必要になります。

量子化を行うことで 16GB に収まれば、より多くの GPU で動かすことができそうですね。既に CogView4-6B に対して量子化を試されている例がいくつかあります:
https://x.com/ostrisai/status/1897344493747355990
https://note.com/hakomikan/n/ne3e483adb45b

というわけで今回は、いくつかの量子化手法を比較 しながら、本当に量子化して生成できるのか/VRAM を削減できるのか の調査をしてみました。

生成結果の比較のみ見たい場合は #生成画像 に飛んでください。

モデル構造

モデルのパーツごとのパラメータ数は以下のようになっています。(カッコ内は diffusers での名称)

Text Encoder (text_encoder) Denoiser (transformer) VAE (vae)
9B 6B 0.4B

テキストエンコーダー

テキストエンコーダーには、同じ Tsinghua University から公開されている 9B の Decoder-only の LLM である、THUDM/glm-4-9b の一部が使用されています。

余談: テキストエンコーダーの層の選択

ベースとなる THUDM/glm-4-9b では、transformer ブロックが 40 層ある[1] のに対して、CogView4-6B ではそのうちの 最後の層を除いた 39 層のみを利用 しています [2]。言い換えると、最後から2番目の層を使用 しています。また、テキストエンコーダーではテキストは生成しないため、単語の予測に使われる LM ヘッドも付属していません。

テキストエンコーダーの最後から2番目の層を使うのは、NovelAI DiffusionStable Diffusion XL 等でも採用されていました。ソースを忘れてしまったのですが、実際にレイヤーの層ごとに SAE を学習すると、最終層は再構成が簡単でロスが低くなるのに、最後から2番目ではロスが高くなる(情報量が多くて再構成が難しい)というのがあるというのを聞いたことがあります。要は、最後の層の出力はトークンの選択用等に情報が疎になってしまうので、より情報が密になっている最後から2番目の層がよく使われるようです。

デノイザー

diffusers では transformer と呼ばれる部分です。Stable Diffusion 等では U-Net がデノイズを担っていたので、単に U-Net といえばどの部分か明確だったのですが、CogView4 を含む最近のモデルでは、テキストエンコーダーもデノイズも transformer を採用しているため、どっちのことを言っているのか紛らわしいので、この記事ではデノイザーと呼ぶ ことにします。

基本的な情報は以下の記事を参考にしてください。
https://zenn.dev/discus0434/articles/cogview4-6b-commentary#ditとadaln-zeroについて

余談: 従来の MMDiT との違い

CogView4 の DiT は、基本的には SD3 のような MMDiT をベースにしていますが、ところどころ異なる点があるので、少し紹介します。

  • SDXL 的な画像のクロップ条件
    • 学習時にSDXL のように、元画像の画像サイズやクロップ位置、生成サイズの情報を Pixart-Alpha と同じような方法でタイムステップ埋め込みと一緒に条件に加えています [3]
  • Transformer の Feed Forward 使い回し:
    • MMDiT の Transformer は通常、テキスト条件と画像パッチでそれぞれ別の Feed Forward に通すのが普通 [4][5] ですが、CogView4 では同じ Feed Forward に通しています [6]。Feed Forward はパラメータ数がデカくなりがちなので、使い回してパラメータ数削減を図っているのかもみたいな感じっぽいです。

コード

以下のレポジトリで作業を行いました。

https://github.com/p1atdev/vision-ft

ただ、頻繁に変更を加えるので、この記事を書いた時点とコード内容が変わっている可能性があることに注意してください。

この記事を書いた時点で使用したライブラリのバージョンは以下です:

名前 バージョン
Python 3.12.2
PyTorch 2.5.1+cu124
bitsandbytes 0.45.3
TorchAO 0.9.0
Optimum Quanto 0.2.6
Flash Attention 2.7.4.post1

量子化手法

今回比較する量子化手法を紹介します。

bitsandbytes

https://github.com/bitsandbytes-foundation/bitsandbytes

NF4 という 4bit への量子化手法が通常の FP4 や FP8 と比べても比較的精度を落とさずに量子化できるので、LLM ではよく採用されているイメージです。

bitsandbytes での量子化は 前回の記事 で書いたような方法で量子化を行いました。

https://github.com/p1atdev/vision-ft/blob/main/src/modules/quant/bnb.py

https://github.com/p1atdev/vision-ft/blob/b4bfe9c1f25d7aa73fcd68cf39f3719b86b18f40/src/modules/quant/functional.py#L159-L183

TorchAO

https://github.com/pytorch/ao

PyTorch 公式の最適化ライブラリです。NF4、Int4、Int8 に対応しています。今回は NF4 と Int8 を試します。

実装では、NF4 は、TorchAO で用意されている NF4Tensor に置き換える形で、Int8 は Float8Linear を単にラップした層で置き換えました:
https://github.com/p1atdev/vision-ft/blob/main/src/modules/quant/ao.py

https://github.com/p1atdev/vision-ft/blob/b4bfe9c1f25d7aa73fcd68cf39f3719b86b18f40/src/modules/quant/functional.py#L185-L200

Optimum Quanto

https://github.com/huggingface/optimum-quanto

HuggingFace による量子化ライブラリで、transformers や diffusers ライブラリとの統合がメインに推されているイメージです。Int4 と Int8 に対応しています。

実装では、単に optimu.quanto.nn.QLinear をラップして、from_moduleweights のみを指定して置き換えました:

https://github.com/p1atdev/vision-ft/blob/b4bfe9c1f25d7aa73fcd68cf39f3719b86b18f40/src/modules/quant/functional.py#L202-L221

PyTorch の float8_e4m3fn

https://pytorch.org/docs/2.4/tensors.html

PyTorch には素で fp8 の Data Type が存在するので、これも試します。これを使うのはすごく簡単で、普段 tensor.to(torch.float16) とかで変換するのと同じように tensor.to(torch.float8_e4m3fn) で変換できます。

https://github.com/p1atdev/vision-ft/blob/b4bfe9c1f25d7aa73fcd68cf39f3719b86b18f40/src/modules/quant/functional.py#L223-L225

公式の方法

公式 GitHub には2種類の量子化の方法がコメントアウトで紹介されています:

  • テキストエンコーダーのみ bitsandbytes NF4

https://github.com/THUDM/CogView4/blob/962816cc760188032713dc5293c4588d42fe88e5/inference/cli_demo_cogview4.py#L20-L32

  • テキストエンコーダーもデノイザーも TorchAO Int8

https://github.com/THUDM/CogView4/blob/962816cc760188032713dc5293c4588d42fe88e5/inference/cli_demo_cogview4_int8.py#L20-L47

結果

以下のような設定で生成を行いました。

生成設定

プロンプトは、ComfyUI の Flux のサンプルワークフロー で使用されているものを使用しました。

パラメータ
prompt cute anime girl with massive fluffy fennec ears and a big fluffy tail blonde messy long hair blue eyes wearing a maid outfit with a long black gold leaf pattern dress and a white apron mouth open holding a fancy black forest cake with candles on top in the kitchen of an old dark Victorian mansion lit by candlelight with a bright window to the foggy forest and very expensive stuff everywhere
negative prompt blurry, low quality, horror
height 1024
width 1024
cfg_scale 3.5
num_inference_steps 20
device cuda:0
seed 0

GPU は RTX A6000 Ada (VRAM 48GB) を使用しました。

量子化の設定
  • テキストエンコーダーの対象層
    • ["q_proj", "k_proj", "v_proj", "o_proj", "mlp.down_proj", "mlp.gate_up_proj"]
    • Transformer の 線形層全部
  • デノイザーの対象層
    • ["to_q", "to_k", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2"]
    • AdaLayerNormZero で使われる埋め込み作成に使われる norm を除いたTransformer の線形層
スクリプト

読み込んでる bfloat16 の単一 safetensors はここに置いてあります:
https://huggingface.co/p1atdev/CogView4-6B-bf16-AIO/blob/main/cogview4-6b.bf16.safetensors

./tools/cogview4_quant_compare.py
import click

from PIL import Image
from pathlib import Path

import torch
import torch.nn as nn

from src.models.cogview4.pipeline import CogView4Model
from src.models.cogview4.config import CogView4Config, DenoiserConfig

from src.modules.quant import quantize_inplace, QUANT_TYPE


def quantize_model(model: nn.Module, text_encoder: QUANT_TYPE, denoiser: QUANT_TYPE):
    if text_encoder != "bf16":
        quantize_inplace(  # text encoder
            model,
            quant_type=text_encoder,
            include_keys=[
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "mlp.down_proj",
                "mlp.gate_up_proj",
            ],
            exclude_keys=["denoiser.", "vae."],
        )
    if denoiser != "bf16":
        quantize_inplace(  # denoiser
            model,
            quant_type=denoiser,
            include_keys=[
                "to_q",
                "to_k",
                "to_v",
                "to_out.0",
                "ff.net.0.proj",
                "ff.net.2",
            ],
            exclude_keys=[
                "time_condition_embed",
                "patch_embed",
                "norm_out",
                "proj_out",
                "norm1",  # do not quantize layernorm
                "text_encoder.",
                "vae.",
            ],
        )


@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
@torch.inference_mode()
def generate_image(
    model: CogView4Model,
    prompt: str,
    height: int,
    width: int,
    cfg_scale: float,
    num_inference_steps: int,
    do_offloading: bool,
    device: str,
    seed: int,
) -> Image.Image:
    image = model.generate(
        prompt=prompt,
        negative_prompt="blurry, low quality, horror",
        height=height,
        width=width,
        cfg_scale=cfg_scale,
        num_inference_steps=num_inference_steps,
        do_offloading=do_offloading,
        device=device,
        seed=seed,
    )[0]

    return image


def get_run_name(
    text_encoder: QUANT_TYPE, denoiser: QUANT_TYPE, skip_offload: bool
) -> str:
    return f"text-encoder-{text_encoder}_denoiser-{denoiser}_offload-{not skip_offload}"


@click.command()
@click.option("--model_path", default="./models/cogview4-6b.bf16.safetensors")
@click.option("--text_encoder", default="bf16", type=str)
@click.option("--denoiser", default="bf16", type=str)
@click.option("--skip_offload", is_flag=True)
@click.option(
    "--prompt",
    default="cute anime girl with massive fluffy fennec ears and a big fluffy tail blonde messy long hair blue eyes wearing a maid outfit with a long black gold leaf pattern dress and a white apron mouth open holding a fancy black forest cake with candles on top in the kitchen of an old dark Victorian mansion lit by candlelight with a bright window to the foggy forest and very expensive stuff everywhere",
)
@click.option("--height", default=1024)
@click.option("--width", default=1024)
@click.option("--cfg_scale", default=3.5)
@click.option("--num_inference_steps", default=20)
@click.option("--device", default="cuda:0")
@click.option("--seed", default=0)
@click.option("--output_dir", default="output")
def main(
    model_path: str,
    text_encoder: QUANT_TYPE,
    denoiser: QUANT_TYPE,
    skip_offload: bool,
    prompt: str,
    height: int,
    width: int,
    cfg_scale: float,
    num_inference_steps: int,
    device: str,
    seed: int,
    output_dir: str,
):
    torch.cuda.memory._record_memory_history()

    config = CogView4Config(
        checkpoint_path=model_path,
        denoiser=DenoiserConfig(
            attention_backend="flash_attention_2",
        ),
    )
    model = CogView4Model.from_checkpoint(config)

    quantize_model(model, text_encoder, denoiser)

    if skip_offload:
        model.to(device)
    else:
        model.to("cpu")

    image = generate_image(
        model,
        prompt=prompt,
        height=height,
        width=width,
        cfg_scale=cfg_scale,
        num_inference_steps=num_inference_steps,
        do_offloading=not skip_offload,
        device=device,
        seed=seed,
    )

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    run_name = get_run_name(text_encoder, denoiser, skip_offload)
    image.save(output_path / f"{run_name}.webp")
    print(f"Image saved to {output_path / f'{run_name}.webp'}")

    torch.cuda.memory._dump_snapshot((output_path / f"{run_name}.pickle").as_posix())


if __name__ == "__main__":
    main()
./scripts/cogview4_quant_compare.sh
#!/bin/bash

source .venv/bin/activate

text_encoders=("bf16" "fp8_e4m3fn" "bnb_int8" "bnb_fp4" "bnb_nf4" "quanto_int4" "quanto_int8" "ao_nf4" "ao_fp8")
denoisers=("bf16" "fp8_e4m3fn" "bnb_int8" "bnb_fp4" "bnb_nf4" "quanto_int4" "quanto_int8" "ao_nf4" "ao_fp8")
skip_offload=("false" "true")

for te in "${text_encoders[@]}"; do
  for dn in "${denoisers[@]}"; do
    for skip in "${skip_offload[@]}"; do
      echo "Running with text_encoder=${te}, denoiser=${dn}, skip_offload=${skip}"
      if [ "$skip" == "true" ]; then
        python ./tools/cogview4_quant_compare.py --text_encoder "${te}" --denoiser "${dn}" --skip_offload
      else
        python ./tools/cogview4_quant_compare.py --text_encoder "${te}" --denoiser "${dn}"
      fi
    done
  done
done

メモリアロケーションサイズの確認用 (ChatGPT が書いた):

./tools/check_memory.py
import pickle
import click
from typing import Any, Dict, List


def load_snapshot(file_path: str) -> Dict[str, Any]:
    with open(file_path, "rb") as f:
        snapshot = pickle.load(f)
    return snapshot


def format_bytes(size: float) -> str:
    for unit in ["B", "KB", "MB", "GB", "TB"]:
        if size < 1024:
            return f"{size:.2f} {unit}"
        size /= 1024
    return f"{size:.2f} PB"


def find_peak_allocated_memory(snapshot: Dict[str, Any]) -> int:
    """
    記録されたtraceイベントから、alloc/freeのシミュレーションを行い、
    ピーク時の総アロケーションサイズを計算する。

    ※ここでは、"alloc" イベントでメモリを加算し、
    "free_completed" イベントでメモリを減算しています。
    """
    peak_memory: int = 0
    current_memory: int = 0
    # device_traces は各デバイスのイベントリスト(各イベントは dict として記録)
    device_traces: List[List[Dict[str, Any]]] = snapshot.get("device_traces", [])
    for trace in device_traces:
        for event in trace:
            action: str = event.get("action", "")
            size: int = event.get("size", 0)
            if action == "alloc":
                current_memory += size
            elif action == "free_completed":
                current_memory -= size
            # その他のイベント(例:"segment_alloc", "segment_free" など)も必要に応じて処理する
            if current_memory > peak_memory:
                peak_memory = current_memory
    return peak_memory


@click.command()
@click.argument("pickle_path", type=click.Path(exists=True))
def main(pickle_path: str) -> None:
    snapshot = load_snapshot(pickle_path)
    peak_memory = find_peak_allocated_memory(snapshot)
    print(f"ピーク時の合計アロケーションサイズ: {format_bytes(float(peak_memory))}")


if __name__ == "__main__":
    main()

生成結果

以下の表に、Text encoder と Denoiser それぞれに量子化を適用して生成した画像 (クリックで拡大表示) と最大消費 VRAM を示します (単位は GiB1 GiB = 1024 x 1024 x 1024 byte)。カッコ外はオフロードなし、(カッコ内)はオフロードありです。太字は、量子化なしと比較して改善されて、かつ同じ行の中で最も消費が少なかったものを強調しています。

Denoiser→ ↓Text_encoder PyTorch bfloat16 PyTorch float8_e4m3fn bitsandbytes FP4 bitsandbytes NF4 bitsandbytes Int8 Quanto Int4 Quanto Int8 TorchAO NF4 TorchAO FP8
PyTorch bfloat16 33.74 (16.43) 30.49 (16.43) 25.95 (16.43) 25.95 (16.43) 33.74 (16.43) 26.03 (16.43) 28.49 (16.43) 33.74 (16.43) 33.74 (16.43)
PyTorch float8_e4m3fn 28.14 (13.14) 22.89 (9.04) 20.35 (9.04) 20.35 (9.04) 28.15 (13.68) 20.43 (9.04) 22.89 (9.04) 28.14 (13.14) 28.14 (13.64)
bitsandbytes FP4 24.46 (13.14) 19.21 (7.89) 16.67 (7.52) 16.67 (7.52) 24.47 (13.68) 16.75 (7.52) 19.22 (7.90) 24.46 (13.14) 24.46 (13.64)
bitsandbytes NF4 24.46 (13.14) 19.21 (7.89) 16.67 (7.52) 16.67 (7.52) 24.47 (13.68) 16.75 (7.52) 19.22 (7.90) 24.46 (13.14) 24.46 (13.64)
bitsandbytes Int8 35.75 (20.75) 30.50 (16.58) 27.96 (16.58) 27.96 (16.58) 35.75 (21.28) 28.04 (16.58) 30.50 (16.58) 35.75 (20.75) 35.75 (21.24)
Quanto Int4 24.58 (13.14) 19.33 (7.89) 16.79 (7.52) 16.79 (7.52) 24.59 (13.68) 16.87 (7.52) 19.33 (7.90) 24.58 (13.14) 24.58 (13.64)
Quanto Int8 28.15 (13.14) 22.89 (9.04) 20.36 (9.04) 20.36 (9.04) 28.15 (13.68) 20.44 (9.04) 22.90 (9.04) 28.15 (13.14) 28.15 (13.64)
TorchAO NF4 35.74 (16.43) 30.49 (16.43) 27.95 (16.43) 27.95 (16.43) 35.75 (16.43) 28.03 (16.43) 30.49 (16.43) 35.74 (16.43) 35.74 (16.43)
TorchAO FP8 35.74 (17.36) 30.49 (17.36) 27.95 (17.36) 27.95 (17.36) 35.75 (17.36) 28.03 (17.36) 30.49 (17.36) 35.74 (17.36) 35.74 (17.36)
________ ________________ ________________ ________________ ________________ ________________ ________________ ________________ ________________ ________________
消費 VRAM のみの表
Denoiser→ ↓Text_encoder PyTorch bfloat16 PyTorch float8_e4m3fn bitsandbytes FP4 bitsandbytes NF4 bitsandbytes Int8 Quanto Int4 Quanto Int8 TorchAO NF4 TorchAO FP8
PyTorch bfloat16 33.74 (16.43) 30.49 (16.43) 25.95 (16.43) 25.95 (16.43) 33.74 (16.43) 26.03 (16.43) 28.49 (16.43) 33.74 (16.43) 33.74 (16.43)
PyTorch float8_e4m3fn 28.14 (13.14) 22.89 (9.04) 20.35 (9.04) 20.35 (9.04) 28.15 (13.68) 20.43 (9.04) 22.89 (9.04) 28.14 (13.14) 28.14 (13.64)
bitsandbytes FP4 24.46 (13.14) 19.21 (7.89) 16.67 (7.52) 16.67 (7.52) 24.47 (13.68) 16.75 (7.52) 19.22 (7.90) 24.46 (13.14) 24.46 (13.64)
bitsandbytes NF4 24.46 (13.14) 19.21 (7.89) 16.67 (7.52) 16.67 (7.52) 24.47 (13.68) 16.75 (7.52) 19.22 (7.90) 24.46 (13.14) 24.46 (13.64)
bitsandbytes Int8 35.75 (20.75) 30.50 (16.58) 27.96 (16.58) 27.96 (16.58) 35.75 (21.28) 28.04 (16.58) 30.50 (16.58) 35.75 (20.75) 35.75 (21.24)
Quanto Int4 24.58 (13.14) 19.33 (7.89) 16.79 (7.52) 16.79 (7.52) 24.59 (13.68) 16.87 (7.52) 19.33 (7.90) 24.58 (13.14) 24.58 (13.64)
Quanto Int8 28.15 (13.14) 22.89 (9.04) 20.36 (9.04) 20.36 (9.04) 28.15 (13.68) 20.44 (9.04) 22.90 (9.04) 28.15 (13.14) 28.15 (13.64)
TorchAO NF4 35.74 (16.43) 30.49 (16.43) 27.95 (16.43) 27.95 (16.43) 35.75 (16.43) 28.03 (16.43) 30.49 (16.43) 35.74 (16.43) 35.74 (16.43)
TorchAO FP8 35.74 (17.36) 30.49 (17.36) 27.95 (17.36) 27.95 (17.36) 35.75 (17.36) 28.03 (17.36) 30.49 (17.36) 35.74 (17.36) 35.74 (17.36)
________ ____________ ____________ ____________ ____________ ____________ ____________ ____________ ____________ ____________

真っ黒な画像は途中でオーバーフローかアンダーフローによって NaN が発生してしまったものになります。

bitsandbytes の Int8 は、計算の途中で float16 にキャストされるらしく[7]、これを適用したものは全て NaN が発生して真っ黒な画像になっています。これは、モデルカードに書かれていた、float16 に対応していないという報告に合致していそうです。

ざっくりと眺めると、テキストエンコーダーの精度がデノイザーよりも重要 になることが伺えます。表の行(横→)に注目すると、テキストエンコーダーの量子化手法を固定した時、(生成に失敗しているものを除くと、) デノイザーの量子化手法の違いによる品質の劣化は相対的に小さい ですが、列(縦↓)に注目すると、テキストエンコーダーの量子化手法の違いによる品質の劣化は大きい ことがわかります。

このことを踏まえると、テキストエンコーダーに適用できそうな量子化手法は、

  • PyTorch bfloat16 (量子化なし)
  • Optimum Quanto Int4
  • Optimum Quanto Int8
  • TorchAO NF4
  • TorchAO FP8

になりそうです。

しかし、消費 VRAM に注目すると、TorchAO では量子化しない場合と比べて、比較的 VRAM 消費が減っていません。むしろ増えている組み合わせもあります。
視覚的な品質の劣化が少なく、消費 VRAM の少ないものは、テキストエンコーダーに Optimum Quanto Int4/Int8 を適用したものになりそうです。Quanto の行のみを取り出して見てみます。(bitsandbytes の Int8 は除外)

量子化なし参考画像


両方 bfloat16 で消費 VRAM 33.74 GiB (16.43 GiB)

Denoiser→ ↓Text_encoder PyTorch bfloat16 PyTorch float8_e4m3fn bitsandbytes FP4 bitsandbytes NF4 Quanto Int4 Quanto Int8 TorchAO NF4 TorchAO FP8
Quanto Int4 24.58 (13.14) 19.33 (7.89) 16.79 (7.52) 16.79 (7.52) 16.87 (7.52) 19.33 (7.90) 24.58 (13.14) 24.58 (13.64)
Quanto Int8 28.15 (13.14) 22.89 (9.04) 20.36 (9.04) 20.36 (9.04) 20.44 (9.04) 22.90 (9.04) 28.15 (13.14) 28.15 (13.64)
________ ________________ ________________ ________________ ________________ ________________ ________________ ________________ ________________

こうして見てみると、テキストエンコーダーのみを Quanto Int4 で量子化したときでは、Int8 と比べると生成結果の品質は大きくは劣化していないものの、すぐわかる程度の変化が見られます。テキストエンコーダーのみを Quanto Int8 にした場合は、生成結果に差異はほとんど見られませんが、VRAM の削減は乏しいです。

テキストエンコーダーを Quanto Int8 で固定した時、最も VRAM 消費が少ないのは、デノイザーに PyTorch の float8_e4m3fnbitsandbytes の FP4/NF4、または Quanto Int8 を適用してオフロードを行ったときで 9.04 GiB となりました。NF4 の場合のほうが FP4 よりも若干出力内容の差異が少なそうです。Quanto の Int4 と Int8 を比較すると、Int4 の方は茶色くぼやけて しまっており、Int8 のほうが品質が高いと言えるでしょう。オフロードしない場合では、bitsandbytes は 20.36 GiB、Quanto Int8 では 22.44 GiB と、2.08 GiB だけ bitsandbytes のほうが消費が少なく、PyTorch float8_e4m3fn は 22.89 GiB と bitsandbytes と比べて PyTorch float8_e4m3fn の方が 2.56 GiB 多い結果となりました。 TorchAO は あんまり VRAM 消費を抑えられないようです。

いくつか良さげな品質と VRAM 消費の組み合わせをまとめてみます。以下は TE:テキストエンコーダー DiT:デノイザー の量子化組み合わせです。(太字は最小の VRAM 消費)

TE:bf16, DiT:bf16 TE:quanto_int8 DiT:bf16 TE:quanto_int8 DiT:fp8_e4m3fn TE:quanto_int8 DiT:bnb_nf4 TE:quanto_int8 DiT:quanto_int8
33.74 GiB (16.43 GiB) 28.15 GiB (13.14 GiB) 22.89 GiB (9.04 GiB) 20.36 GiB (9.04 GiB) 22.90 GiB (9.04 GiB)
________________ ________________ ________________ ________________ ________________

ここから、テキストエンコーダーは Optimum Quanto Int8デノイザーは Pytorch float8_e4m3fn、bitsandbytes NF4、または Optimum Quanto Int8 を適用したものが品質の劣化を抑えながら VRAM 消費も抑えられる選択肢となりそうです。残念ながら、オフロードせずに VRAM 16 GB に収めることはできませんでしたが、オフロードした場合は収められそうですね。

ボトルネック

先程テキストエンコーダーが Optimum Quanto Int8 のとき、デノイザーの3つの量子化手法の間で 非オフロード時の消費 VRAM が異なるのにも関わらず、オフロード時の消費 VRAM が 9.04 GiB と 3 つとも等しくなっていた ことに気づいたかもしれません。このとき、なぜこのようなことが発生するのか、どの部分が VRAM 消費のボトルネックになっているか調べてみました。

以下は、テキストエンコーダーに Quanto Int8、デノイザーに PyTorch float8_e4m3fn を適用したときの VRAM 消費の様子を表したグラフです。左側の数値が消費 VRAM になります。参考: https://pytorch.org/blog/understanding-gpu-memory-2/

  • オフロードなしの場合:


最初に全てのモジュールを17GBほどで VRAM に載せた読み込んだあと、18GB程度の消費でデノイザーが仕事しているのが伺える。最後に VAE の使用で2~3GBほど消費されているようだ。

  • オフロードありの場合:


最初にテキストエンコーダーで9GBほど読み込んだあと、オフロードによって消費 VRAMがちゃんとゼロになっている。その後、デノイザーは8GB前後の消費で動作して、VAE の使用時も 3GB ほどの消費で済んでいる。

オフロードをしなかった場合は、普通に全てのコンポーネントで 17 GBを消費し、画像のデコードで 3 GB消費していたことから、合計で 20 GiB ほどの消費となったみたいです。
一方、オフロードをした場合では、最初のテキストエンコーダーで 9GB を消費した以降は、それを超える VRAM 消費がないため、デノイザー部分がそれよりも小さい場合、常にテキストエンコーダーの時の消費が最大消費となり、量子化手法に関わらず最大消費 VRAM 量がテキストエンコーダーの分で等しくなっていた ようです。

CogView4-6B では、デノイザーよりもテキストエンコーダー部分のほうがパラメータ数が多いため、オフロード時の最大消費 VRAM はテキストエンコーダーの消費 VRAM に依存している、ということがわかります。したがって、いかにテキストエンコーダーを品質を落とさずに量子化できるかが重要 になりそうです。加えて、最初の表の縦向きと横向きを比較してわかる通り、テキストエンコーダーは量子化手法による品質の変化に敏感で、デノイザーは比較的鈍感なため、より一層テキストエンコーダーの量子化手法の選定が重要だと思います。

おわりに

今回は CogView4-6B の量子化を試してみました。今回試した手法では、残念ながら品質を維持したまま VRAM 8GB に収めることはできませんでしたが、VRAM 16 GB であれば現実的な品質で生成ができそうです。テキストエンコーダーの出力をキャッシュすれば、VRAM 16 GB の GPU でも LoRA の学習はできそうな気がします。

テキストエンコーダーとデノイザーを Optimum Quanto Int8 で量子化した safetensors ファイルを一応↓にアップしてありますが、別にどこかの WebUI で使えるわけではなく、今回の作業レポでしか使えないです。Optimum Quanto ライブラリを直接利用すれば、diffusers の形式で読めるかもしれないですが、試してないのでわからんです。

https://huggingface.co/p1atdev/CogView4-6B-quanto_int8

今回、なぜか TorchAO で消費 VRAM をほとんど削減できなかったのが少し不可解です。量子化に使ったコードがどこかバグってる気がします。もし何かバグっていてそれを修正し、VRAM 消費を改善できれば良い選択肢になるかもしれません。

従来の画像生成モデルでは、テキストエンコーダーがここまで大きいサイズのものはあまりなかったので、テキストエンコーダーの VRAM 消費をどのように解決するかが CogView4-6B の取り回しを改善する上での課題になりそう です。

LLM でよく使われる GGUF フォーマットも、最近では FluxSD3.5 で使われていたりするので、今後こちらも試してみたいです。

脚注
  1. https://huggingface.co/THUDM/glm-4-9b/blob/main/config.json#L31 ↩︎

  2. https://github.com/huggingface/diffusers/blob/2e83cbbb6de84be7241218c8f5ea914ceb68c149/src/diffusers/pipelines/cogview4/pipeline_cogview4.py#L218 ↩︎

  3. https://github.com/huggingface/diffusers/blob/813d42cc96d000abe4788227310329ad0027f14c/src/diffusers/models/embeddings.py#L1640-L1670 ↩︎

  4. https://github.com/huggingface/diffusers/blob/3fe3bc0642cf6ebfa1a815367afd0dc57675ecc7/src/diffusers/models/attention.py#L228-L255 ↩︎

  5. https://github.com/huggingface/diffusers/blob/3fe3bc0642cf6ebfa1a815367afd0dc57675ecc7/src/diffusers/models/transformers/transformer_flux.py#L170-L185 ↩︎

  6. https://github.com/huggingface/diffusers/blob/813d42cc96d000abe4788227310329ad0027f14c/src/diffusers/models/transformers/transformer_cogview4.py#L251-L252 ↩︎

  7. 実行中に MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization と言われた ↩︎

GitHubで編集を提案

Discussion

ログインするとコメントできます