
Ordinal logistic regression in Stan

sh1nsh1n

Stan features that look useful

  • Bounded Discrete Distributions
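    • Sampling statement: k ~ ordered_logistic(eta, c)
      • eta: linear predictor
        • In a regression, this is covariates * coefficients (here beta * x); no separate intercept is needed because the cutpoints absorb it
      • c: cutpoints between adjacent categories (must be an ordered, i.e. strictly increasing, vector)
        • For 3 categories, a 2-dimensional vector: the cutpoint between 1 and 2 and the cutpoint between 2 and 3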
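  • Functions
    • real ordered_logistic_lpmf(ints k | vector eta, vectors c)
      • lpmf: log probability mass function
      • Returns the full (normalized) log probability, so it can be added to the log density as target += ordered_logistic_lpmf(...)
    • real ordered_logistic_lupmf(ints k | vector eta, vectors c)
      • unnormalized lpmf (constant terms may be dropped)
      • MCMC does not need the normalizing constant, so this might be faster? For ordered_logistic every term depends on eta and c, so there seems to be little or nothing to drop and the difference is probably negligible.
    • int ordered_logistic_rng(real eta, vector c)
      • RNG that draws a category from the ordered logistic distribution, for posterior predictive / predictive distribution sampling
      • Can be called in the generated quantities block to simulate data or generate predicted values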
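To make the three functions concrete, here is a minimal pure-Python sketch of what they compute for a single observation. It follows the ordered_logistic definition in the Stan Functions Reference, which is equivalent to P(Y <= j) = inv_logit(c_j - eta), with P(Y <= 0) = 0 and P(Y <= K) = 1. The Python function names are hypothetical stand-ins that mirror the Stan names; this is not Stan's implementation, and the naive log of a difference can lose precision far in the tails.

```python
import math
import random
from typing import Sequence


def inv_logit(u: float) -> float:
    """Numerically stable logistic function 1 / (1 + exp(-u))."""
    if u >= 0:
        return 1.0 / (1.0 + math.exp(-u))
    e = math.exp(u)
    return e / (1.0 + e)


def ordered_logistic_pmf(k: int, eta: float, c: Sequence[float]) -> float:
    """Hypothetical stand-in: P(Y = k) for k in 1..K, K = len(c) + 1, c strictly increasing."""
    K = len(c) + 1
    if not 1 <= k <= K:
        raise ValueError("k must be in 1..K")
    upper = 1.0 if k == K else inv_logit(c[k - 1] - eta)  # P(Y <= k)
    lower = 0.0 if k == 1 else inv_logit(c[k - 2] - eta)  # P(Y <= k - 1)
    return upper - lower


def ordered_logistic_lpmf(k: int, eta: float, c: Sequence[float]) -> float:
    """Hypothetical stand-in for the log pmf (naive: raises if the pmf underflows to 0)."""
    return math.log(ordered_logistic_pmf(k, eta, c))


def ordered_logistic_rng(eta: float, c: Sequence[float], rng: random.Random) -> int:
    """Hypothetical stand-in: draw a category by inverting the cumulative probabilities."""
    u = rng.random()
    for j, cut in enumerate(c, start=1):
        if u < inv_logit(cut - eta):  # u < P(Y <= j)
            return j
    return len(c) + 1
```

For example, ordered_logistic_pmf(1, 0.3, [-1.0, 1.0]) equals inv_logit(-1.3), about 0.214, and the three category probabilities for a given eta sum to 1.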

Trying the estimation with CmdStanPy

Data generation:

import numpy as np

np.random.seed(123)

N = 100
x = np.random.uniform(-2, 2, size=N)  # explanatory variable

# True parameters
beta_true = 1.0
c_true = [-1.0, 1.0]  # c1 < c2

# Logistic function (inverse logit)
def inv_logit(u):
    return 1.0 / (1.0 + np.exp(-u))

# Generate the ordinal outcome Y in {1, 2, 3} from P(Y <= k) = inv_logit(c_k - eta)
Y = np.zeros(N, dtype=int)
for i in range(N):
    eta_i = beta_true * x[i]
    p1 = inv_logit(c_true[0] - eta_i)    # P(Y <= 1)
    p2 = inv_logit(c_true[1] - eta_i)    # P(Y <= 2)
    u = np.random.rand()
    if u < p1:
        Y[i] = 1
    elif u < p2:
        Y[i] = 2
    else:
        Y[i] = 3

print("Sample (x, Y):")
for i in range(5):
    print(f"x={x[i]:.2f}, Y={Y[i]}")
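The loop above draws Y from cumulative probabilities, so the category probabilities themselves are differences of adjacent cumulative probabilities: P(Y=1) = p1, P(Y=2) = p2 - p1, P(Y=3) = 1 - p2. A stdlib-only sketch that re-simulates with the same true parameters and compares empirical frequencies with the expected probabilities averaged over x. It uses Python's random module, so the draws will not match the NumPy data above, and the larger sample size M is an assumption chosen only to make the comparison tight.

```python
import math
import random


def inv_logit(u):
    return 1.0 / (1.0 + math.exp(-u))


beta_true = 1.0
c_true = [-1.0, 1.0]
M = 20000  # hypothetical sample size, much larger than N = 100 so frequencies settle
rng = random.Random(123)

counts = {1: 0, 2: 0, 3: 0}
expected = {1: 0.0, 2: 0.0, 3: 0.0}
for _ in range(M):
    x_i = rng.uniform(-2, 2)
    eta_i = beta_true * x_i
    p1 = inv_logit(c_true[0] - eta_i)  # P(Y <= 1)
    p2 = inv_logit(c_true[1] - eta_i)  # P(Y <= 2)
    # Category probabilities are differences of cumulative probabilities
    expected[1] += p1 / M
    expected[2] += (p2 - p1) / M
    expected[3] += (1.0 - p2) / M
    # Same inverse-CDF draw as the loop above
    u = rng.random()
    y = 1 if u < p1 else (2 if u < p2 else 3)
    counts[y] += 1

for k in (1, 2, 3):
    print(f"Y={k}: empirical {counts[k] / M:.3f}, expected {expected[k]:.3f}")
```

With the symmetric cutpoints and symmetric range of x, categories 1 and 3 should come out with about the same frequency.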

Stan file

stan_code = """//stan
data {
  int<lower=1> N;
  array[N] int<lower=1, upper=3> Y;
  vector[N] x;
}
parameters {
  real beta;
  ordered[2] c;
}
model {
  // Priors
  beta ~ normal(0, 1);
  c ~ normal(0, 5);

  // Likelihood
  for (n in 1:N) {
    real eta_n = beta * x[n];
    target += ordered_logistic_lpmf(Y[n] | eta_n, c);
  }
}
"""

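# Write the Stan program to a file (UTF-8, since the file name and comments may contain non-ASCII text)
stan_file = "順序ロジスティック回帰テスト.stan"
with open(stan_file, "w", encoding="utf-8") as f:
    f.write(stan_code)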
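The model block adds the log prior terms and one ordered_logistic_lpmf term per observation to target. A rough Python sketch of the same log posterior density of (beta, c), up to an additive constant: the ~ statements for the priors drop constant terms, and this sketch drops them too. It can be used as a sanity check on simulated data, where the true parameters should score well above clearly wrong ones. The function name log_density and the toy data are hypothetical.

```python
import math
import random


def inv_logit(u):
    """Numerically stable logistic function."""
    if u >= 0:
        return 1.0 / (1.0 + math.exp(-u))
    e = math.exp(u)
    return e / (1.0 + e)


def log_density(beta, c, x, Y):
    """Hypothetical helper: log posterior of (beta, c) up to a constant, mirroring the model block."""
    K = len(c) + 1
    if not all(a < b for a, b in zip(c, c[1:])):
        return -math.inf  # outside the support of an ordered vector
    lp = -0.5 * beta ** 2                            # beta ~ normal(0, 1), constants dropped
    lp += sum(-0.5 * (c_k / 5.0) ** 2 for c_k in c)  # c ~ normal(0, 5), constants dropped
    for x_n, y_n in zip(x, Y):
        eta_n = beta * x_n
        upper = 1.0 if y_n == K else inv_logit(c[y_n - 1] - eta_n)  # P(Y <= y_n)
        lower = 0.0 if y_n == 1 else inv_logit(c[y_n - 2] - eta_n)  # P(Y <= y_n - 1)
        lp += math.log(upper - lower)  # ordered_logistic_lpmf(Y[n] | eta_n, c)
    return lp


# Toy data simulated like the NumPy code (stdlib RNG, so not the same draws)
rng = random.Random(1)
x = [rng.uniform(-2, 2) for _ in range(500)]
Y = []
for x_n in x:
    u = rng.random()
    p1 = inv_logit(-1.0 - x_n)  # P(Y <= 1) with beta = 1, c = [-1, 1]
    p2 = inv_logit(1.0 - x_n)   # P(Y <= 2)
    Y.append(1 if u < p1 else (2 if u < p2 else 3))

print(log_density(1.0, [-1.0, 1.0], x, Y))  # true parameters
print(log_density(-1.0, [0.5, 3.0], x, Y))  # clearly wrong parameters
```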

Run sampling

from cmdstanpy import CmdStanModel

model = CmdStanModel(stan_file=stan_file, model_name="ordinal_3cat_model")

stan_data = {
    "N": N,
    "Y": Y,
    "x": x
}

fit = model.sample(
    data=stan_data,
    chains=4,
    parallel_chains=4,
    iter_sampling=2000,
    seed=123
)

print(fit.summary().loc[["beta","c[1]","c[2]"],:])
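Once fit is available, the posterior draws can be turned into predicted category probabilities for a new x by averaging the ordered-logistic probabilities over draws; this is the quantity that calling ordered_logistic_rng in generated quantities would approximate by simulation. A stdlib-only sketch, assuming the draws come from CmdStanPy's fit.stan_variable("beta") (one value per draw) and fit.stan_variable("c") (one row of cutpoints per draw). The function name predictive_probs is hypothetical, and the three toy draws below are made-up numbers, not output from this run.

```python
import math


def inv_logit(u):
    """Numerically stable logistic function."""
    if u >= 0:
        return 1.0 / (1.0 + math.exp(-u))
    e = math.exp(u)
    return e / (1.0 + e)


def predictive_probs(x_new, beta_draws, c_draws):
    """Hypothetical helper: posterior-averaged P(Y = k | x_new) for k = 1..K."""
    n_draws = len(beta_draws)
    K = len(c_draws[0]) + 1
    probs = [0.0] * K
    for beta, c in zip(beta_draws, c_draws):
        eta = float(beta) * x_new
        # Cumulative probabilities P(Y <= j), padded with P(Y <= 0) = 0 and P(Y <= K) = 1
        cum = [0.0] + [inv_logit(float(c_k) - eta) for c_k in c] + [1.0]
        for k in range(K):
            probs[k] += (cum[k + 1] - cum[k]) / n_draws
    return probs


# Made-up toy draws standing in for fit.stan_variable("beta") and fit.stan_variable("c")
beta_draws = [0.8, 0.9, 0.75]
c_draws = [[-1.1, 1.0], [-1.0, 1.1], [-1.2, 1.05]]
print(predictive_probs(0.5, beta_draws, c_draws))
```

With a real fit, the call would be predictive_probs(x_new, fit.stan_variable("beta"), fit.stan_variable("c")); averaging probabilities rather than simulated categories gives a smoother estimate for the same draws.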

Example output

19:40:46 - cmdstanpy - INFO - CmdStan done processing.
19:40:46 - cmdstanpy - WARNING - Non-fatal error during sampling:
Exception: ordered_logistic: Cut-points is not a valid ordered vector. The element at 2 is -25.4371, but should be greater than the previous element, -25.4371 (in '2025-01-17_順序ロジスティック回帰テスト.stan', line 19, column 4 to column 53)
	Exception: ordered_logistic: Cut-points is not a valid ordered vector. The element at 2 is -25.7445, but should be greater than the previous element, -25.7445 (in '2025-01-17_順序ロジスティック回帰テスト.stan', line 19, column 4 to column 53)
	Exception: ordered_logistic: Cut-points is not a valid ordered vector. The element at 2 is -4532.61, but should be greater than the previous element, -4532.61 (in '2025-01-17_順序ロジスティック回帰テスト.stan', line 19, column 4 to column 53)
	Exception: ordered_logistic: Cut-points is not a valid ordered vector. The element at 2 is -126.652, but should be greater than the previous element, -126.652 (in '2025-01-17_順序ロジスティック回帰テスト.stan', line 19, column 4 to column 53)
Exception: ordered_logistic: Cut-points is not a valid ordered vector. The element at 2 is -14.6781, but should be greater than the previous element, -14.6781 (in '2025-01-17_順序ロジスティック回帰テスト.stan', line 19, column 4 to column 53)
	Exception: ordered_logistic: Cut-points is not a valid ordered vector. The element at 2 is -15.5614, but should be greater than the previous element, -15.5614 (in '2025-01-17_順序ロジスティック回帰テスト.stan', line 19, column 4 to column 53)
	Exception: ordered_logistic: Cut-points is not a valid ordered vector. The element at 2 is -1471.08, but should be greater than the previous element, -1471.08 (in '2025-01-17_順序ロジスティック回帰テスト.stan', line 19, column 4 to column 53)
Exception: ordered_logistic: Cut-points is not a valid ordered vector. The element at 2 is -655.369, but should be greater than the previous element, -655.369 (in '2025-01-17_順序ロジスティック回帰テスト.stan', line 19, column 4 to column 53)
Consider re-running with show_console=True if the above output is unclear!

          Mean      MCSE    StdDev       MAD        5%       50%       95%  \
beta  0.820948  0.003116  0.206858  0.206865  0.487227  0.817506  1.158030   
c[1] -1.065400  0.003796  0.241766  0.242870 -1.469930 -1.063470 -0.672406   
c[2]  1.083340  0.002556  0.238240  0.241798  0.699302  1.082650  1.481470   

      ESS_bulk  ESS_tail    R_hat  
beta   4459.22   4733.79  1.00090  
c[1]   4116.39   4790.56  1.00149  
c[2]   8661.84   6371.00  1.00053  
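  • beta, c[1], c[2] are reasonably close to the true values (beta = 1.0, c = [-1.0, 1.0]); each true value lies inside its 5%-95% posterior interval
  • R_hat is close to 1 and the ESS (effective sample size) is large, so there is no sign of convergence problems
  • Several Exceptions are printed, which looks alarming, but in every message c[2] is exactly equal to c[1], often at extreme values such as -4532.61. Because c is declared as ordered[2], Stan builds it as c[2] = c[1] + exp(unconstrained value), so c is ordered in exact arithmetic; when the sampler (most likely early in warmup) visits extreme regions, the exp term becomes smaller than the floating-point resolution at c[1], the two cutpoints round to the same number, and ordered_logistic rejects them. Those proposals are simply rejected and sampling continues, so the results look unaffected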
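A small sketch of where the equal numbers in those messages can come from. For an ordered vector parameter, the Stan Reference Manual describes the transform c[1] = y[1], c[k] = c[k-1] + exp(y[k]) from an unconstrained vector y. The y values below are made-up illustrations, not values from this run; they only show that a very negative y[2] makes c[2] round to exactly c[1] in double precision, which fails a strict "greater than" check.

```python
import math


def ordered_transform(y):
    """Map an unconstrained vector y to an increasing vector c (Stan-style ordered transform)."""
    c = [y[0]]
    for y_k in y[1:]:
        c.append(c[-1] + math.exp(y_k))
    return c


def is_strictly_ordered(c):
    """The kind of check ordered_logistic applies to its cutpoints."""
    return all(a < b for a, b in zip(c, c[1:]))


ok = ordered_transform([-1.0, 0.7])         # gap exp(0.7), about 2.01
bad = ordered_transform([-25.4371, -40.0])  # gap exp(-40), about 4e-18, far below float spacing near 25
print(ok, is_strictly_ordered(ok))
print(bad, is_strictly_ordered(bad))
```

In double precision the spacing between representable numbers near 25 is about 3.6e-15, so adding exp(-40) leaves c[1] unchanged.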
This scrap was closed on 2025/01/17.