Ordinal logistic regression in Stan
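Stan features that look useful
- Bounded Discrete Distributions
- Sampling statement: k ~ ordered_logistic(eta, c)
  - eta: the predictor. For a linear regression this is something like covariate * coefficient.
  - c: the cut-points (thresholds) between adjacent categories. For a 3-category regression this is a 2-element vector holding the threshold between 1 and 2 and the threshold between 2 and 3.
- Functions
  - real ordered_logistic_lpmf(ints k | vector eta, vectors c)
    - lpmf: log probability mass function
    - It returns a normalized log probability, so it can be used as target += ordered_logistic_lpmf(...)
  - real ordered_logistic_lupmf(ints k | vector eta, vectors c)
    - unnormalized lpmf
    - MCMC and similar methods do not need the normalizing constant, so this one may be faster?
  - ordered_logistic_rng(real eta, vector c)
    - RNG for posterior predictive / predictive distribution sampling
    - Can be used inside the generated quantities block to generate simulations or predicted values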
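To make the roles of eta and c concrete, here is a minimal NumPy sketch of the probability that an ordered logistic model assigns to each category, plus a sampler that plays the role of ordered_logistic_rng. The helper names ordered_logistic_probs and ordered_logistic_draw are made up for this note (they are not Stan or CmdStanPy APIs), and the sketch assumes the parameterization P(k <= j) = inv_logit(c[j] - eta).

```python
import numpy as np

def inv_logit(u):
    return 1.0 / (1.0 + np.exp(-u))

def ordered_logistic_probs(eta, c):
    """P(k = 1..K) for one predictor eta and K-1 ordered cut-points c (hypothetical helper)."""
    c = np.asarray(c, dtype=float)
    # Cumulative probabilities P(k <= j) for j = 1..K-1, padded with 0 and 1
    cum = np.concatenate(([0.0], inv_logit(c - eta), [1.0]))
    # Differences of adjacent cumulative probabilities give the per-category probabilities
    return np.diff(cum)

def ordered_logistic_draw(eta, c, rng):
    """Draw one category in 1..K (hypothetical analogue of ordered_logistic_rng)."""
    probs = ordered_logistic_probs(eta, c)
    return int(rng.choice(len(probs), p=probs)) + 1

probs = ordered_logistic_probs(eta=0.0, c=[-1.0, 1.0])
print(probs)          # one probability per category
print(np.log(probs))  # the corresponding log-pmf values

rng = np.random.default_rng(0)
draws = [ordered_logistic_draw(0.0, [-1.0, 1.0], rng) for _ in range(5)]
print(draws)
```

Increasing eta shifts probability mass toward the higher categories, while the spacing of c controls how wide each category's band is.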
Estimating with CmdStanPy
Data generation:
import numpy as np
np.random.seed(123)
N = 100
x = np.random.uniform(-2, 2, size=N)  # explanatory variable
# True parameters
beta_true = 1.0
c_true = [-1.0, 1.0]  # c1 < c2
# Logistic function
def inv_logit(u):
    return 1.0 / (1.0 + np.exp(-u))
# Generate the categorical variable Y
Y = np.zeros(N, dtype=int)
for i in range(N):
    eta_i = beta_true * x[i]
    p1 = inv_logit(c_true[0] - eta_i)  # P(Y <= 1)
    p2 = inv_logit(c_true[1] - eta_i)  # P(Y <= 2)
    u = np.random.rand()
    if u < p1:
        Y[i] = 1
    elif u < p2:
        Y[i] = 2
    else:
        Y[i] = 3
print("Sample (x, Y):")
for i in range(5):
    print(f"x={x[i]:.2f}, Y={Y[i]}")
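The same generative process can be written without an explicit loop. The following is a sketch under the same true parameters. It uses NumPy's newer Generator API (an assumption of this sketch), so the individual draws will not match the loop version, and the name Y_vec is introduced here only to avoid clashing with Y.

```python
import numpy as np

rng = np.random.default_rng(123)  # assumption: new-style Generator, draws differ from the loop version
N = 100
beta_true = 1.0
c_true = np.array([-1.0, 1.0])

x = rng.uniform(-2, 2, size=N)
eta = beta_true * x

# Cumulative probabilities P(Y <= 1) and P(Y <= 2) for every observation, shape (N, 2)
cum = 1.0 / (1.0 + np.exp(-(c_true[None, :] - eta[:, None])))

# Category = 1 + number of cumulative thresholds that the uniform draw reaches or exceeds
u = rng.random(N)
Y_vec = 1 + (u[:, None] >= cum).sum(axis=1)

counts = np.bincount(Y_vec, minlength=4)[1:]
print(counts)  # number of observations in categories 1, 2, 3
```

Printing the category counts is a quick sanity check that all three categories actually occur before the data is handed to Stan.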
Stan file
stan_code = """//stan
data {
  int<lower=1> N;
  array[N] int<lower=1, upper=3> Y;
  vector[N] x;
}
parameters {
  real beta;
  ordered[2] c;
}
model {
  // Priors
  beta ~ normal(0, 1);
  c ~ normal(0, 5);
  // Likelihood
  for (n in 1:N) {
    real eta_n = beta * x[n];
    target += ordered_logistic_lpmf(Y[n] | eta_n, c);
  }
}
"""
stan_file = "順序ロジスティック回帰テスト.stan"
with open(stan_file, "w") as f:
    f.write(stan_code)
Running the fit
from cmdstanpy import CmdStanModel
model = CmdStanModel(stan_file=stan_file, model_name="ordinal_3cat_model")
stan_data = {
    "N": N,
    "Y": Y,
    "x": x
}
fit = model.sample(
    data=stan_data,
    chains=4,
    parallel_chains=4,
    iter_sampling=2000,
    seed=123
)
print(fit.summary().loc[["beta", "c[1]", "c[2]"], :])
Example output
19:40:46 - cmdstanpy - INFO - CmdStan done processing.
19:40:46 - cmdstanpy - WARNING - Non-fatal error during sampling:
Exception: ordered_logistic: Cut-points is not a valid ordered vector. The element at 2 is -25.4371, but should be greater than the previous element, -25.4371 (in '2025-01-17_順序ロジスティック回帰テスト.stan', line 19, column 4 to column 53)
Exception: ordered_logistic: Cut-points is not a valid ordered vector. The element at 2 is -25.7445, but should be greater than the previous element, -25.7445 (in '2025-01-17_順序ロジスティック回帰テスト.stan', line 19, column 4 to column 53)
Exception: ordered_logistic: Cut-points is not a valid ordered vector. The element at 2 is -4532.61, but should be greater than the previous element, -4532.61 (in '2025-01-17_順序ロジスティック回帰テスト.stan', line 19, column 4 to column 53)
Exception: ordered_logistic: Cut-points is not a valid ordered vector. The element at 2 is -126.652, but should be greater than the previous element, -126.652 (in '2025-01-17_順序ロジスティック回帰テスト.stan', line 19, column 4 to column 53)
Exception: ordered_logistic: Cut-points is not a valid ordered vector. The element at 2 is -14.6781, but should be greater than the previous element, -14.6781 (in '2025-01-17_順序ロジスティック回帰テスト.stan', line 19, column 4 to column 53)
Exception: ordered_logistic: Cut-points is not a valid ordered vector. The element at 2 is -15.5614, but should be greater than the previous element, -15.5614 (in '2025-01-17_順序ロジスティック回帰テスト.stan', line 19, column 4 to column 53)
Exception: ordered_logistic: Cut-points is not a valid ordered vector. The element at 2 is -1471.08, but should be greater than the previous element, -1471.08 (in '2025-01-17_順序ロジスティック回帰テスト.stan', line 19, column 4 to column 53)
Exception: ordered_logistic: Cut-points is not a valid ordered vector. The element at 2 is -655.369, but should be greater than the previous element, -655.369 (in '2025-01-17_順序ロジスティック回帰テスト.stan', line 19, column 4 to column 53)
Consider re-running with show_console=True if the above output is unclear!
          Mean      MCSE    StdDev       MAD        5%       50%       95%  \
beta  0.820948  0.003116  0.206858  0.206865  0.487227  0.817506  1.158030
c[1] -1.065400  0.003796  0.241766  0.242870 -1.469930 -1.063470 -0.672406
c[2]  1.083340  0.002556  0.238240  0.241798  0.699302  1.082650  1.481470
      ESS_bulk  ESS_tail    R_hat
beta   4459.22   4733.79  1.00090
c[1]   4116.39   4790.56  1.00149
c[2]   8661.84   6371.00  1.00053
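- The posterior means of beta, c[1], and c[2] (0.82, -1.07, 1.08) are reasonably close to the true values (1.0, -1.0, 1.0), and each true value lies inside its 5%-95% interval.
- R_hat is close to 1 and ESS (effective sample size) is large, so convergence looks fine.
- The Exception messages look alarming, but they are reported as non-fatal. In each one the two cut-points are numerically equal, so a c that failed the strict ordering check was proposed momentarily during sampling. The results do not appear to be affected.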
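Why both elements in each warning print as the same number can be illustrated numerically. This is only a sketch of one plausible mechanism: it assumes Stan builds a 2-element ordered vector from unconstrained values as c[2] = c[1] + exp(y), and the value y = -40 is made up for illustration (it is not taken from the run). The helper name ordered_from_unconstrained is likewise hypothetical.

```python
import math

def ordered_from_unconstrained(c1, y):
    """Assumed ordered transform for two elements: c[2] = c[1] + exp(y)."""
    return c1, c1 + math.exp(y)

# c[1] values copied from the warnings; y = -40 is a hypothetical unconstrained value
results = {}
for c1 in (-25.4371, -4532.61):
    lo, hi = ordered_from_unconstrained(c1, -40.0)
    results[c1] = hi > lo
    print(c1, hi > lo)  # False: exp(-40) is below float64 resolution at this magnitude

# With moderate values the strict ordering survives
lo_ok, hi_ok = ordered_from_unconstrained(-1.0, 0.0)
print(hi_ok > lo_ok)    # True
```

If this is what happened, the increment between the two cut-points was lost to floating-point rounding, the strict ordering check failed, and that proposal was discarded. Extreme states of this kind are most likely early in warmup, which would be consistent with the clean summary statistics.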
This scrap was closed on 2025/01/17.