😀

Rustで最小二乗法を実装

2025/01/26に公開

備忘録です。

use nalgebra::{DMatrix, DVector, LU};

/// NumPyの `polyfit` を再現する関数
///
/// # 引数
///
/// * `x`: x座標のデータ点のスライス
/// * `y`: y座標のデータ点のスライス
/// * `degree`: 近似する多項式の次数
///
/// # 戻り値
///
/// * `Result<Vec<f64>, &'static str>`:
///     - `Ok(Vec<f64>)`: 多項式係数(次数の高い順)を格納したベクタ
///     - `Err(&'static str)`: エラーが発生した場合、エラーメッセージ


pub fn polyfit(x: &[f64], y: &[f64], degree: usize) -> Result<Vec<f64>, &'static str> {
    if x.len() != y.len() {
        return Err("x and y vectors must have the same length");
    }
    if x.len() <= degree {
        return Err("Number of data points must be greater than degree + 1");
    }
    if degree > 30 { // degreeが大きすぎると計算が不安定になる可能性があるので、上限を設定
        return Err("Degree is too large, may lead to unstable calculation");
    }

    let n = x.len();
    let m = degree + 1;
    let mut matrix_a = DMatrix::zeros(n, m);
    let vector_y = DVector::from_vec(y.to_vec());

    // 行列Aを構築
    for i in 0..n {
        for j in 0..m {
            matrix_a[(i, j)] = x[i].powi(j as i32);
        }
    }

    // 最小二乗法で係数を計算 (正規方程式を解く)
    // (A^T * A) * coefficients = A^T * y
    let matrix_a_t = matrix_a.transpose();
    let matrix_ata = &matrix_a_t * &matrix_a;
    let vector_aty = matrix_a_t * vector_y;

    // Solve (A^T * A) * coefficients = A^T * y using LU decomposition
    let lu_result = Some(LU::new(matrix_ata)).ok_or_else(|| {
        "Singular matrix encountered in LU decomposition, polyfit failed. Try reducing the degree or check input data."
    })?;

    let solution_result = lu_result.solve(&vector_aty);
    if solution_result.is_none() {
        return Err("Failed to solve linear system");
    }
    let solution = solution_result.unwrap();

    let mut coefficients = Vec::new();
    for i in 0..m {
        coefficients.push(solution[i]);
    }
    coefficients.reverse(); // Reverse to match numpy's output order (highest degree first)
    Ok(coefficients)
}

fn main() {
    let x_data = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0];
    let y_data = vec![0.0, 0.8, 0.9, 0.1, -0.8, -1.0];
    let degree = 3;

   let answer=polyfit(&x_data, &y_data, degree).unwrap();
    println!("{:?}",answer);
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_polyfit() {
        let x_data = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0];
        let y_data = vec![0.0, 0.8, 0.9, 0.1, -0.8, -1.0];
        let degree = 3;

        match polyfit(&x_data, &y_data, degree) {
            Ok(coefficients) => {
                println!("Polyfit coefficients for degree {}: {:?}", degree, coefficients);
                // NumPyでのpolyfitの結果 (例): [ -0.08730159   0.64936508  -0.55873016   0.0984127 ]
                assert_relative_eq!(coefficients[0], -0.08730158730158728, epsilon = 1e-6);
                assert_relative_eq!(coefficients[1], 0.6493650793650794, epsilon = 1e-6);
                assert_relative_eq!(coefficients[2], -0.5587301587301588, epsilon = 1e-6);
                assert_relative_eq!(coefficients[3], 0.09841269841269842, epsilon = 1e-6);

            },
            Err(err) => {
                panic!("Error: {}", err);
            }
        }
    }

   
}

Discussion