【備忘録】torch.catのdim
torch.cat
って何しているの?
shape = (2, 3, 4)
の3階テンソル[1]に対して、cat
を施した後の、それぞれのprint
の出力結果が想像つきますか?
少なくとも、初見で私は全く想像できませんでした。
そこで備忘録として、torch.cat
の使い方や、考えのイメージをまとめました。
import torch
shape = (2, 3, 4)
tensor = torch.ones(shape)
print(tensor)
print("--------dim=0---------")
print(torch.cat([tensor, tensor, tensor], dim=0))
print("--------dim=1---------")
print(torch.cat([tensor, tensor, tensor], dim=1))
print("--------dim=2---------")
print(torch.cat([tensor, tensor, tensor], dim=2))
(結果を先に知りたい人はアコーディオンを開いてね)
dim=0
tensor([[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]],
[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]],
[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]],
[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]],
[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]],
[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]]])
dim=1
tensor([[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]],
[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]]])
dim=2
tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],
[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]])
PyTorch公式による解説
簡単に訳すと、
(ここから訳)
torch.cat
torch.cat(tensors, dim=0, *, out=None) -> Tensor
指定された次元で、与えられたテンソル列を連結します。すべてのテンソルは同じ形(shape)であるか(連結する次元を除いて)、またはサイズ (0,)
の
-
torch.cat()
はtroch.split()
とtorch.chunk()
に対する逆演算子と見なせます。 -
torch.cat()
は例を知ることで、ベストな理解が得られます。
(中略)
です。どうやら、複数のテンソルを繋げてくれる関数で、連結する方向に制約があるそうだ。そして、一番ひっかかったのが、連結する次元ってどっちやねんということです。
図解
shape = (2, 3, 4)
の3階テンソルは次のようなイメージです。
3次元空間の中にある、torch.ones(shape)
なので、小立方体の中の値は全て1です。
小立方体のどの値が欲しいかは、3次元空間の位置を示すのと同じように、3つのインデックスを指定します。
例えば、
tensor = torch.tensor([[[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]],
[[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24]]])
という3階テンソルの場合には、tensor[1][2][0]
は21です。
では本題に入ります。torch.cat
でdim
を指定することは、下記のような図で表されます:
torch.cat([tensor, tensor, tensor], dim=0)
torch.cat([tensor, tensor, tensor], dim=1)
torch.cat([tensor, tensor, tensor], dim=2)
まとめ
図から想像できるように、torch.cat
は、小立方体で構成された立体(tensor)を繋ぎ合わせます。繋ぎ合わせるときの「接合面」は互いに同じ数の小立体でないといけないため、dim
の指定が必要になります。
なので、接合面が一致しておけば、それ以外が一致しなくてもOKです。
例えば、
# shape (2, 2)
tensor1 = torch.tensor([[1, 2],
[3, 4]])
# shape (3, 2)
tensor2 = torch.tensor([[5, 6],
[8, 9],
[10, 11]])
# shape(5, 2)
tensor3 = torch.cat([tensor1, tensor2], dim=0)
print(tensor3)
>>>tensor([[ 1, 2],
[ 3, 4],
[ 5, 6],
[ 8, 9],
[10, 11]])
ですね。
一般に、shape = (a_0, a_1, ..., a_i, ..., a_n)
のテンソルとshape = (b_0, b_1, ..., b_i, ..., b_n)
があると、torch.cat
可能ということです。
-
数学、物理の分野では、テンソルは単に複数の成分を持った量を意味しません。特に物理学では、テンソルは座標に依存しない量で、テンソルの成分は座標変換に対する変換生を持ちます。話が噛み合わない時は、相手の専門分野を聞いてみましょう。 ↩︎
Discussion