🥷
指定したインデックスでテンソルからベクトルを抽出する
概要
テンソルに対して指定したインデックスでベクトルを抽出するPytorchのコード。
例えば、batch size
やり方
import torch
tensor = torch.randn(2, 2, 10)
indices = torch.tensor([0, 1])
indices = indices.reshape(-1, 1)
zero_vector = torch.zeros(tensor.shape[0], tensor.shape[1])
one_hot = zero_vector.scatter_(1, indices, 1)
one_hot = one_hot.unsqueeze(-1)
extracted_vector = (tensor * one_hot).sum(1)
解説
このコードではone hotベクトルを作成して、それをテンソルと掛け合わせて指定したベクトルを抽出する。
tensor
は抽出したいベクトルを含むテンソルである。
tensor([[[ 0.1454, -0.9785, -0.8343, 0.4616, 0.5629, -0.4330, -0.1140,
0.5299, 0.8999, -1.6505],
[-1.2407, -2.0927, -0.0927, -0.1455, 0.7530, -0.0190, -2.3263,
-2.7178, -0.5882, 0.5455]],
[[ 0.4071, 0.3638, 0.5934, -1.6908, -0.4121, -0.8543, -0.2508,
1.2647, -1.2513, -2.1701],
[-0.1980, -0.0038, 0.3369, -1.3647, -0.3377, 1.3121, -0.8706,
0.1129, -0.1284, -1.0524]]])
indices
は抽出したいベクトルのインデックスを表すテンソルである。
indices
をone-hotベクトルに変換するために、まずtensor
と同じサイズのゼロベクトルzero_vector
を作成する。
scatter_メソッドは第一引数のベクトルの第二引数のインデックスに第三引数の値を代入する。そのため、one_hote
は以下のようになる。
tensor([[1., 0.],
[0., 1.]])
unsqueezeメソッドにより次元を追加することで、one_hot
は(2, 2, 1)のテンソルになる。(2, 2, 10)のテンソルとかけわせることで、抽出したいベクトル以外のベクトルには0が掛けられゼロベクトルとなる。次元にそって足し合わせることで、抽出したいベクトルのみが残る。
tensor([[ 0.1454, -0.9785, -0.8343, 0.4616, 0.5629, -0.4330, -0.1140, 0.5299,
0.8999, -1.6505],
[-0.1980, -0.0038, 0.3369, -1.3647, -0.3377, 1.3121, -0.8706, 0.1129,
-0.1284, -1.0524]])
Discussion