🔥

Mixtral 250MのpretrainingからInstruction Tuningまで

2024/02/28に公開

MoEを持つMixtralがhuggingface/transformersで公開されているので、これを利用しつつ、250Mの小さいサイズとして日本語と英語でpretraining、finetuningを行います。

学習させたものは以下
https://huggingface.co/if001/tiny_mixtral_ja
https://huggingface.co/if001/tiny_mixtral_ja_instruction

Pretraining

lit-llamaを参考にする
https://github.com/Lightning-AI/lit-llama

データセットの準備

lit-llamaでは、torchで圧縮したデータセットを用意しておく必要がある。データセット作成用のscriptはここ
https://github.com/Lightning-AI/lit-llama/blob/main/scripts/prepare_any_text.py

これを参考にhuggingface datasetsからdatasetを作成できるように修正したものがこれ

https://github.com/if001/lit-llama-ja/blob/mixtral/scripts/prepare_ja.py

今回は合計8.64Bのデータセットを作成した。

total tokens: 8.64B

wikipedia_ja:    844.65M  
wikipedia_en:    3.80B  
open-text-books: 60.17M  
oscar:           3.85B  
aozorabunko:     92.97M

tokenizerは以前作成したこれを使う
https://zenn.dev/if001/articles/87bbe893411fa1

oscarも以前事前処理を行って作成したこれを使う
https://zenn.dev/if001/articles/cc262413e69e3d

モデルの準備

lit-llamaでは、torch modelを学習するようになっているので、torch実装のモデルであれば、簡単にpretrainingが行える。

ここでロードされるモデルをMixtralに差し替える
https://github.com/Lightning-AI/lit-llama/blob/main/pretrain/redpajama.py#L101

Mixtralのtorch実装は、transformersで公開されているものをそのまま利用する
https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py#L1274

小さなサイズで学習させるためMixtralConfigを修正する。layer数やhidden_dimなどを設定
https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/configuration_mixtral.py#L28

MixtralForCausalLMのコンストラクタの引数として、修正したMixtralConfigを渡す

MoEではexpertの選択が偏らないようなlossが追加されています。各layer、batch size、seqence lengthについて偏りがないかを見るのですが、batch sizeやsequence lengthは大きい値になることが多く、ほとんどの場合lossが意味を成さないのでは?と思います。これについては、Github上で議論が行われており、今後の展開に期待です。
https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py#L77

学習

モデルサイズは、275.86M
8.64Bのデータセットを2epoch
環境はpaperspace gradientのA6000(1GPU)を用います。

学習の結果、val loss: 4.9381で、かかった時間は約175hでした。

upload

huggingfaceにモデルをuploadしておきます。

huggingfaceのGUIからrepositoryを作成。

loginしておきます

from huggingface_hub import login
login()

modelをロードしてupload

    with lazy_load(checkpoint_path) as checkpoint:
        config = MixtralConfig_HF.from_name(model_name)
        model = MixtralForCausalLM(config)
        model.load_state_dict(checkpoint)

    model.push_to_hub(repo_id)

Instruction Tuning

instructionデータセットを使いfinetuningします。

データセット

いつもお世話になっている以下を使います
https://huggingface.co/datasets/kunishou/databricks-dolly-15k-ja
https://huggingface.co/datasets/kunishou/oasst1-89k-ja
https://huggingface.co/datasets/izumi-lab/llm-japanese-dataset

学習

huggingfaceにモデルをuploadしたので、学習にはtorchではなくtransformersのコードを使います。

TrainingArgumentsやTrainerを用意
https://huggingface.co/docs/transformers/ja/training#trainer

3epoch学習し、その結果、最終的なlossは0.021となりました。
少し学習率が大きかった気がするのと過学習気味かもしれません。

JGLUE

JGLUEで評価しておきます。
https://github.com/Stability-AI/lm-evaluation-harness

TASK=[
"jcommonsenseqa-1.1-0.3", # 5択の常識的問題を推論
"jnli-1.3-0.3", # 2つの文章の関係を推定
"marc_ja-1.1-0.3" # amazonのレビューがpositive/negativeかを分類
"jsquad-1.1-0.3", # 与えられた文章から答えを抽出
"xlsum_ja-1.0-0.3", # 要約
"jaqket_v2-0.2-0.3" # Wikipediaの記事名を答えとしたQA
"xwinograd_ja",# 照応解析
"mgsm-1.0-0.3" # 簡単な算数のタスク
]

lm-evaluation-harnessの評価

https://github.com/Stability-AI/lm-evaluation-harness?tab=readme-ov-file#leaderboard

model average jcommonsenseqa jnli marc_ja jsquad jaqket_v2 xlsum_ja xwinograd_ja mgsm
tiny-mixtral 41.41 42.90 54.85 83.66 11.00 70.44 4.02 51.41 3
stabilityai-japanese-stablelm-instruct-alpha-7b 54.71 82.22 52.05 82.88 63.26 74.83 7.79 72.68 2
rinna-bilingual-gpt-neox-4b-instruction-sft 47.75 49.51 47.08 95.28 55.99 61.17 5.51 64.65 2.8

平均スコアでは下回っているが、サイズの割に善戦した気がします

今回のモデルは、jnliやmarc_jaはそこそこのスコアになってます。それぞれ「2つの文章の関係を含意、矛盾、中立の3択から選ぶ」、「positive/negativeの2択の分類」のタスクになっており、簡単なタスクや出力は行えそう。

一方、jcommonsenseqaやjsquadではスコアが低くなってます。jcommonsenseqaは問題といくつかの選択肢が提示されその中から正解を抽出するタスクで、文章をより正確に読み取りそれを出力に適応する必要があります。このような複雑なタスクは行えなさそう。

Nejumiリーダーボード

スクリプトは以下
https://github.com/llm-jp/llm-jp-eval/tree/main

https://wandb.ai/wandb-japan/llm-leaderboard/reports/Nejumi-LLM-Neo--Vmlldzo2MTkyMTU0

model average EL FA MC MR NLI QA RC
tiny-mixtral 0.0507 0 0.0312 0.00 0.00 0.000 0.0856 0.2384
rinna/nekomata-14b-instruction 0.4375 0.0067 0.1651 0.77 0.42 0.494 0.3402 0.8663
rinna/nekomata-7b-instruction 0.3194 0 0.059 0.77 0.03 0.162 0.3318 0.8826

タスクはそれぞれ
EL (Entity Linking)
FA (Fundamental Analysis)
MC (Multiple Choice question answering) (lm-evaluation-harnessではJCommonsenseQA)
MR (Mathematical Reasoning)
NLI (Natural Language Inference) (lm-evaluation-harnessではjnli)
QA (Question Answering)
RC (Reading Comprehension) (lm-evaluation-harnessではJSQuAD)

こちらでは、他のモデルと比べてだいぶスコアが低め。

3択を選ぶタスクでもlm-evaluation-harnessでは数字のみで出力で良かったが、正確に正解の文字列を出力しなければ正解とならない。より正確な出力が求められる分、スコアが低くなっている

まとめ

250MのMixtralをpretrainingからfinetuningまでを行いました。小さいサイズなりにうっすら日本語を理解してそう。入力から正確に情報を抽出とそれらを使った出力はさすがに難しそう。あとは、推論時のexpertの選択のされかたや同サイズのモデルとの比較をしてみたいところ

lit-llamaではLightning Fabricを使っており、GPU並列、マシン並列が簡単に試せそうなのでこのあたりを試してみる
https://lightning.ai/docs/fabric/stable/

Discussion