🎃

(フル)アテンションマスクをSlidingなマスクやSparseなマスクに加工する方法

2025/03/01に公開

Sliding Window attentionやMaxViTのGrid attention(Sparse attention)やBlock attentionなど、アテンションにはいくつかのパターンがある。フルアテンションマスクやチャンクアテンションマスクを加工することで、Sliding window版マスク、Grid attention版マスク、Block attention版マスクに変換できる。

具体的には以下のように加工できる ※MaxViTは2D入力だがここでは1D入力で説明する

  • MaskにSlidingWindowMaskを掛け算すると、Sliding window版マスクになる
  • Maskを間引きすると、Sparseアテンション用のマスクになる
  • Maskに対角Blockなマスクを掛け算するとBlockAttention用マスクになる ※この記事では省略





# チャンクの定義
chunks = [3, 5, 2]
size = sum(chunks)
# マスクの初期化
attn_mask = np.zeros((size, size), dtype=int)
# チャンクごとにアクセス可能な範囲を設定
start = 0
for chunk in chunks:
    attn_mask[start:start+chunk, start:start+chunk] = 1
    start += chunk


window = 3
window_mask = np.zeros((size, size), dtype=int)
for i in range(size):
    start = max(0, i - window // 2)
    end = min(size, i + window // 2 + 1)
    window_mask[i, start:end] = 1


# 両者を掛け算して合成マスクを作成
combined_mask = attn_mask * window_mask


# Grid mask
grid1_mask = attn_mask[::2, ::2]
grid2_mask = attn_mask[1::2, 1::2]


# 可視化
def visualize_attn_mask(attn_mask: np.ndarray, title: str):
    """ Attention Maskを可視化する """
    plt.figure(figsize=(6, 6))
    plt.imshow(attn_mask, cmap='viridis')
    plt.colorbar()
    
    N = attn_mask.shape[0]
    ticks = np.arange(0, N, 1)
    plt.xticks(ticks)
    plt.yticks(ticks)
    
    # グリッド線の設定
    plt.gca().set_xticks(np.arange(-0.5, N, 1), minor=True)
    plt.gca().set_yticks(np.arange(-0.5, N, 1), minor=True)
    plt.grid(which='minor', color='black', linestyle='-', linewidth=1)
    
    plt.title(title)
    plt.show()

Discussion