🤗
[メモ]Relative Positional Encodingのためのrel_shift
コード
def rel_shift(x):
    """Relative-shift trick for relative positional encoding (2-D case).

    Left-pads each row with one zero, reinterprets the flat buffer as a
    (n_cols+1, n_rows) matrix, and drops the first row; this realigns the
    data so row i is shifted right by i. Finally only the first
    n_cols//2 + 1 columns (relative positions 0..time2) are kept.

    Args:
        x: 2-D tensor of shape (n_rows, n_cols); n_cols is assumed to be
           the odd number of relative positions (2*t - 1) — TODO confirm
           against the caller.

    Returns:
        Tensor of shape (n_rows, n_cols//2 + 1).
    """
    n_rows, n_cols = x.size()
    # Pad one zero column on the left, then reinterpret the storage so
    # that each original row appears shifted by one extra position.
    shifted = F.pad(x, (1, 0)).view(n_cols + 1, n_rows)[1:].view_as(x)
    return shifted[:, : n_cols // 2 + 1]
print文を追加して随時確認してみる
def rel_shift(x):
    """Relative-shift trick with debug prints showing each intermediate step.

    Identical computation to the plain version: left-pad, reinterpret the
    view, drop the first row, keep the left half of the columns.
    """
    print("x = ")
    print(x)
    n_rows, n_cols = x.size()
    # One zero column prepended to every row.
    x_padded = F.pad(x, (1, 0))
    print("x_padded = ")
    print(x_padded)
    # Reinterpret the same storage as (n_cols+1, n_rows).
    x_padded = x_padded.view(n_cols + 1, n_rows)
    print("x_padded = ")
    print(x_padded)
    # Drop the first reinterpreted row, restore the original shape,
    # then keep relative positions 0..time2.
    x = x_padded[1:].view_as(x)[:, : n_cols // 2 + 1]  # only keep the positions from 0 to time2
    print("x_padded[1:] = ")
    print(x_padded[1:])
    print("result = ")
    print(x)
    return x
>>> x = torch.stack([torch.tensor([-3, -2, -1, 0, 1, 2, 3]) for i in range(4)], dim=0)
>>> x, x.size()
(tensor([[-3, -2, -1, 0, 1, 2, 3],
[-3, -2, -1, 0, 1, 2, 3],
[-3, -2, -1, 0, 1, 2, 3],
[-3, -2, -1, 0, 1, 2, 3]]), torch.Size([4, 7]))
>>> _ = rel_shift(x)
x =
tensor([[-3, -2, -1, 0, 1, 2, 3],
[-3, -2, -1, 0, 1, 2, 3],
[-3, -2, -1, 0, 1, 2, 3],
[-3, -2, -1, 0, 1, 2, 3]])
x_padded =
tensor([[ 0, -3, -2, -1, 0, 1, 2, 3],
[ 0, -3, -2, -1, 0, 1, 2, 3],
[ 0, -3, -2, -1, 0, 1, 2, 3],
[ 0, -3, -2, -1, 0, 1, 2, 3]])
x_padded =
tensor([[ 0, -3, -2, -1],
[ 0, 1, 2, 3],
[ 0, -3, -2, -1],
[ 0, 1, 2, 3],
[ 0, -3, -2, -1],
[ 0, 1, 2, 3],
[ 0, -3, -2, -1],
[ 0, 1, 2, 3]])
x_padded[1:] =
tensor([[ 0, 1, 2, 3],
[ 0, -3, -2, -1],
[ 0, 1, 2, 3],
[ 0, -3, -2, -1],
[ 0, 1, 2, 3],
[ 0, -3, -2, -1],
[ 0, 1, 2, 3]])
result =
tensor([[ 0, 1, 2, 3],
[-1, 0, 1, 2],
[-2, -1, 0, 1],
[-3, -2, -1, 0]])
Discussion