
Survey of SparseAttention Implementations

Published on 2024/10/13

I tried various approaches, but batch=1, seq=300k looks infeasible.

Plan going forward: in my use case, each sample is at most 128 tokens long (minimum 2), so pre-packing the sequences into multiple batches should be enough. Packing the cumulative 300k tokens into roughly 3k batches x 128 seqlen is effectively the same as applying a SparseBlockMask at 128-token granularity, so that is the approach I will try next.
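As a minimal sketch of this packing (my own illustration, not any library API; pack_samples and pad_id are hypothetical names), assuming samples of length 2 to 128 padded out to a fixed 128:

import torch

def pack_samples(samples, max_len=128, pad_id=0):
    """samples: list of 1D LongTensors, each of length 2..max_len."""
    batch = torch.full((len(samples), max_len), pad_id, dtype=torch.long)
    padding_mask = torch.ones(len(samples), max_len, dtype=torch.bool)  # True = padding position
    for i, s in enumerate(samples):
        batch[i, : len(s)] = s
        padding_mask[i, : len(s)] = False
    return batch, padding_mask

# roughly 300k tokens spread over ~3k samples of length 2..128
lengths = torch.randint(2, 129, (3000,))
samples = [torch.randint(0, 1000, (int(n),)) for n in lengths]
batch, padding_mask = pack_samples(samples)
print(batch.shape, padding_mask.shape)  # torch.Size([3000, 128]) torch.Size([3000, 128])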

Flash Attention

Does not appear to be implemented.

xformers

Usage example:

  • MEXMA: Token-level objectives improve sentence representations
    • 7.1 Encoder backbone
      The available implementation of XLM-RoBERTa in HuggingFace employs an inefficient attention mechanism, which we have modified to incorporate the memory-efficient attention from xFormers (Lefaudeux et al., 2022). This modification was necessary due to the random batching process used in our training, which results in a significant amount of padding and increased computational cost. To address this issue and eliminate padding, we have employed the BlockDiagonalMask, which through custom CUDA kernels, avoids computations in padding altogether.

Verification

  • Code example (see the Code section below)
  • When I tried it, 3 tokens x 10k documents (30k tokens total) worked
  • The 3 tokens x 100k documents setting did not work...
    • RuntimeError: CUDA error: invalid configuration argument

There appears to be a hard upper limit.

https://github.com/facebookresearch/xformers/issues/998#issuecomment-2016721065

No, because the optimized kernels are built for specific sizes, and the maximum size anything is built for in sm_80 for ampere is 32768 because of the number of possible in-flight operations IIRC.

https://github.com/facebookresearch/xformers/issues/845#issuecomment-1852076953

I've traced back a bit to cuda code here. I found the problem is came from that the batch size used in the original attention layer will build corresponding SM threads on GPU. If the threads(batch) size is larger than one GPU can support (A100 can only support up to 32 x 2048 = 65536 threads), the error occurred.

Also took a quick look at pytorch source code and found that they always have a constraint constant (one called MAX_BLOCK_SIZE) to deal with large amount of resource. Using the similar logic might solve this issue.
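Since a BlockDiagonalMask never lets sequences attend across blocks, one possible workaround (my own sketch, not taken from either issue; chunk_size is a hypothetical knob) would be to split the documents into chunks that stay under the per-launch limit and call memory_efficient_attention once per chunk:

import torch
from xformers.ops import fmha

def chunked_block_diagonal_attention(q, k, v, seqlens, chunk_size=10_000):
    """q, k, v: [1, total_tokens, n_heads, head_dim], packed along dim 1.
    seqlens: per-document lengths; at most chunk_size documents per kernel call."""
    outs, start = [], 0
    for i in range(0, len(seqlens), chunk_size):
        sub = seqlens[i : i + chunk_size]
        end = start + sum(sub)
        bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens(sub)
        out = fmha.memory_efficient_attention(
            q[:, start:end], k[:, start:end], v[:, start:end], attn_bias=bias
        )
        outs.append(out)
        start = end
    return torch.cat(outs, dim=1)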

Code

xformers

# code from
# https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.fmha.attn_bias.BlockDiagonalMask

import torch
from xformers.ops import fmha


K = 16
dtype = torch.float16
device = "cuda"

if 1:  # build the mask from a list of sequence lengths; set to 0 to build it from a tensor list instead
    # seqlen = [3, 6, 2]
    seqlen = [3] * 10000
    total_len = sum(seqlen)
    attn_bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens(seqlen)
    x = torch.randn([1, total_len, 1, K], dtype=dtype, device=device)
else:
    list_x = [
        torch.randn([1, 3, 1, K], dtype=dtype, device=device),
        torch.randn([1, 6, 1, K], dtype=dtype, device=device),
        torch.randn([1, 2, 1, K], dtype=dtype, device=device),
    ]
    attn_bias, x = fmha.attn_bias.BlockDiagonalMask.from_tensor_list(list_x)

linear = torch.nn.Linear(K, K * 3).to(device=device, dtype=dtype)
q, k, v = linear(x).reshape([1, -1, 1, 3, K]).unbind(-2)
out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias)

list_out = attn_bias.split(out)
print("len(list_out)", len(list_out))
print(list_out[0].shape)  # [1, 3, 1, K]
assert tuple(list_out[0].shape) == (1, 3, 1, K)

flash attention (non-sparse)

"""
検証コード

3token x 10k  doc -> OK
3token x 100k doc -> NG. OOM

"""
import torch
import numpy as np
# flash attn 1.x
from flash_attn.flash_attn_interface import flash_attn_unpadded_func


device = "cuda"
len_arr = [3] * 10000  # per-document lengths, e.g. [3, 2]
cum_len = np.cumsum([0] + len_arr)
cu_seqlens_q = cu_seqlens_k = torch.tensor(cum_len, device=device, dtype=torch.int32)  # cumulative token offsets: one boundary per document plus a leading 0

# cu_seqlens_q[-1] is the total number of packed tokens; below it is used both as the
# packed tensor length and as the max_seqlen argument (the true per-sequence max is max(len_arr))
max_seq_len = cu_seqlens_q[-1]
B, H, SEQ_LEN, HEAD_DIM = 1, 1, max_seq_len, 16

def make_tensor():
    return torch.randn(SEQ_LEN, H, HEAD_DIM // H, device=device, dtype=torch.float16)
    # return torch.randn(B, SEQ_LEN, H, HEAD_DIM // H, device=device, dtype=torch.float16)
    # return torch.ones(B, H, SEQ_LEN, HEAD_DIM, device=device, dtype=torch.float16)

q, k, v = make_tensor(), make_tensor(), make_tensor()
seqlen_q = seqlen_k = max_seq_len
is_causal = False
dropout_p = 0.0
softmax_scale = None

output = flash_attn_unpadded_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    seqlen_q,
    seqlen_k,
    dropout_p,
    softmax_scale=softmax_scale,
    causal=is_causal,
)
print("output", output.shape, output)

# Perturb the last token and recompute; with the per-document (block-diagonal) layout,
# only the outputs of the last document should change
q[-1, 0, 0] = 7
k[-1, 0, 0] = 7
v[-1, 0, 0] = 7
output2 = flash_attn_unpadded_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    seqlen_q,
    seqlen_k,
    dropout_p,
    softmax_scale=softmax_scale,
    causal=is_causal,
)
diff = output - output2
print("diff", diff)

Discussion

lisosia

Plan going forward: in my use case, each sample is at most 128 tokens long (minimum 2), so pre-packing the sequences into multiple batches should work. Packing the cumulative 300k tokens into roughly 3k batches x 128 seqlen is effectively the same as using a SparseBlockMask at 128-token granularity, so that is the approach I will try next.

This looks like the way to go. The default PyTorch MultiheadAttention works at 3k batch x 128 seqlen, and even at 30k batch x 128 seqlen.

That was quite a detour.

import torch
import torch.nn as nn

# Parameter settings
batch_size = 3000  # 3k
seq_len = 128
hidden_dim = 16
num_heads = 1  # hidden_dim is 16, so splitting into e.g. 4 heads would also be possible

# Create a sample input tensor; the default nn.MultiheadAttention layout is (seq_len, batch_size, hidden_dim)
input_tensor = torch.randn(seq_len, batch_size, hidden_dim)

# Initialize the MultiheadAttention module
attention_layer = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads)

# Build an attention mask
# Here a random boolean mask is generated just to ignore some entries (True = attention to that position is disallowed)
# For a 3D attn_mask, the expected shape is (batch_size * num_heads, seq_len, seq_len)
attention_mask = (torch.randn(batch_size * num_heads, seq_len, seq_len) > 0)
print("attention_mask[0]", attention_mask[0])
print("attention_mask[1]", attention_mask[1])

# Compute the output
output, attn_weights = attention_layer(input_tensor, input_tensor, input_tensor, attn_mask=attention_mask)

print("Output shape:", output.shape)
print("Attention Weights shape:", attn_weights.shape)
lisosia

The padded positions where computation is unnecessary can be skipped via key_padding_mask. Its shape is (Batch, Seq).
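A minimal sketch of using it (batch_first=True and the random length distribution are my own assumptions):

import torch
import torch.nn as nn

batch_size, seq_len, hidden_dim = 3000, 128, 16
attention_layer = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=1, batch_first=True)

x = torch.randn(batch_size, seq_len, hidden_dim)
lengths = torch.randint(2, seq_len + 1, (batch_size,))  # true length of each sample
# (batch, seq): True marks a padding position that should be ignored
key_padding_mask = torch.arange(seq_len)[None, :] >= lengths[:, None]

output, attn_weights = attention_layer(x, x, x, key_padding_mask=key_padding_mask)
print(output.shape)  # torch.Size([3000, 128, 16])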