Elixirで(microsoft/)LightGBM クラス分類を試す
下記のライブラリを使ったLightGBMのモデル構築/予測の紹介記事になります。
本記事ではLivebookを使います。
例えば、dockerでLivebookの公式イメージを起動してください。
docker run \
-p 8080:8080 \
-p 8081:8081 \
--pull always \
ghcr.io/livebook-dev/livebook
- その他方法でも可能と思いますが、macOSなどで動作確認をしていません。
前準備
新しいノートブックを開いたら、Notebook dependencies and setup 欄でインストールします。
またモデル構築において(ライブラリ内部で)csvファイルを作るなど、がっつりストレージを使う都合で作業フォルダの設定が必要です。Livebookでは明示する必要があります。
Mix.install(
[
{:lgbm_ex, "0.0.2", github: "tato-gh/lgbm_ex"}
]
)
Application.put_env(:lgbm_ex, :workdir, Path.join(System.tmp_dir!(), "lgbm_ex"))
モデル作成
Irisデータを使ったモデル構築例です。モデル構築に使うデータはExplorer.DataFrame
形式で扱います。
refs: https://hexdocs.pm/explorer/Explorer.DataFrame.html
df = Explorer.Datasets.iris()
{mapping, df} = LgbmEx.preproccessing_label_encode(df, "species")
- ここではデータを
Explorer.Datasets
から頂戴しています。そのため、最初からExplorer.DataFrame
形式になっています。 - ただし、LightGBMには文字列のままでは投入できないのでラベルエンコードしています。
実行結果:
{%{"Iris-setosa" => 0, "Iris-versicolor" => 1, "Iris-virginica" => 2},
#Explorer.DataFrame<
Polars[150 x 5]
sepal_length f64 [5.1, 4.9, 4.7, 4.6, 5.0, ...]
sepal_width f64 [3.5, 3.0, 3.2, 3.1, 3.6, ...]
petal_length f64 [1.4, 1.4, 1.3, 1.5, 1.4, ...]
petal_width f64 [0.2, 0.2, 0.2, 0.2, 0.2, ...]
species string ["0", "0", "0", "0", "0", ...]
>}
モデル構築にはfit
関数を使用します。なお、fit
の4つ目のオプションは、microsoft/LightGBM
のparametersにあたります。
model =
LgbmEx.fit("test", df, "species",
objective: "multiclass",
metric: "multi_logloss",
num_class: 3,
num_iterations: 20
)
modelには構築結果の各データが格納されています。
実行結果:
%LgbmEx.Model{
workdir: "/tmp/lgbm_ex",
name: "test",
files: %{
parameter: "/tmp/lgbm_ex/test/parameter.txt",
model: "/tmp/lgbm_ex/test/model.txt",
train: "/tmp/lgbm_ex/test/train.csv",
train_log: "/tmp/lgbm_ex/test/train_log.txt",
validation: "/tmp/lgbm_ex/test/validation.csv"
},
parameters: [
task: "train",
data: "/tmp/lgbm_ex/test/train.csv",
output_model: "/tmp/lgbm_ex/test/model.txt",
label_column: 0,
saved_feature_importance_type: 1,
valid_data: "/tmp/lgbm_ex/test/validation.csv",
objective: "multiclass",
metric: "multi_logloss",
num_class: 3,
num_iterations: 20,
y_name: "species",
x_names: ["sepal_length", "sepal_width", "petal_length", "petal_width"]
],
ref: #Reference<0.1428843139.2291269646.145754>,
num_iterations: 20,
learning_steps: [
{0, 0.933515},
{1, 0.802883},
{2, 0.697101},
{3, 0.610171},
{4, 0.537393},
{5, 0.475927},
{6, 0.424058},
{7, 0.379137},
{8, 0.340792},
{9, 0.307383},
{10, 0.278473},
{11, 0.253297},
{12, 0.231293},
{13, 0.211347},
{14, 0.194058},
{15, 0.178451},
{16, 0.164777},
{17, 0.151918},
{18, 0.140454},
{19, 0.130164}
],
used_parameters: %{
"time_out" => 120,
"enable_bundle" => true,
"lambdarank_norm" => true,
"cat_l2" => 10,
"first_metric_only" => false,
"boosting" => "gbdt",
"objective" => "multiclass",
"verbosity" => 1,
"monotone_constraints_method" => "basic",
"multi_error_top_k" => 1,
"sigmoid" => 1,
"data" => "/tmp/lgbm_ex/test/train.csv",
"boost_from_average" => true,
"is_unbalance" => false,
"feature_fraction_bynode" => 1,
"refit_decay_rate" => 0.9,
"tweedie_variance_power" => 1.5,
"quant_train_renew_leaf" => false,
"extra_seed" => 6,
"min_data_per_group" => 100,
"bagging_freq" => 0,
"lambda_l1" => 0,
"label_column" => "0",
"gpu_device_id" => -1,
"data_sample_strategy" => "bagging",
"force_col_wise" => false,
"num_gpu" => 1,
"gpu_use_dp" => false,
"other_rate" => 0.1,
"num_machines" => 1,
"is_enable_sparse" => true,
"drop_seed" => 4,
"seed" => 0,
"deterministic" => false,
"cegb_penalty_split" => 0,
"max_drop" => 50,
"drop_rate" => 0.1,
"tree_learner" => "serial",
"num_class" => 3,
"zero_as_missing" => false,
"xgboost_dart_mode" => false,
"max_bin" => 255,
...
},
num_classes: 3,
num_features: 4,
feature_importance_split: [13.0, 52.0, 151.0, 76.0],
feature_importance_gain: [1.0072992438772417, 6.377874833182663, 1010.7391075709484,
517.3148670691899]
}
データ予測
構築したモデルと予測対象データをpredict
に渡して結果を取得します。
x_test =
[
[5.4, 3.9, 1.7, 0.4],
[5.7, 2.8, 4.5, 1.4],
[7.6, 3.0, 6.6, 2.2]
]
[p1, p2, p3] = LgbmEx.predict(model, x_test)
実行結果:
[
[0.9467519536231427, 0.026109531291447274, 0.027138515085409997],
[0.0398378208500453, 0.9141394236142755, 0.04602275553567928],
[0.02991031765897211, 0.03163328954666355, 0.9384563927943644]
]
それぞれspeciesのどれにあたるかの確率です。ラベルエンコード時のmapping
と照らし合わせる必要があります。
mapping
#=> %{"Iris-setosa" => 0, "Iris-versicolor" => 1, "Iris-virginica" => 2}
1つ目のデータ [5.4, 3.9, 1.7, 0.4]
は、Iris-setosaといえそうです(0.9467519536231427
)。
x_test
はExplorer.DataFrame
でも可能です。
# 適当に3つほど拝借
x_test_df = Explorer.DataFrame.slice(df, 0, 3)
[p1, p2, p3] = LgbmEx.predict(model, x_test_df)
端的には以上が、構築と予測になります。以下は付属するTipsです。
アーリーストッピング
モデル構築時に気にしないといけないことの1つが「過学習」です。学習データに過度にフィットするモデル(極端にいえばデータ1つ1つにマッピングされてしまう)は、実運用時の未知データに確率的な応答ができずAI技術として役に立たないことがあります。
過学習を回避する方法の1つに、検証データ(答えもわかる)を用意して、検証データに対して評価が落ちたタイミングで学習を止める方法があります(アーリーストッピング)。
# アーリーストッピングのための検証データを使ったモデル構築例
shuffled = Explorer.DataFrame.shuffle(df)
val_df = Explorer.DataFrame.slice(shuffled, 0, 15)
train_df = Explorer.DataFrame.slice(shuffled, 15..-1)
model =
LgbmEx.fit("sample_with_stopping", {train_df, val_df}, "species",
objective: "multiclass",
metric: "multi_logloss",
num_class: 3,
num_iterations: 100,
early_stopping_round: 2,
learning_rate: 0.1
)
実行結果は省略しますがnum_iterations: 100
を設定していますが30回程度で停止します。
学習対象属性の指定
fit
で使用する属性名の指定も可能です。
LgbmEx.fit("sample_with_stopping", {train_df, val_df}, {"species", ["sepal_length", "sepal_width"]},
objective: "multiclass",
metric: "multi_logloss",
num_class: 3,
num_iterations: 100,
early_stopping_round: 2,
learning_rate: 0.1
)
パラメータ変更
LightGBMにはパラメータが多く、数値によってモデルが大きく変わります。そのため、自動やあるいは手動で、よりよいパラメータを探すことになります。
refit_model
で異なるパラメータでの再学習が可能です。
model =
LgbmEx.refit_model(model, [
min_data_in_leaf: 3,
early_stopping_round: 5
])
モデルのコピー
fit
で物理的なファイル領域(nameにあたるディレクトリ)を使う関係で、異なるモデルを構築する際には別名をつけて行います。
new_model =
LgbmEx.copy_model(model, "copied")
|> LgbmEx.refit_model([learning_rate: 0.01])
モデルのロード
fit
で物理的なファイル領域(nameにあたるディレクトリ)を使うため、workdirとnameを指定することでモデルをロードすることができます。
model = LgbmEx.load_model("test")
モデルのzipとunzip
また既存ディレクトリを指定するだけではなく、zipしたファイルから展開も可能です。
(主にLivebookでモデル構築した後にKinoで手元にダウンロードするような用途を意図しています)
zip_path = LgbmEx.zip_model(model)
model = LgbmEx.unzip_model(zip_path, "load_from_zip")
交差検証
LgbmExとは切り離して下記にあります(が、こちらはより一層個人的なライブラリで不安定です)。
終わりに
LgbmExは
- LGBMExCli / LgbmExCapiの統合
-
Explorer.DataFrame
形式への対応
を目的に作りました。試している段階でもありますのでもし使用される場合はご注意ください - -;
Discussion