How to Customize Attention Processing in diffusers | AttnProcessor
Introduction
The U-Net used in Stable Diffusion contains multiple Attention modules. By manipulating how attention is computed during image generation, you can inspect pixel-level information of the generated image or explicitly control the correspondence between image regions and the prompt. However, the forward pass of these modules is usually heavily wrapped, so you often end up digging deep into the code.
This article explains how to override the forward pass of the Attention modules inside the U-Net (UNet2DConditionModel) of a StableDiffusionPipeline by using the Attention Processor mechanism provided by Hugging Face's diffusers library.
TL;DR
To customize the Attention classes inside the U-Net (UNet2DConditionModel), do the following two things (a minimal sketch follows this list; the full version is developed in the rest of the article):
- Define your own Attention Processor class
- Overwrite the existing processors with the new Attention Processor via UNet2DConditionModel.set_attn_processor()
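A minimal sketch of these two steps (NewAttnProcessor is just a placeholder name here; a full implementation is developed later in this article):
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor


class NewAttnProcessor(AttnProcessor):
    """Placeholder: override __call__ to change how attention is computed."""
    ...


unet = UNet2DConditionModel.from_pretrained(
    'CompVis/stable-diffusion-v1-4', subfolder='unet'
)
# Recursively replaces the processor of every Attention module in the U-Net
unet.set_attn_processor(NewAttnProcessor())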
Attention Processor in Detail
Looking at the implementation of the Attention class contained in UNet2DConditionModel, you will find code along the following lines.
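The listing below is an abridged sketch of the relevant part (argument handling is trimmed and details differ between diffusers versions); the important point is that forward simply delegates the whole computation to self.processor.
import torch
from torch import nn


# Abridged sketch of diffusers.models.attention_processor.Attention
class Attention(nn.Module):
    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        # The actual attention computation is delegated to self.processor
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )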
If you are using PyTorch 2.0, AttnProcessor2_0 is used by default, and its implementation looks roughly as follows.
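Again abridged to the essentials (spatial/group normalization, masking, and reshaping of 4D inputs are omitted here; see the diffusers source for the full version). The key difference from the older AttnProcessor is that the attention itself is computed with PyTorch 2.0's F.scaled_dot_product_attention.
import torch
import torch.nn.functional as F


# Abridged sketch of AttnProcessor2_0
class AttnProcessor2_0:
    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        query = attn.to_q(hidden_states)

        # Self-Attention if no condition is given, Cross-Attention otherwise
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Split heads: (batch, seq, dim) -> (batch, heads, seq, head_dim)
        batch_size = hidden_states.shape[0]
        head_dim = key.shape[-1] // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # Fused attention kernel introduced in PyTorch 2.0
        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)

        # Merge heads back, then output projection and dropout
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states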
To summarize, in UNet2DConditionModel a class called the Attention Processor is responsible for the attention computation, so if we can swap this class out, we can customize the behavior.
Observing the Attention Modules inside the U-Net
Let's check how many Attention modules are contained in UNet2DConditionModel. They can be listed with the following code.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    'CompVis/stable-diffusion-v1-4', subfolder='unet'
)

for key, value in unet.attn_processors.items():
    print(key, value)
The output looks like this.
Full list of attn_processors in UNet2DConditionModel
down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x131194050>
down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x131194690>
down_blocks.0.attentions.1.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x131195b90>
down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x131196210>
down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x1311c80d0>
down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x1311c8590>
down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x1311c9b90>
down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x1311ca3d0>
down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x12495f910>
down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x132110510>
down_blocks.2.attentions.1.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x132111ad0>
down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x132112150>
up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x13217c690>
up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x13217cb10>
up_blocks.1.attentions.1.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x13217e1d0>
up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x124921690>
up_blocks.1.attentions.2.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x130ed3c10>
up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x1321b4550>
up_blocks.2.attentions.0.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x1321b6250>
up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x1321b68d0>
up_blocks.2.attentions.1.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x124a239d0>
up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x1321e45d0>
up_blocks.2.attentions.2.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x1321e5cd0>
up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x1321e6350>
up_blocks.3.attentions.0.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x132234050>
up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x132234690>
up_blocks.3.attentions.1.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x132235d50>
up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x1322363d0>
up_blocks.3.attentions.2.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x132237a90>
up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x132260150>
mid_block.attentions.0.transformer_blocks.0.attn1.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x1321488d0>
mid_block.attentions.0.transformer_blocks.0.attn2.processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x132148f50>
From this output, we can see that the Attention Processors break down as follows:
- Down blocks: 12
- Up blocks: 18
- Mid block: 2
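This breakdown can be double-checked programmatically by grouping the keys of unet.attn_processors by their top-level block name (continuing from the unet loaded above):
from collections import Counter

block_counts = Counter(key.split('.')[0] for key in unet.attn_processors)
print(block_counts)  # down_blocks: 12, up_blocks: 18, mid_block: 2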
Overwriting the Attention Processors of UNet2DConditionModel
The forward function of the Attention class is defined by processor.__call__, and the processor attached to every Attention module in the U-Net can be overwritten recursively by calling UNet2DConditionModel's set_attn_processor method.
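As a side note, set_attn_processor also accepts a dict keyed by the same names as unet.attn_processors, which makes it possible to replace processors selectively. A minimal sketch that swaps only the attn2 (Cross-Attention) processors for the plain AttnProcessor and keeps AttnProcessor2_0 everywhere else:
from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0

selective_processors = {
    name: AttnProcessor() if name.endswith('attn2.processor') else AttnProcessor2_0()
    for name in unet.attn_processors.keys()
}
unet.set_attn_processor(selective_processors)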
The Attention in the Stable Diffusion U-Net comes in two kinds:
- Self-Attention: attention computed within the latent that is being denoised inside the U-Net
- Cross-Attention: attention computed between the U-Net latent and a condition such as the prompt
The two can be distinguished via the encoder_hidden_states argument of the Attention Processor's __call__: if it is None, the call is Self-Attention; otherwise it is Cross-Attention. Let's define a new Attention Processor to look at the tensors used in each kind of attention.
Define NewAttnProcessor as follows.
import dataclasses

import torch
from diffusers.models.attention_processor import AttnProcessor, Attention
from diffusers.utils.constants import USE_PEFT_BACKEND


@dataclasses.dataclass
class ShapeStore:
    """Container for recording tensor shapes."""

    q: torch.Size  # query
    k: torch.Size  # key
    v: torch.Size  # value
    attn: torch.Size  # attention score/probs


class NewAttnProcessor(AttnProcessor):
    def __init__(self):
        super().__init__()
        # Storage for Self-/Cross-Attention shapes
        self.self_attentions = []
        self.cross_attentions = []

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        temb: torch.FloatTensor | None = None,
        scale: float = 1.0,
    ) -> torch.Tensor:
        residual = hidden_states

        args = () if USE_PEFT_BACKEND else (scale,)

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states, *args)

        # !!! Addition (1): remember whether this call is Cross-Attention
        is_cross_attn = encoder_hidden_states is not None

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)

        # !!! Addition (2): record the shapes of query / key / value / attention_probs
        if is_cross_attn:
            self.cross_attentions.append(
                ShapeStore(q=query.shape, k=key.shape, v=value.shape, attn=attention_probs.shape)
            )
        else:
            self.self_attentions.append(
                ShapeStore(q=query.shape, k=key.shape, v=value.shape, attn=attention_probs.shape)
            )

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
The code is long, but there are only two changes from the original Attention Processor:
- A boolean is_cross_attn is defined, which tells Cross-Attention apart from Self-Attention based on whether encoder_hidden_states is given
- The shapes of query, key, value, and attention_probs are stored for every attention call
Next, let's register this Attention Processor on the U-Net.
unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(
    'CompVis/stable-diffusion-v1-4', subfolder='unet'
)
unet.set_attn_processor(NewAttnProcessor())

for key, value in unet.attn_processors.items():
    print(key, value)
Checking the processors shows that they have all been successfully overwritten with the new Attention Processor.
Full list of attn_processors after the change
down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
down_blocks.0.attentions.1.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
down_blocks.2.attentions.1.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.1.attentions.1.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.1.attentions.2.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.2.attentions.0.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.2.attentions.1.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.2.attentions.2.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.3.attentions.0.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.3.attentions.1.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.3.attentions.2.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
mid_block.attentions.0.transformer_blocks.0.attn1.processor <__main__.NewAttnProcessor object at 0x144050850>
mid_block.attentions.0.transformer_blocks.0.attn2.processor <__main__.NewAttnProcessor object at 0x144050850>
As a quick test, let's also inspect the contents of NewAttnProcessor.
unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(
    'CompVis/stable-diffusion-v1-4', subfolder='unet'
)
unet.set_attn_processor(NewAttnProcessor())

# Sample input
latents = torch.randn(2, 4, 64, 64)
text_embs = torch.randn(2, 77, 768)
timestep = 0

with torch.inference_mode():
    _ = unet(latents, encoder_hidden_states=text_embs, timestep=timestep)

# Every Attention Processor currently references the same object, so just take one
attn_processor = next(iter(unet.attn_processors.values()))
print(attn_processor.self_attentions)
print(attn_processor.cross_attentions)
# Self-Attention
[
ShapeStore(q=torch.Size([16, 4096, 40]), k=torch.Size([16, 4096, 40]), v=torch.Size([16, 4096, 40]), attn=torch.Size([16, 4096, 4096])),
ShapeStore(q=torch.Size([16, 4096, 40]), k=torch.Size([16, 4096, 40]), v=torch.Size([16, 4096, 40]), attn=torch.Size([16, 4096, 4096])),
ShapeStore(q=torch.Size([16, 1024, 80]), k=torch.Size([16, 1024, 80]), v=torch.Size([16, 1024, 80]), attn=torch.Size([16, 1024, 1024])),
ShapeStore(q=torch.Size([16, 1024, 80]), k=torch.Size([16, 1024, 80]), v=torch.Size([16, 1024, 80]), attn=torch.Size([16, 1024, 1024])),
ShapeStore(q=torch.Size([16, 256, 160]), k=torch.Size([16, 256, 160]), v=torch.Size([16, 256, 160]), attn=torch.Size([16, 256, 256])),
ShapeStore(q=torch.Size([16, 256, 160]), k=torch.Size([16, 256, 160]), v=torch.Size([16, 256, 160]), attn=torch.Size([16, 256, 256])),
ShapeStore(q=torch.Size([16, 64, 160]), k=torch.Size([16, 64, 160]), v=torch.Size([16, 64, 160]), attn=torch.Size([16, 64, 64])),
ShapeStore(q=torch.Size([16, 256, 160]), k=torch.Size([16, 256, 160]), v=torch.Size([16, 256, 160]), attn=torch.Size([16, 256, 256])),
ShapeStore(q=torch.Size([16, 256, 160]), k=torch.Size([16, 256, 160]), v=torch.Size([16, 256, 160]), attn=torch.Size([16, 256, 256])),
ShapeStore(q=torch.Size([16, 256, 160]), k=torch.Size([16, 256, 160]), v=torch.Size([16, 256, 160]), attn=torch.Size([16, 256, 256])),
ShapeStore(q=torch.Size([16, 1024, 80]), k=torch.Size([16, 1024, 80]), v=torch.Size([16, 1024, 80]), attn=torch.Size([16, 1024, 1024])),
ShapeStore(q=torch.Size([16, 1024, 80]), k=torch.Size([16, 1024, 80]), v=torch.Size([16, 1024, 80]), attn=torch.Size([16, 1024, 1024])),
ShapeStore(q=torch.Size([16, 1024, 80]), k=torch.Size([16, 1024, 80]), v=torch.Size([16, 1024, 80]), attn=torch.Size([16, 1024, 1024])),
ShapeStore(q=torch.Size([16, 4096, 40]), k=torch.Size([16, 4096, 40]), v=torch.Size([16, 4096, 40]), attn=torch.Size([16, 4096, 4096])),
ShapeStore(q=torch.Size([16, 4096, 40]), k=torch.Size([16, 4096, 40]), v=torch.Size([16, 4096, 40]), attn=torch.Size([16, 4096, 4096])),
ShapeStore(q=torch.Size([16, 4096, 40]), k=torch.Size([16, 4096, 40]), v=torch.Size([16, 4096, 40]), attn=torch.Size([16, 4096, 4096]))
]
# Cross-Attention
[
ShapeStore(q=torch.Size([16, 4096, 40]), k=torch.Size([16, 77, 40]), v=torch.Size([16, 77, 40]), attn=torch.Size([16, 4096, 77])),
ShapeStore(q=torch.Size([16, 4096, 40]), k=torch.Size([16, 77, 40]), v=torch.Size([16, 77, 40]), attn=torch.Size([16, 4096, 77])),
ShapeStore(q=torch.Size([16, 1024, 80]), k=torch.Size([16, 77, 80]), v=torch.Size([16, 77, 80]), attn=torch.Size([16, 1024, 77])),
ShapeStore(q=torch.Size([16, 1024, 80]), k=torch.Size([16, 77, 80]), v=torch.Size([16, 77, 80]), attn=torch.Size([16, 1024, 77])),
ShapeStore(q=torch.Size([16, 256, 160]), k=torch.Size([16, 77, 160]), v=torch.Size([16, 77, 160]), attn=torch.Size([16, 256, 77])),
ShapeStore(q=torch.Size([16, 256, 160]), k=torch.Size([16, 77, 160]), v=torch.Size([16, 77, 160]), attn=torch.Size([16, 256, 77])),
ShapeStore(q=torch.Size([16, 64, 160]), k=torch.Size([16, 77, 160]), v=torch.Size([16, 77, 160]), attn=torch.Size([16, 64, 77])),
ShapeStore(q=torch.Size([16, 256, 160]), k=torch.Size([16, 77, 160]), v=torch.Size([16, 77, 160]), attn=torch.Size([16, 256, 77])),
ShapeStore(q=torch.Size([16, 256, 160]), k=torch.Size([16, 77, 160]), v=torch.Size([16, 77, 160]), attn=torch.Size([16, 256, 77])),
ShapeStore(q=torch.Size([16, 256, 160]), k=torch.Size([16, 77, 160]), v=torch.Size([16, 77, 160]), attn=torch.Size([16, 256, 77])),
ShapeStore(q=torch.Size([16, 1024, 80]), k=torch.Size([16, 77, 80]), v=torch.Size([16, 77, 80]), attn=torch.Size([16, 1024, 77])),
ShapeStore(q=torch.Size([16, 1024, 80]), k=torch.Size([16, 77, 80]), v=torch.Size([16, 77, 80]), attn=torch.Size([16, 1024, 77])),
ShapeStore(q=torch.Size([16, 1024, 80]), k=torch.Size([16, 77, 80]), v=torch.Size([16, 77, 80]), attn=torch.Size([16, 1024, 77])),
ShapeStore(q=torch.Size([16, 4096, 40]), k=torch.Size([16, 77, 40]), v=torch.Size([16, 77, 40]), attn=torch.Size([16, 4096, 77])),
ShapeStore(q=torch.Size([16, 4096, 40]), k=torch.Size([16, 77, 40]), v=torch.Size([16, 77, 40]), attn=torch.Size([16, 4096, 77])),
ShapeStore(q=torch.Size([16, 4096, 40]), k=torch.Size([16, 77, 40]), v=torch.Size([16, 77, 40]), attn=torch.Size([16, 4096, 77]))
]
The shapes appear to have been captured correctly.
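Reading the shapes: 77 is the token length of text_embs, the sequence lengths 4096 / 1024 / 256 / 64 correspond to the flattened spatial resolutions 64x64 / 32x32 / 16x16 / 8x8 of the latent at each U-Net level, and the leading 16 is batch size times the number of attention heads (assuming 8 heads, which is consistent with 16 = 2 x 8). A quick sanity check of the first Self-Attention entry:
import math

# First Self-Attention entry above: q=torch.Size([16, 4096, 40])
batch_size = 2                      # matches the latents we fed in
bh, seq_len, head_dim = 16, 4096, 40
num_heads = bh // batch_size        # head_to_batch_dim folds batch and heads together
print(num_heads)                    # 8
print(int(math.isqrt(seq_len)))     # 64 -> the 64x64 latent resolution
print(num_heads * head_dim)         # 320 -> inner attention dim at this level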
Summary
This article explained how to customize the attention mechanism used inside the Stable Diffusion U-Net. Understanding the U-Net's Attention modules and plugging in your own processing allows finer-grained control over the image generation process. Personally, I find visualizing attention maps especially interesting, so I may cover that in a separate article.