💯

llm-jp-evalコードリーディング

2024/01/11に公開

概要

llm-jp-evalの評価部分のコードを読んだときのまとめです。
https://github.com/llm-jp/llm-jp-eval

以下の評価コマンドを打った時の流れを見ていこうと思います。

CUDA_VISIBLE_DEVICES=0 poetry run python scripts/evaluate_llm.py -cn config.yaml \
  model.pretrained_model_name_or_path=/path/to/model_dir \
  tokenizer.pretrained_model_name_or_path=/path/to/tokenizer_dir \
  dataset_dir=/path/to/dataset_dir

引数のパース

公式のREADMEにも記載がありますが、hydraで引数の管理をしています。

https://github.com/llm-jp/llm-jp-eval/blob/main/scripts/evaluate_llm.py#L14-L16

wandb

configに指定しておけばwandbにログインして結果を送信してくれます。

https://github.com/llm-jp/llm-jp-eval/blob/main/src/llm_jp_eval/evaluator.py#L25-L38

model生成

https://github.com/llm-jp/llm-jp-eval/blob/main/src/llm_jp_eval/evaluator.py#L70-L79

model = hydra.utils.call(cfg.model, torch_dtype=torch_dtype, _recursive_=False)

ここで設定ファイルからモデルを動的にインスタンス化しています。

model = AutoModelForCausalLM.from_pretrained(...)

と同等のことを

https://github.com/llm-jp/llm-jp-eval/blob/main/configs/model/llm-jp_llm-jp-1.3b-v1.0.yaml
こちらの設定ファイルを使って行っています。

pipelineの生成

https://github.com/llm-jp/llm-jp-eval/blob/main/src/llm_jp_eval/evaluator.py#L115-L124

few_shotサンプルの取得

trainデータセットの中から指定レコード数サンプリングしています。
先頭から数件取得しています。(シャッフルはしていなそう)
https://github.com/llm-jp/llm-jp-eval/blob/main/src/llm_jp_eval/utils.py#L34-L44

プロンプト生成部分(テンプレートなし)

特にテンプレートを指定しなかった時の処理です。
Alpacaっぽいフォーマット?
https://github.com/llm-jp/llm-jp-eval/blob/main/src/llm_jp_eval/utils.py#L69-L72

プロンプト生成部分(テンプレートあり)

custom_prompt_templateやcustom_fewshots_templateを指定した時の処理です。
https://github.com/llm-jp/llm-jp-eval/blob/main/src/llm_jp_eval/utils.py#L56-L67

以下の様な設定を行うとそれに従ってプロンプトが生成されます。

custom_prompt_template: "以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。\n\n### 指示:\n{instruction}{few_shots_text}\n\n### 入力:\n{input}\n\n### 応答:\n"
custom_fewshots_template: "\n\n### 入力:\n{input}\n\n### 応答:\n{output}"

zero-shotの場合はcustom_prompt_templateのみ。

custom_prompt_template: "以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。\n\n### 指示:\n{instruction}\n\n### 入力:\n{input}\n\n### 応答:\n"

LangChain形式する

LangChain形式にしています。
https://github.com/llm-jp/llm-jp-eval/blob/main/src/llm_jp_eval/evaluator.py#L143-L153

LangChainはこういったケースで良く使われているのですが、個人的には文法を覚えるのが手間なのでそのままpipelineで推論して欲しいです。

推論&評価

評価や正解判定の部分です。
https://github.com/llm-jp/llm-jp-eval/blob/main/src/llm_jp_eval/utils.py#L93-L108

ここで推論結果を文字列として取得しています。
lm-evaluation-harnessでは推論結果のlogitsを直接見て評価している部分がありました。
llm-jp-evalでは生成結果の文字列のみを見ているようで、明確で良さそうに思えました。

y_pred: str = normalize(chain({"input": sample["input"]})["output"].split("\n\n")[0])

正解判定

        y_true: str = normalize(sample["output"])
        output_dict.append({"input": sample["input"], "pred": y_pred, "gold": y_true})
        y_trues.append(y_true)
        y_preds.append(y_pred)
        exact = 1 if y_pred == y_true else 0

正解と回答の文字列間の類似度を計算している。

char_f1 = fuzz.token_sort_ratio(y_pred, y_true) / 100.0

F1スコアの計算。

        set_y_true: list[str] = [x.strip() for x in y_true.split("\n")]
        set_y_pred: list[str] = list({x.strip() for x in y_pred.split("\n")})
        set_pre = sum([1 if y in set_y_true else 0 for y in set_y_pred]) / len(set_y_pred)
        set_rec = sum([1 if y in set_y_true else 0 for y in set_y_pred]) / len(set_y_true)
        set_f1 = 2 * (set_pre * set_rec) / (set_pre + set_rec) if (set_pre + set_rec) != 0 else 0

スコアの平均値を出している部分

各スコアの平均スコアを出す部分。
_exact_match以外のスコアも使われていることに注意。
https://github.com/llm-jp/llm-jp-eval/blob/main/src/llm_jp_eval/utils.py#L161-L178

Discussion