
【Tips】Behaviour of squeeze() in PyTorch

Published on 2024/06/05

1. squeeze

General Use

squeeze() scans the input tensor and removes every dimension that has only one element (size 1).
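Because squeeze() with no argument drops every size-1 dimension, the rank of the result can depend on the data itself. The minimal sketch below assumes a hypothetical (batch, channel, features) layout: when the batch happens to contain a single sample, the batch dimension disappears along with the channel dimension.

```python
import torch

batch_of_two = torch.randn(2, 1, 5)  # batch size 2, one channel, 5 features
batch_of_one = torch.randn(1, 1, 5)  # same layout, but batch size 1

# squeeze() with no argument drops every dimension whose size is 1
squeezed_two = batch_of_two.squeeze()
squeezed_one = batch_of_one.squeeze()

print(squeezed_two.shape)  # torch.Size([2, 5])
print(squeezed_one.shape)  # torch.Size([5]): the batch dimension is dropped as well
```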

Specifying an Index

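squeeze(index) checks only the dimension at the given index and removes it if its size is 1. If that dimension is larger than 1, the tensor is returned with its shape unchanged.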
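Passing an index is one way to keep a size-1 batch dimension that a bare squeeze() would remove. In this sketch (the (batch, channel, features) layout is again a hypothetical example), only the channel dimension is squeezed, and squeezing a dimension of size 5 leaves the shape as it was.

```python
import torch

x = torch.randn(1, 1, 5)  # (batch=1, channel=1, features=5)

# Only dimension 1 is checked; it has size 1, so it is removed.
# The batch dimension (also size 1) is left alone.
y = x.squeeze(1)
print(y.shape)  # torch.Size([1, 5])

# Dimension 2 has size 5, so squeeze(2) returns a tensor with the same shape.
z = x.squeeze(2)
print(z.shape)  # torch.Size([1, 1, 5])
```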

・squeeze

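import torch

# A 4-D tensor whose first and last dimensions have size 1
features = torch.randn(1, 3, 4, 1)
print(features.shape)  # Output: torch.Size([1, 3, 4, 1])

# No argument: every size-1 dimension is removed
features_ = features.squeeze()
print(features_.shape)  # Output: torch.Size([3, 4])

# dim=-1 (the last dimension) has size 1, so it is removed
features__ = features.squeeze(-1)
print(features__.shape)  # Output: torch.Size([1, 3, 4])

# dim=-2 has size 4, so nothing is removed and the shape is unchanged
features___ = features.squeeze(-2)
print(features___.shape)  # Output: torch.Size([1, 3, 4, 1])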

This is often used in PyTorch code to align tensor dimensions between operations.
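A common case of aligning dimensions is comparing a model output of shape (N, 1) with targets of shape (N,). The names `predictions` and `targets` below are hypothetical; the point is that, without squeeze, PyTorch's broadcasting silently turns the element-wise difference into an (N, N) matrix.

```python
import torch

n = 4
predictions = torch.randn(n, 1)  # e.g. a head with one output unit: shape (4, 1)
targets = torch.randn(n)         # shape (4,)

# Without squeeze, broadcasting produces a (4, 4) matrix of differences
wrong = predictions - targets
print(wrong.shape)  # torch.Size([4, 4])

# squeeze(-1) drops the trailing size-1 dimension so the shapes match element-wise
right = predictions.squeeze(-1) - targets
print(right.shape)  # torch.Size([4])
```

Using squeeze(-1) rather than a bare squeeze() here keeps the result one-dimensional even when n happens to be 1.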

Summary

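squeeze() removes dimensions that have only one element: all of them when called with no argument, or only the one at the given index when an index is specified.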
