[メモ] L1ノルム正則化を含む線形回帰に対するADMMとRustによる実装
拡張ラグランジュ法に基づいた最適化手法であるADMMをLASSOに適用したときのメモです。
拡張Lagrange法
以下の最適化問題を考える。
さらにこの最適化問題の制約条件
で与えられる。
拡張Lagrange法では以下のステップを繰り返す。
すなわち、
-
に関する最適化問題を解く\bm x - Lagrange未定乗数
を更新する\bm \alpha -
を満たしていない場合:c_m(\bm x) = 0 だけ\rho_m^{(k)} c_m(\bm x^{(k+1)}) が変化\alpha_m -
を満たしている場合:c_m(\bm x) = 0 は変化しない\alpha_m
-
という手続きを繰り返す。
この方法の利点は、Lagrange未定乗数
ADMM
2つの凸関数
この制約付き最適化問題に対する拡張ラグランジアンは
ただし
前半は
これがADMMと呼ばれる最適化手続きである。拡張Lagrange法におけるLagrange未定乗数
軟判定閾値関数
L1正則化に対する最適化アルゴリズムの内部では以下の型の最小化問題が頻出する。
最適解は
この
L1正則化問題に対するADMM
さてここからが本題。L1正則化の最適化問題は
ただし
これは2つの凸な目的関数
の和であるから、ADMMが適用できる。拡張ラグランジアンは
あとはこれを
wについての最小化問題
まずは
ただし
ゆえに目的関数は二次形式であるから、最小化問題の解は
zについての最小化問題
次に
であることから、線形分離可能である。絶対値を含む関数の最小化問題
を各
さらに軟判定閾値関数がベクトルに対して要素ごとに適用されるように
とすれば、
L1正則化問題に対するADMMの更新規則
以上の結果をまとめると、L1正則化問題に対するADMMの更新規則は以下の通りである。
ただし
特に正則化係数
を解いていることになる。ADMMの手続き上罰金係数
Rust + Nalgebraによる実装
以下はRustによる雑な実装例。
# Cargo.toml
[package]
name = "admm-rust"
version = "0.1.0"
license = "Apache-2.0"
edition = "2021"
[dependencies]
nalgebra = { version = "0.33", features = ["rand"] }
// main.rs
use nalgebra as na;
fn soft_thresholding(x: &f64, threshold: &f64) -> f64 {
// returns the soft thresholding of x with threshold
if x > threshold {
return x - threshold;
} else if x < &-threshold {
return x + threshold;
} else {
return 0.0;
}
}
fn soft_thresholding_vec(x_vec: &na::DVector<f64>, threshold_vec: &na::DVector<f64>) -> na::DVector<f64> {
// returns the soft thresholding of each element of x_vec with the corresponding element of threshold_vec
let mut result_vec: na::DVector<f64> = na::DVector::zeros(x_vec.len());
let num_variables: usize = x_vec.len();
for i in 0..num_variables {
result_vec[i] = soft_thresholding(&x_vec[i], &threshold_vec[i]);
}
return result_vec;
}
fn admm(
x_mat: na::DMatrix<f64>,
y_vec: na::DVector<f64>,
lam_vec: na::DVector<f64>,
z_init_vec: na::DVector<f64>,
u_init_vec: na::DVector<f64>,
rho_vec: na::DVector<f64>,
max_iter: usize,
return_history: bool,
show_progress: bool,
) -> (na::DVector<f64>, na::DVector<f64>, na::DVector<f64>, Vec<na::DVector<f64>>, Vec<na::DVector<f64>>, Vec<na::DVector<f64>>) {
let (num_data_points, num_variables): (usize, usize) = x_mat.shape();
assert!(y_vec.shape() == (num_data_points, 1));
assert!(lam_vec.shape() == (num_variables, 1));
assert!(z_init_vec.shape() == (num_variables, 1));
assert!(u_init_vec.shape() == (num_variables, 1));
assert!(rho_vec.shape() == (num_variables, 1));
let rho_diag: na::DMatrix<f64> = na::DMatrix::from_diagonal(&rho_vec);
let a_mat: na::DMatrix<f64> = x_mat.transpose() * &x_mat + &rho_diag;
let a_mat_chol= a_mat.cholesky().unwrap();
let threshold_vec: na::DVector<f64> = lam_vec.component_div(&rho_vec);
let mut w_vec = na::DVector::zeros(num_variables);
let mut z_vec = z_init_vec;
let mut u_vec = u_init_vec;
let mut w_vec_history: Vec<na::DVector<f64>> = Vec::new();
let mut z_vec_history: Vec<na::DVector<f64>> = Vec::new();
let mut u_vec_history: Vec<na::DVector<f64>> = Vec::new();
for t in 0..max_iter {
let b: na::DVector<f64> = &x_mat.transpose() * &y_vec + &rho_diag * (&z_vec - &u_vec);
w_vec = a_mat_chol.solve( &b );
let w_vec_plus_u_vec: na::DVector<f64> = &w_vec + &u_vec;
z_vec = soft_thresholding_vec( &w_vec_plus_u_vec, &threshold_vec );
u_vec = &u_vec + &w_vec - &z_vec;
if return_history {
w_vec_history.push(w_vec.clone());
z_vec_history.push(z_vec.clone());
u_vec_history.push(u_vec.clone());
}
if show_progress {
let rmse: f64 = (&x_mat * &w_vec - &y_vec).norm() / f64::sqrt(num_variables as f64);
println!("[{:>4} / {}] {}", t, max_iter, rmse);
}
}
return (w_vec, z_vec, u_vec, w_vec_history, z_vec_history, u_vec_history);
}
fn main() {
let num_variables: usize = 256;
let num_data_points: usize = 128;
let density: f64 = 0.5;
let rho: f64 = 1.0;
let lam: f64 = 1.0;
let max_iter: usize = 100;
let x_true_mat: na::DMatrix<f64> =
na::DMatrix::new_random(num_data_points, num_variables) * 2.0
- na::DMatrix::from_element(num_data_points, num_variables, 1.0);
let w_true_vec_row: na::DVector<f64> =
na::DVector::new_random(num_variables) * 2.0
- na::DVector::from_element(num_variables, 1.0);
let w_true_vec_mask: na::DVector<f64> =
na::DVector::new_random(num_variables).map(|x:f64| if x < density { 1.0 } else { 0.0 });
let w_true_vec: na::DVector<f64> = w_true_vec_row.component_mul(&w_true_vec_mask);
let y_true_vec: na::DVector<f64> =
&x_true_mat * &w_true_vec;
let z_init_vec: na::DVector<f64> =
na::DVector::from_element(num_variables, 0.0);
let u_init_vec: na::DVector<f64> =
na::DVector::from_element(num_variables, 0.0);
let rho_vec: na::DVector<f64> =
na::DVector::from_element(num_variables, rho);
let lam_vec: na::DVector<f64> =
na::DVector::from_element(num_variables, lam);
let (_w_est_vec, _, _, _, _, _) = admm(
x_true_mat,
y_true_vec,
lam_vec,
z_init_vec,
u_init_vec,
rho_vec,
max_iter,
false,
true
);
let mse: f64 = (&w_true_vec - &_w_est_vec).norm_squared() / num_variables as f64;
println!("MSE of w_est vs. w_true: {}", mse);
}
データ点数 128、変数の数 256 より、データの比 0.5 の列決定系の状況におけるLASSOとなっている。定数 density
が大きいほど真の係数 w_true_vec
に含まれる非零の成分の割合が大きくなる。色々試してみると、予想通り density
を大きくしていくほどMSEが大きくなる。非零要素が多いほど確かに係数の推定に失敗しやすいことが確認できる。
Discussion