📝

Llama-3をColabで記事執筆用にファインチューニングしてみた

2024/05/10に公開

こんにちは、Kaiです。
先日公開されたMetaのオープンソースLLMであるLlama-3、早速様々なところで派生モデルが作られていますね。

8BモデルであればColabでファインチューニングと推論ができるということだったので、早速当社でも試してみました。

解きたいタスク

CareNetは医療メディアですので、医学論文を記事化することが多々あります。
これをAI化できないかという発想はすぐに思いつきますが、文体や表現などのノウハウはなかなかプロンプトに埋め込むのが難しい暗黙知でした。

ということで、CareNetで医学論文を扱った過去記事数千件を学習データとして、ファインチューニングで「英語論文アブストラクト→日本語記事」という変換タスクが取り扱えるか試してみることにします。

ソース記事

いつもの通りnpakaさんのnoteを参考にさせて頂いています。
かなりの部分、npakaさんの写経です。
https://note.com/npaka/n/n315c0bdbbf00

Llama-3について

こちらは詳細な記事が多数出ていますので省略します。
私が先日まとめたAIニュースにもいくつか掲載していますのでご覧ください。
https://zenn.dev/carenet/articles/6173fdf67db4d6

事前準備1:HuggingFace

Llama-3のモデルはHuggingFaceに上がっていますが、利用するためにはAgreementに合意する必要があります。
https://huggingface.co/meta-llama/Meta-Llama-3-8B

この下の方にいくと、氏名や所属を入れた上でAgreementに同意するというボタンがあるはずですので、まずこちらに同意しておきます。

私の場合、数分程度でHuggingFaceに登録したメールアドレスに
[Access granted] Your request to access model meta-llama/Meta-Llama-3-8B has been accepted
というメールが届きました。

事前準備2:wandb

npakaさんのコードでは、WandBで学習ログの分析をしていますので、予め登録してトークンを取得しておきましょう。
(参考)
https://qiita.com/Yu_Mochi/items/4fc283ebc31225d4e106

事前準備3:学習データ

CareNetの記事データベースから、英語の医学論文をもとに執筆された記事約6千件を取得し、対応する英語AbstractをPubmedから取得したデータを用意します。

データは40MB弱のcsvファイルになり、こいつを用いてファインチューニングします。
使うカラム名は「abstract(英文アブストラクト)」と「CN_body(日本語記事)」です。

学習

npakaさんに従って進めていきます。

Colab設定

Llama-3-8Bのファインチューニングには24GBほどのVRAMが必要なため、Colabは「GPU→A100」を選択します。Pro加入が必要かもしれません。

パッケージインストール

# パッケージのインストール
!pip install -U transformers accelerate bitsandbytes
!pip install trl peft wandb
!git clone https://github.com/huggingface/trl
%cd trl

環境変数

Colab左端の鍵マークをクリックして、HF_TOKENとしてHuggingFaceのトークンを登録してください。スイッチを押して有効化するのを忘れずに。

HuggingFaceログイン

# HuggingFaceへのログイン
!huggingface-cli login

実行するとトークンを入力するよう言われますので、再度入力してEnter。
なお、Colabではそのままセルの実行結果にコマンドラインのごとく値を入力できます。

Googleドライブマウント

学習データファイルはGoogleドライブに置きましたので、パスを通します。

# Googleドライブのマウント
from google.colab import drive
drive.mount('/content/drive')

PEFT学習コードの編集

さて、ここでnpakaさんの記事と直近のHuggingFaceコードに差分が出てきます。
まず、編集すべきファイルは「trl/examples/scripts/sft.py」になります。
こちらの記事を参考に、一旦ファイルの全てを書き換えます。
https://note.com/_4piken/n/n196b81eedf9e

その上で、116行目からのデータセット作成部分を、自前のコードに置き換えます。

編集前
    ################
    # Dataset
    ################
    raw_datasets = load_dataset(args.dataset_name)

    train_dataset = raw_datasets[args.dataset_train_name]
    eval_dataset = raw_datasets[args.dataset_test_name]
編集後
    ################
    # Dataset
    ################
    # データセットの読み込み
    dataset = load_dataset("csv", data_files = "CSVファイルパス", split="train")

    # プロンプトの生成
    def generate_prompt(example):
        messages = [
            {
                'role': "system",
                'content': "あなたは優秀な医療メディアの記者です。与えられた英語の医学論文アブストラクトから、日本語の記事を執筆してください。" # シンプルな指示で試してみます
            },
            {
                'role': "user",
                'content': example["abstract"] # CSVのアブストラクトカラム
            },
            {
                'role': "assistant",
                'content': example["CN_body"] # CSVの記事本文カラム
            }
        ]
        return tokenizer.apply_chat_template(messages, tokenize=False)

    # textカラムの追加
    def add_text(example):
        example["text"] = generate_prompt(example)
        return example

    dataset = dataset.map(add_text)
    dataset = dataset.remove_columns(["CN_body","abstract"]) # 元ファイルデータの削除、必要に応じて追加

    # データセットの分割
    train_test_split = dataset.train_test_split(test_size=0.1)
    train_dataset = train_test_split["train"]
    eval_dataset = train_test_split["test"]

学習実行

# 学習
!python examples/scripts/sft.py \
    --model_name meta-llama/Meta-Llama-3-8B-Instruct \
    --dataset_name 学習用データセットのパス \
    --dataset_text_field text \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 1 \
    --learning_rate 2e-4 \
    --optim adamw_torch \
    --save_steps 50 \
    --logging_steps 50 \
    --max_steps 3000 \
    --use_peft \
    --lora_r 64 \
    --lora_alpha 16 \
    --lora_dropout 0.1 \
    --load_in_4bit \
    --report_to wandb \
    --output_dir Llama-3-CN-8B-Instruct

試しに学習させてみたところ、500stepsで20分ほどかかったため、2時間程度の現実的な時間で終わる3000stepsにしてみます。その他のパラメータは、とりあえずそのまま実行します。
実行すると学習可視化用のWandBからトークンを聞かれるので、先ほど取得したトークンを入力します。

学習実行中のリソース状況はこんな感じです。

学習結果

lossの状況はこんな感じです。
まぁ正直、タスク的に綺麗に収束するようなものではないのでこんなものでしょう。

推論

こちらもnpakaさんに従って進めていきます。

トークナイザーとモデルの準備

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# トークナイザーとモデルの準備
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct"
)
model = AutoModelForCausalLM.from_pretrained(
    "./Llama-3-CN-8B-Instruct",
    device_map="auto",
    torch_dtype="auto",
)

推論実行

import torch

# プロンプトの準備
chat = [
    { "role": "system", "content": "あなたは優秀な医療メディアの記者です。与えられた英語の医学論文アブストラクトから、日本語の記事を執筆してください。" },
    { "role": "user", "content": "試したい英文アブストラクトを入れます。" },
]
prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

# 推論の実行
# max_new_tokensはタスクの性質上広めに取ります。
token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(
        token_ids.to(model.device),
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
        max_new_tokens=4096,
        eos_token_id=[
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ],
    )
output = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1) :], skip_special_tokens=True)
print(output)

推論結果

元論文
https://jamanetwork.com/journals/jama/article-abstract/2817546

ファインチューニング後の推論結果

非バルサルタン系抗血液凝固薬であるアピキサバン(エリキサリ)やリヴァロキサバン(Xァレルタ)を投与する心房細動患者に対し、心室頻拍を制御する目的で用いられるジチアゼムは、メトプロロールに比べ重篤な出血のリスクが高く、特に120mg以上のジチアゼム投与はメトプロロール投与に比べ重篤な出血のリスクが約1.5倍となった。米国・ジョンズ・ホプキンス大学のKatherine L. Cohen氏らが、米国メディケア保険者を対象に実施した後ろ向きコホート研究の結果を報告した。心房細動とアピキサバン/リヴァロキサバンを併用している患者において、ジチアゼムはメトプロロールに比べ重篤な出血のリスクが高いことを示した。心房細動は、世界の約10%に及ぶ心疾患の原因である。心房細動患者は、血栓塞栓のリスクが高く、血栓塞栓の発生率は、心房細動患者の約4%である。心房細動患者に対し、抗血液凝固薬が投与されるが、そのうち非バルサルタン系抗血液凝固薬は、ジチアゼムを用いる心室頻拍の制御と併用されることが多い。ジチアゼムは、心室頻拍を制御する目的で用いられるが、非バルサルタン系抗血液凝固薬を併用する患者においては、血液中の濃度を増加させるため、重篤な出血のリスクが高まる可能性があることが示されている。ジチアゼムの使用は、心房細動患者の約50%に及ぶが、ジチアゼムの使用が重篤な出血のリスクを高めるかどうかは、明らかにされていない。米国・ジョンズ・ホプキンス大学のKatherine L. Cohen氏らは、「ジチアゼムを用いる心室頻拍の制御と非バルサルタン系抗血液凝固薬を併用する心房細動患者の重篤な出血のリスクは、メトプロロールを用いる患者に比べ高いか?」という研究の目的を、米国メディケア保険者を対象に実施した後ろ向きコホート研究の結果を報告した。研究の方法は、ジチアゼムとメトプロロールの使用を開始した心房細動患者を対象に、2012年1月1日から2020年11月29日までの期間を対象にした後ろ向きコホート研究で、ジチアゼムとメトプロロールを開始した患者を対象に、2012年1月1日から2020年11月29日までの期間を対象にした後ろ向きコホート研究を行い、ジチアゼムとメトプロロールを併用した患者とメトプロロールを用いる患者との重篤な出血のリスクを比較した。研究の結果、ジチアゼムを用いる患者は、メトプロロールを用いる患者に比べ重篤な出血のリスクが高く、ジチアゼムを120mg以上投与する患者は、メトプロロールを用いる患者に比べ重篤な出血のリスクが約1.5倍となった。研究の結果、ジチアゼムを用いる患者は、メトプロロールを用いる患者に比べ重篤な出血のリスクが高く、ジチアゼムを120mg以上投与する患者は、メトプロロールを用いる患者に比べ重篤な出血のリスクが約1.5倍となった。心房細動とアピキサバン/リヴァロキサバンを併用している患者において、ジチアゼムはメトプロロールに比べ重篤な出血のリスクが高いことを示した。

いかがでしょうか。
ファインチューニングを施していない素のLlama-3-8Bに同様のプロンプトとアブストラクトを食わせて比較してみましょう。

ファインチューニング前の推論結果

【新規使用者におけるatrial fibrillation患者の血液栓塞療法に関する研究】

ドイツメディカルジャーナル、2024年

atrial fibrillation(心房細動)患者の血液栓塞療法において、diltiazemというventricular rate–control薬剤がapixabanとrivaroxabanの排出を阻害し、過抗凝血化を引き起こす可能性があることが示唆された。そこで、diltiazemとmetoprololを併用したatrial fibrillation患者の新規使用者における重篤な出血リスクを比較する研究を行った。

研究では、2012年1月1日から2020年11月29日までの間にatrial fibrillationを患い、diltiazemやmetoprololを開始した65歳以上のMedicare保険者を対象にした回帰コホート研究を行った。患者たちは、2020年11月30日までの365日間を対象にした。データは、2023年8月から2024年2月にかけて分析された。

研究では、diltiazemとmetoprololを併用した患者たちの出血リスクを比較し、主な結果としては、出血関連の入院死亡、死に伴う出血、主な出血イベント(出血性脳梗塞、肺塞栓症、内臓出血、死亡)、死亡(出血関連のない場合)を測定した。

研究では、204,155人のMedicare保険者を対象にした。うち53,275人はdiltiazemを、150,880人はmetoprololを開始した。平均年齢は76.9歳、女性は52.7%であった。患者たちは、90,927人年のフォローアップ(median 120 [IQR, 59-281] days)を実施した。

結果として、diltiazemを開始した患者たちは、主な結果として出血関連の入院死亡(rate difference 10.6 [95% CI, 7.0-14.2] per 1000 person-years; hazard ratio 1.21 [95% CI, 1.13-1.29])と死に伴う出血(rate difference 2.4 [95% CI, 0.6-4.2] per 1000 person-years; hazard ratio 1.19 [95% CI, 1.05-1.34])のリスクが高かったことが示唆された。特に、初回のdiltiazemの投与量が120mg以上であった患者たちは、より高いリスクを示した(rate difference 15.1 [95% CI, 10.2-20.1] per 1000 person-years; hazard ratio 1.29 [95% CI, 1.19-1.39])。また、高用量グループでは、主な出血イベントのリスクも高かった(hazard ratio 1.14 [95% CI, 1.02-1.27))。

一方、metoprololを開始した患者たちは、出血関連の入院死亡、死に伴う出血、主な出血イベント、死亡のリスクが低かった。

結論として、atrial fibrillation患者の新規使用者におけるdiltiazemの使用は、metoprololに比べてより高い出血リスクを示し、特に120mg以上のdiltiazemの投与量ではより高いリスクを示したことが示された。

うーん、ファインチューニング前のモデルは直訳に近く、記事っぽくはありませんが正確さに勝るような気がします。一方、ファインチューニング後のモデルは、記事っぽさを出すために前後や文脈を入れ替えようとした結果、繰り返しや破綻が発生しているように見えます。

結論としては、やはりGPT-4やClaude3など、プロプライエタリで強力なモデルの方が、性能の面でも使いやすさの面でも相当に上回っていますね。

ただ、これはあくまで8Bモデルであって、これまでの同規模モデルよりも遥かに強力であることは間違いありません。また、今回のファインチューニングはデータセットも学習パラメータも工夫せず、単にそのまま突っ込んでいるだけですので、実用にはまだ遠いものになっています。

残念ながら、最先端モデルに匹敵すると言われているLlama-3-70Bモデルは簡単に手元で試すというわけにはいかないサイズになっていますが、OSSモデルの洗練はこれからも続いていくと思われます。クローズドかつセキュアな環境下で実行しなければならないという企業ニーズに応える一つの選択肢になりそうです。

CareNet Engineers

Discussion