
How to Customize Attention Processing in diffusers | AttnProcessor

Published on 2024/02/23

Introduction

The U-Net used in Stable Diffusion contains multiple Attention modules, and by manipulating how Attention is computed during image generation you can inspect how the pixels of the generated image come about, or explicitly control the correspondence between regions of the generated image and the prompt. However, the forward pass of these modules is usually wrapped behind several layers of abstraction, so you often end up having to dig deep into the code.

In this article, I explain how to override the forward pass of the Attention modules inside the U-Net (UNet2DConditionModel) of a StableDiffusionPipeline by using the Attention Processor mechanism of Hugging Face's diffusers library.

TL;DR

To customize the Attention classes inside the U-Net (UNet2DConditionModel), do the following:

  1. Define your own Attention Processor class
  2. Overwrite the existing processors with the new Attention Processor via UNet2DConditionModel.set_attn_processor() (a minimal sketch follows right below)
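
As a minimal sketch, the two steps look roughly like this (MyAttnProcessor is a placeholder name, and the model id mirrors the one used later in this article; the full, working example follows below):

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor

class MyAttnProcessor(AttnProcessor):
    """Custom processor: add your own logic, then fall back to the default computation."""
    def __call__(self, attn, hidden_states, encoder_hidden_states=None, **kwargs):
        # 1. insert custom behaviour here (logging, editing attention maps, ...)
        return super().__call__(attn, hidden_states, encoder_hidden_states=encoder_hidden_states, **kwargs)

unet = UNet2DConditionModel.from_pretrained('CompVis/stable-diffusion-v1-4', subfolder='unet')
unet.set_attn_processor(MyAttnProcessor())  # 2. overwrite every processor in the U-Net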

Attention Processor in Detail

Looking at the implementation of the Attention class contained in UNet2DConditionModel, you will find the following code.

https://github.com/huggingface/diffusers/blob/bb1b76d3bf9ef78a827086d1b9449975237ecbac/src/diffusers/models/attention_processor.py#L488-L529

If you are using PyTorch 2.0, AttnProcessor2_0 is used by default, and its implementation looks like this.

https://github.com/huggingface/diffusers/blob/bb1b76d3bf9ef78a827086d1b9449975237ecbac/src/diffusers/models/attention_processor.py#L1193-L1275

To sum up, in UNet2DConditionModel the Attention computation is carried out by a class called an Attention Processor, so if we can swap this class out, we are done.
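
In other words, the call chain can be pictured roughly as follows (a conceptual sketch, not the actual diffusers source; see the links above for the real implementation):

import torch

# Conceptual sketch: Attention.forward delegates the whole computation
# to the pluggable self.processor object.
class Attention(torch.nn.Module):
    def forward(self, hidden_states, encoder_hidden_states=None, **cross_attention_kwargs):
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            **cross_attention_kwargs,
        )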

Observing the Attention Modules in the U-Net

Let's see how many Attention modules are contained in UNet2DConditionModel. They can be enumerated with code like the following.

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    'CompVis/stable-diffusion-v1-4', subfolder='unet'
)

for key, value in unet.attn_processors.items():
    print(key, value)

The output looks like this.

Details of the attn_processors contained in UNet2DConditionModel
down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x131194050>
down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x131194690>
down_blocks.0.attentions.1.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x131195b90>
down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x131196210>
down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x1311c80d0>
down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x1311c8590>
down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x1311c9b90>
down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x1311ca3d0>
down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x12495f910>
down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x132110510>
down_blocks.2.attentions.1.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x132111ad0>
down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x132112150>
up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x13217c690>
up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x13217cb10>
up_blocks.1.attentions.1.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x13217e1d0>
up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x124921690>
up_blocks.1.attentions.2.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x130ed3c10>
up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x1321b4550>
up_blocks.2.attentions.0.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x1321b6250>
up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x1321b68d0>
up_blocks.2.attentions.1.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x124a239d0>
up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x1321e45d0>
up_blocks.2.attentions.2.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x1321e5cd0>
up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x1321e6350>
up_blocks.3.attentions.0.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x132234050>
up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x132234690>
up_blocks.3.attentions.1.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x132235d50>
up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x1322363d0>
up_blocks.3.attentions.2.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x132237a90>
up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x132260150>
mid_block.attentions.0.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x1321488d0>
mid_block.attentions.0.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x132148f50>

From this output, we can see that the number of Attention Processors breaks down as:

  • Down blocks: 12
  • Up blocks: 18
  • Mid block: 2

for a total of 32.
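
The same breakdown can also be computed directly from the key names (a quick check reusing the unet loaded above):

from collections import Counter

# count processors per top-level block (down_blocks / up_blocks / mid_block)
print(Counter(key.split('.')[0] for key in unet.attn_processors))
# Counter({'up_blocks': 18, 'down_blocks': 12, 'mid_block': 2})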

Overwriting the Attention Processors of UNet2DConditionModel

The forward function of the Attention class is defined by processor.__call__, and the processors attached to all of the Attention modules in the U-Net can be overwritten recursively by calling UNet2DConditionModel.set_attn_processor().

https://github.com/huggingface/diffusers/blob/bb1b76d3bf9ef78a827086d1b9449975237ecbac/src/diffusers/models/unets/unet_2d_condition.py#L716-L748

In addition, the Attention modules of the Stable Diffusion U-Net come in two kinds:

  • Self-Attention: attention is computed within the latent that the U-Net is denoising
  • Cross-Attention: attention is computed between the U-Net's latent and a conditioning input such as the prompt

The two can be told apart from the encoder_hidden_states argument of the Attention Processor's __call__ function: if it is None, the call is Self-Attention; otherwise it is Cross-Attention.
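
As an aside, in the key names listed earlier, attn1 is the Self-Attention and attn2 is the Cross-Attention of each transformer block in this U-Net, so the same split can also be read off the keys (a quick check reusing the unet from above):

# attn1 = Self-Attention, attn2 = Cross-Attention in this U-Net
n_self = sum(key.endswith('attn1.processor') for key in unet.attn_processors)
n_cross = sum(key.endswith('attn2.processor') for key in unet.attn_processors)
print(n_self, n_cross)  # 16 16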

To see which tensors are involved in each kind of Attention, we define a new Attention Processor, NewAttnProcessor, as follows.

import dataclasses

import torch
from diffusers.models.attention_processor import AttnProcessor, Attention
from diffusers.utils.constants import USE_PEFT_BACKEND

@dataclasses.dataclass
class ShapeStore:
    """shapeを保存しておく用のクラス"""
    q: torch.Size  # query
    k: torch.Size  # key
    v: torch.Size  # value
    attn: torch.Size  # attention score/probs

class NewAttnProcessor(AttnProcessor):
    def __init__(self):
        super().__init__()
        # Added: storage for Self-/Cross-Attention records
        self.self_attentions = []
        self.cross_attentions = []

    def __call__(
            self,
            attn: Attention,
            hidden_states: torch.FloatTensor,
            encoder_hidden_states: torch.FloatTensor | None = None,
            attention_mask: torch.FloatTensor | None = None,
            temb: torch.FloatTensor | None = None,
            scale: float = 1.0,
    ) -> torch.Tensor:
        residual = hidden_states

        args = () if USE_PEFT_BACKEND else (scale,)

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states, *args)

        # !!! Added here (1): determine whether this call is Cross-Attention
        is_cross_attn = encoder_hidden_states is not None

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)

        # !!! Added here (2): record the shapes of query / key / value / attention_probs
        if is_cross_attn:
            self.cross_attentions.append(
                ShapeStore(q=query.shape, k=key.shape, v=value.shape, attn=attention_probs.shape)
            )
        else:
            self.self_attentions.append(
                ShapeStore(q=query.shape, k=key.shape, v=value.shape, attn=attention_probs.shape)
            )

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

The code is long and a little hard to read, but there are two changes relative to the original Attention Processor:

  1. Define a boolean is_cross_attn that distinguishes Cross-Attention from Self-Attention based on whether encoder_hidden_states is given
  2. Record the shapes of query, key, value, and attention_probs for every Attention call

Next, let's register this Attention Processor with the U-Net.

unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(
    'CompVis/stable-diffusion-v1-4', subfolder='unet'
)
unet.set_attn_processor(NewAttnProcessor())

for key, value in unet.attn_processors.items():
    print(key, value)

Checking the result, we can see that all processors have been successfully overwritten with the new Attention Processor.

Details of attn_processors after the replacement
down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
down_blocks.0.attentions.1.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
down_blocks.2.attentions.1.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.1.attentions.1.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.1.attentions.2.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.2.attentions.0.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.2.attentions.1.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.2.attentions.2.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.3.attentions.0.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.3.attentions.1.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.3.attentions.2.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
mid_block.attentions.0.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
mid_block.attentions.0.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
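
Note that every entry in this listing points to the same object (the same memory address), because we passed a single NewAttnProcessor instance. If you want an independent processor per Attention module instead, set_attn_processor also accepts a dict keyed by the processor names shown above; a small sketch:

# one independent NewAttnProcessor per Attention module
unet.set_attn_processor({name: NewAttnProcessor() for name in unet.attn_processors})

Whether to share one instance or use separate ones depends on how you want to aggregate the recorded values; the rest of this article keeps the single shared instance.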

Let's also take a look at what ends up inside NewAttnProcessor.

unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(
    'CompVis/stable-diffusion-v1-4', subfolder='unet'
)
unet.set_attn_processor(NewAttnProcessor())

# Sample Input
latents = torch.randn(2, 4, 64, 64)
text_embs = torch.randn(2, 77, 768)
timestep = 0

with torch.inference_mode():
    _ = unet(latents, encoder_hidden_states=text_embs, timestep=timestep)

# All Attention Processors currently reference the same instance, so just take one of them
attn_processor = next(iter(unet.attn_processors.values()))

print(attn_processor.self_attentions)
print(attn_processor.cross_attentions)
# Self-Attention
[
    ShapeStore(q=torch.Size([16, 4096, 40]), k=torch.Size([16, 4096, 40]), v=torch.Size([16, 4096, 40]), attn=torch.Size([16, 4096, 4096])),
    ShapeStore(q=torch.Size([16, 4096, 40]), k=torch.Size([16, 4096, 40]), v=torch.Size([16, 4096, 40]), attn=torch.Size([16, 4096, 4096])),
    ShapeStore(q=torch.Size([16, 1024, 80]), k=torch.Size([16, 1024, 80]), v=torch.Size([16, 1024, 80]), attn=torch.Size([16, 1024, 1024])),
    ShapeStore(q=torch.Size([16, 1024, 80]), k=torch.Size([16, 1024, 80]), v=torch.Size([16, 1024, 80]), attn=torch.Size([16, 1024, 1024])),
    ShapeStore(q=torch.Size([16, 256, 160]), k=torch.Size([16, 256, 160]), v=torch.Size([16, 256, 160]), attn=torch.Size([16, 256, 256])),
    ShapeStore(q=torch.Size([16, 256, 160]), k=torch.Size([16, 256, 160]), v=torch.Size([16, 256, 160]), attn=torch.Size([16, 256, 256])),
    ShapeStore(q=torch.Size([16, 64, 160]), k=torch.Size([16, 64, 160]), v=torch.Size([16, 64, 160]), attn=torch.Size([16, 64, 64])),
    ShapeStore(q=torch.Size([16, 256, 160]), k=torch.Size([16, 256, 160]), v=torch.Size([16, 256, 160]), attn=torch.Size([16, 256, 256])),
    ShapeStore(q=torch.Size([16, 256, 160]), k=torch.Size([16, 256, 160]), v=torch.Size([16, 256, 160]), attn=torch.Size([16, 256, 256])),
    ShapeStore(q=torch.Size([16, 256, 160]), k=torch.Size([16, 256, 160]), v=torch.Size([16, 256, 160]), attn=torch.Size([16, 256, 256])),
    ShapeStore(q=torch.Size([16, 1024, 80]), k=torch.Size([16, 1024, 80]), v=torch.Size([16, 1024, 80]), attn=torch.Size([16, 1024, 1024])),
    ShapeStore(q=torch.Size([16, 1024, 80]), k=torch.Size([16, 1024, 80]), v=torch.Size([16, 1024, 80]), attn=torch.Size([16, 1024, 1024])),
    ShapeStore(q=torch.Size([16, 1024, 80]), k=torch.Size([16, 1024, 80]), v=torch.Size([16, 1024, 80]), attn=torch.Size([16, 1024, 1024])),
    ShapeStore(q=torch.Size([16, 4096, 40]), k=torch.Size([16, 4096, 40]), v=torch.Size([16, 4096, 40]), attn=torch.Size([16, 4096, 4096])),
    ShapeStore(q=torch.Size([16, 4096, 40]), k=torch.Size([16, 4096, 40]), v=torch.Size([16, 4096, 40]), attn=torch.Size([16, 4096, 4096])),
    ShapeStore(q=torch.Size([16, 4096, 40]), k=torch.Size([16, 4096, 40]), v=torch.Size([16, 4096, 40]), attn=torch.Size([16, 4096, 4096]))
]

# Cross-Attention
[
    ShapeStore(q=torch.Size([16, 4096, 40]), k=torch.Size([16, 77, 40]), v=torch.Size([16, 77, 40]), attn=torch.Size([16, 4096, 77])),
    ShapeStore(q=torch.Size([16, 4096, 40]), k=torch.Size([16, 77, 40]), v=torch.Size([16, 77, 40]), attn=torch.Size([16, 4096, 77])),
    ShapeStore(q=torch.Size([16, 1024, 80]), k=torch.Size([16, 77, 80]), v=torch.Size([16, 77, 80]), attn=torch.Size([16, 1024, 77])),
    ShapeStore(q=torch.Size([16, 1024, 80]), k=torch.Size([16, 77, 80]), v=torch.Size([16, 77, 80]), attn=torch.Size([16, 1024, 77])),
    ShapeStore(q=torch.Size([16, 256, 160]), k=torch.Size([16, 77, 160]), v=torch.Size([16, 77, 160]), attn=torch.Size([16, 256, 77])),
    ShapeStore(q=torch.Size([16, 256, 160]), k=torch.Size([16, 77, 160]), v=torch.Size([16, 77, 160]), attn=torch.Size([16, 256, 77])),
    ShapeStore(q=torch.Size([16, 64, 160]), k=torch.Size([16, 77, 160]), v=torch.Size([16, 77, 160]), attn=torch.Size([16, 64, 77])),
    ShapeStore(q=torch.Size([16, 256, 160]), k=torch.Size([16, 77, 160]), v=torch.Size([16, 77, 160]), attn=torch.Size([16, 256, 77])),
    ShapeStore(q=torch.Size([16, 256, 160]), k=torch.Size([16, 77, 160]), v=torch.Size([16, 77, 160]), attn=torch.Size([16, 256, 77])),
    ShapeStore(q=torch.Size([16, 256, 160]), k=torch.Size([16, 77, 160]), v=torch.Size([16, 77, 160]), attn=torch.Size([16, 256, 77])),
    ShapeStore(q=torch.Size([16, 1024, 80]), k=torch.Size([16, 77, 80]), v=torch.Size([16, 77, 80]), attn=torch.Size([16, 1024, 77])),
    ShapeStore(q=torch.Size([16, 1024, 80]), k=torch.Size([16, 77, 80]), v=torch.Size([16, 77, 80]), attn=torch.Size([16, 1024, 77])),
    ShapeStore(q=torch.Size([16, 1024, 80]), k=torch.Size([16, 77, 80]), v=torch.Size([16, 77, 80]), attn=torch.Size([16, 1024, 77])),
    ShapeStore(q=torch.Size([16, 4096, 40]), k=torch.Size([16, 77, 40]), v=torch.Size([16, 77, 40]), attn=torch.Size([16, 4096, 77])),
    ShapeStore(q=torch.Size([16, 4096, 40]), k=torch.Size([16, 77, 40]), v=torch.Size([16, 77, 40]), attn=torch.Size([16, 4096, 77])),
    ShapeStore(q=torch.Size([16, 4096, 40]), k=torch.Size([16, 77, 40]), v=torch.Size([16, 77, 40]), attn=torch.Size([16, 4096, 77]))
]

The shapes appear to have been captured correctly. For example, in the first Self-Attention entry the leading 16 is batch size 2 × 8 attention heads and 4096 = 64 × 64 spatial positions, while the 77 in the Cross-Attention entries is the prompt token length.
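
Finally, the customized unet can be plugged back into a StableDiffusionPipeline and used during actual image generation, for example like this (a sketch; generation itself works exactly as usual):

from diffusers import StableDiffusionPipeline

# reuse the unet whose Attention Processors were replaced above
pipe = StableDiffusionPipeline.from_pretrained(
    'CompVis/stable-diffusion-v1-4', unet=unet
)
image = pipe('a photograph of an astronaut riding a horse').images[0]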

Summary

In this article, I explained how to customize the Attention mechanism used inside the Stable Diffusion U-Net. By understanding the U-Net's Attention modules and plugging in your own processing, you gain finer-grained control over the image generation process. Personally, I find visualizing Attention Maps interesting, so I may write a separate article on that if I find the time.
