🙌
PyTorchのsqueezeの基本
squeezeとは?
公式ドキュメントによると、squeeze
「入力されたテンソルから指定されたすべてのサイズ1の次元を取り除いたテンソルを返す」です。
数式だと以下になります。
A×B×C×D = torch.squeeze(A×1×B×C×1×D)
分かりにくいので実際にコードを動かしてみましょう。
実例
import torch
x = torch.randn(3, 1, 5)
x.squeeze()
print(x) # shape: (3. 5)
つまり意味のない「長さ1の次元を消すことが出来ます」
第二引数
二番目の引数は、消したい長さ1の次元のインデックスです。
例えば以下のように指定したインデックスのみ、squeezeを適用できます。
import torch
x = torch.zeros(2, 1, 2, 1, 2)
print(x.size()) # torch.Size([2, 1, 2, 1, 2])
y = torch.squeeze(x, 0)
print(y.size()) # torch.Size([2, 1, 2, 1, 2])
y = torch.squeeze(x, 1)
print(y.size()) # torch.Size([2, 2, 1, 2])
y = torch.squeeze(x, (1, 2, 3))
print(y.size()) # torch.Size([2, 2, 2])
Discussion