👻

Julia におけるブロードキャストの使い方

2022/08/17に公開

ブロードキャストを使うにあたって公式ドキュメントで書き方を調べていると、様々な場所に説明が散らばっていて困るので、関わりがあることをこの記事にまとめておきたい。

記法

x, y = rand(3), rand(3)

sin.(x)
broadcast(sin, x)
Broadcast.BroadcastFunction(sin)(x)

# 中置記法が可能な2引数関数
# +, *, =, <, ==, in, isa など
x .+ y
(+).(x, y)
broadcast(+, x, y)
.+(x, y) # (.+) == Broadcast.BroadcastFunction(+)

.+ も関数なので、そのまま他の関数に渡すこともできる。

https://docs.julialang.org/en/v1/manual/arrays/#Broadcasting
https://github.com/JuliaLang/julia/blob/742b9abb4dd4621b667ec5bb3434b8b3602f96fd/base/broadcast.jl#L1307-L1337

Map との違い

broadcast は配列の次元が違った場合は適当に拡張して大きいほうに合わせてくれる。
map は次元が異なっていれば1次元として扱い、要素数の少ないほうに合わせる。

broadcast(+, rand(2), rand(2, 3)) # 2×3 Matrix
      map(+, rand(2), rand(2, 3)) # 2-element Vector

broadcast(+, rand(3), rand(2, 3)) # ERROR: DimensionMismatch
      map(+, rand(3), rand(2, 3)) # 3-element Vector

https://docs.julialang.org/en/v1/base/collections/#Base.map

例:直積

A = ["a$i$j" for i in 1:2, j in 1:3] # 2×3 Matrix
B = ["b$i" for i in 1:4] # 4-element Vector

# 遅延評価
C1 = Iterators.product(A, B)

# 正格評価
"""
    Bx = expanddim(B, A)

`Bx = reshape(B, (1, 1, 1, m, n))` where `ndims(A) == 3`, `size(B) == (m, n)`
"""
expanddim(B, A) = reshape(B, (ntuple(_ -> 1, ndims(A))..., size(B)...))
C2 = tuple.(A, expanddim(B, A)) # 2×3×4 Array

@assert collect(C1) == C2

ブロードキャストを利用した正格評価のほうでは、 tuple のところを適当な二項演算に替えることもできる。 A .* expanddim(B, A) など。

https://docs.julialang.org/en/v1/base/iterators/#Base.Iterators.product

特定の引数をマスクする

ある引数 z をその要素に開いてブロードキャストしてしまうことを「防ぎ」、スカラーのように振る舞ってほしい場合、適当なコンテナに入れればよい。

# (x + z, y + z) を得たい

x, y, z = rand(3), rand(3), rand(3)

(x, y) .+ z # ERROR: DimensionMismatch
# x + z[1], y + z[2] となり、 z[3] の相方がいないため

# 1要素タプルに入れる
(x, y) .+ (z,) # == (x + z, y + z)

# Ref に入れる
(x, y) .+ Ref(z) # == (x + z, y + z)

# もし (x .+ z[1], y .+ z[2], z .+ z[3]) を得たいなら

(x, y, z) .+ z # ERROR: MethodError: no method matching +(::Vector{Float64}, ::Float64)
# ブロードキャストは成功している (x + z[1], y + z[2], z + z[3])
# ただし x + z[1] の演算が許されていない(ちなみに x * z[1] なら動作する)

broadcast(.+, (x, y, z), z) # == (x .+ z[1], y .+ z[2], z .+ z[3])

https://docs.julialang.org/en/v1/manual/arrays/#Broadcasting
https://docs.julialang.org/en/v1/base/c/#Core.Ref

マクロ

マクロ @. を使えば、指定した部分全体にブロードキャストを適用することができる。

@. y = f(x)
y .= f.(x)
broadcast!(f, y, x)

マクロ内で特定の関数をブロードキャストしたくない場合は $ を付けるとよい。

@. sqrt(abs($(sort(x))))
sqrt.(abs.(sort(x)))
broadcast(sqrt∘abs, sort(x))

Discussion