😽

Elixirで(microsoft/)LightGBM クラス分類を試す

2024/04/12に公開

下記のライブラリを使ったLightGBMのモデル構築/予測の紹介記事になります。

https://github.com/tato-gh/lgbm_ex

本記事ではLivebookを使います。

例えば、dockerでLivebookの公式イメージを起動してください。

docker run \
  -p 8080:8080 \
  -p 8081:8081 \
  --pull always \
  ghcr.io/livebook-dev/livebook
  • その他方法でも可能と思いますが、macOSなどで動作確認をしていません。

前準備

新しいノートブックを開いたら、Notebook dependencies and setup 欄でインストールします。

またモデル構築において(ライブラリ内部で)csvファイルを作るなど、がっつりストレージを使う都合で作業フォルダの設定が必要です。Livebookでは明示する必要があります。

Mix.install(
  [
    {:lgbm_ex, "0.0.2", github: "tato-gh/lgbm_ex"}
  ]
)

Application.put_env(:lgbm_ex, :workdir, Path.join(System.tmp_dir!(), "lgbm_ex"))

モデル作成

Irisデータを使ったモデル構築例です。モデル構築に使うデータはExplorer.DataFrame形式で扱います。

refs: https://hexdocs.pm/explorer/Explorer.DataFrame.html

df = Explorer.Datasets.iris()
{mapping, df} = LgbmEx.preproccessing_label_encode(df, "species")
  • ここではデータをExplorer.Datasetsから頂戴しています。そのため、最初からExplorer.DataFrame形式になっています。
  • ただし、LightGBMには文字列のままでは投入できないのでラベルエンコードしています。

実行結果:

{%{"Iris-setosa" => 0, "Iris-versicolor" => 1, "Iris-virginica" => 2},
 #Explorer.DataFrame<
   Polars[150 x 5]
   sepal_length f64 [5.1, 4.9, 4.7, 4.6, 5.0, ...]
   sepal_width f64 [3.5, 3.0, 3.2, 3.1, 3.6, ...]
   petal_length f64 [1.4, 1.4, 1.3, 1.5, 1.4, ...]
   petal_width f64 [0.2, 0.2, 0.2, 0.2, 0.2, ...]
   species string ["0", "0", "0", "0", "0", ...]
 >}

モデル構築にはfit関数を使用します。なお、fitの4つ目のオプションは、microsoft/LightGBMparametersにあたります。

model =
  LgbmEx.fit("test", df, "species",
    objective: "multiclass",
    metric: "multi_logloss",
    num_class: 3,
    num_iterations: 20
  )

modelには構築結果の各データが格納されています。

実行結果:

%LgbmEx.Model{
  workdir: "/tmp/lgbm_ex",
  name: "test",
  files: %{
    parameter: "/tmp/lgbm_ex/test/parameter.txt",
    model: "/tmp/lgbm_ex/test/model.txt",
    train: "/tmp/lgbm_ex/test/train.csv",
    train_log: "/tmp/lgbm_ex/test/train_log.txt",
    validation: "/tmp/lgbm_ex/test/validation.csv"
  },
  parameters: [
    task: "train",
    data: "/tmp/lgbm_ex/test/train.csv",
    output_model: "/tmp/lgbm_ex/test/model.txt",
    label_column: 0,
    saved_feature_importance_type: 1,
    valid_data: "/tmp/lgbm_ex/test/validation.csv",
    objective: "multiclass",
    metric: "multi_logloss",
    num_class: 3,
    num_iterations: 20,
    y_name: "species",
    x_names: ["sepal_length", "sepal_width", "petal_length", "petal_width"]
  ],
  ref: #Reference<0.1428843139.2291269646.145754>,
  num_iterations: 20,
  learning_steps: [
    {0, 0.933515},
    {1, 0.802883},
    {2, 0.697101},
    {3, 0.610171},
    {4, 0.537393},
    {5, 0.475927},
    {6, 0.424058},
    {7, 0.379137},
    {8, 0.340792},
    {9, 0.307383},
    {10, 0.278473},
    {11, 0.253297},
    {12, 0.231293},
    {13, 0.211347},
    {14, 0.194058},
    {15, 0.178451},
    {16, 0.164777},
    {17, 0.151918},
    {18, 0.140454},
    {19, 0.130164}
  ],
  used_parameters: %{
    "time_out" => 120,
    "enable_bundle" => true,
    "lambdarank_norm" => true,
    "cat_l2" => 10,
    "first_metric_only" => false,
    "boosting" => "gbdt",
    "objective" => "multiclass",
    "verbosity" => 1,
    "monotone_constraints_method" => "basic",
    "multi_error_top_k" => 1,
    "sigmoid" => 1,
    "data" => "/tmp/lgbm_ex/test/train.csv",
    "boost_from_average" => true,
    "is_unbalance" => false,
    "feature_fraction_bynode" => 1,
    "refit_decay_rate" => 0.9,
    "tweedie_variance_power" => 1.5,
    "quant_train_renew_leaf" => false,
    "extra_seed" => 6,
    "min_data_per_group" => 100,
    "bagging_freq" => 0,
    "lambda_l1" => 0,
    "label_column" => "0",
    "gpu_device_id" => -1,
    "data_sample_strategy" => "bagging",
    "force_col_wise" => false,
    "num_gpu" => 1,
    "gpu_use_dp" => false,
    "other_rate" => 0.1,
    "num_machines" => 1,
    "is_enable_sparse" => true,
    "drop_seed" => 4,
    "seed" => 0,
    "deterministic" => false,
    "cegb_penalty_split" => 0,
    "max_drop" => 50,
    "drop_rate" => 0.1,
    "tree_learner" => "serial",
    "num_class" => 3,
    "zero_as_missing" => false,
    "xgboost_dart_mode" => false,
    "max_bin" => 255,
    ...
  },
  num_classes: 3,
  num_features: 4,
  feature_importance_split: [13.0, 52.0, 151.0, 76.0],
  feature_importance_gain: [1.0072992438772417, 6.377874833182663, 1010.7391075709484,
   517.3148670691899]
}

データ予測

構築したモデルと予測対象データをpredictに渡して結果を取得します。

x_test =
  [
    [5.4, 3.9, 1.7, 0.4],
    [5.7, 2.8, 4.5, 1.4],
    [7.6, 3.0, 6.6, 2.2]
  ]

[p1, p2, p3] = LgbmEx.predict(model, x_test)

実行結果:

[
  [0.9467519536231427, 0.026109531291447274, 0.027138515085409997],
  [0.0398378208500453, 0.9141394236142755, 0.04602275553567928],
  [0.02991031765897211, 0.03163328954666355, 0.9384563927943644]
]

それぞれspeciesのどれにあたるかの確率です。ラベルエンコード時のmappingと照らし合わせる必要があります。

mapping
#=> %{"Iris-setosa" => 0, "Iris-versicolor" => 1, "Iris-virginica" => 2}

1つ目のデータ [5.4, 3.9, 1.7, 0.4] は、Iris-setosaといえそうです(0.9467519536231427)。

x_testExplorer.DataFrameでも可能です。

# 適当に3つほど拝借
x_test_df = Explorer.DataFrame.slice(df, 0, 3)
[p1, p2, p3] = LgbmEx.predict(model, x_test_df)

端的には以上が、構築と予測になります。以下は付属するTipsです。

アーリーストッピング

モデル構築時に気にしないといけないことの1つが「過学習」です。学習データに過度にフィットするモデル(極端にいえばデータ1つ1つにマッピングされてしまう)は、実運用時の未知データに確率的な応答ができずAI技術として役に立たないことがあります。

過学習を回避する方法の1つに、検証データ(答えもわかる)を用意して、検証データに対して評価が落ちたタイミングで学習を止める方法があります(アーリーストッピング)。

# アーリーストッピングのための検証データを使ったモデル構築例
shuffled = Explorer.DataFrame.shuffle(df)
val_df = Explorer.DataFrame.slice(shuffled, 0, 15)
train_df =  Explorer.DataFrame.slice(shuffled, 15..-1)

model =
  LgbmEx.fit("sample_with_stopping", {train_df, val_df}, "species",
    objective: "multiclass",
    metric: "multi_logloss",
    num_class: 3,
    num_iterations: 100,
    early_stopping_round: 2,
    learning_rate: 0.1
  )

実行結果は省略しますがnum_iterations: 100を設定していますが30回程度で停止します。

学習対象属性の指定

fitで使用する属性名の指定も可能です。

LgbmEx.fit("sample_with_stopping", {train_df, val_df}, {"species", ["sepal_length", "sepal_width"]},
  objective: "multiclass",
  metric: "multi_logloss",
  num_class: 3,
  num_iterations: 100,
  early_stopping_round: 2,
  learning_rate: 0.1
)

パラメータ変更

LightGBMにはパラメータが多く、数値によってモデルが大きく変わります。そのため、自動やあるいは手動で、よりよいパラメータを探すことになります。

refit_modelで異なるパラメータでの再学習が可能です。

model =
  LgbmEx.refit_model(model, [
    min_data_in_leaf: 3,
    early_stopping_round: 5
  ])

モデルのコピー

fitで物理的なファイル領域(nameにあたるディレクトリ)を使う関係で、異なるモデルを構築する際には別名をつけて行います。

new_model =
  LgbmEx.copy_model(model, "copied")
  |> LgbmEx.refit_model([learning_rate: 0.01])

モデルのロード

fitで物理的なファイル領域(nameにあたるディレクトリ)を使うため、workdirとnameを指定することでモデルをロードすることができます。

model = LgbmEx.load_model("test")

モデルのzipとunzip

また既存ディレクトリを指定するだけではなく、zipしたファイルから展開も可能です。
(主にLivebookでモデル構築した後にKinoで手元にダウンロードするような用途を意図しています)

zip_path = LgbmEx.zip_model(model)
model = LgbmEx.unzip_model(zip_path, "load_from_zip")

交差検証

LgbmExとは切り離して下記にあります(が、こちらはより一層個人的なライブラリで不安定です)。

https://github.com/tato-gh/lgbm_exx

終わりに

LgbmExは

  • LGBMExCli / LgbmExCapiの統合
  • Explorer.DataFrame形式への対応

を目的に作りました。試している段階でもありますのでもし使用される場合はご注意ください - -;

Discussion