NumPyro:時系列分析
はじめに
今回は時系列分析を扱います。NumPyroではFor文を使用すると速度が遅くなる&メモリを大量に使用するようになるので、時系列などの繰り返し構造がある場合は、scan
関数を使用する必要があります。この記事ではscan
関数と簡単な分析にとどめて紹介します。
ライブラリのインポート
import os
import jax
import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
from numpyro.contrib.control_flow import scan
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive
from numpyro.infer.util import initialize_model
import arviz as az
az.style.use("arviz-darkgrid")
assert numpyro.__version__.startswith("0.11.0")
numpyro.set_platform("cpu")
numpyro.set_host_device_count(1)
Scan関数
説明がしづらいのですが、おそらく公式ドキュメントが一番わかりやすいです。
以下のコードは擬似コードですが、時系列の関数transition
と初期値carry
、時間が格納された配列timesteps
を引数にして時間発展のforループを回すような関数になっています。
def scan(transition, carry, timesteps, length=None):
if timesteps is None:
timesteps = [None] * length
ys = []
for t in timesteps:
carry, y = transition(carry, t)
ys.append(y)
return carry, np.stack(ys)
以下の適当な例を見てみます。y_t
は初期値0から始まり、timestepを逐次追加していくような関数になっています。xsで渡した時間配列だけループが回っていることがわかります。
def transition(carry, timesteps):
y_prev = carry
y_t = y_prev + timesteps
carry = y_t
return carry, y_t
jax.lax.scan(transition, init=(0), xs=jnp.arange(10))
(Array(45, dtype=int32),
Array([ 0, 1, 3, 6, 10, 15, 21, 28, 36, 45], dtype=int32))
時系列分析
簡単な例として、1階差分のトレンド項を扱います。モデル式は以下になります。
データの準備
アヒル本の12章のデータを使用します。
df = pd.read_csv("./RStanBook/chap12/input/data-ss1.txt")
df.head()
X Y
0 1 11.2
1 2 11.0
2 3 11.3
3 4 10.8
4 5 10.8
モデルの定義
for文使用(非推奨)
まずはfor文で書いた方を見ていきます。これは単純にモデル式通り書くだけです。サンプルサイトが重複するとエラーが出るので、"mu_t{}".format(i)
のようにイテレーションごとに名前を変えています。
def model_loop(y):
sigma_mu = numpyro.sample("sigma_mu", dist.HalfNormal(1))
sigma_y = numpyro.sample("sigma_y", dist.HalfNormal(1))
mu_prev = numpyro.sample("mu_prev", dist.Normal(0, 10))
for i in range(len(y)):
mu_t = numpyro.sample("mu_t{}".format(i), dist.Normal(mu_prev, sigma_mu))
y_t = numpyro.sample("y_t{}".format(i), dist.Normal(mu_t, sigma_y), obs=y[i])
mu_prev = mu_t
scan使用(推奨)
for文をscanに置き換えます。scanを使用する場合はobs=y
のように指定するのではなく、with numpyro.handlers.condition(data={"y_t": y}):
とすることで、モデル内の各確率変数(ここでは、"y_t")の実現値を固定(条件付け)してサンプリングを行うことが可能になります。
def model_scan(y):
sigma_mu = numpyro.sample("sigma_mu", dist.HalfNormal(1))
sigma_y = numpyro.sample("sigma_y", dist.HalfNormal(1))
mu_prev = numpyro.sample("mu_prev", dist.Normal(0, 10))
def transition(carry, _):
mu_prev = carry
mu_t = numpyro.sample("mu_t", dist.Normal(mu_prev, sigma_mu))
y_t = numpyro.sample("y_t", dist.Normal(mu_t, sigma_y))
carry = mu_t
return carry, None
timesteps = jnp.arange(len(y))
init = mu_prev
with numpyro.handlers.condition(data={"y_t": y}):
scan(transition, init, timesteps)
MCMC
scan
を扱う際はjnp.asarray(df["Y"].values)
のようにjax.numpyの形式に変換しておく必要があります。結果を確認すると、アヒル本の結果とほとんど一致しました。
# 乱数の固定に必要
rng_key= random.PRNGKey(0)
# NUTSでMCMCを実行する
kernel = NUTS(model_scan)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
mcmc.run(
rng_key=rng_key,
# scanの場合はjnpに変換しておく必要があった
y=jnp.asarray(df["Y"].values),
)
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
mu_prev 11.16 0.43 11.15 10.42 11.81 1353.72 1.00
mu_t[0] 11.18 0.16 11.18 10.91 11.42 1557.75 1.00
mu_t[1] 11.05 0.15 11.04 10.81 11.28 1020.97 1.00
mu_t[2] 11.21 0.15 11.23 10.95 11.44 907.79 1.01
mu_t[3] 10.87 0.16 10.85 10.65 11.15 765.45 1.00
mu_t[4] 10.88 0.16 10.85 10.64 11.14 816.44 1.01
mu_t[5] 11.22 0.15 11.24 10.97 11.45 857.40 1.00
mu_t[6] 11.12 0.14 11.11 10.88 11.36 1503.10 1.00
mu_t[7] 11.09 0.16 11.06 10.84 11.36 570.96 1.01
mu_t[8] 11.42 0.15 11.41 11.17 11.66 1169.89 1.00
mu_t[9] 11.75 0.15 11.74 11.52 12.01 1784.41 1.00
mu_t[10] 12.41 0.16 12.43 12.16 12.66 697.04 1.01
mu_t[11] 12.59 0.15 12.60 12.36 12.84 2036.37 1.00
mu_t[12] 12.79 0.15 12.79 12.56 13.05 1855.04 1.00
mu_t[13] 12.97 0.15 12.98 12.70 13.18 1578.22 1.00
mu_t[14] 13.05 0.14 13.04 12.83 13.29 1654.12 1.00
mu_t[15] 13.46 0.18 13.50 13.15 13.71 393.74 1.01
mu_t[16] 13.24 0.16 13.25 12.98 13.48 768.24 1.01
mu_t[17] 12.74 0.17 12.70 12.50 13.02 579.45 1.01
mu_t[18] 12.99 0.17 13.01 12.69 13.23 563.64 1.01
mu_t[19] 12.60 0.15 12.60 12.37 12.86 2084.01 1.00
mu_t[20] 12.19 0.17 12.16 11.94 12.50 561.57 1.02
sigma_mu 0.39 0.09 0.38 0.24 0.52 756.03 1.00
sigma_y 0.17 0.09 0.15 0.03 0.28 144.40 1.04
Number of divergences: 0
scan関数の時とほとんど同じ結果です。今回の簡単な例ではscan
関数を使用時は2.3s、for文使用時は5.4sなのでfor文を使用すると2倍遅くなっていることがわかります。より複雑なモデルの場合はより遅く&重くなるので可能な限りscan
関数を使用するようにします。
# 乱数の固定に必要
rng_key= random.PRNGKey(0)
# NUTSでMCMCを実行する
kernel = NUTS(model_loop) # init_to_feasibleにするとN_effが増える
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
mcmc.run(
rng_key=rng_key,
y=df["Y"].values,
)
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
mu_prev 11.18 0.43 11.18 10.48 11.88 1210.47 1.00
mu_t0 11.18 0.15 11.19 10.92 11.41 1627.87 1.00
mu_t1 11.04 0.14 11.03 10.82 11.27 1858.22 1.00
mu_t10 12.43 0.14 12.45 12.18 12.64 609.38 1.00
mu_t11 12.59 0.14 12.60 12.36 12.82 2652.47 1.00
mu_t12 12.80 0.14 12.80 12.59 13.03 1999.18 1.00
mu_t13 12.98 0.14 12.99 12.75 13.21 2032.72 1.00
mu_t14 13.05 0.13 13.03 12.83 13.27 1356.47 1.00
mu_t15 13.49 0.16 13.52 13.22 13.73 293.55 1.00
mu_t16 13.25 0.15 13.27 12.98 13.47 1095.34 1.00
mu_t17 12.72 0.16 12.69 12.49 12.99 365.10 1.00
mu_t18 13.00 0.15 13.03 12.76 13.23 383.59 1.00
mu_t19 12.60 0.14 12.60 12.38 12.83 1832.62 1.00
mu_t2 11.22 0.14 11.25 10.98 11.44 524.19 1.00
mu_t20 12.18 0.17 12.15 11.94 12.46 524.21 1.00
mu_t3 10.86 0.15 10.84 10.65 11.12 761.71 1.00
mu_t4 10.86 0.15 10.84 10.62 11.09 1134.56 1.00
mu_t5 11.23 0.14 11.25 10.97 11.43 604.95 1.00
mu_t6 11.11 0.14 11.11 10.87 11.32 1723.22 1.00
mu_t7 11.07 0.15 11.05 10.85 11.31 715.61 1.00
mu_t8 11.41 0.14 11.40 11.19 11.63 1585.63 1.00
mu_t9 11.75 0.14 11.73 11.55 12.03 931.34 1.00
sigma_mu 0.39 0.08 0.39 0.25 0.51 526.86 1.00
sigma_y 0.15 0.08 0.14 0.02 0.26 90.64 1.00
Number of divergences: 0
最後に
以上で「時系列分析」は終わりです。scan
関数に慣れさえすれば、簡単に試せますね。
今回でNumPyroの使い方を中心にした記事は一旦終了です。興味がある方はNumPyroのドキュメントやPyroのドキュメントを漁ると参考になるコードも多いかと思います。
次回以降は、階層モデルやガウス過程、状態空間モデルの論文やProphet等のソースコード等を調査しながら勉強&実装を進めていきたいと思います。
Discussion