🥷

指定したインデックスでテンソルからベクトルを抽出する

2024/01/08に公開

概要

テンソルに対して指定したインデックスでベクトルを抽出するPytorchのコード。
例えば、batch size\timessequence length\timeshidden sizeのテンソルからバッチごとに指定してbatch size\timeshidden sizeのベクトルを抽出する操作をfor文を使わないでバッチで行うことができる。

やり方

import torch

tensor = torch.randn(2, 2, 10)
indices = torch.tensor([0, 1])
indices = indices.reshape(-1, 1)
zero_vector = torch.zeros(tensor.shape[0], tensor.shape[1])
one_hot = zero_vector.scatter_(1, indices, 1)
one_hot = one_hot.unsqueeze(-1)
extracted_vector = (tensor * one_hot).sum(1)

解説

このコードではone hotベクトルを作成して、それをテンソルと掛け合わせて指定したベクトルを抽出する。

tensorは抽出したいベクトルを含むテンソルである。

tensor([[[ 0.1454, -0.9785, -0.8343,  0.4616,  0.5629, -0.4330, -0.1140,
           0.5299,  0.8999, -1.6505],
         [-1.2407, -2.0927, -0.0927, -0.1455,  0.7530, -0.0190, -2.3263,
          -2.7178, -0.5882,  0.5455]],

        [[ 0.4071,  0.3638,  0.5934, -1.6908, -0.4121, -0.8543, -0.2508,
           1.2647, -1.2513, -2.1701],
         [-0.1980, -0.0038,  0.3369, -1.3647, -0.3377,  1.3121, -0.8706,
           0.1129, -0.1284, -1.0524]]])

indicesは抽出したいベクトルのインデックスを表すテンソルである。
indicesをone-hotベクトルに変換するために、まずtensorと同じサイズのゼロベクトルzero_vectorを作成する。
scatter_メソッドは第一引数のベクトルの第二引数のインデックスに第三引数の値を代入する。そのため、one_hoteは以下のようになる。

tensor([[1., 0.],
        [0., 1.]])

unsqueezeメソッドにより次元を追加することで、one_hotは(2, 2, 1)のテンソルになる。(2, 2, 10)のテンソルとかけわせることで、抽出したいベクトル以外のベクトルには0が掛けられゼロベクトルとなる。次元にそって足し合わせることで、抽出したいベクトルのみが残る。

tensor([[ 0.1454, -0.9785, -0.8343,  0.4616,  0.5629, -0.4330, -0.1140,  0.5299,
          0.8999, -1.6505],
        [-0.1980, -0.0038,  0.3369, -1.3647, -0.3377,  1.3121, -0.8706,  0.1129,
         -0.1284, -1.0524]])

Discussion