mean_squared_error のエラー
概要
生成AI にコードを書かせていると,よく sklearn.metrics の mean_squared_error でミスをしてしまいます.この記事ではその原因と解決策を説明します.
背景・課題
生成AI にコードを書かせていると, sklearn.metrics の mean_squared_error の squared 引数を指定してしまうことがあります.
以下 「sklearn.metrics の mean_squared_error の使用例を示して.」 に対するChatGPT の解答例です:
以下は sklearn.metrics の mean_squared_error を使った例です.MSE(平均二乗誤差)と,squared=False を指定してRMSE(平方根平均二乗誤差)を計算する方法を示しています.
from sklearn.metrics import mean_squared_error import numpy as np def calculate_mse(y_true, y_pred, squared=True): """Calculate Mean Squared Error (MSE) or Root Mean Squared Error (RMSE). Args: y_true (array-like): 正解ラベルの配列. y_pred (array-like): 予測値の配列. squared (bool): True の場合は MSE,False の場合は RMSE を返す.デフォルトは True. Returns: float: MSE または RMSE の値. """ return mean_squared_error(y_true, y_pred, squared=squared) # サンプルデータ y_true = np.array([3.0, -0.5, 2.0, 7.0]) y_pred = np.array([2.5, 0.0, 2.0, 8.0]) # MSE の計算 mse = calculate_mse(y_true, y_pred) print(f"Mean Squared Error: {mse}") # RMSE の計算 rmse = calculate_mse(y_true, y_pred, squared=False) print(f"Root Mean Squared Error: {rmse}")
このコードの実行結果は以下のとおりです.
TypeError: got an unexpected keyword argument 'squared'
これは,「そんな引数ないぞ」というエラーです.
ただし,sklearn のバージョンは 1.7.0 です.
解決策
mean_squared_error(y_true, y_pred, squared=squared) の部分を mean_squared_error(y_true, y_pred) と squared 引数を削除してあげれば解決します.
また,RMSE を計算したい場合は,以下の手段があります.
-
mean_squared_error(y_true, y_pred, squared=False)の部分をnp.sqrt(mean_squared_error(y_true, y_pred))として,RMSE を計算する. -
sklearn.metricsのroot_mean_squared_errorを使用する.
以下実行例:
import numpy as np
from sklearn.metrics import mean_squared_error, root_mean_squared_error
y_true = np.array([3.0, -0.5, 2.0, 7.0])
y_pred = np.array([2.5, 0.0, 2.0, 8.0])
# 方法1: np.sqrt() を使用する
rmse_1 = np.sqrt(mean_squared_error(y_true, y_pred))
print(f"方法1: {rmse_1:.4f}")
# 方法2: root_mean_squared_error を使用する
mse = mean_squared_error(y_true, y_pred)
rmse_2 = root_mean_squared_error(y_true, y_pred)
print(f"方法2: {rmse_2:.4f}")
出力
方法1: 0.6124
方法2: 0.6124
原因
以前のバージョンの scikit-learn では,mean_squared_error 関数に squared=False という引数を指定することで,平均二乗誤差(MSE)ではなく,その平方根である 二乗平均平方根誤差(RMSE) を計算できました.しかし,scikit-learn の バージョン 1.4 (2024年2月頃リリース)で squared 引数は非推奨となり,将来のバージョン(1.6)(2024年12月頃リリース)で削除されることが決定しました.
まとめ
生成AI の出力する sklearn.metrics の mean_squared_error には squared の引数が指定されていることがあるので,削除してあげましょう.
Discussion