🐥

NumPyro:欠損値の扱い方

2023/04/26に公開

連載している記事の1つです。以前までの記事を読んでいる前提で書いているので、必要であればNumPyroの記事一覧から各記事を参考にしてください。

はじめに

今回は実務でもよくある欠損値がデータにある際の取り扱い方をみていきます。

ライブラリのインポート

import os

import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive
from numpyro.infer.util import initialize_model

import arviz as az

az.style.use("arviz-darkgrid")

assert numpyro.__version__.startswith("0.11.0")

numpyro.enable_x64(True)
numpyro.set_platform("cpu")
numpyro.set_host_device_count(1)

maskの挙動

今回は欠損値をベイズ的に取り扱い、欠損値を何かしらの潜在変数からサンプリングされた値で置換することを考えます。この時、欠損値を置換した値は勝手に潜在変数からサンプリングされた値なので、全体の対数尤度に追加しないようにします。これを実現するのがnumpyroのmask()です。

公式の例の説明ですが、以下のmaskを使わない場合のmodel2aとmaskを使う場合のmodel2bは同じ意味のコードになります。maskを使うことでobs=x_imputedだったとしても観測された値だけ考慮することになるわけですね。以降でより詳細に見ていきます。

def model2a(x):
    x_impute = numpyro.sample("x_impute", dist.Normal(0, 1).expand([4]))
    x_obs = numpyro.sample("x_obs", dist.Normal(0, 1).expand([6]), obs=x[4:])
    x_imputed = jnp.concatenate([x_impute, x_obs])

def model2b(x):
    x_impute = numpyro.sample("x_impute", dist.Normal(0, 1).expand([4]).mask(False))
    x_imputed = jnp.concatenate([x_impute, x[4:]])
    numpyro.sample("x", dist.Normal(0, 1).expand([10]), obs=x_imputed)

mask関数

NumPyro特有の関数などまとめでも触れましたが、numpyroのmaskを使用することでlog_probが0になります。

d = dist.Normal(0, 1).expand([4]).mask(False)
samples = numpyro.sample("a", d, rng_key=random.PRNGKey(0))

print(samples)
print(d.log_prob(0))
[ 1.8160863  -0.75488514  0.33988908 -0.53483534]
[0. 0. 0. 0.]

固定値で置換

最初は全て欠損値のデータの極端な例で挙動を見ていきましょう。

y = np.repeat(np.nan, 10)
y
array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan])

maskがない場合

欠損値を全て100に置換してみました。この場合、100が10回観測されたとして推論が行われるので、y_muの事後分布の平均値は90くらいになっています。

def model_nomask(y):
    
    # imputation
    y_isnan = np.isnan(y)
    # nanのデータのindexを取得
    y_nanidx = np.nonzero(y_isnan)[0]
    # 全てのデータを100で置換
    y = jnp.asarray(y).at[y_nanidx].set(100)
    
    y_mu = numpyro.sample("y_mu", dist.Normal(0, 1))
    numpyro.sample("obs", dist.Normal(y_mu, 1).expand([len(y)]), obs=y)
        
mcmc = MCMC(NUTS(model_nomask), num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0), y)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
      y_mu     90.90      0.30     90.90     90.45     91.39    533.58      1.00

Number of divergences: 0

maskがある場合

先ほどと同様に欠損値を全て100に置換してみましたが、dist.Normal(y_mu, 1).expand([len(y)]).mask(False)mask(False)でlog_probが0になるようになっているので、観測値が与えられてないのと同様になり、事後分布は事前分布と同じ(平均0、標準偏差1)になっています。このようにmaskを適宜使用することで、欠損値に対してはlog_probが0になります。

def model_mask(y):
    
    # imputation
    y_isnan = np.isnan(y)
    # nanのデータのindexを取得
    y_nanidx = np.nonzero(y_isnan)[0]
    # 全てのデータを100で置換
    y = jnp.asarray(y).at[y_nanidx].set(100)
    
    y_mu = numpyro.sample("y_mu", dist.Normal(0, 1))
    numpyro.sample("obs", dist.Normal(y_mu, 1).expand([len(y)]).mask(False), obs=y)
        
        
mcmc = MCMC(NUTS(model_mask), num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0), y)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
      y_mu     -0.04      0.99     -0.00     -1.64      1.57    359.39      1.00

Number of divergences: 0

欠損値を分布からサンプリングした値で置換

次は以下のような欠損値と実測値が混じった例で、欠損値を分布からサンプリングした値で置換する方法をみていきます。

y = np.hstack([np.repeat(np.nan, 10), np.repeat(10., 5)])
print(y)
[nan nan nan nan nan nan nan nan nan nan 10. 10. 10. 10. 10.]

maskがない場合

まず最初に、欠損値がない場合の結果を確認しておきます。欠損値はないので置換部分のコードは実行されない単純なコードです。観測値だけの場合の結果は、y_muの事後分布の平均値が8.32になることが分かりました。

def model_nomask(y):
    # imputation
    y_isnan = np.isnan(y)
    if y_isnan.any():
        # nanのデータのindexを取得
        y_nanidx = np.nonzero(y_isnan)[0]
	# 置換
        impute_mu = numpyro.sample("impute_mu", dist.Normal(100, 1))
        y_impute = numpyro.sample("y_impute", dist.Normal(impute_mu, 1).expand([len(y_nanidx)])) 
        y = jnp.asarray(y).at[y_nanidx].set(y_impute)
    
    y_mu = numpyro.sample("y_mu", dist.Normal(0, 1))
    numpyro.sample("obs", dist.Normal(y_mu, 1).expand([len(y)]), obs=y)

mcmc = MCMC(NUTS(model_nomask), num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0), y[~np.isnan(y)])
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
      y_mu      8.32      0.41      8.34      7.64      8.97    373.80      1.00

Number of divergences: 0

次は、先ほどと同様に欠損値があるケースでmaskがない場合をみていきます。この場合、y_muの事後分布の平均値が19.52となり、欠損値を除いた結果と異なる結果になりました。これは欠損値を置換した値も観測値として扱われているためです。

def model_nomask(y):
    # imputation
    y_isnan = np.isnan(y)
    if y_isnan.any():
        # nanのデータのindexを取得
        y_nanidx = np.nonzero(y_isnan)[0]
	# 置換
        impute_mu = numpyro.sample("impute_mu", dist.Normal(100, 1))
        y_impute = numpyro.sample("y_impute", dist.Normal(impute_mu, 1).expand([len(y_nanidx)])) 
        y = jnp.asarray(y).at[y_nanidx].set(y_impute)
    
    y_mu = numpyro.sample("y_mu", dist.Normal(0, 1))
    numpyro.sample("obs", dist.Normal(y_mu, 1).expand([len(y)]), obs=y)

mcmc = MCMC(NUTS(model_nomask), num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0), y)
mcmc.print_summary()
                 mean       std    median      5.0%     95.0%     n_eff     r_hat
  impute_mu     32.94      0.50     32.93     32.19     33.78    357.83      1.00
y_impute[0]     26.22      0.83     26.18     24.92     27.61    788.86      1.00
y_impute[1]     26.21      0.81     26.19     24.94     27.57    813.18      1.00
y_impute[2]     26.22      0.82     26.24     24.75     27.43    584.58      1.00
y_impute[3]     26.22      0.85     26.20     24.72     27.49    705.82      1.00
y_impute[4]     26.23      0.78     26.20     25.05     27.58    579.04      1.00
y_impute[5]     26.23      0.83     26.23     24.96     27.63    812.54      1.00
y_impute[6]     26.23      0.81     26.22     24.94     27.56    719.84      1.00
y_impute[7]     26.23      0.83     26.21     24.76     27.42    667.67      1.00
y_impute[8]     26.22      0.81     26.24     24.88     27.55    560.44      1.00
y_impute[9]     26.23      0.82     26.25     24.94     27.59    817.96      1.00
       y_mu     19.52      0.38     19.52     18.94     20.13    432.41      1.00

Number of divergences: 0

maskがある場合

今回は入力に欠損値と観測値が混じっているので、欠損値を生成する分布にmask(False)がついています。この場合、y_muの事後分布の平均値が8.35となり、多少のずれはありますが欠損値を除いた結果と同じ結果になりました。これはmask(False)により欠損値が置換された値はlog_probが0になるためです。
このように、欠損値を生成する分布にmask(False)をつけて生成した置換値と実測値を合わせてnumpyro.sample("obs", base_dist.expand([len(y)]), obs=y)をしてあげれば欠損値に対応できます。

def model_mask(y):
    # imputation
    y_isnan = np.isnan(y)
    if y_isnan.any():
        # nanのデータのindexを取得
        y_nanidx = np.nonzero(y_isnan)[0]
        impute_mu = numpyro.sample("impute_mu", dist.Normal(100, 1))
        y_impute = numpyro.sample("y_impute", dist.Normal(impute_mu, 1).expand([len(y_nanidx)]).mask(False)) 
        y = jnp.asarray(y).at[y_nanidx].set(y_impute)
    
    y_mu = numpyro.sample("y_mu", dist.Normal(0, 1))
    numpyro.sample("obs", dist.Normal(y_mu, 1).expand([len(y)]), obs=y)
    
mcmc = MCMC(NUTS(model_mask), num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0), y)
mcmc.print_summary()
                 mean       std    median      5.0%     95.0%     n_eff     r_hat
  impute_mu     99.97      1.02     99.95     98.31    101.68   1686.42      1.00
y_impute[0]      8.33      1.11      8.32      6.51     10.14   1391.07      1.00
y_impute[1]      8.32      1.04      8.30      6.66     10.04   1350.25      1.00
y_impute[2]      8.34      1.07      8.37      6.43      9.92    969.87      1.00
y_impute[3]      8.33      1.13      8.31      6.45     10.17   1218.58      1.00
y_impute[4]      8.36      1.04      8.32      6.82     10.21    920.16      1.00
y_impute[5]      8.38      1.09      8.38      6.67     10.19   1193.24      1.00
y_impute[6]      8.36      1.10      8.38      6.64     10.24   1241.57      1.00
y_impute[7]      8.34      1.10      8.35      6.60     10.07    926.52      1.00
y_impute[8]      8.33      1.06      8.37      6.73     10.24   1124.78      1.00
y_impute[9]      8.32      1.14      8.36      6.51     10.25   1566.90      1.00
       y_mu      8.35      0.39      8.34      7.69      8.99    637.89      1.00

Number of divergences: 0

以上が連続値の変数に欠損値が含まれる際の例になります。NumPyroの例ではより発展的な内容で、欠損値を埋める際に階層モデルでモデル化していますので興味のある方はご覧ください。少し複雑に見えますが、mask()を使用してやっていることは上記と同じです。

離散値の潜在変数に欠損値が含まれる場合

NumPyroは離散潜在変数を使用しても自動で場合の数を列挙(周辺化)してくれますが、その機能を応用して欠損値に対応することもできます。このトピックに関しては素晴らしいチュートリアルがありまして、ここではあまり自分で追記することもないのでコードを一部だけ説明するのに留めます。こちらの元になったissueも要点がまとまっていてわかりやすいかもしれません。

チュートリアルのコードを少し簡略化したものが以下になります。考え方自体は連続値の時と同じで、全体の対数尤度にはXが観測された場合のlog_probのみを追加しています。離散潜在変数の場合は、裏で場合の数が列挙されるので、列挙された値が観測値と異なる場合はlog_probをキャンセルする処理が書かれています。

このように離散潜在変数の場合でも多少コードを変えるだけで対応できるため非常に便利です。

def model(x, y)
    # 観測値かどうかの情報を取得
    x_isobs = ~np.isnan(x)

    # yのためのパラメータ
    b_x = sample("b_x", dist.Normal(0, 2.5))
    s_Y = sample("s_Y", dist.HalfCauchy(2.5))

    # 欠損値用のパラメータ
    eta = numpyro.sample("eta", dist.Normal(0, 1))
    with numpyro.plate("obs", len(y)):
        # 離散潜在変数xをサンプリング
	# maskすることで、ここではlog_probは計算しない
        x_impute = sample(
            "x_impute",
            dist.Bernoulli(logits=eta).mask(False),
            infer={"enumerate": "parallel"},
        )

        # 手動でlog_probを計算する
        log_prob = dist.Bernoulli(logits=eta).log_prob(x_impute)

	# 列挙された値が観測と異なる場合はlog_probをキャンセルする
        log_prob = jnp.where(x_isobs & (x_impute != x), -inf, log_prob)

        # 全体の対数尤度に追加する
        numpyro.factor("obs_x", log_prob)
	
        # for y
        mu = b_x * x_impute
        sample("obs_Y", dist.Normal(mu, s_Y), obs=Y)

最後に

以上で「欠損値の扱い方」は終わりです。ややこしい内容が多いので多少慣れるのに時間がかかるかと思います。次は「ODE」です。

Discussion