🐈

NumPyro:時系列分析

2023/05/04に公開

はじめに

今回は時系列分析を扱います。NumPyroではFor文を使用すると速度が遅くなる&メモリを大量に使用するようになるので、時系列などの繰り返し構造がある場合は、scan関数を使用する必要があります。この記事ではscan関数と簡単な分析にとどめて紹介します。

ライブラリのインポート

import os

import jax
import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
from numpyro.contrib.control_flow import scan
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive
from numpyro.infer.util import initialize_model

import arviz as az

az.style.use("arviz-darkgrid")

assert numpyro.__version__.startswith("0.11.0")

numpyro.set_platform("cpu")
numpyro.set_host_device_count(1)

Scan関数

説明がしづらいのですが、おそらく公式ドキュメントが一番わかりやすいです。
以下のコードは擬似コードですが、時系列の関数transitionと初期値carry、時間が格納された配列timestepsを引数にして時間発展のforループを回すような関数になっています。

def scan(transition, carry, timesteps, length=None):
  if timesteps is None:
    timesteps = [None] * length
  ys = []
  for t in timesteps:
    carry, y = transition(carry, t)
    ys.append(y)
  return carry, np.stack(ys)

以下の適当な例を見てみます。y_tは初期値0から始まり、timestepを逐次追加していくような関数になっています。xsで渡した時間配列だけループが回っていることがわかります。

def transition(carry, timesteps):
    y_prev = carry
    y_t = y_prev + timesteps
    carry = y_t
    return carry, y_t

jax.lax.scan(transition, init=(0), xs=jnp.arange(10))
(Array(45, dtype=int32),
 Array([ 0,  1,  3,  6, 10, 15, 21, 28, 36, 45], dtype=int32))

時系列分析

簡単な例として、1階差分のトレンド項を扱います。モデル式は以下になります。

\mu[t] \sim Normal(\mu[t-1], \sigma_{\mu}) \\ y[t] \sim Normal(\mu[t], \sigma_{y})

データの準備

アヒル本の12章のデータを使用します。

df = pd.read_csv("./RStanBook/chap12/input/data-ss1.txt")
df.head()

X	Y
0	1	11.2
1	2	11.0
2	3	11.3
3	4	10.8
4	5	10.8

モデルの定義

for文使用(非推奨)

まずはfor文で書いた方を見ていきます。これは単純にモデル式通り書くだけです。サンプルサイトが重複するとエラーが出るので、"mu_t{}".format(i)のようにイテレーションごとに名前を変えています。

def model_loop(y):
    sigma_mu = numpyro.sample("sigma_mu", dist.HalfNormal(1))
    sigma_y = numpyro.sample("sigma_y", dist.HalfNormal(1))
    
    mu_prev = numpyro.sample("mu_prev", dist.Normal(0, 10))
    
    for i in range(len(y)):
        mu_t = numpyro.sample("mu_t{}".format(i), dist.Normal(mu_prev, sigma_mu))
        y_t = numpyro.sample("y_t{}".format(i), dist.Normal(mu_t, sigma_y), obs=y[i])
        mu_prev = mu_t

scan使用(推奨)

for文をscanに置き換えます。scanを使用する場合はobs=yのように指定するのではなく、with numpyro.handlers.condition(data={"y_t": y}):とすることで、モデル内の各確率変数(ここでは、"y_t")の実現値を固定(条件付け)してサンプリングを行うことが可能になります。

def model_scan(y):
    sigma_mu = numpyro.sample("sigma_mu", dist.HalfNormal(1))
    sigma_y = numpyro.sample("sigma_y", dist.HalfNormal(1))
    
    mu_prev = numpyro.sample("mu_prev", dist.Normal(0, 10))
    
    def transition(carry, _):
        mu_prev = carry
        mu_t = numpyro.sample("mu_t", dist.Normal(mu_prev, sigma_mu))
        y_t = numpyro.sample("y_t", dist.Normal(mu_t, sigma_y))
        carry = mu_t
        return carry, None

    timesteps = jnp.arange(len(y))
    init = mu_prev

    with numpyro.handlers.condition(data={"y_t": y}):
        scan(transition, init, timesteps)

MCMC

scanを扱う際はjnp.asarray(df["Y"].values)のようにjax.numpyの形式に変換しておく必要があります。結果を確認すると、アヒル本の結果とほとんど一致しました。

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model_scan)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
mcmc.run(
    rng_key=rng_key,
    # scanの場合はjnpに変換しておく必要があった
    y=jnp.asarray(df["Y"].values),
)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
   mu_prev     11.16      0.43     11.15     10.42     11.81   1353.72      1.00
   mu_t[0]     11.18      0.16     11.18     10.91     11.42   1557.75      1.00
   mu_t[1]     11.05      0.15     11.04     10.81     11.28   1020.97      1.00
   mu_t[2]     11.21      0.15     11.23     10.95     11.44    907.79      1.01
   mu_t[3]     10.87      0.16     10.85     10.65     11.15    765.45      1.00
   mu_t[4]     10.88      0.16     10.85     10.64     11.14    816.44      1.01
   mu_t[5]     11.22      0.15     11.24     10.97     11.45    857.40      1.00
   mu_t[6]     11.12      0.14     11.11     10.88     11.36   1503.10      1.00
   mu_t[7]     11.09      0.16     11.06     10.84     11.36    570.96      1.01
   mu_t[8]     11.42      0.15     11.41     11.17     11.66   1169.89      1.00
   mu_t[9]     11.75      0.15     11.74     11.52     12.01   1784.41      1.00
  mu_t[10]     12.41      0.16     12.43     12.16     12.66    697.04      1.01
  mu_t[11]     12.59      0.15     12.60     12.36     12.84   2036.37      1.00
  mu_t[12]     12.79      0.15     12.79     12.56     13.05   1855.04      1.00
  mu_t[13]     12.97      0.15     12.98     12.70     13.18   1578.22      1.00
  mu_t[14]     13.05      0.14     13.04     12.83     13.29   1654.12      1.00
  mu_t[15]     13.46      0.18     13.50     13.15     13.71    393.74      1.01
  mu_t[16]     13.24      0.16     13.25     12.98     13.48    768.24      1.01
  mu_t[17]     12.74      0.17     12.70     12.50     13.02    579.45      1.01
  mu_t[18]     12.99      0.17     13.01     12.69     13.23    563.64      1.01
  mu_t[19]     12.60      0.15     12.60     12.37     12.86   2084.01      1.00
  mu_t[20]     12.19      0.17     12.16     11.94     12.50    561.57      1.02
  sigma_mu      0.39      0.09      0.38      0.24      0.52    756.03      1.00
   sigma_y      0.17      0.09      0.15      0.03      0.28    144.40      1.04

Number of divergences: 0

scan関数の時とほとんど同じ結果です。今回の簡単な例ではscan関数を使用時は2.3s、for文使用時は5.4sなのでfor文を使用すると2倍遅くなっていることがわかります。より複雑なモデルの場合はより遅く&重くなるので可能な限りscan関数を使用するようにします。

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model_loop) # init_to_feasibleにするとN_effが増える
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
mcmc.run(
    rng_key=rng_key,
    y=df["Y"].values,
)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
   mu_prev     11.18      0.43     11.18     10.48     11.88   1210.47      1.00
     mu_t0     11.18      0.15     11.19     10.92     11.41   1627.87      1.00
     mu_t1     11.04      0.14     11.03     10.82     11.27   1858.22      1.00
    mu_t10     12.43      0.14     12.45     12.18     12.64    609.38      1.00
    mu_t11     12.59      0.14     12.60     12.36     12.82   2652.47      1.00
    mu_t12     12.80      0.14     12.80     12.59     13.03   1999.18      1.00
    mu_t13     12.98      0.14     12.99     12.75     13.21   2032.72      1.00
    mu_t14     13.05      0.13     13.03     12.83     13.27   1356.47      1.00
    mu_t15     13.49      0.16     13.52     13.22     13.73    293.55      1.00
    mu_t16     13.25      0.15     13.27     12.98     13.47   1095.34      1.00
    mu_t17     12.72      0.16     12.69     12.49     12.99    365.10      1.00
    mu_t18     13.00      0.15     13.03     12.76     13.23    383.59      1.00
    mu_t19     12.60      0.14     12.60     12.38     12.83   1832.62      1.00
     mu_t2     11.22      0.14     11.25     10.98     11.44    524.19      1.00
    mu_t20     12.18      0.17     12.15     11.94     12.46    524.21      1.00
     mu_t3     10.86      0.15     10.84     10.65     11.12    761.71      1.00
     mu_t4     10.86      0.15     10.84     10.62     11.09   1134.56      1.00
     mu_t5     11.23      0.14     11.25     10.97     11.43    604.95      1.00
     mu_t6     11.11      0.14     11.11     10.87     11.32   1723.22      1.00
     mu_t7     11.07      0.15     11.05     10.85     11.31    715.61      1.00
     mu_t8     11.41      0.14     11.40     11.19     11.63   1585.63      1.00
     mu_t9     11.75      0.14     11.73     11.55     12.03    931.34      1.00
  sigma_mu      0.39      0.08      0.39      0.25      0.51    526.86      1.00
   sigma_y      0.15      0.08      0.14      0.02      0.26     90.64      1.00

Number of divergences: 0

最後に

以上で「時系列分析」は終わりです。scan関数に慣れさえすれば、簡単に試せますね。

今回でNumPyroの使い方を中心にした記事は一旦終了です。興味がある方はNumPyroのドキュメントPyroのドキュメントを漁ると参考になるコードも多いかと思います。

次回以降は、階層モデルやガウス過程、状態空間モデルの論文やProphet等のソースコード等を調査しながら勉強&実装を進めていきたいと思います。

Discussion