NumPyro について、確率モデルを扱いたい理由を再確認したあとに動かしてみる
この記事はTensorFlow Advent Calendar 2020の6日目の記事です。
概要
この記事では NumPyro について扱います。NumPyro は確率的プログラミングを行うためのフレームワークの1つで、バックエンドに JAX を使っていることが特徴的です。この時点で次のような疑問が生まれるでしょう。
- そもそもなんで確率なの?サイコロを投げるの?
- 確率的プログラミングとは?
- JAXって?
これにできる限り真正面から答えようというのがこの記事の目的です。まず確率モデルを導入する理由について述べます。次に、確率的プログラミングが扱う課題について述べます。その後、 NumPyro に関係する技術である Pyro や JAX について確認したあとに NumPyro について触れます。
このような構成のため、それぞれの構成要素について深くは触れません。また、ベイズ推論や機械学習に関する基本的な知識は前提とします。具体的には次のとおりです。
- 仕事ではじめる機械学習と機械学習図鑑 の内容については既知
- はじめてのパターン認識とデータ解析のための統計モデリング入門は頑張れば読めるが本棚に積んでいる
- TensorFlow のチュートリアルと100 numpy exercise の前半 は手を出したことがある
できるだけ前提知識がなくても雰囲気で読めるように記述しますが、上記の知識があるとより楽しんでいただけると思います。
確率モデルについて
ここではまず、確率モデルを扱うモチベーションについて確認していきましょう。機械学習があるんだからデータからよろしく結果を推論してくれそうで、これ以上データの活用において何が必要なんだと思われるかもしれませんが (実際はそこまで持ち込むことがすごく大変なのですが)、実際は機械学習と確率モデルは切っても切り離せない関係にあります。
たとえば、機械学習の分野のさまざまなアルゴリズムを使いやすくまとめているライブラリであるscikit-learn に含まれているアルゴリズムのうち、特に Latent Direchlet Allocation (LDA) はベイズ推定の手法でモデリングされており、次のようなグラフィカルモデルで表されます。
(図はLatent Direchlet Allocation (LDA)から引用)
LDAは逆転オセロニアでも活用されているアルゴリズムです。活用方法については次のスライドに記述されています。
このほかにもリコメンドやロジスティック回帰、ニューラルネットワークなどもベイズ推定の手法でモデリングできます。これらについてはベイズ推論による機械学習入門の後半で触れられています。その一部を勉強会用にまとめた資料がこちらです。
このように、ベイズ推定の手法を用いた確率モデルはライブラリに隠蔽されているものの、すでに使われていると考えても良いでしょう。では、なぜベイズ推定の手法を用いた確率モデルを用いると何がよいのでしょうか。個人的な見解ですが、次の3点が挙げられると思います。
- 外部の知識の注入が可能
- 因果関係の考慮が可能
- 不確実性の考慮が可能
これらについて、それぞれ見ていきましょう。
外部の知識の注入
機械学習において典型的なアプローチは、すでに存在するモデルをデータを用いて訓練するものです。ここでは、過去のデータが現在のデータと似ているという仮定が暗黙的になされています。このため、過去のデータが現在のデータと似ていることを保証できないケースでは別のアプローチが必要です。
このような状況が発生しているのが COVID-19 の感染拡大です。この感染症は人類にとって未知のものであり、また、社会環境も流動的に対応しているため、過去のデータが現在も利用できる保証は全くありません。たとえば、多数のデモが発生していたアメリカの感染拡大状況のデータをもって、政府の統制が強い中国の感染拡大状況を予測するのはナンセンスですし、同一の国でもロックダウン中のデータをもってロックダウンしていないときの感染拡大状況を予測するのはナンセンスです。このため、予測を行うためには何らかの追加の知識が必要になります。
データに対する追加の知識として数理モデルを導入することができます。COVID-19 の場合、未知のものではあるにしろ、感染症であることは間違いないため、感染症に関する感染拡大の数理モデルを用いることができます。最も単純なものは SIRモデル でしょう。SIRモデルについてはすでに解説があるためこちらを参照ください。
これを拡張したモデルが実行再生産数の算出に使われたことはまだ記憶に新しいでしょう。西浦先生による実行再生算数のモデルのコードはこちらから確認できます。
解説はこちらを参照ください。
モデルでは感染症の広がり方について確率モデルを立て、その中のパラメーターをデータから推定しています。このようにすることで、感染症に関する理論的な知識をモデルに持たせられます。Google が感染拡大について予測を公表していますが、そこでも同様のアプローチが取られています。モデルの概念図をホワイトペーパーから引用します。
こちらも同様に、曝露群や発病者、回復、入院、死亡といった一連の感染における状態遷移が定義され、それが確率モデルとして定義されています。確率モデルの中のパラメーターは機械学習モデルを用いていますが、データからパラメーターを求めている点もほぼ同様です。
ここまでで見てきたように、確率モデルを用いることで外部の知識をモデルに導入できることがわかりました。このようにすることで、過去のデータをよりよく用いた予測が可能になると言えるでしょう。
因果関係の考慮
確率モデルを導入することで、因果関係の考慮ができます。本節は少し専門的になります。また、結論は前節の「確率モデルを用いることで、外部の知識をモデルに導入できる」とほぼ同じなので読み飛ばしていただいても構いません。
ここからしばらく述べる内容は、過去にスライドに起こしていますのでそちらも参照ください。
RCT による介入効果の推定は一般に行われていますが(A/Bテストと呼ばれるもの)、それを超えて「一体何をやるのが重要なのか」に答えるためには単一の実験だけではなく、それを繰り返して何を変えることが実験の効果を大きく変えるのかを解析する必要があります。このための手法に分散分析があります。
分散分析では効果を検証したい変数を要因、変数の取る具体的な数値や設定を水準として整理し、どの要因が実験全体の分散に大きく寄与するのかを分析します。また、それぞれの要因の交互作用が分散に与える影響についても分析できます。このあたりについて平易に解説したサイトに「ハンバーガー統計学」があります。非常にわかりやすくおすすめです。
分散分析により、要因の変動が実験結果の分散を説明する度合いはわかりましたが、実際には要因の背後に共通した因子があり、その因子を発見することが重要な場合があります。たとえば、アンケート調査において各回答者の属性を10通り収集したものの、それを更に要約した4つの属性を抽出したい場合がそれに該当します。実際の例は 日経リサーチの因子分析の解説ページ を参照ください。作成されるアウトプットのイメージについて、同サイトから引用します。
一方、これはアンケート作成時の業務フローから考えると多少不自然ではあります。アンケートでは取得したいなんらかの属性をもとに、それに関連する質問文を必要なだけ考えてアンケートをとり、目的とする変数との関係を分析するでしょう。このように、それぞれの因子の背後にある要因が予め想定できるケースがあります。
背後にある要因をあらかじめ設定する分析方法が共分散構造解析です。この手法は因子と各要因の間に線形性を仮定した連立方程式を建てるため、構造方程式モデリング (SEM) とも呼ばれます。実際の分析例はJCSIによる顧客満足モデルの構築 (PDF) にあります。得られる結果について引用します。
構造方程式モデリングは、確率変数を正規分布に、確率変数間の関係を線形関数のみにした階層ベイズモデルとみなせます。筆者は当時の状況に詳しくないのですが、構造方程式モデリングにより因果関係が把握できることを期待する記述が見られます。これは最近話題になる因果推論の分野で扱われる問題です。構造方程式モデリングと因果推論の関係についてもう少し考えてみましょう。
因果関係の把握を行うための分野に因果推論があります。因果推論にはいくつかの流儀があり、それぞれに関連がありますが、ここでは Judea Pearl による Directed Acyclic Graph(DAG) の記述によるものを考えます。DAG の導入とその記述方法については解説資料がさまざまにあります。非専門家向けとしての資料は次が読みやすいと思います。
因果推論の手法の一つである、IPW法では、DAG によりバイアスを整理し、何らかの調整 (IPWなど) を施して介入効果を推定し、観察研究のデータから因果関係の把握を試みます。ここで、DAG はそれぞれの要因 (交絡因子や共変量と呼ぶほうが適切だと思います) を確率的な事象だとみなし、それぞれの事象の関係について人間が外部から知識を持ち込んだものです。このため、DAG を書くことは確率的なモデルを作っていることにほぼ等価です。
また、DAG に現れる
因果推論とDAGについての俯瞰的な資料はこちらがわかりやすかったです。
不確実性の考慮が可能
確率モデルでは得られる結果は単一の予測値ではなく、確率分布になります。このため単一の予測値ではなく、その予測値の不確実さも同時に示すことができます。
このため、その予測がどの程度信頼できるのかを信頼区間として表示できます。たとえば、さきほどの Google によるCOVID-19の感染予測(日本版) では一部のグラフに信頼区間が表示されています。次は 2020-12-06 に取得した累計陽性者数の予測グラフです。
右側に行くほど信頼区間が広がっていることから、将来になればなるほど予測精度が悪化していくことが見て取れます。平均的な予測値よりも、ある程度悪いケースまでを想定した意思決定を行う場合、信頼区間が表示できることは役に立つでしょう。
予測値の不確実さを利用する別の例として、ブラックボックス最適化があります。ブラックボックス最適化は、あるシステムの中身を知ることなく、そのシステムが最良な結果を出すような入力を見つけようというものです。具体的なタスクの例としては、機械学習アルゴリズムのハイパーパラメーター探索や、サーバー (nginx などが動いている任意のサーバー) の設定項目のチューニングがあります。
ブラックボックス最適化の分野において、ベイズ最適化の手法が適用される事例が増えてきています。ハイパーパラメーター自動最適化フレームワークである Optuna や Google が提供する最適化サービスである Cloud Vizier はどちらもベイズ最適化に基づいています。ベイズ最適化に基づくハイパーパラメーターチューニングの仕組みについて Google のブログ記事 から画像を引用します。
図中の黒い実線が予測値、点線が真の値です。不確実性は信頼区間として青い帯で表されています。点はそれぞれの試行で試したパラメーターの値 (横軸) と、その時得られた値 (縦軸) を表しています。試行するにあたり、全体を均一に探すのではなく、もしかしたら最良の結果が得られるかもしれない箇所を探しながら試行を行っていることがわかります。最適化については Optuna の作者の一人が作成したスライドがあり、こちらが大変わかりやすいです。
Cloud Vizier については以前調査した際の資料がありますので、そちらも掲載しておきます。
上記のアルゴリズムはガウス過程というモデルに基づきます。これもまたグラフィカルモデルとして記述できる、確率的なモデルです。ガウス過程については以前ガウス過程と機械学習についての勉強会でまとめた資料があるので、そちらを参照ください。
だいぶ長くなってしまいましたが、確率モデルを柔軟に組み替えたい、得た確率分布を利用したいというモチベーションは共有できたのではないかと思います。
確率的プログラミングについて
ここでは確率的プログラミングそのものについて詳細な定義を与えるのではなく、確率的プログラミングを行うためのツールの役割の紹介を行います。
確率モデルを自分で設計する場合、結構な量の計算を行わなければいけません。特殊なケース (たとえば共役事前分布を使って事後分布を計算するケース) では手で計算ができますが、一般の場合には計算は困難です。このため、機械学習フレームワークを用いたとき同様に、モデルの定義とデータを与えたら、自動的に必要な計算を行って確率分布を出力するフレームワークが必要となります。
フレームワークにはいくつか代表的なものがあり、現状では次のものが有名でしょう。
これらのフレームワークについては次のスライドで横断的に解説されています。
ただし、 PyMC については方針の変更があり、TensorFlow Probability の上で動かす PyMC4 を作るのではなく、PyMC3 が利用している Theano を JAX バックエンドで動かすことにしたようです。
大まかには PyMC3 で利用している Theano の計算グラフを JAX の上で動かせるように変換して動かすようです。筆者が PyMC3 にも Theano にもあまり詳しくないためこれ以上は述べません。
Pyro について
Pyro は Python で実行できる確率的プログラミングのためのライブラリです。当時 Uber AI Labo 所属の人たちが中心に作ったようです。Pyro については論文があるためこちらをまとめます。
目的
AI Researcher が階層モデルを実装する際に、環境構築やそのモデルのためだけのコードを書くことなく、素早く実装し、拡張性のあるモデルを書きやすいような状態を実現したかったようですね。
デザイン原則
次を要素をすべて充足することを目指しています。
- expressive : さまざまな制御フローをモデルの内部で柔軟にかけること
- scalable : 大規模なデータや高次元データを扱え、ハードウェアによる高速化もできること
- flexible : 素早く手軽に研究者が自分の書きたい処理を実装できること
- minimal : 独自のDSLを持たず、他の可視化ツールやライブラリがそのまま使えること
この要件をすべて満たすものがないので、自分たちで作ったとしています。既存のツールとの比較表は次のとおりです。
見て分かる通り、PyTorch をバックエンドに使っています。
速度比較
PyTorch の上に構築されているので、PyTorch と Pyro で同じモデル VAE を実装して比較をした結果がこちらです。
この結果を指して、Pyro によるオーバーヘッドはさほどないと筆者たちは主張しています。
コードサンプル
チュートリアルから正規分布に従う乱数を発生させるかんたんなコードを確認しましょう。
def weather():
cloudy = pyro.sample('cloudy', pyro.distributions.Bernoulli(0.3))
cloudy = 'cloudy' if cloudy.item() == 1.0 else 'sunny'
mean_temp = {'cloudy': 55.0, 'sunny': 75.0}[cloudy]
scale_temp = {'cloudy': 10.0, 'sunny': 15.0}[cloudy]
temp = pyro.sample('temp', pyro.distributions.Normal(mean_temp, scale_temp))
return cloudy, temp.item()
for _ in range(3):
print(weather())
# ('cloudy', 64.5440444946289)
# ('sunny', 94.37557983398438)
# ('sunny', 72.5186767578125)
3行目のコードから Python の制御構文を使えていることがわかります。また、4-5行目は見慣れない感じになっていますが、{'cloudy': 55.0, 'sunny': 75.0}
という辞書にキー cloudy
を渡して、対応する値を取得しています。
ここでは割愛しますが、VAE の実装も KL Divergence を直接計算しているオリジナルの実装 よりは複雑さが隠蔽されていて好印象です。
触ってみた感想
実際に VAE の実装を動かしてみたのですが、やはり訓練に時間がかかるなあという印象でした。Colab で動かしてみたのですが、CPU だとちょっと試すくらいの時間では終わらなさそうです。
JAX について
JAX は機械学習のための JIT (Just-in-Time) コンパイラです。JAX は2つの大きな機能を持っています、1つは XLA (Accelerated Linear Algebra) でのコンパイル時に行われる計算グラフの最適化とアクセラレーター(GPU, TPU)での実行、もう一つは AutoGrad による numpy 互換な演算の自動微分です。
XLA によるコンパイル
ドキュメント によると、XLA は次のように動作します。
- Python コードのトレースを行い HLO (High Level Optimizer) を生成 (1回目)
- 出力された HLO について実行環境に依存しない最適化を行い、HLO を生成 (2回目)
- 最適化された HLO について実行環境に依存した最適化を施し LLVM でコンパイル
次の図は同じドキュメントからの引用で、上記の動作を説明するものです。
HLO はコンパイラの生成する中間表現と捉えて良さそうです。XLA は機械学習で用いられるさまざまな演算をサポートしています。TensorFLow のユーザーが直接 XLA を触ることはなかなかないと思いますが、tf.function を通じて利用することができます。tf.functionを通じてXLAを用いるチュートリアルがあるので参照してください。また、tf.function についてのチュートリアルやtf.function について以前書いた解説記事もありますので参考にしていただければと思います。
Autograd
Autograd はかなり多機能 なのですが、ざっくりというと numpy 互換の行列演算 jax.numpy
とその自動微分機能群からなります。自動微分を行うものの一つが jax.grad
です。
NumPy 互換の行列演算
jax.numpy
は numpy のサブセットと互換性があり、dot
などの関数が用意されています。利用は次のようにしてできます。
import jax.numpy as jnp
import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
jnp.dot(x, x.T).block_until_ready()
CUDA のインストールなど、JAX が GPU を利用できるように環境構築したあとにこのように記述することで、numpy の配列から GPU にデータが転送され、ドット積が GPU 上で計算されます。また、jax.random
には正規分布などの確率分布が用意されています。これらも次のように numpy と同じ感覚で利用できます。
from jax import random
key = random.PRNGKey(0)
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
jnp.dot(x, x.T).block_until_ready() # runs on the GPU
この場合は GPU にデータを転送する必要はなく、GPU 上でドット積の計算が行われます。jax.numpy
で利用可能な関数はjax.numpy
のドキュメントを、jax.random
で利用可能な確率分布についてはjax.random
のドキュメント を参照してください。また、関連する機能がjax.scipy
にもあります。
自動微分
jax.numpy
による演算結果は jax.grad
により自動的に微分され、結果を得られます。サンプルコードを次に示します。
import jax.numpy as jnp
from jax import grad
def sum_logistic(x):
return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))
key = random.PRNGKey(0)
x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
# [0.25 0.19661197 0.10499357]
grad
により、自動的に導関数が算出され (derivative_fn
) それに入力を渡すことで、導関数の値が計算されます。grad
を重ねることで高階微分も可能です。
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
# -0.035325605
Autograd は多変数関数のヤコビアンの計算やヘッセ行列の計算、多変数複素関数の微分も扱えます。詳細はドキュメントを参照してください。
深層学習ライブラリ
Jax はこれまで見てきたように、機械学習に必要になる一通りの機能を取り揃えています。一方、利用者に要求する知識は TensorFlow や PyTorch と比較したときに相対的に少ないです。このため、これを利用するフレームワークが次々に誕生しています。
もっともシンプルなものは JAX に付随している stax でしょう。これは1ファイルで構成されていますが、かんたんなニューラルネットワークや CNN の構築が可能です。サンプルコードは次のようになります。
import jax.numpy as jnp
from jax import random
from jax.experimental import stax
from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax
# Use stax to set up network initialization and evaluation functions
net_init, net_apply = stax.serial(
Conv(32, (3, 3), padding='SAME'), Relu,
Conv(64, (3, 3), padding='SAME'), Relu,
MaxPool((2, 2)), Flatten,
Dense(128), Relu,
Dense(10), LogSoftmax,
)
Keras に触れたことのある方なら何となく何が行われているのか予想がつくのではないでしょうか。本稿では詳細は触れませんが、他にも google/flax, google/trax, PyTorch 的なインターフェースをしている google/objax, Sonnet っぽい deepmind/dm-haiku (Haiku) など、さまざまなチームが思い思いにフレームワークを作っています。
JAX はまだ荒削りではあるものの、事前知識を要求しない気軽さから利用はまだ広がるのではないかと思います。
NumPyro について
NumPyro は自分自身のことを「Pyro と同じ確率的モデリング用のインターフェースを持つ、NumPy をバックエンドに持つ軽量ライブラリ」だと主張しています。実際は JAX をバックエンドに持っているので numpy 互換の配列を取り扱える上に、GPU/TPU での高速化が有効です。
こちらも論文が出ているので内容を確認しましょう。
変更点
Pyro のインターフェースをできるだけ保ったまま実装したようですが、次の点は変更があったようです。
NUTS (No-U-Turn Sampler) の実装
JAX の JIT に対応させるためには Python の制御構造 (具体的には再帰呼び出し) が利用できないため、新規にアルゴリズムを実装し直したようです。 NUTS では逆戻りしていないかの確認のために過去の履歴を保持する木構造を、関数を再帰的に呼び出して生成していましたが、再帰的な呼び出しを行わないようにアルゴリズムを置き換えることでJAXでコンパイルできるようにしたようです。
vmap の実装
バッチ処理について、JAX が提供する vmap
を利用し、バッチ処理を用意にしたようです。サンプルコードが掲載されていたので引用します。
numpy の vectorize
に似たような書き方かと思います。
機械学習モデルとの連携
論文には書かれていませんでしたが、バックエンドに PyTorch を使っていないので深層学習モデルを利用したい場合、何らかの別のフレームワークを使う必要があります。NumPyro のチュートリアルでは stax を利用していましたが他のものを検討しても良いでしょう。
実行速度の比較
Pyro や Stan との実行速度の比較がなされていました。HMM の実行において、Pyro と比較して 340 倍、Stan と比較して 6倍の高速に実行できたようです。
NumPyro を動かす
NumPyro と Pyro を動かしてみて比較しましょう。動作させるアルゴリズムは両方ともにサンプルコードが用意されていた VAE を用います。全く同じ実装を用いての比較は時間の都合上できませんでしたが、隠れ層の次元などのハイパーパラメーターはできるだけ揃えました。Pyro での実装は pyro/vae.ipynb を、NumPyro での実装は numpyro/vae.py を参考に Jupyter Notebook に再実装したものを用いました。実行環境は Colab です。結果は次のようになりました。
利用したライブラリ | 訓練に要した実行時間 |
---|---|
Pyro (CPU) | - (10分以上) |
Pyro (GPU) | 10 分 |
NumPyro (CPU) | 9 分 |
NumPyro (GPU) | 25 秒 |
確かに CPU 間で比較しても GPU 間で比較しても高速になっていることがわかります。Pyro と NumPyro の CPU 実行時の速度比較はできませんでした。GPU 間での実行速度は 25 倍と論文の報告内容とは差異があったものの、たしかに大幅に高速化されました。
最後に
準備にかけられた時間の都合上、前半で触れていた SIR のような確率モデルの作成までは行き着きませんでしたが、NumPyro が登場した背景については一通り触れられたかと思います。これを期に今後も確率的プログラミングに取り組んでいきたいと思います。
おまけ
- NumPyroとJax Numpyで時系列 - HELLO CYBERNETICS 勇気づけられました
- Example: Causal Effect VAE — Pyro Tutorials 1.5.0 documentation 本当はこれを解説した上で NumPyro で再実装したかった
Discussion