🧩

mean_squared_error のエラー

に公開

概要

生成AI にコードを書かせていると,よく sklearn.metricsmean_squared_error でミスをしてしまいます.この記事ではその原因と解決策を説明します.

背景・課題

生成AI にコードを書かせていると, sklearn.metricsmean_squared_errorsquared 引数を指定してしまうことがあります.
以下 「sklearn.metrics の mean_squared_error の使用例を示して.」 に対するChatGPT の解答例です:

以下は sklearn.metrics の mean_squared_error を使った例です.MSE(平均二乗誤差)と,squared=False を指定してRMSE(平方根平均二乗誤差)を計算する方法を示しています.

from sklearn.metrics import mean_squared_error
import numpy as np

def calculate_mse(y_true, y_pred, squared=True):
    """Calculate Mean Squared Error (MSE) or Root Mean Squared Error (RMSE).

    Args:
        y_true (array-like): 正解ラベルの配列.
        y_pred (array-like): 予測値の配列.
        squared (bool): True の場合は MSE,False の場合は RMSE を返す.デフォルトは True.

    Returns:
        float: MSE または RMSE の値.
    """
    return mean_squared_error(y_true, y_pred, squared=squared)

# サンプルデータ
y_true = np.array([3.0, -0.5, 2.0, 7.0])
y_pred = np.array([2.5,  0.0, 2.0, 8.0])

# MSE の計算
mse = calculate_mse(y_true, y_pred)
print(f"Mean Squared Error: {mse}")

# RMSE の計算
rmse = calculate_mse(y_true, y_pred, squared=False)
print(f"Root Mean Squared Error: {rmse}")

このコードの実行結果は以下のとおりです.

TypeError: got an unexpected keyword argument 'squared'

これは,「そんな引数ないぞ」というエラーです.

ただし,sklearn のバージョンは 1.7.0 です.

解決策

mean_squared_error(y_true, y_pred, squared=squared) の部分を mean_squared_error(y_true, y_pred)squared 引数を削除してあげれば解決します.

また,RMSE を計算したい場合は,以下の手段があります.

  1. mean_squared_error(y_true, y_pred, squared=False) の部分を np.sqrt(mean_squared_error(y_true, y_pred)) として,RMSE を計算する.
  2. sklearn.metricsroot_mean_squared_error を使用する.

以下実行例:

import numpy as np
from sklearn.metrics import mean_squared_error, root_mean_squared_error

y_true = np.array([3.0, -0.5, 2.0, 7.0])
y_pred = np.array([2.5,  0.0, 2.0, 8.0])

# 方法1: np.sqrt() を使用する
rmse_1 = np.sqrt(mean_squared_error(y_true, y_pred))
print(f"方法1: {rmse_1:.4f}")

# 方法2: root_mean_squared_error を使用する
mse = mean_squared_error(y_true, y_pred)
rmse_2 = root_mean_squared_error(y_true, y_pred)
print(f"方法2: {rmse_2:.4f}")

出力

方法1: 0.6124
方法2: 0.6124

原因

以前のバージョンの scikit-learn では,mean_squared_error 関数に squared=False という引数を指定することで,平均二乗誤差(MSE)ではなく,その平方根である 二乗平均平方根誤差(RMSE) を計算できました.しかし,scikit-learn の バージョン 1.4 (2024年2月頃リリース)で squared 引数は非推奨となり,将来のバージョン(1.6)(2024年12月頃リリース)で削除されることが決定しました.

まとめ

生成AI の出力する sklearn.metricsmean_squared_error には squared の引数が指定されていることがあるので,削除してあげましょう.

参考リンク

Discussion