🤗

[メモ]Relative Positional Encodingのためのrel_shift

2022/03/21に公開

コード

def rel_shift(x):
    """Apply the relative-shift trick to a 2-D score matrix.

    Pads one zero column on the left, then reinterprets the flat buffer so
    that dropping the first row displaces every original row by one position;
    finally only the leading half of the columns is kept.
    """
    n_rows, n_cols = x.size()
    # Prepend a zero column: the one extra element per row is what
    # produces the diagonal shift after the reshape below.
    padded = F.pad(x, (1, 0))

    # Reinterpret the same buffer with (n_cols + 1, n_rows); removing the
    # first row of that view realigns the data with a per-row offset of one.
    shifted = padded.view(n_cols + 1, n_rows)[1:].view_as(x)
    # Keep only positions 0 .. n_cols // 2.
    return shifted[:, : n_cols // 2 + 1]

print文を追加して随時確認してみる

def rel_shift(x):
    """Same relative-shift as above, printing every intermediate tensor.

    Debug variant: shows the input, the zero-padded tensor, the reshaped
    view, the view with its first row dropped, and the final result.
    """
    print("x = ")
    print(x)
    n_rows, n_cols = x.size()
    padded = F.pad(x, (1, 0))
    print("x_padded = ")
    print(padded)

    # Same buffer viewed as (n_cols + 1, n_rows): dropping row 0 shifts
    # each original row by one position.
    padded = padded.view(n_cols + 1, n_rows)
    print("x_padded = ")
    print(padded)
    result = padded[1:].view_as(x)[
        :, : n_cols // 2 + 1
    ]  # only keep the positions from 0 to time2
    print("x_padded[1:] = ")
    print(padded[1:])
    print("result = ")
    print(result)
    return result
>>> x = torch.stack([torch.tensor([-3, -2, -1, 0, 1, 2, 3]) for i in range(4)], dim=0)
>>> x, x.size()
(tensor([[-3, -2, -1,  0,  1,  2,  3],
         [-3, -2, -1,  0,  1,  2,  3],
         [-3, -2, -1,  0,  1,  2,  3],
         [-3, -2, -1,  0,  1,  2,  3]]), torch.Size([4, 7]))
>>> _ = rel_shift(x)
x = 
tensor([[-3, -2, -1,  0,  1,  2,  3],
        [-3, -2, -1,  0,  1,  2,  3],
        [-3, -2, -1,  0,  1,  2,  3],
        [-3, -2, -1,  0,  1,  2,  3]])
x_padded = 
tensor([[ 0, -3, -2, -1,  0,  1,  2,  3],
        [ 0, -3, -2, -1,  0,  1,  2,  3],
        [ 0, -3, -2, -1,  0,  1,  2,  3],
        [ 0, -3, -2, -1,  0,  1,  2,  3]])
x_padded = 
tensor([[ 0, -3, -2, -1],
        [ 0,  1,  2,  3],
        [ 0, -3, -2, -1],
        [ 0,  1,  2,  3],
        [ 0, -3, -2, -1],
        [ 0,  1,  2,  3],
        [ 0, -3, -2, -1],
        [ 0,  1,  2,  3]])
x_padded[1:] = 
tensor([[ 0,  1,  2,  3],
        [ 0, -3, -2, -1],
        [ 0,  1,  2,  3],
        [ 0, -3, -2, -1],
        [ 0,  1,  2,  3],
        [ 0, -3, -2, -1],
        [ 0,  1,  2,  3]])
result = 
tensor([[ 0,  1,  2,  3],
        [-1,  0,  1,  2],
        [-2, -1,  0,  1],
        [-3, -2, -1,  0]])

Discussion