【0.1Bから作るLLM】JAX/Flaxで作るTransformer言語モデル ❸ wiki40b編
JAX/Flaxで作るTransformer言語モデル ❸ wiki40b編
概要
wiki40b編では、Wikipediaベースのデータセットであるwiki40bの日本語データを使って、クラウド上のTPUでTransformer言語モデルをトレーニングする手順を解説します。前回のcc100編では、モデルのハイパーパラメーターを変更することはなくオリジナル同様6層のTransformerでしたが、今回はGPT-2 smallのモデルアーキテクチャを参考にして、12層のTransformerに拡張しモデルの性能向上を図ります。そして、最終的にパラメータ数が約1億3000万(0.1B)の言語モデルを完成させます。最後に、学習済み言語モデルを使って、テキスト生成スクリプトで実際に日本語文章を生成してみます。
記事の構成
本シリーズは、以下の3つの記事から構成されています。
記事 | 説明 |
---|---|
❶lm1b編 | まずは、オリジナルと同様に英語データセット(lm1b)で学習させる基本的な手順を解説します。 |
❷cc100編 | 次に、lm1bとよく似た短文形式のデータセットであるcc100の日本語で学習させる手順を解説します。また、その際に必要なGoogleクラウドストレージ上のデータセットで学習する手順も解説します。 |
❸wiki40b編 | 最後に、wikipediaベースの日本語データセットで学習させます。さらに、GPT-2 smallのモデルアーキテクチャを参考に12層まで拡張し0.1Bのモデルを完成させます。テキスト生成スクリプトで実際に日本語文章を生成します。 |
❸ wiki40b編 - 日本語データセットでトレーニングする
前回までのおさらい
前回記事のcc100編では、Googleクラウドストレージ(GCS)上にデータセット保存用とワークディレクトリ用の2つのバケットを作成しました。そして、データセット用のバケットにデータセットを事前ダウンロードも行いました。また、PythonコードからGCSバケットにAPIアクセスするために必要なAPIキーファイルを作成しダウンロードしました。
この記事では、GCSバケットにデータセットの事前ダウンロードが完了していて、APIキーファイルも手元にある前提で『TPU-VMの作成』以降の手順から解説します。
【前提条件】
- GCSバケットにAPIアクセスするためキーファイル(.json)が手元にある
- GCSバケットにデータセット(wiki40b/ja)の事前ダウンロードが完了している
もしまだの方は、前回記事のcc100編を参照下さい。
TPU-VMの作成
Google Cloud Consoleから、[Compute Engine]-[TPU]-[TPUノードを作成]を選択、以下の設定でTPU-VMを作成します。
- 名前:my-tpu-vm
- ゾーン:uscentral1-a
- TPU VM アーキテクチャ(推奨)
- TPUタイプ:v3-8
- TPUソフトウェアバージョン:v2-alpha
- このTPUノードのプリエンプティブをオンにする
TPU-VMにSSHアクセス
以下のコマンドでTPU-VMにSSHでアクセスします。
(local)$ gcloud compute tpus tpu-vm ssh my-tpu-vm --zone=us-central1-a
ソースコードのクローンとパッケージのインストール
lm1bを日本語データセットで学習できるように修正したコードをGitHubからクローンし、必要なPythonパッケージをインストールします。
(tpu-vm)$ git clone -b 1.0.0.RC3 https://github.com/FookieMonster/transformer-lm-japanese
(tpu-vm)$ cd ./transformer-lm-japanese/transformer_lm
(tpu-vm)$ pip install -r requirements.txt
(tpu-vm)$ pip install "jax[tpu]==0.4.13" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
設定ファイルを確認
今回トレーニングで使用する設定ファイル(japanese_0.1b_v1.py)を少しだけ確認してみます。
まず、学習と評価に使用するデータセットがwiki40bの日本語であることが分かります。
config.dataset_name = "wiki40b/ja"
config.train_split = "train"
config.eval_dataset_name = "wiki40b/ja"
config.eval_split = "validation"
SentencePieceの語彙サイズは30000、サブワード学習に使用するテキストサイズ400MB、character_coverageを日本語の推奨値0.9995、byte_fallback有効で未知語なしに設定しています。
config.vocab_size = 30_000
config.max_corpus_chars = (10**6) * 400
config.spm_train_options = "--character_coverage=0.9995 --byte_fallback=true"
トランスフォーマーのレイヤー数が12層、アテンションのQKVが768次元、単語埋め込み768次元、全結合層が3072次元、アテンションが12ヘッド、であることが分かります。GPT-2 smallのモデルアーキテクチャを参考にしています。
# Number of transformer layers.
config.num_layers = 12
# Size of query/key/value for attention.
config.qkv_dim = 768
# Size of embeddings.
config.emb_dim = 768
# Size of the MLP.
config.mlp_dim = 3072
# Number of attention heads.
config.num_heads = 12
最大の入力シーケンス長(max_target_length)を128から256に増やしています。増やすとモデルの性能が上がるみたいですが、その分トレーニングに時間がかかります。ここでは約1.5日でトレーニングが完了する256を採用しました。
# Maximum length cutoff for training examples.
config.max_target_length = 256
# Maximum length cutoff for eval examples.
config.max_eval_target_length = 512
# Maximum length cutoff for predicted tokens.
config.max_predict_length = 512
トレーニング時間の短縮のため、eval_every_stepsを1000から5000に変更しています。
# Frequency of eval during training, e.g. every 1_000 steps.
config.eval_every_steps = 5_000
トレーニングの開始
APIキーのアップロード
APIキーをローカルPC側からTPU-VMインスタンス側へコピーする必要があります。
gcloud CLIのscpコマンドでコピーします。scpコマンドの書式は以下のとおりです。
書式
gcloud compute tpus tpu-vm scp [ローカルファイルのパス] [TPU-VMインスタンス名]:[リモートファイルのパス] --zone=[ゾーン]
例) カレントディレクトリにあるAPIキーファイルを、my-tpu-vm側の/tmp/service-account-api-key.jsonにコピー
(local)$ gcloud compute tpus tpu-vm scp ./service-account-api-key.json my-tpu-vm:/tmp/service-account-api-key.json --zone=us-central1-a
環境変数のセット
GCSバケットにAPIアクセスするために必要なキーファイルを環境変数にセットします。
(tpu-vm)$ export GOOGLE_APPLICATION_CREDENTIALS="/tmp/service-account-api-key.json"
TensorFlow DatasetsのデータディレクトリをGCSのバケットに設定します。
(tpu-vm)$ export TFDS_DATA_DIR=gs://my-tfds-data
トレーニング
GCS上のワークディレクトリを指定してトレーニングを開始します。
(tpu-vm)$ python3 main.py --workdir=gs://my-lm-work/japanese_0.1b_v1 \
--config=configs/japanese_0.1b_v1.py
トレーニング結果
トレーニングが正常に進んでいれば、GCSのバケット内のフォルダにTensorBoardのイベントログやチェックポイントファイルが保存されているはずです。イベントログをダウンロードし、ローカルのTensorBoardで表示します。
(local)$ tensorboard --logdir logs/japanese_0.1b_v1
トレーニングが正しく行われていると、Perplexity=35、Loss=3.5ぐらいになるはずです。
日本語文章の生成
完成した0.1Bの言語モデルを使って、実際に日本語文章の生成を行います。generate_text.pyというテキスト生成スクリプトをTPU-VM上で使います。各引数の説明は以下の通りです。
引数 | 説明 |
---|---|
workdir | 学習済みのチェックポイントがあるディレクトリを指定 |
config | モデルのアーキテクチャを決定する設定ファイルを指定 |
config.sampling_temperature | サンプリング温度、0に近いほど貪欲法と等しくなる |
config.sampling_top_k | 上位k個を使ってサンプリング、0の場合全てからサンプリング |
config.seed | ランダム生成器のシード値を指定 |
config.prompts | プロンプトを指定 |
num_generated_texts | 生成する文章の件数を指定します |
(tpu-vm)$ python3 generate_text.py --workdir=gs://my-lm-work/japanese_0.1b_v1 \
--config=configs/japanese_0.1b_v1.py \
--config.sampling_temperature=0.6 \
--config.sampling_top_k=20 \
--config.seed=0 \
--config.prompts="夏目漱石は、" \
--num_generated_texts=10
学習済みの言語モデルから、「夏目漱石は、」で始まる日本語の文章を10件生成しています。
Generating text.
Sample: 夏目漱石は、自分の作品を「文学の本」として出版することを構想していた。
Generating text.
Sample: 夏目漱石は、明治の文学運動を「文学の原点に立ち帰る」と位置づけ、漱石が「文学の本質をあらわすのが文学である」との認識を、当時の知識人たちが持っていたことを指摘している。
Generating text.
Sample: 夏目漱石は、小説『坊っちゃん』で、この「坊っちゃん」を「坊っちゃん」に置き換えた。「坊っちゃん」は、坊っちゃんの「坊」の字を、「坊」は「坊」の字をもじってつけられた。
Generating text.
Sample: 夏目漱石は、漱石の『坊っちゃん』を読んで、「漱石は、私に『坊っちゃん』をおもしろおかしく書かせた。これは、私に『坊っちゃん』を書かせるのを、私に教えてくれたからだ」と述懐している。
Generating text.
Sample: 夏目漱石は、自身の著作『漱石全集』の中で「漱石が生涯のほとんどを漱石の文学に捧げた」と評価している。
Generating text.
Sample: 夏目漱石は、漱石が「『吾輩は猫』を観るのが嫌だ」と言ったのを、漱石が「あんなに怖いとは思わなかった」と返している。
Generating text.
Sample: 夏目漱石は、自身の日記の中で「文学の本質と現実との間には、対立関係があり、また対立関係があっても、それが文学の本質と現実との間には関係がある」と書いている。
Generating text.
Sample: 夏目漱石は、夏目が漱石の『吾輩は猫である』を読んでいた時に、漱石の『吾輩は猫である』を読んだという。漱石は「猫は猫である」と書いていたが、漱石は「猫である」と書いた。
Generating text.
Sample: 夏目漱石は、小説『坊っちゃん』の中で、主人公が「おばあさん」と「おばあさん」の2人で暮らしていると、その家から「おばあさん」と「おばあさん」が飛び出してくるという話を紹介している。
Generating text.
Sample: 夏目漱石は、漱石の「吾輩は猫である」という言葉を、漱石が「猫を飼っている人は猫である」という誤解から誤解したのだろうと、著書『猫の散歩道』で述べている。
パラメータ数0.1Bで、トレーニング期間1.5日にしては、悪くない生成結果ではないでしょうか?
個人的には、10番目の生成結果の "著書『猫の散歩道』で述べている。" が気になります。幻覚(ハルシネーション)でしょうか。
今回作成した、0.1Bの学習済みモデルはHugging Faceのモデルハブでも公開しています。TPUではなく、CPUを使って日本語の文章生成を行う手順も記載してあるので、気軽に試すことが可能です。詳細はモデルカードを参照ください。
さいごに
このシリーズでは、Transformer言語モデルの事前学習を通じて、日本語文章の生成が可能になりました。今後、ChatGPTのような対話応答できるモデルを目指して、続編の「❹ ファインチューニング編」や「❺ RLHF編」に取り組む予定です。具体的な時期はまだ確定していませんが、このアカウントをフォローをして頂ければ、開始時に通知が届くと思います。
参考資料
- JAX
- Flax
- TensorFlow Datasets
- Google Cloud Platform
- SentencePiece
- OpenAI GPT2
Discussion