🔢

最強の機械学習フレームワークを作りたい その 1 「何がしたいか、Lean 4 の基礎とともに」

2024/12/16に公開

はじめに

PyTorch という機械学習フレームワークがあり、弊社でもよくお世話になっています。
しかし、PyTorch で開発をしていると、以下のような事態に陥ることがあります:

テンソル形状で不整合を起こすコードを書いてしまい、実行後、それも短くない時間経過ののちにエラーによってそのことに気づく

常に詳細なドキュメンテーションを心がけることでこうした事態をある程度防ぐことはできますが、できればこういうのはシステマチックに検出したいものです。

PyTorch の枠組みでそのような検出をすることは難しいです。なぜなら、テンソルは torch.Tensor という型を持っているのみで、その形状について静的には知ることができないからです。これは PyTorch が悪いというより、そもそも Python では数値の情報を型に乗せることはできても、その情報の取り回しが十分柔軟ではないことに起因します。

そこで、Lean 4 のような強力な依存型システムを持つ言語で機械学習フレームワークを構築できないかと考えました。以下のようなことが期待されます:

  • テンソル形状の不整合のようなエラーを型検査によって検出できる。
  • 定理証明支援系としての機能を活かし、モデルの持つ性質をコードに埋め込める。

もちろん、書きやすいフレームワーク設計や、GPU サポートのような性能面の課題を考えると、実現は容易ではありません。しかしともかく、まずはどういうことができたらいいと考えているのかを、Lean 4 の基礎と簡単な例とともに述べていきたいと思います。読んでいて不明瞭に感じる点は、適宜各種ドキュメントを参照してください。本記事のコードは Lean 4 Web にて試すことができます。

なお、本記事の内容はすべて筆者の個人的興味に基づいており、社を挙げて取り組む予定は今のところ全くないことに注意してください。

まずは行列積

早速ですが、Lean 4 では行列積を以下のような関数として記述できます:

import Mathlib.Data.Matrix.Basic

def matMulFloat {m n p : Nat}
  (left : Matrix (Fin m) (Fin n) Float)
  (right : Matrix (Fin n) (Fin p) Float)
  : Matrix (Fin m) (Fin p) Float :=
  fun i j =>
    (List.finRange n).foldl (fun acc k => acc + left i k * right k j) 0.0

順を追って見ていきましょう:

  • Lean 4 には Mathlib という、種々の数学的概念とそれにまつわる定理証明を提供するライブラリがあり、その中には行列が含まれています。これをインポートします。
  • def 構文によって、関数 matMulFloat を定義します。
    • {m n p : Nat} は暗黙の引数というやつで、非負整数 (Nat) 型の値 m, n, p を受け取ることを表しています。これらはのちの明示的な引数によって推論されます。
    • (left : Matrix (Fin m) (Fin n) Float) は引数で、m×n で各成分が Float の行列型の値 left を受け取ることを表しています。Fin mm 未満の非負整数値の集合で、このように、どういうものが添字になるかということによって間接的に形状を指定します。こういうゴリゴリの関数型言語に慣れていない方には読みにくいかもしれませんが、Matrix は型を返すカリー化された関数であり、関数適用自体は括弧を伴いません。
    • (right : Matrix (Fin n) (Fin p) Float) は引数で、n×p で各成分が Float の行列型の値 right を受け取ることを表しています。
    • : Matrix (Fin m) (Fin p) Float は返り値の型で、m×p で各成分が Float の行列型の値を返すことを表しています。
    • := 以降が関数の本体です。
      • 行列値の実体は、添字を受け取って成分を返す関数です。このため、fun i j => という構文で、添字 i, j を受け取る無名関数を作っています。
      • (List.finRange n).foldl (fun acc k => acc + left i k * right k j) 0.0 によって行列の i, j 成分を計算します。いかにも関数型チックな書き方になっていますが、要は 0.0 をスタートとして、List.finRange n0 から n - 1 までのリストを生成し、foldl によって加算を行っています。

これは以下のように使用することができます:

import Mathlib.Data.Matrix.Notation

def a :=
  !![
    1.0, 2.0, 3.0;
    4.0, 5.0, 6.0;
  ]

def b :=
  !![
    1.0, 2.0;
    3.0, 4.0;
    5.0, 6.0;
  ]

#eval matMulFloat a b

順を追って見ていきましょう:

  • 後述の行列の表記法をインポートします。

  • 変数 a を定義します。Lean 4 では、変数も関数と同様に def で定義します。変数が無引数の関数であるかのように思うことができます。

    • !![...] によって、決まった成分を持つ行列を簡単に表記することができます。
  • 変数 b を定義します。

  • #eval コマンドによって、式 matMulFloat a b を評価します。以下のような評価結果が得られ、行列積が正しく計算できていることがわかります:

    !![22.000000, 28.000000; 49.000000, 64.000000]
    

さて、これの何が嬉しいのでしょうか。実は、行列型がその形状の情報を含んでいて、matMulFloat がその情報をシグネチャで使っているため、不整合を静的に検知することができます。以下のコードを追加すると、型エラーになります:

def c :=
  !![
    1.0, 2.0;
    3.0, 4.0;
  ]

#eval matMulFloat a c

行列積では、左行列の列数と右行列の行数が一致していなければなりませんが、a2×3 行列、c2×2 行列です。これは計算不可能であり、型エラーが出るのは正しい振る舞いだと言えます。さらに、返り値の型にも形状情報が載っているため、後続の処理でもその恩恵が得られます。

すでに嬉しい雰囲気が漂ってきたと思いますが、ここまでは Python でもギリ可能な範囲です。

ちょっと踏み込んで、カーネル畳み込み

今度は、CNN などで用いられるカーネル畳み込みを考えましょう。形状の観点ではより複雑になり、Lean 4 を使用する利点がよく顕れます。ついでに、定理証明を記述するとはどういうことかも見ることができます。

以下の関数が本記事の本命です:

def conv2D {height_signal width_signal height_kernel width_kernel : Nat}
  (signal : Matrix (Fin height_signal) (Fin width_signal) Float)
  (kernel : Matrix (Fin height_kernel) (Fin width_kernel) Float)
  (rel_height : height_signal ≥ height_kernel)
  (rel_width : width_signal ≥ width_kernel)
  : Matrix (Fin (height_signal - height_kernel + 1)) (Fin (width_signal - width_kernel + 1)) Float :=
  fun i j =>
    (List.finRange height_kernel).foldl (fun acc₁ u =>
      (List.finRange width_kernel).foldl (fun acc₂ v =>
        acc₂ + signal (Fin.mk (i.val + u.val) (by
          have h₀ := Nat.le_of_succ_le_succ (Nat.succ_le_of_lt i.isLt)
          have h₁ := Nat.succ_le_of_lt u.isLt
          have h₂ := Nat.add_le_add h₀ h₁
          rw [Nat.add_succ, Nat.sub_add_cancel rel_height] at h₂
          exact Nat.succ_le.mp h₂
        )) (Fin.mk (j.val + v.val) (by
          have h₀ := Nat.le_of_succ_le_succ (Nat.succ_le_of_lt j.isLt)
          have h₁ := Nat.succ_le_of_lt v.isLt
          have h₂ := Nat.add_le_add h₀ h₁
          rw [Nat.add_succ, Nat.sub_add_cancel rel_width] at h₂
          exact Nat.succ_le.mp h₂
        )) * kernel u v
      ) acc₁
    ) 0.0

すべてを説明していると長くなるので、面白みのある部分をピックアップします:

  • (rel_height : height_signal ≥ height_kernel) は引数で、height_signal ≥ height_kernel であるという事実 rel_height を受け取ることを表しています。型による定理証明の世界では、命題が型、事実が値に対応します。拙記事も参考にしてください。
  • : Matrix (Fin (height_signal - height_kernel + 1)) (Fin (width_signal - width_kernel + 1)) Float は返り値の型で、(height_signal - height_kernel + 1)×(width_signal - width_kernel + 1) で各成分が Float の行列の型の値を返すことを表しています。このように、当然かの如く型パラメータの上で算術演算を行うことができます。
  • Fin.mk は、非負整数値と、それが別の非負整数値 n 未満であるという事実から、Fin n 型の値を生成する関数です。したがって by 以下が、i.val + u.val < height_signal であることの証明になっています。ここで、.valFin n 値から n 未満の非負整数値を得るメソッドです。一つ目の by について、以下の流れで証明が行われています:
    • i.val ≤ height_signal - height_kernel であるという事実を h₀ と置きます。
    • u.val.succ ≤ height_kernel であるという事実を h₁ と置きます。ここで、.succ は非負整数値に 1 を加えるメソッドです。
    • 辺々足して、i.val + u.val.succ ≤ height_signal - height_kernel + height_kernel であるという事実を得、h₂ と置きます。
    • h₂(i.val + u.val).succ ≤ height_signal という形に書き換えます。
    • h₂ から i.val + u.val < height_signal であるという事実を得て、証明終了とします。

これは以下のように使用することができます:

def d :=
  !![
    1.0, 2.0, 3.0;
    4.0, 5.0, 6.0;
    7.0, 8.0, 9.0;
  ]

def e :=
  !![
    1.0, 2.0;
    3.0, 4.0;
  ]

#eval conv2D d e (by trivial) (by trivial)
  • #eval コマンドによって、式 conv2D d e (by trivial) (by trivial) を評価します。以下のような評価結果が得られ、カーネル畳み込みが正しく計算できていることがわかります:

    !![37.000000, 47.000000; 67.000000, 77.000000]
    
  • conv2D の第三、第四引数は、信号画像の縦横のサイズがカーネルの縦横のサイズ以上であるという事実です。これは trivial によって証明されています。このように、自明なものは自明と書くだけで察してくれたりするのも Lean 4 のよいところです。

以下のコードを追加すると、ちゃんと型エラーになります:

def f :=
  !![
    1.0, 2.0;
    3.0, 4.0;
    5.0, 6.0;
    7.0, 8.0;
  ]

#eval conv2D d f (by trivial) (by trivial)

いかがでしょうか。カーネル畳み込みの結果の形状を得るには、信号画像の形状とカーネルの形状を元に算術演算を行う必要があり、これを型で表現し切るのは多くの言語では難しいですが、Lean 4 なら可能なのです。

まとめ

本記事では、筆者が考える最強の機械学習フレームワークを Lean 4 で作るための基礎を述べました。
まだまだ全然妄想の段階ですが、今後も暇があればブラッシュアップしていきたいです。

mutex Official Tech Blog

Discussion