Open5

深層学習の色々(PyTorch)

calcifercalcifer

Flatten(平坦化)

デフォルトは start_dim=1, end_dim=-1 の範囲が平坦化.

# t = torch.Size([2, 3, 4, 5])

flatten = nn.Flatten()
print(flatten(t).shape)
# torch.Size([2, 60])

start_dim=0, end_dim=-1 の範囲を平坦化すると,

# t = torch.Size([2, 3, 4, 5])

flatten = nn.Flatten(0, -1)
print(flatten(t).shape)
# torch.Size([120])

https://note.nkmk.me/python-pytorch-flatten/

calcifercalcifer

周波数軸とチャンネル軸に沿って平坦化するとは,チャネルと周波数の次元を1つのベクトルにまとめるという意味なので,
[バッチサイズ、チャネル数、周波数、時間]

[バッチサイズ、時間、チャネル×周波数]
にする.