🦌
【Tips】Behaviour of squeeze() in PyTorch
1. squeeze
General Use
squeeze() with no argument removes every dimension of the input whose size is 1.
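Because squeeze() with no argument drops every size-1 dimension, it also drops the batch dimension when a batch happens to contain a single sample. Here is a minimal sketch, assuming an image-like tensor laid out as (batch, channels, height, width). The shapes are illustrative only.

```python
import torch

# One grayscale 28x28 image: (batch=1, channels=1, height=28, width=28)
single = torch.randn(1, 1, 28, 28)
# squeeze() removes both size-1 dims, so the batch dim disappears too
print(single.squeeze().shape)  # torch.Size([28, 28])

# With two images the batch dim has size 2, so only the channel dim is removed
pair = torch.randn(2, 1, 28, 28)
print(pair.squeeze().shape)  # torch.Size([2, 28, 28])
```

The number of dimensions in the result therefore depends on the batch size, which is a common source of shape bugs. Squeezing one specific dimension avoids this.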
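Specifying an Index
squeeze(dim) checks only the dimension at index dim and removes it if its size is 1. If that dimension has any other size, the shape is left unchanged and no error is raised. Negative indices count from the end, so -1 means the last dimension.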
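The two rules above can be restated as a small pure-Python function over shapes. `squeezed_shape` is a hypothetical helper, not a PyTorch function. It exists only to make the rule explicit, and it is checked against torch's own squeeze on a few inputs. 0-d tensors are not handled.

```python
import torch

def squeezed_shape(shape, dim=None):
    """Hypothetical helper: the shape that squeeze() / squeeze(dim) would produce."""
    shape = list(shape)
    if dim is None:
        # No dim given: drop every dimension of size 1
        return tuple(s for s in shape if s != 1)
    if not -len(shape) <= dim < len(shape):
        raise IndexError("dimension out of range")
    # Normalize a negative index such as -1 to a positive one
    dim = dim % len(shape)
    if shape[dim] == 1:
        # The chosen dimension has size 1, so it is removed
        return tuple(shape[:dim] + shape[dim + 1:])
    # Otherwise the shape is returned unchanged
    return tuple(shape)

x = torch.randn(1, 3, 4, 1)
for d in (None, 0, 1, -1, -2):
    expected = x.squeeze() if d is None else x.squeeze(d)
    # The helper agrees with torch for each case
    assert squeezed_shape(x.shape, d) == tuple(expected.shape)
    print(d, squeezed_shape(x.shape, d))
```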
・squeeze
import torch

features = torch.randn(1, 3, 4, 1)
print(features.shape)  # Output: torch.Size([1, 3, 4, 1])

# No argument: every size-1 dimension is removed
features_ = features.squeeze()
print(features_.shape)  # Output: torch.Size([3, 4])

# dim=-1: the last dimension has size 1, so it is removed
features__ = features.squeeze(-1)
print(features__.shape)  # Output: torch.Size([1, 3, 4])

# dim=-2: this dimension has size 4, so nothing is removed
features___ = features.squeeze(-2)
print(features___.shape)  # Output: torch.Size([1, 3, 4, 1])
This is commonly used in PyTorch code to make tensor dimensions line up before tensors are combined.
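As an illustration of aligning dimensions (a sketch, not taken from the original article): a regression model often outputs shape (N, 1) while its targets have shape (N,). Subtracting them directly silently broadcasts to (N, N). Squeezing the last dimension first gives the intended element-wise result. The names and shapes here are assumptions for the example.

```python
import torch

n = 8
preds = torch.randn(n, 1)   # e.g. model output with a trailing size-1 dim
target = torch.randn(n)     # one target per sample

# Without squeeze: (8, 1) and (8,) broadcast to (8, 8), which is not intended
wrong = preds - target
print(wrong.shape)  # torch.Size([8, 8])

# squeeze(-1) turns (8, 1) into (8,), so the subtraction is element-wise
diff = preds.squeeze(-1) - target
print(diff.shape)  # torch.Size([8])

# Mean squared error computed on correctly aligned tensors
mse = (diff ** 2).mean()
```

squeeze(-1) is used rather than squeeze() so that a batch of size 1 still keeps its batch dimension.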
Summary
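squeeze() removes dimensions of size 1. With no argument it removes all of them; with a dimension index it removes only that dimension, and only if its size is 1.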