numpy.stackのaxis=-1について調べてみた

2024/09/23に公開

こんにちは、沙代です。
axis=-1ってよく分からなくなってしまうことありますよね。今回はそれについて調べてみました。

環境
MacOS: macOS Montery 12.6.2 (Apple M1 Pro)
iTerm2: 3.5.4
Python: 3.10.0
numpy: 1.22.4

ドキュメントを見てみます

numpy.stack(arrays, axis=0, out=None)
Join a sequence of arrays along a new axis.

The axis parameter specifies the index of the new axis in the dimensions of the result. For example, if axis=0 it will be the first dimension and if axis=-1 it will be the last dimension.

「新しい軸(axis)に沿って、配列(array)らを繋げる。
軸(axis)は、出力されるものの次元において新しい軸が何番目になるか(index)を指定する。例えば、axis=0は最初の次元となることを示し、axis=-1は最後の次元となることを示す。」(筆者超訳)

例を試してみます

準備(arrays)
>>> arrays = [np.random.randn(3, 4) for _ in range(10)]
axis=0
>>> np.stack(arrays, axis=0).shape
(10, 3, 4)
axis=1
>>> np.stack(arrays, axis=1).shape
(3, 10, 4)
axis=2 (or axis=-1)
>>> np.stack(arrays, axis=2).shape
(3, 4, 10)
準備(a,b)
>>> a = np.array([1, 2, 3])
>>> b = np.array([4, 5, 6])
axis=0(default)
>>> np.stack((a, b))
array([[1, 2, 3],
       [4, 5, 6]])
axis=-1
>>> np.stack((a, b), axis=-1)
array([[1, 4],
       [2, 5],
       [3, 6]])
>>> c = [np.array([[i*6+j*2, i*6+j*2+1] for j in range(3)]) for i in range(4)]
>>> c
[array([[0, 1],
       [2, 3],
       [4, 5]]), array([[ 6,  7],
       [ 8,  9],
       [10, 11]]), array([[12, 13],
       [14, 15],
       [16, 17]]), array([[18, 19],
       [20, 21],
       [22, 23]])]
axis=0
>>> np.stack(c, axis=0)
array([[[ 0,  1],
        [ 2,  3],
        [ 4,  5]],
       [[ 6,  7],
        [ 8,  9],
        [10, 11]],
       [[12, 13],
        [14, 15],
        [16, 17]],
       [[18, 19],
        [20, 21],
        [22, 23]]])
axis=1
>>> np.stack(c, axis=1)
array([[[ 0,  1],
        [ 6,  7],
        [12, 13],
        [18, 19]],
       [[ 2,  3],
        [ 8,  9],
        [14, 15],
        [20, 21]],
       [[ 4,  5],
        [10, 11],
        [16, 17],
        [22, 23]]])
axis=2 (or axis=-1)
>>> np.stack(c, axis=2)
array([[[ 0,  6, 12, 18],
        [ 1,  7, 13, 19]],
       [[ 2,  8, 14, 20],
        [ 3,  9, 15, 21]],
       [[ 4, 10, 16, 22],
        [ 5, 11, 17, 23]]])

(出力の空行は省きました)

最後に

axis=-1は、転置かと思っていましたが、確かに2次元になる場合はそのように言えるのかもしれません。

参考ページ

https://numpy.org/doc/stable/reference/generated/numpy.stack.html

Discussion