📝
PytorchのFlexAttentionの動作確認
背景
Pytorch 2.6.0 (nightly) でFlexAttentionが導入されて、カスタムのMaskやBias付きでAttentionできる。
可変長のサンプルを学習する場合、1バッチにサンプルをまとめてる方法(sequence packing)があり、パディングするよりメモリ効率がよい。この場合、サンプルのtoken間のみでAttentionする必要があり、具体的には block diagonal な mask でAttentionする必要がある。
これについてFlexAttentionで動作確認した。
なお、このようなMaskつきのAttentionはblock-diagonal attentionなどと呼ばれる。上のPytorchの記事では、Document Maskingと呼ばれている 。
動作環境
- Windows WSL
- RTX 2070 Super
- Driver Version: 536.23
- torch==2.6.0.dev20241012+cu118
torchは以下でインストール ※公式サイトのPreview,Linux,Pip,Python,CUDA 11.8を選択して出てくるインストール手順。
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118
実施結果
意図通り動作していそう。入力の特定位置のKey,Value,Queryを変えると、Attention Maskした範囲で出力が変わった。
コード
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
### casual mask
# APIドキュメントのコードから引用
# https://pytorch.org/docs/main/nn.attention.flex_attention.html#torch.nn.attention.flex_attention.create_block_mask
if 0:
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
block_mask = create_block_mask(causal_mask, 1, 1, 8192, 8192, device="cuda", BLOCK_SIZE=128)
query = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
key = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
value = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
output = flex_attention(query, key, value, block_mask=block_mask)
print(f"casual mask:\n{block_mask}")
# change last value
query[:,:,-1] = 7
key[:,:,-1] = 7
value[:,:,-1] = 7
output2 = flex_attention(query, key, value, block_mask=block_mask)
diff = output - output2
print(f"output diff after last element modified:\n{diff}")
### block diagonal attention
# v1だとエラーになったのでv2に変更
# v1エラー内容: RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow. We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 .
# v2参考: https://github.com/pytorch/pytorch/issues/136427#issue-2542452642
if 1:
BORDER_ARR = (0, 1024, 1024+256, 8192 - 2, 8192) # 4 documentでdocument内のtoken間でattentionする
def find_first(arr, value):
for i, e in enumerate(arr):
if e > value:
return i - 1
return -1
def block_diagonal_mask_v1(b, h, q_idx, kv_idx):
q_doc_pos = find_first(BORDER_ARR, q_idx)
kv_doc_pos = find_first(BORDER_ARR, kv_idx)
return q_doc_pos != kv_doc_pos
def gen_idx_to_doc_idx_arr():
"""kvqのidnexからdocument indexに変換するための配列を作成: [0,0,0,1,1,1,1,...]"""
r = []
for i in range(len(BORDER_ARR) - 1):
r.extend([i] * (BORDER_ARR[i+1]-BORDER_ARR[i]))
return torch.tensor(r, device='cuda')
QKV_IDX_TO_DOC_IDX_MAP = gen_idx_to_doc_idx_arr()
def block_diagonal_mask_v2(b, h, q_idx, kv_idx):
q_doc_pos = QKV_IDX_TO_DOC_IDX_MAP[q_idx]
kv_doc_pos = QKV_IDX_TO_DOC_IDX_MAP[kv_idx]
return q_doc_pos == kv_doc_pos
block_mask2 = create_block_mask(block_diagonal_mask_v2, 1, 1, 8192, 8192, device="cuda", BLOCK_SIZE=128)
query = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
key = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
value = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
print(f"block diagonal mask:\n{block_mask2}")
output = flex_attention(query, key, value, block_mask=block_mask2)
# 最後の要素だけ変更して再度推論すると、最後の2要素だけdiffが発生する.
# 想定通りのattentionになっている
CHANGE_IDX = -1 # or -2
query[:,:,CHANGE_IDX] = 9
key[:,:,CHANGE_IDX] = 9
value[:,:,CHANGE_IDX] = 9
output2 = flex_attention(query, key, value, block_mask=block_mask2)
diff = output - output2
print(f"output diff after last element modified:\n{diff}")
標準出力
block diagonal mask:
BlockMask(shape=(1, 1, 8192, 8192), sparsity=27.15%,
(0, 0)
████
████
░░░░░░░░░░░░░░░░░░░░░░░░░░░░
░░██████████████████████████
░░██████████████████████████
░░██████████████████████████
░░██████████████████████████
░░██████████████████████████
░░██████████████████████████
░░██████████████████████████
░░██████████████████████████
░░██████████████████████████
░░██████████████████████████
░░██████████████████████████
░░██████████████████████████
░░██████████████████████████
)
output diff after last element modified:
tensor([[[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
...,
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[ -8.5391, -9.8516, -8.2969, ..., -8.7969, -10.1875, -10.2266],
[ -8.6250, -9.8594, -8.3359, ..., -8.8438, -10.2500, -10.2812]]]],
Discussion