📝

[メモ] L1ノルム正則化を含む線形回帰に対するADMMとRustによる実装

2025/01/10に公開

拡張ラグランジュ法に基づいた最適化手法であるADMMをLASSOに適用したときのメモです。

拡張Lagrange法

以下の最適化問題を考える。

\begin{aligned} &\operatorname*{minimize}_{\bm x} && f(\bm x) \\ &\operatorname{subject~to} && c_m(\bm x) = 0 \quad (m = 1, \dots, M). \end{aligned}

さらにこの最適化問題の制約条件 c_m(\bm x) を罰金法で目的関数に組み込んだ最適化問題を考える。

\begin{aligned} &\operatorname*{minimize}_{\bm x} && f(\bm x) + \sum_{m=1}^M \frac{\rho_m}{2} c_m^2(\bm x) \\ &\operatorname{subject~to} && c_m(\bm x) = 0 \quad (m = 1, \dots, M). \end{aligned}

\bm \rho は罰金係数。これに対するLagrange未定乗数法によるラグランジアンは

\begin{aligned} L(\bm x, \bm \alpha, \bm \rho) &= f(\bm x) + \sum_{m=1}^M \alpha_m c_m(\bm x) + \sum_{m=1}^M \frac{\rho_m}{2} c_m^2(\bm x) \\ &= f(\bm x) + \sum_{m=1}^M \frac{\rho_m}{2} \left( c_m(\bm x) + \frac{\alpha_m}{\rho_m} \right)^2 - \frac{1}{2} \sum_{m=1}^M \frac{\alpha_m^2}{\rho_m} \end{aligned}

で与えられる。\bm \alpha は新たに導入されたLagrange未定乗数。このラグランジアンは拡張ラグランジアンと呼ばれる。

拡張Lagrange法では以下のステップを繰り返す。

\begin{aligned} &(1) && \bm x^{(k+1)} = \argmin_{\bm x} L(\bm x, \bm \alpha^{(k)}, \bm \rho^{(k)}). \\ &(2) && \alpha_m^{(k+1)} = \alpha_m^{(k)} + \rho_m^{(k)} c_m(\bm x^{(k+1)}). \end{aligned}

すなわち、

  1. \bm x に関する最適化問題を解く
  2. Lagrange未定乗数 \bm \alpha を更新する
    1. c_m(\bm x) = 0 を満たしていない場合: \rho_m^{(k)} c_m(\bm x^{(k+1)}) だけ \alpha_m が変化
    2. c_m(\bm x) = 0 を満たしている場合: \alpha_m は変化しない

という手続きを繰り返す。

この方法の利点は、Lagrange未定乗数 \bm \alpha を新たに導入し、その \bm \alpha に罰金係数の調整を押し付ける事によって、もとの定式化の罰金係数 \bm \rho を手続き上は変化させないという点にある。

ADMM

2つの凸関数 f(\bm x)g(\bm z) の和に対する最適化問題を考える。

\begin{aligned} &\operatorname*{minimize}_{\bm x \in \mathbb R^N} && f(\bm x) + g(\bm x). \end{aligned}

f(\bm x) 単体の最適化と g(\bm x) 単体の最適化はそれぞれ容易に解けるが、f(\bm x)g(\bm x) の和の最適化は難しい。そこで、2つの変数 \bm x\bm z を導入して以下の最適化問題を考える。

\begin{aligned} &\operatorname*{minimize}_{\bm x, \bm z \in \mathbb R^N} && f(\bm x) + g(\bm z) \\ &\operatorname{subject~to} && x_n - z_n = 0 \quad (n = 1, \dots, N). \end{aligned}

この制約付き最適化問題に対する拡張ラグランジアンは

\begin{aligned} L(\bm x, \bm z, \bm \alpha, \bm \rho) &= f(\bm x) + g(\bm z) + \frac{1}{2} \sum_{n=1}^N \rho_n \left( x_n - z_n + \frac{\alpha_n}{\rho_n} \right)^2 + \mathrm{const}, \\ L(\bm x, \bm z, \bm u, \bm \rho) &= f(\bm x) + g(\bm z) + \frac{1}{2} \sum_{n=1}^N \rho_n \left( x_n - z_n + u_n \right)^2 + \mathrm{const}. \end{aligned}

ただし u_n = \alpha_n / \rho_n である。あとは通常の拡張Lagrange法と同様にして

\begin{aligned} &(1)&& (\bm x[t+1], \bm z[t+1]) = \argmin_{\bm x, \bm z} L(\bm x, \bm z, \bm \alpha[t], \rho). \\ &(2)&& \alpha_n[t+1] = \alpha_n[t] + \rho_n (x_n[t+1] - z_n[t+1]). \end{aligned}

前半は \bm x\bm z について分離して解くことにし、さらに u_n[t] = \alpha_n[t] / \rho_n であることを反映させると

\begin{aligned} &(1)& \bm x[t+1] &= \argmin_{\bm x} L(\bm x, \bm z[t], \bm u[t], \rho), \\ &(2)& \bm z[t+1] &= \argmin_{\bm z} L(\bm x[t+1], \bm z, \bm u[t], \rho), \\ &(3)& \bm u[t+1] &= \bm u[t] + \bm x[t+1] - \bm z[t+1]. \end{aligned}

これがADMMと呼ばれる最適化手続きである。拡張Lagrange法におけるLagrange未定乗数 \bm \alpha はADMMにおいて \bm u に置き換えられていることに注意。

軟判定閾値関数

L1正則化に対する最適化アルゴリズムの内部では以下の型の最小化問題が頻出する。

\begin{aligned} &\operatorname*{minimize}_{x} && f(x) = \frac{1}{2} (x - a)^2 + b |x|. \end{aligned}

x の値に関して場合分けして変形すると

\begin{aligned} f(x) &= \left\{ \begin{aligned} &\frac{1}{2} (x - (a - b))^2 - \frac{1}{2} (a - b)^2 && (x > 0), \\ &\frac{1}{2} a^2 && (x = 0), \\ &\frac{1}{2} (x - (a + b))^2 - \frac{1}{2} (a + b)^2 && (x < 0). \end{aligned} \right. \end{aligned}

最適解は ab の関係によって場合分けされ、

\begin{aligned} \argmin_x f(x) &= \left\{ \begin{aligned} &a - b && (a > b), \\ &0 && (-b \le a \le b), \\ &a + b && (a < -b). \end{aligned} \right. \end{aligned}

この \argmin_x f(x) はしばしば軟判定閾値関数と呼ばれる。以降 \mathcal S_{b}(a) と書くことにする。

L1正則化問題に対するADMM

さてここからが本題。L1正則化の最適化問題は

\begin{aligned} &\operatorname*{minimize}_{\bm w \in \mathbb R^N} && \frac{1}{2} \sum_{m=1}^M (y_m - \bm x_m^\top \bm w) + \sum_{n=1}^N \lambda_n |w_n|. \end{aligned}

ただし

\begin{aligned} \bm X \in \mathbb R^{M \times N}, \quad \bm y \in \mathbb R^M, \quad \bm w \in \mathbb R^N. \end{aligned}

これは2つの凸な目的関数

\begin{aligned} f(\bm w) &= \frac{1}{2} \| \bm y - \bm X \bm w \|_2^2, \\ g(\bm w) &= \lambda_n |w_n| \end{aligned}

の和であるから、ADMMが適用できる。拡張ラグランジアンは

\begin{aligned} L(\bm w, \bm z, \bm \lambda, \rho) &= \frac{1}{2} \sum_{m=1}^M (y_m - \bm x_m^\top \bm w) + \sum_{n=1}^N \lambda_n |w_n| + \frac{1}{2} \sum_{n=1}^N \rho_n \left( w_n - z_n + u_n \right)^2 + \mathrm{const} . \end{aligned}

あとはこれを \bm x\bm z について最小化すれば更新規則が得られる。

wについての最小化問題

まずは \bm w についての最小化問題の目的関数を変形していく。

\begin{aligned} & \sum_{m=1}^M (y_m - \bm x_m^\top \bm w) + \sum_{n=1}^N \rho_n \left( w_n - z_n + u_n \right)^2 \\ &= (\bm y - \bm X \bm w)^\top (\bm y - \bm X \bm w) + (\bm w - \bm z + \bm u)^\top \operatorname{diag}\{ \rho_n \} (\bm w - \bm z + \bm u) \\ &= \bm w^\top ( \bm X^\top \bm X + \operatorname{diag}\{ \rho_n \} ) \bm w - 2 \bm w^\top ( \bm X^\top \bm y + \operatorname{diag}\{ \rho_n \} (\bm z - \bm u) ) + \mathrm{const} \\ &= (\bm w - \bm \mu)^\top \bm A (\bm w - \bm \mu) + \mathrm{const}. \end{aligned}

ただし

\begin{aligned} \bm A &= \bm X^\top \bm X + \operatorname{diag}\{ \rho_n \}, \\ \bm \mu &= \bm A^{-1} \left( \bm X^\top \bm y + \operatorname{diag}\{ \rho_n \} (\bm z - \bm u) \right). \end{aligned}

ゆえに目的関数は二次形式であるから、最小化問題の解は \bm w = \bm \mu である。

\begin{aligned} \argmin_{\bm w} L(\bm w, \bm z, \bm \lambda, \rho) = \bm A^{-1} \left( \bm X^\top \bm y + \operatorname{diag}\{ \rho_n \} (\bm z - \bm u) \right). \end{aligned}

\bm w の事後分布がガウス分布になるような状況の最大事後確率推定と同じことをやっている。\bm A^{-1} は事後共分散行列 \mathbb E[\bm w \bm w^\top] に対応し、\bm \mu は事後平均ベクトル \mathbb E[\bm w] に対応する。

zについての最小化問題

次に \bm z についての最小化問題。目的関数は

\begin{aligned} & \sum_{n=1}^N \lambda_n |w_n| + \frac{1}{2} \sum_{n=1}^N \rho_n \left( w_n - z_n + u_n \right)^2 + \mathrm{const} \\ &= \sum_{n=1}^N \rho_n \left( \left( w_n - z_n + u_n \right)^2 + \frac{\lambda_n}{\rho_n} |w_n| \right) + \mathrm{const} \end{aligned}

であることから、線形分離可能である。絶対値を含む関数の最小化問題

\begin{aligned} &\operatorname*{minimize}_{z_n} && \left( w_n - z_n + u_n \right)^2 + \frac{\lambda_n}{\rho_n} |w_n| \end{aligned}

を各 n に関して解けば良い。したがって軟判定閾値関数 \mathcal S_{\lambda_n / \rho_n}(w_n + u_n) を用いて

\begin{aligned} \argmin_{z_n} L(\bm w, \bm z, \bm \lambda, \rho) &= \mathcal S_{\lambda_n / \rho_n}(w_n + u_n). \end{aligned}

さらに軟判定閾値関数がベクトルに対して要素ごとに適用されるように

\begin{aligned} \mathcal S_{\bm a}(\bm x) = [ \mathcal S_{a_n}(x_n) ]_n \in \mathbb R^N \end{aligned}

とすれば、

\begin{aligned} \argmin_{\bm z} L(\bm w, \bm z, \bm \lambda, \rho) &= \mathcal S_{[\lambda_n / \rho_n]}(\bm w + \bm u). \end{aligned}

L1正則化問題に対するADMMの更新規則

以上の結果をまとめると、L1正則化問題に対するADMMの更新規則は以下の通りである。

\begin{aligned} &(1)& \bm w[t+1] &= \bm A^{-1} \left( \bm X^\top \bm y + \operatorname{diag}\{ \rho_n \} (\bm z[t] - \bm u[t]) \right), \\ &(2)& \bm z[t+1] &= \mathcal S_{[\lambda_n / \rho_n]}(\bm w[t+1] + \bm u[t]), \\ &(3)& \bm u[t+1] &= \bm u[t] + \bm w[t+1] - \bm z[t+1]. \end{aligned}

ただし

\begin{aligned} \bm A &= \bm X^\top \bm X + \operatorname{diag}\{ \rho_n \}. \end{aligned}

特に正則化係数 \lambda_1, \dots, \lambda_N がすべて等しい場合はLASSO

\begin{aligned} \operatorname*{minimize}_{\bm w} \quad \frac{1}{2} \| \bm y - \bm X \bm w \|_2^2 + \lambda \| \bm w \|_1 \end{aligned}

を解いていることになる。ADMMの手続き上罰金係数 \bm \rho は変化しないから、\bm A も変化しない。よって逆行列 \bm A^{-1} は一度だけ計算すればよく、しかも \bm \rho をきちんと設定すれば正定値対称行列となるので計算が安定する。

Rust + Nalgebraによる実装

以下はRustによる雑な実装例。

# Cargo.toml
[package]
name = "admm-rust"
version = "0.1.0"
license = "Apache-2.0"
edition = "2021"

[dependencies]
nalgebra = { version = "0.33", features = ["rand"] }
// main.rs
use nalgebra as na;

fn soft_thresholding(x: &f64, threshold: &f64) -> f64 {
    // returns the soft thresholding of x with threshold
    if x > threshold {
        return x - threshold;
    } else if x < &-threshold {
        return x + threshold;
    } else {
        return 0.0;
    }
}

fn soft_thresholding_vec(x_vec: &na::DVector<f64>, threshold_vec: &na::DVector<f64>) -> na::DVector<f64> {
    // returns the soft thresholding of each element of x_vec with the corresponding element of threshold_vec
    let mut result_vec: na::DVector<f64> = na::DVector::zeros(x_vec.len());
    let num_variables: usize = x_vec.len();
    for i in 0..num_variables {
        result_vec[i] = soft_thresholding(&x_vec[i], &threshold_vec[i]);
    }
    return result_vec;
}

fn admm(
    x_mat: na::DMatrix<f64>,
    y_vec: na::DVector<f64>,
    lam_vec: na::DVector<f64>,
    z_init_vec: na::DVector<f64>,
    u_init_vec: na::DVector<f64>,
    rho_vec: na::DVector<f64>,
    max_iter: usize,
    return_history: bool,
    show_progress: bool,
) -> (na::DVector<f64>, na::DVector<f64>, na::DVector<f64>, Vec<na::DVector<f64>>, Vec<na::DVector<f64>>, Vec<na::DVector<f64>>) {
    let (num_data_points, num_variables): (usize, usize) = x_mat.shape();

    assert!(y_vec.shape() == (num_data_points, 1));
    assert!(lam_vec.shape() == (num_variables, 1));
    assert!(z_init_vec.shape() == (num_variables, 1));
    assert!(u_init_vec.shape() == (num_variables, 1));
    assert!(rho_vec.shape() == (num_variables, 1));

    let rho_diag: na::DMatrix<f64> = na::DMatrix::from_diagonal(&rho_vec);
    let a_mat: na::DMatrix<f64> = x_mat.transpose() * &x_mat + &rho_diag;
    let a_mat_chol= a_mat.cholesky().unwrap();
    let threshold_vec: na::DVector<f64> = lam_vec.component_div(&rho_vec);

    let mut w_vec = na::DVector::zeros(num_variables);
    let mut z_vec = z_init_vec;
    let mut u_vec = u_init_vec;

    let mut w_vec_history: Vec<na::DVector<f64>> = Vec::new();
    let mut z_vec_history: Vec<na::DVector<f64>> = Vec::new();
    let mut u_vec_history: Vec<na::DVector<f64>> = Vec::new();

    for t in 0..max_iter {
        let b: na::DVector<f64> = &x_mat.transpose() * &y_vec + &rho_diag * (&z_vec - &u_vec);
        w_vec = a_mat_chol.solve( &b );

        let w_vec_plus_u_vec: na::DVector<f64> = &w_vec + &u_vec;
        z_vec = soft_thresholding_vec( &w_vec_plus_u_vec, &threshold_vec );
        u_vec = &u_vec + &w_vec - &z_vec;

        if return_history {
            w_vec_history.push(w_vec.clone());
            z_vec_history.push(z_vec.clone());
            u_vec_history.push(u_vec.clone());
        }

        if show_progress {
            let rmse: f64 = (&x_mat * &w_vec - &y_vec).norm() / f64::sqrt(num_variables as f64);
            println!("[{:>4} / {}] {}", t, max_iter, rmse);
        }
    }

    return (w_vec, z_vec, u_vec, w_vec_history, z_vec_history, u_vec_history);
}

fn main() {
    let num_variables: usize = 256;
    let num_data_points: usize = 128;
    let density: f64 = 0.5;
    let rho: f64 = 1.0;
    let lam: f64 = 1.0;
    let max_iter: usize = 100;

    let x_true_mat: na::DMatrix<f64> =
        na::DMatrix::new_random(num_data_points, num_variables) * 2.0
        - na::DMatrix::from_element(num_data_points, num_variables, 1.0);

    let w_true_vec_row: na::DVector<f64> =
        na::DVector::new_random(num_variables) * 2.0
        - na::DVector::from_element(num_variables, 1.0);
    let w_true_vec_mask: na::DVector<f64> =
        na::DVector::new_random(num_variables).map(|x:f64| if x < density { 1.0 } else { 0.0 });
    let w_true_vec: na::DVector<f64> = w_true_vec_row.component_mul(&w_true_vec_mask);

    let y_true_vec: na::DVector<f64> =
        &x_true_mat * &w_true_vec;

    let z_init_vec: na::DVector<f64> =
        na::DVector::from_element(num_variables, 0.0);
    let u_init_vec: na::DVector<f64> =
        na::DVector::from_element(num_variables, 0.0);
    let rho_vec: na::DVector<f64> =
        na::DVector::from_element(num_variables, rho);
    let lam_vec: na::DVector<f64> =
        na::DVector::from_element(num_variables, lam);

    let (_w_est_vec, _, _, _, _, _) = admm(
        x_true_mat,
        y_true_vec,
        lam_vec,
        z_init_vec,
        u_init_vec,
        rho_vec,
        max_iter,
        false,
        true
    );

    let mse: f64 = (&w_true_vec - &_w_est_vec).norm_squared() / num_variables as f64;
    println!("MSE of w_est vs. w_true: {}", mse);
}

データ点数 128、変数の数 256 より、データの比 0.5 の列決定系の状況におけるLASSOとなっている。定数 density が大きいほど真の係数 w_true_vec に含まれる非零の成分の割合が大きくなる。色々試してみると、予想通り density を大きくしていくほどMSEが大きくなる。非零要素が多いほど確かに係数の推定に失敗しやすいことが確認できる。

Discussion