🙃

やっぱり細かすぎるかもだけど伝わってほしいJuliaのTips2024

2024/12/25に公開3

メリークリスマス!

この記事は Julia Advent Calendar 2024 の25日目(最終日)の記事です。

はじめに

一昨年・昨年も同じ日に同じようなネタでお祝いしました(細かすぎて伝わらないかもしれないJuliaのTips細かすぎてたぶん伝わらないJuliaのTips2023)。
昨年も一昨年とは少しだけ方向性を変えましたが、今年もちょっとだけ趣向を変えます。

毎年この枠は「(なるべく)Julia の標準API・文法の範疇で、みんな使ってないかもしれないけれどこんな書き方・機能もあるんだよ」ということを紹介しています(そもそも Tips なので)。でも昨年・一昨年は「とは言えこの機能知らなくてもそんなに困らないよなー」というトピックが多かったような気がします。
そこで今年は、「こんなAPI・機能あるんだよ」という根本はもちろんそのままで、その上でもっとさらに「知っておくと便利」な機能を紹介していこうと思います。
具体的には、今回は以下の点を指針とした記事をお届けします:

  • 今まで書けなかった(書くのがめんどくさかった)のが記述もしやすくなり便利になったモノ(新文法・新API等)の紹介。
  • もしくは「それを実現するAPI・関数初めから用意されているよ」という例。
  • ターゲットは記事公開時点の Julia LTS(v1.10.7)。

それでは早速参りましょう!

1. ジャグ配列→多次元配列への変換

ジャグ配列jagged array)とは、簡単に言うと『配列の配列』のことです。配列の要素として配列を持つことで、多次元配列をエミュレートすることができます。

一方で Julia には初めから 多次元配列 が用意されています。後でも補足しますが Julia の多次元配列はメモリ上で連続した領域にデータを格納するため、ジャグ配列と比較するとメモリパフォーマンスにすぐれています。
また文法レベルで多次元配列のリテラル記法が整備されており、ジャグ配列と比較して記述も簡潔になります。

※以下、コードブロック中の #> で始まるコメントは、実行結果(式の値、Julia の REPL での出力例)を示しています。

# ジャグ配列の例(型は `Vector{Vector{Int64}}` となる)
jagged2d = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
#> 3-element Vector{Vector{Int64}}:
#>  [1, 2, 3]
#>  [4, 5, 6]
#>  [7, 8, 9]

# 上記と同等な多次元配列の例(型は `Matrix{Int64}` となる)
array2d = [1; 2; 3;; 4; 5; 6;; 7; 8; 9]
#> 3×3 Matrix{Int64}:
#>  1  4  7
#>  2  5  8
#>  3  6  9

# 3次元ジャグ配列の例(型は `Vector{Vector{Vector{Int64}}}` となる)
jagged3d = [
    [[ 1,  2,  3],
     [ 4,  5,  6],
     [ 7,  8,  9],
     [10, 11, 12]],
    [[13, 14, 15],
     [16, 17, 18],
     [19, 20, 21],
     [22, 23, 24]]
]
#> 2-element Vector{Vector{Vector{Int64}}}:
#>  [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
#>  [[13, 14, 15], [16, 17, 18], [19, 20, 21], [22, 23, 24]]

# 上記と同等な3次元配列の例(型は `Array{Int64, 3}` となる)
array3d = [ 1;  2;  3;;  4;  5;  6;;  7;  8;  9;; 10; 11; 12;;;
           13; 14; 15;; 16; 17; 18;; 19; 20; 21;; 22; 23; 24]
#> 3×4×2 Array{Int64, 3}:
#> [:, :, 1] =
#>  1  4  7  10
#>  2  5  8  11
#>  3  6  9  12
#> 
#> [:, :, 2] =
#>  13  16  19  22
#>  14  17  20  23
#>  15  18  21  24

このジャグ配列と多次元配列を相互に変換する方法が、Julia 標準で用意されています(標準関数の組合せで実現できます)。

ベクトルのベクトル→行列への変換

Julia の stack() 関数は、『(各要素の配列のサイズが等しい)ジャグ配列を多次元配列に変換する(ことに直接利用できる)』関数です。
ベクトルのベクトルを行列(2次元配列)に変換する場合、stack() 関数をそのまま使えばOKです。

stack(jagged2d)
#> 3×3 Matrix{Int64}:
#>  1  4  7
#>  2  5  8
#>  3  6  9

stack(jagged2d) == array2d
#> true

なおひょっとしたら「Julia の方なんか縦軸と横軸入れ替わってない?」と先ほどから疑問に思っている方もいらっしゃるかもしれません。Julia の(多次元)配列は 列指向column-major order)なので、第1軸は行列で言う列になります。一方で Numpy などの多くの言語・ライブラリは 行指向row-major order)なので第1軸は行です。行指向の方が記述と結果が一致して直感的な場合も多いですが、列指向の方が線形代数演算時などに直感的な表現ができる(Julia のベクトル(=1次元配列)は 列ベクトル です)ので Julia では列指向を採用しています。
ただ他言語・ライブラリとの互換性を考えて行指向の結果がほしい場合もあります。これにも stack() 関数は対応しています。dims キーワード引数を使って「積み上げる方向」を指定すればOKです。

# ↓第1軸方向に沿ってスタックする(=縦に積み上げる=行指向)という意味になる
stack(jagged2d, dims=1)
#> 3×3 Matrix{Int64}:
#>  1  2  3
#>  4  5  6
#>  7  8  9

stack(jagged2d, dims=1) == array2d'
#> true

例えば以下のように stack_t(args...) = stack(args..., dims=1) という関数を定義しておけば、行指向の積み上げも簡単に行えます。

stack_t(args...) = stack(args..., dims=1)
#> stack_t (generic function with 1 method)

stack_t(jagged2d)
#> 3×3 Matrix{Int64}:
#>  1  2  3
#>  4  5  6
#>  7  8  9

stack_t(jagged2d) == array2d'
#> true

多段ネストしたジャグ配列 → 多次元配列一発変換

3次元以上のジャグ配列(3つ以上ネストしたジャグ配列)を多次元配列に変換する場合、stack() 関数を再帰的に適用することで実現できます。
ただし多少煩雑になるので、こちらも以下のような関数を定義しておくと便利です。

function stack_nested(@nospecialize(src::AbstractArray{<:AbstractArray{<:AbstractArray}}))
    stack(stack_nested, src)
end
function stack_nested(@nospecialize(src::AbstractArray{<:AbstractArray}))
    stack(src)
end

先ほど挙げた3次元のジャグ配列(ベクトルのベクトルのベクトル)を同等な3次元配列に変換する例を以下に挙げます。

stack_nested(jagged3d)
#> 3×4×2 Array{Int64, 3}:
#> [:, :, 1] =
#>  1  4  7  10
#>  2  5  8  11
#>  3  6  9  12
#> 
#> [:, :, 2] =
#>  13  16  19  22
#>  14  17  20  23
#>  15  18  21  24

stack_nested(jagged3d) == array3d
#> true

念のため4次元以上の例もおいておきます。

# 4次元ジャグ配列の例(型は `Vector{Vector{Vector{Vector{Int64}}}}` となる)
jagged4d = [
    [
        [[ 1,  2,  3],
         [ 4,  5,  6]],
        [[ 7,  8,  9],
         [10, 11, 12]],
    ], [
        [[13, 14, 15],
         [16, 17, 18]],
        [[19, 20, 21],
         [22, 23, 24]]
    ]
];

# 上記と同等な多次元配列の例(型は `Array{Int64, 4}` となる)
array4d = [ 1;  2;  3;;  4;  5;  6;;;
            7;  8;  9;; 10; 11; 12;;;;
           13; 14; 15;; 16; 17; 18;;;
           19; 20; 21;; 22; 23; 24];

stack_nested(jagged4d) == array4d
#> true

多次元ジャグ配列を行指向(他言語・ライブラリと互換性のある形式)に変換したい場合、dims キーワード引数を使う方法もありますが、煩雑になるので stack_nested() で変換した後で permutedims() で次元を入れ替えた方が簡潔になります。

permutedims(stack_nested(jagged2d), (2, 1))
#> 3×3 Matrix{Int64}:
#>  1  2  3
#>  4  5  6
#>  7  8  9

permutedims(stack_nested(jagged2d), (2, 1)) == stack_t(jagged2d) == array2d'
#> true

permutedims(stack_nested(jagged3d), (3, 2, 1))  # 結果省略

permutedims(stack_nested(jagged4d), (4, 3, 2, 1))  # 結果省略

多次元配列 → ジャグ配列への逆変換

行列(2次元配列)をベクトルのベクトル(ジャグ配列)に変換するには、eachcol() 関数を使えばOKです。

eachcol(array2d)
#> 3-element ColumnSlices{Matrix{Int64}, Tuple{Base.OneTo{Int64}}, SubArray{Int64, 1, Matrix{Int64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}:
#>  [1, 2, 3]
#>  [4, 5, 6]
#>  [7, 8, 9]

eachcol(array2d) isa AbstractVector
#> true

eachcol(array2d) isa AbstractVector{<:AbstractVector}
#> true

eachcol(array2d) == jagged2d
#> true

結果の型が ColumnSlices となっており「本当に配列の配列になってる?」と思われるかも知れませんが、少し調べると分かるとおり「配列の配列と互換性のある型」になっており、対応するジャグ配列と等価であることも確認できるので、普通のジャグ配列として扱えるようになっていることが分かりますね。

多次元配列を対応するネストしたジャグ配列に変換するには、eachslice() をこの eachcol() と組み合わせてネストすれば実現可能です。こちらも以下のように関数化してみました。

tojagged(vec::AbstractVector) = vec
tojagged(mat::AbstractMatrix) = eachcol(mat)
tojagged(arr::AbstractArray{<:Any, N}) where N =
    tojagged(eachslice(arr, dims=ntuple(n->n+1, Val(N-1))))
#> tojagged (generic function with 3 methods)

tojagged(array2d) == eachcol(array2d) == jagged2d
#> true

tojagged(array3d) == jagged3d
#> true

tojagged(array4d) == jagged4d
#> true

typeof(tojagged(array4d)) を調べて見るとものすごいことになっている(多次元になればなるほどすごい(語彙力))ですが、実運用上あまり気にする必要はありません(興味があれば調べてみてください)。

あと、こちらも列指向の変換方法であり、行指向で直接変換するには一工夫必要ですが、こちらも先に permutedims() で次元を入れ替えてから tojagged() で変換すれば取り敢えず実現可能です(コード例は割愛します)。

補足

  • ジャグ配列は多次元配列に比べてメモリ効率が良くありません。ジャグ配列は各要素の配列がメモリ上で連続していないため、キャッシュ効率が悪くなります。一方で多次元配列はメモリ上で連続しているため、キャッシュ効率が良くなります。
    • 本記事で紹介した tojagged() 関数で生成されるジャグ配列は、元の多次元配列のビューとしてジャグ配列を表現しているため、メモリ効率は元の多次元配列と同じで高効率となります。
  • ジャグ配列は、各要素の配列のサイズが異なる場合にも対応できます。一方で多次元配列は、各次元のサイズが等しい必要があります。
    • Julia の stack() 関数は(従って本記事で紹介した stack_nested() 関数も)、各要素の配列のサイズが等しいことを前提としています。各要素のサイズが異なるジャグ配列を渡すとエラーとなります。
    • tojagged() 関数で返されるジャグ配列は、各要素の配列のサイズは必然的に等しくなります。
  • Julia の配列は多次元配列をサポートしているため、「多次元配列の多次元配列(例:各要素が行列の行列)」という形のジャグ配列も表現できます。stack() 関数はこれらにも対応しており、例えば行列の行列に対して適用すると4次元配列が得られます(コード例割愛)。
    • 多次元配列から「多次元配列の多次元配列」という形のジャグ配列を生成する場合、eachslice() 関数の dims キーワード引数の指定を工夫すれば実現可能です(例:eachslice(array4d, dims=(3, 4)) isa AbstractMatrix{<:AbstractMatrix})。
  • 本記事で紹介した stack_nested()stack() をネスト呼び出しして配列を生成しており、ネストが深く(=次元数が大きく)なればなるほど余分なメモリアロケーションが発生します(つまり決して高パフォーマンスとは限りません)。(同じく本記事で紹介した)tojagged() 関数のように「ジャグ配列を多次元配列のように見せるビュー」が定義できればこの問題を回避できます。本記事の範囲を超えるので実装例は割愛しますが、手元で簡単な実装は成功しているので反響があれば別途記事にするかもしれません。

2. 続きから列挙再開できたり巻き戻しができるイテレータ

最初に他言語の例を出しますが、例えば Ruby の Enumerator というクラスは、所謂「外部イテレータ」を提供するクラスで、通常のように .each メソッドでイテレーションを進める他に、.peek メソッドで次の要素を取得(参照)することができます。また、.next メソッドでイテレーションを進めることもできます。さらに、.rewind メソッドでイテレーションを最初に戻すこともできます。

※以下の Ruby コードブロック中の # => で始まるコメントは、実行結果(式の値、Ruby の REPL である irbなどの出力例)を示しています。

require 'prime'
# => true

primes_iterator = Prime.to_enum  # 素数を列挙する外部イテレータを生成
# => #<Enumerator: ...>

primes_iterator.peek  # `.peek` はイテレーションを進めない
# => 2
primes_iterator.peek
# => 2

primes_iterator.next  # `.next` はイテレーションを進める
# => 2
primes_iterator.next
# => 3
primes_iterator.next
# => 5
primes_iterator.peek
# => 7

primes_iterator.rewind  # 巻き戻し
# => #<Enumerator: ...>

primes_iterator.next  # 最初に戻っている
# => 2
primes_iterator.next
# => 3
primes_iterator.peek
# => 5

# `.peek`,`.next` 以外の一般的な `Enumerable` のメソッドは影響受けない(最初から列挙する)
primes_iterator.take 5
# => [2, 3, 5, 7, 11]

これに対して Julia の通常のイテレータ(※1)は、この Ruby の Enumerator のような、.peek/.next/.rewind のような機能はなく、『常に最初から列挙』という挙動になります。

※以下、コードブロック中の ## で始まるコメントは、標準出力への出力結果を示しています。#> で始まるものは先ほどと同様実行結果(式の値)を示しています。

# ]add Primes  # 外部パッケージなので事前に追加してください。
using Primes

primes_iterator = nextprimes(1);  # 1より大きい最初の素数(=2)から列挙するイテレータを生成

@show first(primes_iterator)
## first(primes_iterator) = 2
@show first(primes_iterator)
## first(primes_iterator) = 2

@show Iterators.take(primes_iterator, 5) |> collect
## Iterators.take(primes_iterator, 5) |> collect = [2, 3, 5, 7, 11]
@show Iterators.take(primes_iterator, 5) |> collect
## Iterators.take(primes_iterator, 5) |> collect = [2, 3, 5, 7, 11]

これを、Ruby の Enumerator のように「最初の要素を取得(参照)する」「最初の要素を取り出してイテレーションを進める」「巻き戻す」という機能を提供するイテレータに変換する機能が、実は Julia 標準で用意されています。

Iterators.Stateful の基本

Julia には標準で Iterators.Stateful という型が用意されています。これは名称が示すとおり「状態を保持するイテレータ」です。以下に具体的な例を示します。

primes_iterator_s = Iterators.Stateful(primes_iterator);  # 状態を持つイテレータに変換

# `peek()` は次の要素を取得(するだけでイテレーションは進めない)
# Ruby の `.peek` メソッド相当
@show peek(primes_iterator_s)
## peek(primes_iterator_s) = 2
@show peek(primes_iterator_s)
## peek(primes_iterator_s) = 2

# `popfirst!()` は次の要素を取得し、イテレーションを進める
# Ruby の `.next` メソッド相当の動きになる
@show popfirst!(primes_iterator_s)
## popfirst!(primes_iterator_s) = 2
@show popfirst!(primes_iterator_s)
## popfirst!(primes_iterator_s) = 3
@show popfirst!(primes_iterator_s)
## popfirst!(primes_iterator_s) = 5

# イテレーションをリセット(巻き戻し)する(Ruby の `.rewind` メソッド相当)
Iterators.reset!(primes_iterator_s);
@show popfirst!(primes_iterator_s)
## popfirst!(primes_iterator_s) = 2
@show peek(primes_iterator_s)
## peek(primes_iterator_s) = 3

# その他の一般的なイテレーション関連関数も「続きから」
@show Iterators.take(primes_iterator_s, 5) |> collect
## terators.take(primes_iterator_s, 5) |> collect = [3, 5, 7, 11, 13]
@show peek(primes_iterator_s)  # その次の要素
## peek(primes_iterator_s) = 17

Julia の Iterators.Stateful は、イテレータをラップして 状態を持つイテレータ に変換します。peek() は次の要素を 取得(するだけでイテレーションは進めない)します。popfirst!() は次の要素を 取り出し ます(取得し、イテレーションを進めます)。Iterators.reset!() はイテレーションを リセット(巻き戻し) します。
Ruby の Enumerator に似ていますが、大きく違うのは、「それ以外のイテレーション操作の扱い」です。Ruby の Enumerator.peek/.next/.rewind 以外のイテレーション操作は「状態」の影響を受けず常に最初から列挙しますが、Julia の Iterators.Stateful は「状態」を完全に保持していて popfirst!() などを実行するとその後のイテレーション関連のどの操作も「次の状態」からの再開となる点です(Stateful という名称様々です)。

Iterators.Stateful の応用例

Iterators.Stateful の「peek() で次の要素を取得するけれどイテレーションを進めない」という挙動は結構便利で、それを活かした例をいくつか考えてみました。

1. 2つのイテレータをマージするイテレータ

まずは以下のコードを見てください。

function merge_stream(itr1::I1, itr2::I2) where {I1, I2}
    T1 = eltype(I1)
    T2 = eltype(I2)
    T = promote_type(T1, T2)
    # 状態を持つイテレータに変換
    sitr1 = Iterators.Stateful(itr1)
    sitr2 = Iterators.Stateful(itr2)
    Channel{T}() do chnl
        while !Base.isdone(sitr1) && !Base.isdone(sitr2)
            # itr1 と itr2 の最初の要素のうち小さい方を列挙(等価なら1つにまとめる)
            v1 = peek(sitr1)
            v2 = peek(sitr2)
            v = T(min(v1, v2))
            isequal(v, v1) && popfirst!(sitr1)
            isequal(v, v2) && popfirst!(sitr2)
            put!(chnl, v)
        end
        for v in sitr1
            # itr1 の残りを全部列挙
            put!(chnl, T(v))
        end
        for v in sitr2
            # itr2 の残りを全部列挙
            put!(chnl, T(v))
        end
    end
end

これは「2つのイテレータをマージするイテレータ(を生成する関数)」です。ただし入力となるイテレータは『要素が比較可能』でかついずれも『昇順にソートされている』ことを前提としています。そのようなイテレータ同士なら、「両方のイテレータの要素を列挙する(同じ要素があったら1つにまとめる)」という挙動になります。
具体例は以下のようになります。

# 3 または 5 の倍数を列挙(有限版、標準出力に出力)
for n in merge_stream(0:3:30, 0:5:30)
	println(n)
end
## 0
## 3
## 5
## 6
## 9
## 10
## 12
## 15
## 18
## 20
## 21
## 24
## 25
## 27
## 30

# 3 の倍数、または数字の 3 のつく数字を列(無限版、先頭から25個を抽出)
Iterators.take(merge_stream(
    Iterators.filter(x->x%3==0, Iterators.countfrom(1)),
    Iterators.filter(x->3∈digits(x), Iterators.countfrom(1))
), 25) |> collect
#> [3, 6, 9, 12, 13, 15, 18, 21, 23, 24, 27, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 42, 43, 45, 48]

つまりこれは 演算子(和集合演算子、union() 関数と同じ)の挙動を一般のイテレータに適用したもの、と言えます。
同様に『積集合』を列挙するイテレータ( 演算子、intersect() 関数と同様の挙動)、『差集合』を列挙するイテレータ(setdiff() 関数の挙動)なども同様に実装できます。

2. ハミング数を無限列挙するイテレータ

少しトリッキーで高度な例として、ハミング数(2, 3, 5 のみを素因数とする正の整数)を無限に列挙するイテレータを実装してみました。

struct HammingNumbers end

const hammingnumbers = HammingNumbers()

Base.IteratorSize(::Type{HammingNumbers}) = Base.IsInfinite()

Base.eltype(::Type{HammingNumbers}) = Int

Base.iterate(::HammingNumbers) = (1, 1)
Base.iterate(::HammingNumbers, prev::Int) = iterate(hammingnumbers, (prev, 
    Iterators.Stateful(hammingnumbers), 
    Iterators.Stateful(hammingnumbers), 
    Iterators.Stateful(hammingnumbers), 
))
function Base.iterate(::HammingNumbers, (prev, mul2, mul3, mul5)::Tuple{Int, Vararg{Iterators.Stateful{HammingNumbers}, 3}})
    n2 = 2 * peek(mul2)
    while n2 <= prev; popfirst!(mul2); n2 = 2 * peek(mul2); end
    n3 = 3 * peek(mul3)
    while n3 <= prev; popfirst!(mul3); n3 = 3 * peek(mul3); end
    n5 = 5 * peek(mul5)
    while n5 <= prev; popfirst!(mul5); n5 = 5 * peek(mul5); end
    n = min(n2, n3, n5)
    (n, (n, mul2, mul3, mul5))
end

先ほどは Channel を使って1つの関数で実装していましたが、今回は専用の型を設計(ただしシングルトン型)し、Base.iterate() メソッドを実装することでイテレータとして実装しています。ただコアな部分は先ほどの merge_stream() とよく似ているので解読してみてください。
動作確認結果もご覧ください。きちんと期待通りに動作していることが分かります。

Iterators.take(hammingnumbers, 16) .|> println
## 1
## 2
## 3
## 4
## 5
## 6
## 8
## 9
## 10
## 12
## 15
## 16
## 18
## 20
## 24
## 25

collect(Iterators.takewhile(<(100), hammingnumbers))
#> [1, 2, 3, 4, 5, 6, 8, 9, 10, 12, 15, 16, ...《中略》..., 54, 60, 64, 72, 75, 80, 81, 90, 96]

# 100番目のハミング数
Iterators.drop(hammingnumbers, 99) |> first
#> 1536

3. 2年前の『Y/A findXXX() 系関数』

2年前のアドベントカレンダーの シチュエーション2: Y/A findXXX() 系関数 節で meetfirst() という関数を実装し紹介しました。これは「条件に合致する最初の要素を取得(なければ nothing)」というものでした。
これを Iterators.Stateful を使って実装しなおしてみました。

# 第1引数で条件判定して、第2引数のコレクションで最初に合致する要素を返す
# (なければ `nothing`) ver.3
function meetfirst_v3(pred::Function, itr)
    peek(Iterators.Stateful(Iterators.dropwhile(!pred, itr)))
end

2年前は色々駆使していたのが、なんとたった1行で実装出来てしまいました。これは peek() が「状態を保持するイテレータの最初の要素を取得する、ただし 次の要素がない場合(=イテレーションが終了していた場合)は nothing を返す」という挙動を利用したものとなっています。
念のため動作確認してみましょう。

a = [314, 159, 265, 358, 979, 323, 846, 264];

@show meetfirst_v3(n -> n % 11 == 0, a)
## meetfirst_v3((n->begin
##             #= REPL[31]:1 =#
##             n % 11 == 0
##         end), a) = 979
@show meetfirst_v3(n -> n % 7 == 0, a)
## meetfirst_v3((n->begin
##             #= REPL[32]:1 =#
##             n % 7 == 0
##         end), a) = nothing

# コラッツ数列を列挙するイテレータ(`Channel`)を返す関数
function collatz(n::Int)
    Channel{Int}() do chnl
        put!(chnl, n)
        while n > 1
            n = iseven(n) ? n ÷ 2 : 3n + 1
            put!(chnl, n)
        end
    end
end

@show meetfirst_v3((200), collatz(27))
## meetfirst_v3((≥)(200), collatz(27)) = 214
@show meetfirst_v3((50), collatz(3))
## meetfirst_v3((≥)(50), collatz(3)) = nothing

きちんと期待通りに動作(2年前の記事と同じ結果)となっていますね。ベンチマークも見てみましょう。
(注:以下のコードを実行するには2年前の記事のコードをコピー(meetfirst() 関数は meetfirst_v2() に置換)する必要があります)

using BenchmarkTools, Random

Random.seed!(1234);
@benchmark meetfirst_v1((200), c) setup=(c=collatz(rand(3:100)))  # v.1: `try~catch~end` を利用した実装

Random.seed!(1234);
@benchmark meetfirst_v2((200), c) setup=(c=collatz(rand(3:100)))  # v.2: `Base.iterate()` を利用した高パフォーマンス実装

Random.seed!(1234);
@benchmark meetfirst_v3((200), c) setup=(c=collatz(rand(3:100)))  # v.3: 今回の実装

 関数のパフォーマンス計測(ベンチマーク)結果

前回の v.2(高パフォーマンスな実装)と同等の結果となりました。これで良かったんや!

補足

  • Iterators.Stateful は、正確には『状態を持たないイテレータ(ステートレスなイテレータ) を状態を持つイテレータ(ステートフルなイテレータ)に変換するもの』です。なおこの節の最初の方(※1)で「通常のイテレータ」という言葉を使いましたが、これは ステートレスなイテレータ を意図しています。それに対して元々ステートフルな(状態を内部管理している)イテレータも Julia には存在し、それらに対して Iterators.Stateful は、期待通りの挙動とならないことがあります。
    • Channel はステートフルなイテレータであり、これを Iterators.Stateful でラップすると、peek()popfirst!() は期待通りの動作をしますが、Iterators.reset!() は無効となり、イテレーションはリセットされず引き続きその続きからイテレーション再開するという挙動となります。
    • Iterators.Stateful で生成されたイテレータを Iterators.Stateful でラップしようとすると、ラッピングではなくそれ自身が返されるような実装となっている模様です。同じオブジェクトなので、それに対して popfirst!()reset!() やその他状態の変化を伴うイテレーション操作を行うと、元のイテレータにも反映されます。
  • 毎回いちいち「Iterators.」というモジュールプレフィックスを付けていることからも分かるとおり、この型は export もされておらず public なものではありません(少なくとも Julia v1.11 の時点までは)。ですので将来的には、この型の挙動が変わったり、もしくはこの型そのものが削除されてしまう可能性があります。
    • 十分便利で有用なので、将来的には public 扱いになってくれないかな、と期待しています。

3. invmod()

Julia はイテレーション(通常の for 文(for 式)による繰り返し処理)が十分に高速で、大量のデータをイテレーションして必要なデータを探したり抽出したりといったことも「for を利用したナイーブな実装」で十分短時間で結果が求まることもあります。

例えば「ある整数 nmgcd(n, m) == 1 と仮定)が与えられて、mod(n * x, m) == 1 となる x を求めたい」という要件があるとします。もちろんこの問題を解くのに、for を使って x を1つずつ調べれば結果は出すことができます。

let n=3, m=257
    for x in 1:m-1
        if mod(n * x, m) == 1
            println(x)
            break
        end
    end
end
## 86

でもこんなコード書く必要はありません!
Julia には invmod() という関数があります。これは「nm を法とする逆数(=mod(n * x, m) == 1 となる x)」を求める関数です。

invmod(3, 257)
#> 86

もちろん十分に高速です。何せ直接そのような数を算出するような実装になっているので。

function my_invmod(n, m)
    for x in 1:m-1
        mod(n * x, m) == 1 && return x
    end
    throw(DomainError((n, m)))
end

all(my_invmod(n-1, n) == n-1 for n in 2:100)
#> true

all(my_invmod(n, m) == invmod(n, m) for m in 2:100 for n in 1:m if gcd(n, m) == 1)
#> true

using BenchmarkTools, Random

Random.seed!(1234)
@benchmark my_invmod(n-1, n) setup=(n=rand(2:1000000))

Random.seed!(1234)
@benchmark invmod(n-1, n) setup=(n=rand(2:1000000))

 のベンチマーク

応用として、「mod(d * x, m) == mod(n, m) となるような x を求める関数(m を法とした n ÷ d という除算をする関数)」も以下のように書けます。

function divmod(n, d, m)
    mod(n * invmod(d, m), m)
end

divmod.(1:10, 3, 7)
#> [5, 3, 1, 6, 4, 2, 0, 5, 3, 1]

補足

  • Julia v1.11 からは、引数1つの invmod(n)、および第2引数に(整数)型を受け取る invmod(n, T) が追加されました。invmod(n, T) は、整数型 T のビット数を N としたときに invmod(n, 2^N) を計算する関数(メソッド)、invmod(n) == invmod(n, typeof(n)) です。n が奇数でないとエラーになります。また T として BigInt(多倍長整数)は指定できません。
    • 例:invmod(3) == invmod(3, Int64) == -6148914691236517205
    • 例:invmod(0x03) == invmod(3, UInt8) == 0xab

関連する話題

gcdx(n, m)

Julia の gcd(n, m) という関数は n, m の最大公約数を返す関数です。
Julia には gcdx(n, m) という関数もあります。戻り値は (g, x, y) の3値タプルで、g は最大公約数、x, yn * x + m * y == g を満たす 数値(代表値)です(ベズー係数 と言います)(例:gcdx(8, 6) == (2, 1, -1)8 * 1 + 6 * (-1) == 2)。
gcd() および gcdx() はともにユークリッドの互除法を利用して実装されています。

【2024.12.25 追記】補足:コメントで指摘がありましたが gcd() の方はより高速な Stein's algorithm で実装されたメソッドもあり、64bit/128bit の整数型の場合はそちらが利用されるようになっています。

実は invmod(n, m)gcdx(n, m) を使って実装されています。ざっくり言うと、(g, x, y) = gcdx(n, m) としたとき、g == 1 ならば mod(x, m) を返すような実装です(どうしてそうなのかは、数式から読み解いてください)。
(なお g ≠ 1 の場合はエラーとなります)。

実は先ほどの divmod(n, d, m)gcdx(d, m) を使って実装することもできます。なおこの実装は、g = gcd(d, m) としたとき g == 1 でなくてもなんらかの値が算出される(n % g == 0 なら正しい結果が得られる)よう拡張したものとなっています。

function divmod_v2(n, d, m)
    g, x, _y = gcdx(d, m)  # _y は不使用
    g == 1 && return mod(n * x, m)
    mod(n * x ÷ g, m)
end

divmod_v2.(1:10, 3, 7)
#> [5, 3, 1, 6, 4, 2, 0, 5, 3, 1]

divmod_v2.(1:15, 4, 6)
#> [0, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0, 5, 5]

powermod(n, p, m)

似たような関数に powermod(n, p, m) というものがあります。これはざっくり言うと mod(n^p, m)np乗をmで割った余り)を計算する関数です。ただし普通に書くとオーバーフローするような大きな数値でも計算できるようになっています。

また p < 0 の場合(普通に計算すると 0 < n ^ p < 1 となってしまうような場合)にも拡張されており、その場合 powermod(invmod(n, m), -p, m) を計算(つまり再帰呼び出し)するような実装になっています(実際にはもう少ししっかり場合分けをしています、詳細は割愛します)。

まとめに変えて

  • Julia 楽しいよ!(まとめになってない)
  • そして良いお年を!

参考リンク

来栖川電算

Discussion

YTOKYTOK

細かいツッコミなのですが、gcdのほうはStein's algorithmではないでしょうか。
https://github.com/JuliaLang/julia/blob/d386e40c17d43b79fc89d3e579fc04547241787c/base/intfuncs.jl#L28-L49

antimon2antimon2

ツッコミありがとうございます!
確かに整数の型が64bitまたは128bit(Int64, UInt64, Int128, UInt128)の時はこちらのメソッド(Stein's algorithm)の方が選択されますね(あまり意識してなかった)。
補足しておきます。