🙌

PyTorchのsqueezeの基本

に公開

squeezeとは?

公式ドキュメントによると、squeeze
「入力されたテンソルから指定されたすべてのサイズ1の次元を取り除いたテンソルを返す」です。

数式だと以下になります。

A×B×C×D = torch.squeeze(A×1×B×C×1×D)

分かりにくいので実際にコードを動かしてみましょう。

実例

import torch

x = torch.randn(3, 1, 5)
x.squeeze()
print(x) # shape: (3. 5)

つまり意味のない「長さ1の次元を消すことが出来ます」

第二引数

二番目の引数は、消したい長さ1の次元のインデックスです。
例えば以下のように指定したインデックスのみ、squeezeを適用できます。

import torch

x = torch.zeros(2, 1, 2, 1, 2)
print(x.size()) # torch.Size([2, 1, 2, 1, 2])
y = torch.squeeze(x, 0)
print(y.size()) # torch.Size([2, 1, 2, 1, 2])
y = torch.squeeze(x, 1)
print(y.size()) # torch.Size([2, 2, 1, 2])
y = torch.squeeze(x, (1, 2, 3))
print(y.size()) # torch.Size([2, 2, 2])

Discussion