💯
llm-jp-evalコードリーディング
概要
llm-jp-evalの評価部分のコードを読んだときのまとめです。
以下の評価コマンドを打った時の流れを見ていこうと思います。
CUDA_VISIBLE_DEVICES=0 poetry run python scripts/evaluate_llm.py -cn config.yaml \
model.pretrained_model_name_or_path=/path/to/model_dir \
tokenizer.pretrained_model_name_or_path=/path/to/tokenizer_dir \
dataset_dir=/path/to/dataset_dir
引数のパース
公式のREADMEにも記載がありますが、hydraで引数の管理をしています。
wandb
configに指定しておけばwandbにログインして結果を送信してくれます。
model生成
model = hydra.utils.call(cfg.model, torch_dtype=torch_dtype, _recursive_=False)
ここで設定ファイルからモデルを動的にインスタンス化しています。
model = AutoModelForCausalLM.from_pretrained(...)
と同等のことを
こちらの設定ファイルを使って行っています。
pipelineの生成
few_shotサンプルの取得
trainデータセットの中から指定レコード数サンプリングしています。
先頭から数件取得しています。(シャッフルはしていなそう)
プロンプト生成部分(テンプレートなし)
特にテンプレートを指定しなかった時の処理です。
Alpacaっぽいフォーマット?
プロンプト生成部分(テンプレートあり)
custom_prompt_templateやcustom_fewshots_templateを指定した時の処理です。
以下の様な設定を行うとそれに従ってプロンプトが生成されます。
custom_prompt_template: "以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。\n\n### 指示:\n{instruction}{few_shots_text}\n\n### 入力:\n{input}\n\n### 応答:\n"
custom_fewshots_template: "\n\n### 入力:\n{input}\n\n### 応答:\n{output}"
zero-shotの場合はcustom_prompt_templateのみ。
custom_prompt_template: "以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。\n\n### 指示:\n{instruction}\n\n### 入力:\n{input}\n\n### 応答:\n"
LangChain形式する
LangChain形式にしています。
LangChainはこういったケースで良く使われているのですが、個人的には文法を覚えるのが手間なのでそのままpipelineで推論して欲しいです。
推論&評価
評価や正解判定の部分です。
ここで推論結果を文字列として取得しています。
lm-evaluation-harnessでは推論結果のlogitsを直接見て評価している部分がありました。
llm-jp-evalでは生成結果の文字列のみを見ているようで、明確で良さそうに思えました。
y_pred: str = normalize(chain({"input": sample["input"]})["output"].split("\n\n")[0])
正解判定
y_true: str = normalize(sample["output"])
output_dict.append({"input": sample["input"], "pred": y_pred, "gold": y_true})
y_trues.append(y_true)
y_preds.append(y_pred)
exact = 1 if y_pred == y_true else 0
正解と回答の文字列間の類似度を計算している。
char_f1 = fuzz.token_sort_ratio(y_pred, y_true) / 100.0
F1スコアの計算。
set_y_true: list[str] = [x.strip() for x in y_true.split("\n")]
set_y_pred: list[str] = list({x.strip() for x in y_pred.split("\n")})
set_pre = sum([1 if y in set_y_true else 0 for y in set_y_pred]) / len(set_y_pred)
set_rec = sum([1 if y in set_y_true else 0 for y in set_y_pred]) / len(set_y_true)
set_f1 = 2 * (set_pre * set_rec) / (set_pre + set_rec) if (set_pre + set_rec) != 0 else 0
スコアの平均値を出している部分
各スコアの平均スコアを出す部分。
_exact_match以外のスコアも使われていることに注意。
Discussion