🧱

【Numpy】indexing と broadcasting を直感的に理解したい

に公開

はじめに

Numpy の indexing ってなかなかに複雑で、混乱した経験は無いでしょうか?
例えば、こんな numpy 操作を見て結果がすぐに予測できるでしょうか?

import numpy as np
a = np.arange(24).reshape(2, 3, 4)
b = np.array([10, 20])
res = a[0, ..., 1, np.newaxis] + b
# res.shape は? 中身は?

(答えは末尾の Quiz にあります)

機械学習 (ML) 系のプログラムでは必ずといっていいほど Numpy や PyTorch のテンソルを扱いますが、テンソルの形状や broadcasting の動きが理解できないことがきっかけで、何をやっているのかよくわからなくなってしまうことが筆者はよくありました。

この記事では、基本的な indexing, broadcasting 操作について特に結果のテンソルの shape に注目して直感的に理解することを目指します。
形状さえわかってしまえば、各位置にどの値が入るかはさほど迷わないはずです。

記事の内容

Indexing に表われる以下の要素を説明します

  • 整数 (e.g. 2)
  • :
  • ...
  • None (np.newaxis)

この記事では、以下の点に注目して indexing の結果を簡潔に説明することを目指します

  • 結果のテンソルの shape
  • indexing 前後の次元の対応

また、最後に broadcasting における次元の拡張の考え方を簡単に説明します。

Reference

この記事で説明する indexing は、 reference ("Indexing on ndarrays") の "Basic Indexing" (をやや簡略化したもの) におおむね対応します。
ほかの indexing の方法 (e.g. multidimensional array, boolean array) 等については reference をご参照ください。

Python の例について

import numpy as np

により、 np という名前で numpy モジュールを参照できる状態を前提とします

環境

記事中の例は以下の環境で検証しています

  • Numpy 2.4.2

この記事の内容について、 Numpy のバージョン間で大きな差は無いと想定しています

Indexing の基本的な構造

N 次元テンソルに対して、カンマ区切りで N 要素までの index を指定して indexing できます。

>>> a = np.array([[0, 1, 2], [3, 4, 5]])
>>> a.shape
(2, 3)
>>> print(a[0, 2])
2

ここでは、 (0, 2) によって a を indexing しています。
0 は最初の次元から対応する要素を選び、 2 は次の次元から対応する要素を選びます。

: による indexing

整数による indexing では対応する次元から 1 要素だけを選んでいましたが、 : は対応する次元の全ての要素を選びます。

>>> a = np.array([[0, 1, 2], [3, 4, 5]])
>>> a[:, 1] # 0 番目の次元は全て、1 番目の次元はインデックス 1 のみ
array([1, 4])
>>> a[:, 1].shape
(2,)

... による indexing

... は整数や : による indexing で使われずに残った次元の数 (0 個以上) の : に展開されます

>>> a = np.ones((2, 3, 4, 5))
>>> a[0, ..., 0].shape # a[0, ..., 0] は a[0, :, :, 0] と同じ
(3, 4)

... が明示的に指定されていない場合でも、 indexing の末尾に常に暗黙に ... があると思うと考えやすいです。

>>> a = np.ones((2, 3, 4, 5))
>>> a[0].shape # a[0] は a[0, ...] または a[0, :, :, :] と同じ
(3, 4, 5)

None (np.newaxis) による indexing

np.newaxisNone の alias です。

https://numpy.org/doc/2.2/reference/constants.html#numpy.newaxis

整数, :, ... はいずれも元のテンソルに対して対応する次元をもっていましたが、 np.newaxis は少し異なり、元のテンソルと対応しないサイズ 1 の次元を新たに増やします。

>>> a = np.array([0, 1, 2])
>>> a.shape
(3,)
>>> a[:, np.newaxis].shape
(3, 1)
>>> a[np.newaxis, :].shape
(1, 3)

Shape の決定プロセス

手続き的な理解

ここでは indexing した結果のテンソルの shape がどのように決まっているかをアルゴリズム的に説明します。
この考え方を用いると、元のテンソルの形状と index の並びさえ見れば形式的に結果の形状を計算できます。

また、計算の過程で元の配列との次元の対応がわかるので、実際結果テンソルのどの位置に元のテンソルのどの要素が対応するかは容易にわかるでしょう (形式化するのは少し面倒なので省略させてください)。

ここでは index は :, ..., None, 整数のみからなるとします

index = (index_elem (, index_elem)*)
index_elem = : | ... | None | integer

はじめに、次元の数合わせと ... の除去の正規化処理を行います (前処理)。

  • Index が ... を含まない場合、末尾に ... を追加する
  • ... を必要な数の : に展開する

... を除去した後の、残りの Shape の決定を Python 風の疑似コードで表わします。

def calc_result_shape(a: np.array, index: List[IndexElement]):
  current_dim = 0
  result_shape = []

  for index_elem in index:
    if is_integer(index_elem):
      current_dim += 1
    elif index_elem == ":":
      result_shape.append(a.shape[current_dim])
      current_dim += 1
    elif index_elem is None:
      result_shape.append(1)

  return result_shape

各変数の意味は次のように考えてください

  • current_dim:「現在 a の何番目の次元を処理しているか」
  • result_shape: indexing した結果のテンソルの形状

current_dim += 1 というのは、a の次元を進める ("消費" している) と考えることができます。
None のみ a の次元を消費せずに、結果の次元を増やします。

a.shape = (10, 20, 30)
index = (:, None, 5) 

# 正規化
(:, None, 5)
-> (:, None, 5, ...)
-> (:, None, 5, :)
index = (:, None, 5, :) 

calc_result_shape:
0. result_shape = [], current_dim = 0
1. ":" -> result_shape = [10], current_dim = 1  (a.shape[0] を消費)
2. None -> result_shape = [10, 1], current_dim = 1 (a の次元は進めず、result_shape に 1 を追加)
3. 5 -> result_shape = [10, 1], current_dim = 2 (a.shape[1] を消費、整数なので result_shape には追加しない)
4. ":" -> result_shape = [10, 1, 30], current_dim = 3 (a.shape[2] を消費)

result_shape = [10, 1, 30]

視覚的理解 (補足)

テンソルをひとつの "箱" のようなものだと考えて、これを (外から内へ) ネストさせていってテンソルを組み立てるような視覚的イメージで説明してみます。

を次に考えるテンソルを示す箱のようなものだと考えてください。
結果のテンソルははじめ空 ( 1 つ) から始まって、以下の変換を繰り返して形状が決まると考えます。

が複数ある場合はすべての を同じように置きかえます。

  • 整数: ⎵ -> ⎵ (次元を消費)

  • :: ⎵ -> [⎵, ⎵, ⎵, ..., ⎵] (次元を消費)

  • None: ⎵ -> [⎵]

>>> a = np.array([[0, 1, 2], [3, 4, 5]])
>>> b = a[:, np.newaxis, 2]
>>> b
array([[2],
       [5]])
>>> b.shape
(2, 1)

説明

## Init

index: (:, np,newaxis, 2)
a.shape: (2, 3)
b: ⎵


## First (`:`) (∀⎵. ⎵ -> [⎵, ..., ⎵])

index: (:, np,newaxis, 2)
        ^
a.shape: (2, 3)
          ^
b: ⎵ -> b: [⎵, ⎵]


## Second (`np.newaxis`) (∀⎵. ⎵ -> [⎵])

index: (:, np,newaxis, 2)
           ^^^^^^^^^^
a.shape: (2, 3)
             ^ (not used)
b: [⎵, ⎵] -> b: [[⎵], [⎵]]


## Third (`2`) (∀⎵. ⎵ -> ⎵)

index: (:, np,newaxis, 2)
                       ^
a.shape: (2, 3)
             ^
b: [[⎵], [⎵]] -> b: [[⎵], [⎵]]

Broadcasting

Broadcasting とは、形状が異なる配列同士で算術演算を行う際に、形状を自動的に合わせる仕組みのことです。

Broadcasting は、対応する次元を 右から左の順 で考えるとわかりやすいです。

以下のようなルールで次元を拡張します。

  • 右詰めで次元の数を合わせる (もともと無い部分は 1 と考える)
  • 各次元の大きさについて、以下のどちらかが満たされるとき小さくないほうに合わせる
    • a. 値が一致する
    • b. 片方が 1

対応する次元について a, b どちらも満たされない場合は ValueError となります。

  • Shape の例
A = np.random.random((8, 1, 6, 5))
B = np.random.random((7, 1, 5))
Result = A + B

A      (4d array):  8 x 1 x 6 x 5
B      (3d array):      7 x 1 x 5
Result (4d array):  8 x 7 x 6 x 5

次元を 1 から大きいほうに合わせて拡張する際には、論理的にその次元の要素を複製して計算に使います。
np.newaxis と組み合わせることで、 outer product 等を簡単に表現できます。

>>> import numpy as np
>>> a = np.array([1, 2, 3])
>>> b = np.array([4, 5])

>>> a[:, np.newaxis] * b
array([[ 4,  5],
       [ 8, 10],
       [12, 15]])

>>> np.outer(a, b)  # equivalent
array([[ 4,  5],
       [ 8, 10],
       [12, 15]])

Quiz

最後に indexing, broadcasting について簡単な quiz を出題します。答えは折り畳みの中にあります。

  • Shape のみ

結果のテンソルの shape を答えてください。

Q1. np.ones((10, 20, 30))[5, ..., 0].shape
Q2. np.ones((2, 3, 4))[:, None, :, 0].shape
Q3. (np.ones((5, 1, 4)) + np.ones((3, 4))).shape

答え

A1. (20,)
A2. (2, 1, 3)
A3. (5, 3, 4)

  • 値つき

Q. res の中身は?

>>> a = np.arange(24).reshape(2, 3, 4)
>>> a
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])
>>> b = np.array([10, 20])
>>> b
array([10, 20])
>>> res = a[0, ..., 1, np.newaxis] + b  # res = ???
答え
>>> res
array([[11, 21],
       [15, 25],
       [19, 29]])

解説

a.shape = (2, 3, 4)

a[0, ..., 1, np.newaxis].shape = (3, 1)
a[0, ..., 1, np.newaxisa] (= a'):
array([[1],
       [5],
       [9]])

b.shape = (2,)
b: array([10, 20])

a' : 3 x 1
b  :     2
res: 3 x 2
GitHubで編集を提案

Discussion