🔥

RecBoleを使ってみよう4 モデルの学習

2023/09/15に公開

概要

今回はモデルの学習と精度の検証を行います。

データセットは第2回で用意したAnimeにします。複数のモデルの学習を行い、精度を比較することを目標とします。

run_recbole関数

初回にも使ったrun_recboleは、次の全てを実行してくれるオールインワンな関数です。

  • データの読み込み
  • データの分割(デフォルトでは8:1:1に分割します)
    • 学習データ
    • 検証データ
    • テストデータ
  • 学習データと検証データを用いたモデルの学習
    (学習データで学習し、検証データの精度を見て、改善が見られなくなってきたら学習を早期停止します)
  • テストデータの精度を算出

今回はただ動かすだけでなく、設定の方法を理解した上で実行してみます。

設定の優先度について

モデルの学習を行う際、RecBoleは次の4つの場所の設定を参照します。

  1. コマンドライン
  2. パラメータ辞書
  3. 設定ファイル
  4. デフォルト値

これらは上にあるものほど優先されます。参照: 公式ドキュメント

ただし、モデル名modelだけはrun_recbole関数の引数を指定することもでき、上記4つよりも優先されるようです。

簡単のため、今回はモデル名の設定をrun_recbole関数の引数に、その他の設定をパラメータ辞書に書くことにします。

モデルの学習

Animeデータセット ./data/processed/anime/anime.inter, anime.item を使って学習を実行します。(実際は.interファイルの方しか用いられません。)

~/src/recbole_sandbox/train.py
from pathlib import Path

import pandas as pd
from recbole.quick_start import run_recbole

config_dict = {
    "data_path": "./data/processed/",
    "dataset": "anime",
    "epochs": 1,
}
result = run_recbole(model="Pop", config_dict=config_dict)
print(result)
poetry run python train.py

各種設定はパラメータ辞書 config_dict に書かれています。

  • モデル名はPopにしています。これはBPRと同じgeneralレコメンドです。
  • data_pathdataset を上記のように指定することで、Animeデータセットを読み込んでくれます。
    一般にも、{data_path}/{dataset}/{dataset}.inter, {dataset}.user, {dataset}.item のように配置しておけば、同様にデータセットを読み込んでくれます。
  • epochsエポック数を決めています(デフォルト値: 300)。今回は実行テストのため、1に設定しました。

run_recboleの戻り値resultは次のような辞書です。実証データの最善の精度とテストデータの精度が収納されていることがわかります。

result = {
    "best_valid_score": 0.1945,
    "valid_score_bigger": True,
    "best_valid_result": OrderedDict(
        [
            ("recall@10", 0.0984),
            ("mrr@10", 0.1945),
            ("ndcg@10", 0.103),
            ("hit@10", 0.4325),
            ("precision@10", 0.0673),
        ]
    ),
    "test_result": OrderedDict(
        [
            ("recall@10", 0.1052),
            ("mrr@10", 0.216),
            ("ndcg@10", 0.1145),
            ("hit@10", 0.4443),
            ("precision@10", 0.0738),
        ]
    ),
}

ログファイル

学習後にログファイルが、./log/{モデル名}/に保存されます。
ログファイルを見ることで、パラメータがどのような値に設定されたのかが全て分かります。
末尾には上記のresultと同じ、精度の情報が書かれています。

実行ログ
~/src/recbole_sandbox/log/BPR/BPR-anime-Sep-15-2023_13-02-34-8238be.log
Fri 15 Sep 2023 13:02:34 INFO  ['train.py']
Fri 15 Sep 2023 13:02:34 INFO  
General Hyper Parameters:
gpu_id = 0
use_gpu = True
seed = 2020
state = INFO
reproducibility = True
data_path = ./data/processed/anime
checkpoint_dir = saved
show_progress = True
save_dataset = False
dataset_save_path = None
save_dataloaders = False
dataloaders_save_path = None
log_wandb = False

Training Hyper Parameters:
epochs = 1
train_batch_size = 2048
learner = adam
learning_rate = 0.001
train_neg_sample_args = {'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}
eval_step = 1
stopping_step = 10
clip_grad_norm = None
weight_decay = 0.0
loss_decimal_place = 4

Evaluation Hyper Parameters:
eval_args = {'split': {'RS': [0.8, 0.1, 0.1]}, 'group_by': 'user', 'order': 'RO', 'mode': 'full'}
repeatable = False
metrics = ['Recall', 'MRR', 'NDCG', 'Hit', 'Precision']
topk = [10]
valid_metric = MRR@10
valid_metric_bigger = True
eval_batch_size = 4096
metric_decimal_place = 4

Dataset Hyper Parameters:
field_separator = 	
seq_separator =  
USER_ID_FIELD = user_id
ITEM_ID_FIELD = item_id
RATING_FIELD = rating
TIME_FIELD = timestamp
seq_len = None
LABEL_FIELD = label
threshold = None
NEG_PREFIX = neg_
load_col = {'inter': ['user_id', 'item_id']}
unload_col = None
unused_col = None
additional_feat_suffix = None
rm_dup_inter = None
val_interval = None
filter_inter_by_user_or_item = True
user_inter_num_interval = [0,inf)
item_inter_num_interval = [0,inf)
alias_of_user_id = None
alias_of_item_id = None
alias_of_entity_id = None
alias_of_relation_id = None
preload_weight = None
normalize_field = None
normalize_all = None
ITEM_LIST_LENGTH_FIELD = item_length
LIST_SUFFIX = _list
MAX_ITEM_LIST_LENGTH = 50
POSITION_FIELD = position_id
HEAD_ENTITY_ID_FIELD = head_id
TAIL_ENTITY_ID_FIELD = tail_id
RELATION_ID_FIELD = relation_id
ENTITY_ID_FIELD = entity_id
benchmark_filename = None

Other Hyper Parameters: 
worker = 0
wandb_project = recbole
shuffle = True
require_pow = False
enable_amp = False
enable_scaler = False
transform = None
embedding_size = 64
numerical_features = []
discretization = None
kg_reverse_r = False
entity_kg_num_interval = [0,inf)
relation_kg_num_interval = [0,inf)
MODEL_TYPE = ModelType.GENERAL
MODEL_INPUT_TYPE = InputType.PAIRWISE
eval_type = EvaluatorType.RANKING
single_spec = True
local_rank = 0
device = cpu
eval_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}


Fri 15 Sep 2023 13:02:54 INFO  anime
The number of users: 73516
Average actions of users: 106.28765558049378
The number of items: 11201
Average actions of items: 697.6550892857143
The number of inters: 7813737
The sparsity of the dataset: 99.05110070703805%
Remain Fields: ['user_id', 'item_id']
Fri 15 Sep 2023 13:03:08 INFO  [Training]: train_batch_size = [2048] train_neg_sample_args: [{'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}]
Fri 15 Sep 2023 13:03:08 INFO  [Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [0.8, 0.1, 0.1]}, 'group_by': 'user', 'order': 'RO', 'mode': 'full'}]
Fri 15 Sep 2023 13:03:08 INFO  BPR(
  (user_embedding): Embedding(73516, 64)
  (item_embedding): Embedding(11201, 64)
  (loss): BPRLoss()
)
Trainable parameters: 5421888
Fri 15 Sep 2023 13:03:08 INFO  FLOPs: 128.0
Fri 15 Sep 2023 13:03:55 INFO  epoch 0 training [time: 47.01s, train loss: 739.1542]
Fri 15 Sep 2023 13:04:36 INFO  epoch 0 evaluating [time: 40.90s, valid_score: 0.194500]
Fri 15 Sep 2023 13:04:36 INFO  valid result: 
recall@10 : 0.0984    mrr@10 : 0.1945    ndcg@10 : 0.103    hit@10 : 0.4325    precision@10 : 0.0673
Fri 15 Sep 2023 13:04:36 INFO  Saving current: saved/BPR-Sep-15-2023_13-03-08.pth
Fri 15 Sep 2023 13:04:37 INFO  Loading model structure and parameters from saved/BPR-Sep-15-2023_13-03-08.pth
Fri 15 Sep 2023 13:05:18 INFO  best valid : OrderedDict([('recall@10', 0.0984), ('mrr@10', 0.1945), ('ndcg@10', 0.103), ('hit@10', 0.4325), ('precision@10', 0.0673)])
Fri 15 Sep 2023 13:05:18 INFO  test result: OrderedDict([('recall@10', 0.1052), ('mrr@10', 0.216), ('ndcg@10', 0.1145), ('hit@10', 0.4443), ('precision@10', 0.0738)])

複数モデルの精度比較

上記をモデル名を変えて複数回行った後、ログファイルを見比べれば、複数のモデルの精度比較ができます。
しかし、手作業でまとめるのは面倒なので、複数のモデルの精度が自動でcsv出力されるようにしてみました。

~/src/recbole_sandbox/train_models.py
from pathlib import Path

import pandas as pd
from recbole.quick_start import run_recbole

results_path = Path("./data/output/results.csv")
results_path.parent.mkdir(exist_ok=True)

models = [
    "BPR",
    "Pop",
]
topk = 10
config_dict = {
    "data_path": "./data/processed/",
    "dataset": "anime",
    "epochs": 1,
    "topk": topk,
}
evaluation_columns = [
    f"recall@{topk}",
    f"mrr@{topk}",
    f"ndcg@{topk}",
    f"hit@{topk}",
    f"precision@{topk}",
]

# モデルごとの結果をまとめ、csvに出力
output_list = []
for model in models:
    result = run_recbole(
        model=model,
        config_dict=config_dict,
    )
    test_result = result["test_result"]

    output_dict = {
        "model": model,
        "best_valid_score": result["best_valid_score"],
        **{col: test_result[col] for col in evaluation_columns},
    }

    output_list.append(output_dict)

output_dataframe = pd.DataFrame(output_list)
output_dataframe.to_csv(results_path, index=False)

次のようなcsvが出力されれば成功です。
エポック数=1なのでズダボロですが... コードが動いたので一旦ヨシ!とします。

model best_valid_score recall@10 mrr@10 ndcg@10 hit@10 precision@10
BPR 0.1945 0.1052 0.216 0.1145 0.4443 0.0738
Pop 0.188 0.1037 0.2077 0.11 0.4324 0.069

ColabでGPUを使うべきかな。。。

Discussion