📝

PytorchのFlexAttentionの動作確認

2024/10/13に公開

背景

Pytorch 2.6.0 (nightly) でFlexAttentionが導入されて、カスタムのMaskやBias付きでAttentionできる。
https://pytorch.org/blog/flexattention/

可変長のサンプルを学習する場合、1バッチにサンプルをまとめてる方法(sequence packing)があり、パディングするよりメモリ効率がよい。この場合、サンプルのtoken間のみでAttentionする必要があり、具体的には block diagonal な mask でAttentionする必要がある。

これについてFlexAttentionで動作確認した。

なお、このようなMaskつきのAttentionはblock-diagonal attentionなどと呼ばれる。上のPytorchの記事では、Document Maskingと呼ばれている 。

動作環境

  • Windows WSL
  • RTX 2070 Super
  • Driver Version: 536.23
  • torch==2.6.0.dev20241012+cu118

torchは以下でインストール ※公式サイトのPreview,Linux,Pip,Python,CUDA 11.8を選択して出てくるインストール手順。

pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118

実施結果

意図通り動作していそう。入力の特定位置のKey,Value,Queryを変えると、Attention Maskした範囲で出力が変わった。

コード

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask


###  casual mask
# APIドキュメントのコードから引用
# https://pytorch.org/docs/main/nn.attention.flex_attention.html#torch.nn.attention.flex_attention.create_block_mask
if 0:
    def causal_mask(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    block_mask = create_block_mask(causal_mask, 1, 1, 8192, 8192, device="cuda", BLOCK_SIZE=128)
    query = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
    key = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
    value = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
    output = flex_attention(query, key, value, block_mask=block_mask)

    print(f"casual mask:\n{block_mask}")

    # change last value
    query[:,:,-1] = 7
    key[:,:,-1] = 7
    value[:,:,-1] = 7
    output2 = flex_attention(query, key, value, block_mask=block_mask)

    diff = output - output2
    print(f"output diff after last element modified:\n{diff}")

### block diagonal attention
# v1だとエラーになったのでv2に変更
# v1エラー内容: RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow. We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 .
# v2参考: https://github.com/pytorch/pytorch/issues/136427#issue-2542452642
if 1:
    BORDER_ARR = (0, 1024, 1024+256, 8192 - 2, 8192)  # 4 documentでdocument内のtoken間でattentionする

    def find_first(arr, value):
        for i, e in enumerate(arr):
            if e > value:
                return i - 1
        return -1

    def block_diagonal_mask_v1(b, h, q_idx, kv_idx):
        q_doc_pos  = find_first(BORDER_ARR, q_idx)
        kv_doc_pos = find_first(BORDER_ARR, kv_idx)
        return q_doc_pos != kv_doc_pos

    def gen_idx_to_doc_idx_arr():
        """kvqのidnexからdocument indexに変換するための配列を作成: [0,0,0,1,1,1,1,...]"""
        r = []
        for i in range(len(BORDER_ARR) - 1):
            r.extend([i] * (BORDER_ARR[i+1]-BORDER_ARR[i]))
        return torch.tensor(r, device='cuda')

    QKV_IDX_TO_DOC_IDX_MAP = gen_idx_to_doc_idx_arr()

    def block_diagonal_mask_v2(b, h, q_idx, kv_idx):
        q_doc_pos  = QKV_IDX_TO_DOC_IDX_MAP[q_idx]
        kv_doc_pos = QKV_IDX_TO_DOC_IDX_MAP[kv_idx]
        return q_doc_pos == kv_doc_pos

    block_mask2 = create_block_mask(block_diagonal_mask_v2, 1, 1, 8192, 8192, device="cuda", BLOCK_SIZE=128)
    query = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
    key = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
    value = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)

    print(f"block diagonal mask:\n{block_mask2}")

    output = flex_attention(query, key, value, block_mask=block_mask2)

    # 最後の要素だけ変更して再度推論すると、最後の2要素だけdiffが発生する. 
    # 想定通りのattentionになっている
    CHANGE_IDX = -1  # or -2
    query[:,:,CHANGE_IDX] = 9
    key[:,:,CHANGE_IDX] = 9
    value[:,:,CHANGE_IDX] = 9
    output2 = flex_attention(query, key, value, block_mask=block_mask2)
    diff = output - output2
    print(f"output diff after last element modified:\n{diff}")

標準出力

block diagonal mask:
BlockMask(shape=(1, 1, 8192, 8192), sparsity=27.15%,
(0, 0)
████
████
    ░░░░░░░░░░░░░░░░░░░░░░░░░░░░
    ░░██████████████████████████
    ░░██████████████████████████
    ░░██████████████████████████
    ░░██████████████████████████
    ░░██████████████████████████
    ░░██████████████████████████
    ░░██████████████████████████
    ░░██████████████████████████
    ░░██████████████████████████
    ░░██████████████████████████
    ░░██████████████████████████
    ░░██████████████████████████
    ░░██████████████████████████
)
output diff after last element modified:
tensor([[[[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [ -8.5391,  -9.8516,  -8.2969,  ...,  -8.7969, -10.1875, -10.2266],
          [ -8.6250,  -9.8594,  -8.3359,  ...,  -8.8438, -10.2500, -10.2812]]]],

Discussion