🤗

2022/03/21に公開

# コード

```python
import torch


def rel_shift(x):
    """Shift a relative-position score matrix into query/key alignment.

    Args:
        x (torch.Tensor): 2-D tensor of shape (time1, 2 * time1 - 1).
            Column k is read as relative offset k - (time1 - 1), so the
            columns run from -(time1 - 1) to time1 - 1.

    Returns:
        torch.Tensor: Tensor of shape (time1, time1) with
        ``out[i, j] == x[i, time1 - 1 + j - i]``, i.e. entry (i, j) holds
        the value for relative offset j - i.

    Note:
        The alignment is exact only when ``x.size(1) == 2 * x.size(0) - 1``.
        Other 2-D shapes still run without error but give a misaligned result.
    """
    l1, l2 = x.size()
    # Prepend one zero column: (l1, l2) -> (l1, l2 + 1). Using x's dtype and
    # device keeps integer inputs integer and gives torch.cat a single device.
    zero_pad = torch.zeros((l1, 1), device=x.device, dtype=x.dtype)
    x_padded = torch.cat([zero_pad, x], dim=-1)
    # Reinterpret the same (contiguous) buffer as (l2 + 1, l1) rows and drop
    # the first row, i.e. the first l1 elements. Viewing the remainder as
    # (l1, l2) makes result row i start at column l1 - i of padded row i.
    x_padded = x_padded.view(l2 + 1, l1)
    x = x_padded[1:].view_as(x)[
        :, :l2 // 2 + 1
    ]  # keep l2 // 2 + 1 (== time1) columns: key positions 0 .. time1 - 1
    return x
```

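各ステップの後に print 文を追加して、中間テンソルを随時確認してみる。入力 `x` の形状は `(time1, 2*time1-1)` を想定しており、各列が相対位置 `-(time1-1)` 〜 `time1-1` に対応する（例: time1 = 4 なら列は -3 〜 3）。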

```python
def rel_shift(x):
    """Debug version of rel_shift that prints every intermediate tensor.

    Args:
        x (torch.Tensor): 2-D tensor of shape (time1, 2 * time1 - 1).

    Returns:
        torch.Tensor: Tensor of shape (time1, time1) with
        ``out[i, j] == x[i, time1 - 1 + j - i]``.
    """
    print("x = ")
    print(x)
    l1, l2 = x.size()
    # Zero pad built with x's dtype/device so integer input stays integer.
    zero_pad = torch.zeros((l1, 1), device=x.device, dtype=x.dtype)
    x_padded = torch.cat([zero_pad, x], dim=-1)
    print(x_padded)  # (l1, l2 + 1): zero column prepended to every row
    x_padded = x_padded.view(l2 + 1, l1)
    print(x_padded)  # same buffer reinterpreted as (l2 + 1, l1)
    print(x_padded[1:])  # first row (the first l1 elements) dropped
    x = x_padded[1:].view_as(x)[
        :, :l2 // 2 + 1
    ]  # keep l2 // 2 + 1 (== time1) columns: key positions 0 .. time1 - 1
    print("result = ")
    print(x)
    return x
```
```python
>>> x = torch.stack([torch.tensor([-3, -2, -1, 0, 1, 2, 3]) for i in range(4)], dim=0)
>>> x, x.size()
(tensor([[-3, -2, -1,  0,  1,  2,  3],
        [-3, -2, -1,  0,  1,  2,  3],
        [-3, -2, -1,  0,  1,  2,  3],
        [-3, -2, -1,  0,  1,  2,  3]]), torch.Size([4, 7]))
```
```python
>>> _ = rel_shift(x)
x =
tensor([[-3, -2, -1,  0,  1,  2,  3],
        [-3, -2, -1,  0,  1,  2,  3],
        [-3, -2, -1,  0,  1,  2,  3],
        [-3, -2, -1,  0,  1,  2,  3]])
tensor([[ 0, -3, -2, -1,  0,  1,  2,  3],
        [ 0, -3, -2, -1,  0,  1,  2,  3],
        [ 0, -3, -2, -1,  0,  1,  2,  3],
        [ 0, -3, -2, -1,  0,  1,  2,  3]])
tensor([[ 0, -3, -2, -1],
        [ 0,  1,  2,  3],
        [ 0, -3, -2, -1],
        [ 0,  1,  2,  3],
        [ 0, -3, -2, -1],
        [ 0,  1,  2,  3],
        [ 0, -3, -2, -1],
        [ 0,  1,  2,  3]])
tensor([[ 0,  1,  2,  3],
        [ 0, -3, -2, -1],
        [ 0,  1,  2,  3],
        [ 0, -3, -2, -1],
        [ 0,  1,  2,  3],
        [ 0, -3, -2, -1],
        [ 0,  1,  2,  3]])
result =
tensor([[ 0,  1,  2,  3],
        [-1,  0,  1,  2],
        [-2, -1,  0,  1],
        [-3, -2, -1,  0]])
```