Investigating SparseAttention implementations
Tried several approaches, but batch=1 with seq=300k looks infeasible.
Next direction: in my use case, a single sample is at most 128 tokens long (shortest 2), so pre-packing the sequences into multiple batch rows should be enough. Packing the cumulative ~300k tokens into roughly 3k batches x 128 seqlen amounts to the same thing as applying a sparse block mask at 128-token granularity, so that is the approach to try next (see the sketch below).
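A minimal sketch of this packing step (one padded sample per 128-token row), assuming the samples are already lists of token ids and using a hypothetical pad id of 0; the real tokenizer's pad id and batching code will differ.

import torch

MAX_LEN = 128  # longest sample in my use case (shortest is 2)
PAD_ID = 0     # hypothetical pad token id

def pad_to_batch(samples, max_len=MAX_LEN, pad_id=PAD_ID):
    # One sample per row, padded to max_len.
    # Returns batch (num_samples, max_len) and key_padding_mask (num_samples, max_len),
    # where True marks padding positions (the convention nn.MultiheadAttention expects).
    batch = torch.full((len(samples), max_len), pad_id, dtype=torch.long)
    mask = torch.ones(len(samples), max_len, dtype=torch.bool)
    for i, s in enumerate(samples):
        batch[i, : len(s)] = torch.tensor(s)
        mask[i, : len(s)] = False
    return batch, mask

# toy data: ~3k samples of length 2..128 (in the real setting these total ~300k tokens)
samples = [[1] * n for n in [3, 7, 128, 2] * 750]
batch, key_padding_mask = pad_to_batch(samples)
print(batch.shape, key_padding_mask.shape)  # torch.Size([3000, 128]) torch.Size([3000, 128])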
Flash Attention
Doesn't seem to be implemented (attention masks are not supported)
- https://github.com/Dao-AILab/flash-attention/issues/31
  - Having trouble getting it to work
- https://github.com/Dao-AILab/flash-attention/issues/352
  - "Attention mask isn't supported (either in v1 or v2). I might implement it at some point but there are other priorities now."
xformers
Example of use:
- MEXMA: Token-level objectives improve sentence representations
  - 7.1 Encoder backbone: "The available implementation of XLM-RoBERTa in HuggingFace employs an inefficient attention mechanism, which we have modified to incorporate the memory-efficient attention from xFormers (Lefaudeux et al., 2022). This modification was necessary due to the random batching process used in our training, which results in a significant amount of padding and increased computational cost. To address this issue and eliminate padding, we have employed the BlockDiagonalMask, which through custom CUDA kernels, avoids computations in padding altogether."
Verification
- Code example: see the Code section below
- 3 tokens x 10k documents (30k tokens total) ran fine
- The 3 tokens x 100k documents setting did not work...
  - RuntimeError: CUDA error: invalid configuration argument
- There seems to be a hard upper limit.. (rough arithmetic against the quoted numbers in the sketch below)
  - "No, because the optimized kernels are built for specific sizes, and the maximum size anything is built for in sm_80 for ampere is 32768 because of the number of possible in-flight operations IIRC."
  - "I've traced back a bit to cuda code here. I found the problem is came from that the batch size used in the original attention layer will build corresponding SM threads on GPU. If the threads(batch) size is larger than one GPU can support (A100 can only support up to 32 x 2048 = 65536 threads), the error occurred."
  - "Also took a quick look at pytorch source code and found that they always have a constraint constant (one called MAX_BLOCK_SIZE) to deal with large amount of resource. Using the similar logic might solve this issue."
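A back-of-the-envelope comparison against the limits quoted above (32768 and 65536 are the issue commenters' numbers, not my own measurements), assuming the launch configuration scales with the number of packed documents; it is at least consistent with 10k documents working and 100k failing.

# Limits quoted in the issue comments above (not verified independently).
reported_limits = {"sm_80 build size": 32768, "A100 threads quoted": 65536}

for num_docs in (10_000, 100_000):
    for name, limit in reported_limits.items():
        verdict = "OK" if num_docs <= limit else "exceeds limit"
        print(f"{num_docs:>7} documents vs {name} ({limit}): {verdict}")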
Code
xformers
# code from
# https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.fmha.attn_bias.BlockDiagonalMask
import torch
from xformers.ops import fmha

K = 16
dtype = torch.float16
device = "cuda"

if 1:
    # build the block-diagonal mask from per-document lengths and pack everything into one row
    # seqlen = [3, 6, 2]
    seqlen = [3] * 10000
    total_len = sum(seqlen)
    attn_bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens(seqlen)
    x = torch.randn([1, total_len, 1, K], dtype=dtype, device=device)
else:
    # alternative: build the mask and the packed tensor directly from a list of per-document tensors
    list_x = [
        torch.randn([1, 3, 1, K], dtype=dtype, device=device),
        torch.randn([1, 6, 1, K], dtype=dtype, device=device),
        torch.randn([1, 2, 1, K], dtype=dtype, device=device),
    ]
    attn_bias, x = fmha.attn_bias.BlockDiagonalMask.from_tensor_list(list_x)

linear = torch.nn.Linear(K, K * 3).to(device=device, dtype=dtype)
q, k, v = linear(x).reshape([1, -1, 1, 3, K]).unbind(-2)
out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
list_out = attn_bias.split(out)  # split the packed output back into per-document tensors
print("len(list_out)", len(list_out))
print(list_out[0].shape)  # [1, 3, 1, K]
assert tuple(list_out[0].shape) == (1, 3, 1, K)
flash attention (non-sparse)
"""
検証コード
3token x 10k doc -> OK
3token x 100k doc -> NG. OOM
"""
import torch
import numpy as np

# flash attn 1.x
from flash_attn.flash_attn_interface import flash_attn_unpadded_func

device = "cuda"

len_arr = [3] * 10000  # per-document lengths, e.g. [3, 2]
cum_len = np.cumsum([0] + len_arr)
# cumulative offsets of each document within the packed sequence
cu_seqlens_q = cu_seqlens_k = torch.tensor(cum_len, device=device, dtype=torch.int32)
max_seq_len = int(cu_seqlens_q[-1])  # total number of packed tokens

B, H, SEQ_LEN, HEAD_DIM = 1, 1, max_seq_len, 16

def make_tensor():
    # packed layout: (total_tokens, nheads, headdim)
    return torch.randn(SEQ_LEN, H, HEAD_DIM // H, device=device, dtype=torch.float16)
    # return torch.randn(B, SEQ_LEN, H, HEAD_DIM // H, device=device, dtype=torch.float16)
    # return torch.ones(B, H, SEQ_LEN, HEAD_DIM, device=device, dtype=torch.float16)

q, k, v = make_tensor(), make_tensor(), make_tensor()

# NOTE: flash_attn expects the maximum single-document length here (max(len_arr) == 3);
# passing the total packed length as done here also ran in the 10k-doc test, but it is much
# more expensive and is likely what blows up memory in the 100k-doc case.
seqlen_q = seqlen_k = max_seq_len
is_causal = False
dropout_p = 0.0
softmax_scale = None

output = flash_attn_unpadded_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    seqlen_q,
    seqlen_k,
    dropout_p,
    softmax_scale=softmax_scale,
    causal=is_causal,
)
print("output", output.shape, output)

# perturb the last token (it belongs to the last document) and run again
q[-1, 0, 0] = 7
k[-1, 0, 0] = 7
v[-1, 0, 0] = 7
output2 = flash_attn_unpadded_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    seqlen_q,
    seqlen_k,
    dropout_p,
    softmax_scale=softmax_scale,
    causal=is_causal,
)
diff = output - output2
print("diff", diff)
Discussion
This looks good. The default PyTorch MultiheadAttention handles 3k batch x 128 seqlen and even 30k batch x 128 seqlen.
Quite a detour to end up here.
The positions that need no computation because of padding can be controlled with key_padding_mask, whose shape is (Batch, Seq); a minimal sketch follows.
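A minimal sketch of this final setup, assuming random inputs, a toy embedding width, and random valid lengths between 2 and 128 per row; the point is only batch_first, key_padding_mask, and the (Batch, Seq) mask convention (True = padding to ignore).

import torch
import torch.nn as nn

B, S, E, H = 3000, 128, 256, 8  # ~3k packed rows of 128 tokens; toy model width
device = "cuda" if torch.cuda.is_available() else "cpu"

mha = nn.MultiheadAttention(embed_dim=E, num_heads=H, batch_first=True).to(device)
x = torch.randn(B, S, E, device=device)

# key_padding_mask: (Batch, Seq) bool, True marks padding positions to ignore
lengths = torch.randint(2, S + 1, (B,), device=device)
key_padding_mask = torch.arange(S, device=device)[None, :] >= lengths[:, None]

out, _ = mha(x, x, x, key_padding_mask=key_padding_mask, need_weights=False)
print(out.shape)  # torch.Size([3000, 128, 256])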