Mixtral 250MのpretrainingからInstruction Tuningまで
MoEを持つMixtralがhuggingface/transformersで公開されているので、これを利用しつつ、250Mの小さいサイズとして日本語と英語でpretraining、finetuningを行います。
学習させたものは以下
Pretraining
lit-llamaを参考にする
データセットの準備
lit-llamaでは、torchで圧縮したデータセットを用意しておく必要がある。データセット作成用のscriptはここ
これを参考にhuggingface datasetsからdatasetを作成できるように修正したものがこれ
今回は合計8.64Bのデータセットを作成した。
total tokens: 8.64B
wikipedia_ja: 844.65M
wikipedia_en: 3.80B
open-text-books: 60.17M
oscar: 3.85B
aozorabunko: 92.97M
tokenizerは以前作成したこれを使う
oscarも以前事前処理を行って作成したこれを使う
モデルの準備
lit-llamaでは、torch modelを学習するようになっているので、torch実装のモデルであれば、簡単にpretrainingが行える。
ここでロードされるモデルをMixtralに差し替える
Mixtralのtorch実装は、transformersで公開されているものをそのまま利用する
小さなサイズで学習させるためMixtralConfigを修正する。layer数やhidden_dimなどを設定
MixtralForCausalLMのコンストラクタの引数として、修正したMixtralConfigを渡す
MoEではexpertの選択が偏らないようなlossが追加されています。各layer、batch size、seqence lengthについて偏りがないかを見るのですが、batch sizeやsequence lengthは大きい値になることが多く、ほとんどの場合lossが意味を成さないのでは?と思います。これについては、Github上で議論が行われており、今後の展開に期待です。
学習
モデルサイズは、275.86M
8.64Bのデータセットを2epoch
環境はpaperspace gradientのA6000(1GPU)を用います。
学習の結果、val loss: 4.9381で、かかった時間は約175hでした。
upload
huggingfaceにモデルをuploadしておきます。
huggingfaceのGUIからrepositoryを作成。
loginしておきます
from huggingface_hub import login
login()
modelをロードしてupload
with lazy_load(checkpoint_path) as checkpoint:
config = MixtralConfig_HF.from_name(model_name)
model = MixtralForCausalLM(config)
model.load_state_dict(checkpoint)
model.push_to_hub(repo_id)
Instruction Tuning
instructionデータセットを使いfinetuningします。
データセット
いつもお世話になっている以下を使います
学習
huggingfaceにモデルをuploadしたので、学習にはtorchではなくtransformersのコードを使います。
TrainingArgumentsやTrainerを用意
3epoch学習し、その結果、最終的なlossは0.021となりました。
少し学習率が大きかった気がするのと過学習気味かもしれません。
JGLUE
JGLUEで評価しておきます。
TASK=[
"jcommonsenseqa-1.1-0.3", # 5択の常識的問題を推論
"jnli-1.3-0.3", # 2つの文章の関係を推定
"marc_ja-1.1-0.3" # amazonのレビューがpositive/negativeかを分類
"jsquad-1.1-0.3", # 与えられた文章から答えを抽出
"xlsum_ja-1.0-0.3", # 要約
"jaqket_v2-0.2-0.3" # Wikipediaの記事名を答えとしたQA
"xwinograd_ja",# 照応解析
"mgsm-1.0-0.3" # 簡単な算数のタスク
]
lm-evaluation-harnessの評価
model | average | jcommonsenseqa | jnli | marc_ja | jsquad | jaqket_v2 | xlsum_ja | xwinograd_ja | mgsm |
---|---|---|---|---|---|---|---|---|---|
tiny-mixtral | 41.41 | 42.90 | 54.85 | 83.66 | 11.00 | 70.44 | 4.02 | 51.41 | 3 |
stabilityai-japanese-stablelm-instruct-alpha-7b | 54.71 | 82.22 | 52.05 | 82.88 | 63.26 | 74.83 | 7.79 | 72.68 | 2 |
rinna-bilingual-gpt-neox-4b-instruction-sft | 47.75 | 49.51 | 47.08 | 95.28 | 55.99 | 61.17 | 5.51 | 64.65 | 2.8 |
平均スコアでは下回っているが、サイズの割に善戦した気がします
今回のモデルは、jnliやmarc_jaはそこそこのスコアになってます。それぞれ「2つの文章の関係を含意、矛盾、中立の3択から選ぶ」、「positive/negativeの2択の分類」のタスクになっており、簡単なタスクや出力は行えそう。
一方、jcommonsenseqaやjsquadではスコアが低くなってます。jcommonsenseqaは問題といくつかの選択肢が提示されその中から正解を抽出するタスクで、文章をより正確に読み取りそれを出力に適応する必要があります。このような複雑なタスクは行えなさそう。
Nejumiリーダーボード
スクリプトは以下
model | average | EL | FA | MC | MR | NLI | QA | RC |
---|---|---|---|---|---|---|---|---|
tiny-mixtral | 0.0507 | 0 | 0.0312 | 0.00 | 0.00 | 0.000 | 0.0856 | 0.2384 |
rinna/nekomata-14b-instruction | 0.4375 | 0.0067 | 0.1651 | 0.77 | 0.42 | 0.494 | 0.3402 | 0.8663 |
rinna/nekomata-7b-instruction | 0.3194 | 0 | 0.059 | 0.77 | 0.03 | 0.162 | 0.3318 | 0.8826 |
タスクはそれぞれ
EL (Entity Linking)
FA (Fundamental Analysis)
MC (Multiple Choice question answering) (lm-evaluation-harnessではJCommonsenseQA)
MR (Mathematical Reasoning)
NLI (Natural Language Inference) (lm-evaluation-harnessではjnli)
QA (Question Answering)
RC (Reading Comprehension) (lm-evaluation-harnessではJSQuAD)
こちらでは、他のモデルと比べてだいぶスコアが低め。
3択を選ぶタスクでもlm-evaluation-harnessでは数字のみで出力で良かったが、正確に正解の文字列を出力しなければ正解とならない。より正確な出力が求められる分、スコアが低くなっている
まとめ
250MのMixtralをpretrainingからfinetuningまでを行いました。小さいサイズなりにうっすら日本語を理解してそう。入力から正確に情報を抽出とそれらを使った出力はさすがに難しそう。あとは、推論時のexpertの選択のされかたや同サイズのモデルとの比較をしてみたいところ
lit-llamaではLightning Fabricを使っており、GPU並列、マシン並列が簡単に試せそうなのでこのあたりを試してみる
Discussion