【0.1Bから作るLLM】JAX/Flaxで作るTransformer言語モデル ❷ cc100編
JAX/Flaxで作るTransformer言語モデル ❷ cc100編
概要
cc100編では、lm1bとよく似た短文形式のデータセットであるCC-100の日本語データを使って、クラウド上のTPUでTransformer言語モデルをトレーニングする手順を解説します。lm1bを単純に日本語データセットで学習するようにするのみで、モデルアーキテクチャはオリジナル同様6層のTransformerです。また、その際に必要なGoogleクラウドストレージ(以下GCS)のセットアップについても解説します。
記事の構成
本シリーズは、以下の3つの記事から構成されています。
記事 | 説明 |
---|---|
❶lm1b編 | まずは、オリジナルと同様に英語データセット(lm1b)で学習させる基本的な手順を解説します。 |
❷cc100編 | 次に、lm1bとよく似た短文形式のデータセットであるcc100の日本語で学習させる手順を解説します。また、その際に必要なGoogleクラウドストレージ上のデータセットで学習する手順も解説します。 |
❸wiki40b編 | 最後に、wikipediaベースの日本語データセットで学習させます。さらに、GPT-2 smallのモデルアーキテクチャを参考に12層まで拡張し0.1Bのモデルを完成させます。テキスト生成スクリプトで実際に日本語文章を生成します。 |
❷ cc100編 - 日本語データセットでトレーニングする
なぜGCSを使う必要があるのか?
前回記事の❶lm1b編では、データセットのダウンロード先はデフォルトの~/tensorflow_datasetsに保存しました、また、TensorBoardのイベントログやチェックポイントファイルの保存であるワークディレクトリも~/logsのローカルディスク上に保存していました。
今回のcc100編では、データセットとワークディレクトリの両方でGCSを利用します。
理由は以下の2つです。
- TPU-VMのローカルディスク容量は100GB程しかありません、cc100/jaなどの大規模データセットをダウンロードし展開することが困難です。
- TPU-VMのローカルディスクはインスタンスを削除するとアクセスできなくなります。特に、TPU-VMを料金7割引のプリエンプティブ設定で利用した場合、インスタンスはいつでも停止される可能性があり、一度プリエンプティブ状態になるとローカルディスクにアクセスできず、インスタンスを削除することしかできません。
GCSバケットを作成
前回記事の❶lm1b編で作成したGCPプロジェクトと同一プロジェクト内に、データセットの保存先とワークディレクトリ用の2つのGCSバケットを作成します。ここでは以下の名前で作成することとします。
- my-lm-work (ワークディレクトリ用でチェックポイントやTensorBoardログが保存される)
- my-tfds-data (TensorFlow Datasetsのデータ保存用)
APIキーの作成とダウンロード
PythonコードからGCSバケットにAPIアクセスするために必要なAPIキーファイルを作成しダウンロードします。
Google Cloud Consoleから、[IAMと管理]-[サービスアカウント]を選択します。 そのリストに「Compute Engine default service account」という名前のアカウントがデフォルトで作成されているはずです。 そのリンクを選択し、[キー]-[鍵を追加]-[新しい鍵を作成]-[キーのタイプ JSON]を選択肢キーファイルをダウンロードします。
データセットの事前ダウンロード
データセットの初回ダウンロードにはとても時間がかかります。 lm1bやwiki40b/jaは数時間で完了しますが、cc100/jaは数十時間かかります。 TPU-VM上でダウンロードを行うとコストがかかるので、別のPython3.8環境で事前ダウンロード(GCSのバケットにアップロード)を行います。
また、cc100/jaはデータセットのサイズが74GB(temporaryが数百GB)になるので、一時的なディスクの空き容量が最低でも数百GB必要です。 私の場合は、以下のようなGCE上のVMインスタンス(以下CPU-VM)を使ってPython3.8環境を構築しました。
Google Cloud Consoleから、[Compute Engine]-[VM インスタンス]-[インスタンスを作成]を選択、以下の設定でCPU-VMを作成します。
- 名前: my-cpu-vm
- マシンタイプ: c2-standard-4 CPUx4 メモリ16GB
- ゾーン: us-central1-a
- ディスク: 2TB(標準永続ディスク)
- OS: Ubuntu 20.04 LTS x86/64
CPU-VMにSSHでアクセスします。
(local)$ gcloud compute ssh my-cpu-vm --zone=us-central1-a
まず、Python3.8とpipをインストールします。
(cpu-vm)$ sudo apt-get update
(cpu-vm)$ sudo apt-get install python3.8 python3-pip build-essential
次に、tensorflow-datasetsとその関連パッケージをインストールします。
(cpu-vm)$ pip install tensorflow==2.11.1
(cpu-vm)$ pip install tensorflow-datasets==4.8.3
(cpu-vm)$ pip install datasets==2.12.0
次に、GCSバケットにAPIアクセスするために、APIキーファイルをアップロードしそのパスを環境変数に設定します。
例) カレントディレクトリにあるAPIキーファイルを、my-cpu-vm側の/tmp/service-account-api-key.jsonにコピー
(local)$ gcloud compute scp ./service-account-api-key.json my-cpu-vm:/tmp/service-account-api-key.json --zone=us-central1-a
(cpu-vm)$ export GOOGLE_APPLICATION_CREDENTIALS="/tmp/service-account-api-key.json"
次に、Pythonインタプリターを起動して、データセットを順番にダウンロードします。
(cpu-vm)$ python3
>>> import tensorflow_datasets as tfds
>>> tfds.load('lm1b', data_dir="gs://my-tfds-data")
>>> tfds.load('wiki40b/ja', data_dir="gs://my-tfds-data")
>>> tfds.load('huggingface:cc100/lang=ja', data_dir="gs://my-tfds-data")
データセットの事前ダウンロードが完了すると、以下のようにGCSバケット内に各データセット用のフォルダが作成されているはずです。
TPU-VMの作成
Google Cloud Consoleから、[Compute Engine]-[TPU]-[TPUノードを作成]を選択、以下の設定でTPU-VMを作成します。
- 名前:my-tpu-vm
- ゾーン:uscentral1-a
- TPU VM アーキテクチャ(推奨)
- TPUタイプ:v3-8
- TPUソフトウェアバージョン:v2-alpha
- このTPUノードのプリエンプティブをオンにする
TPU-VMにSSHアクセス
以下のコマンドでTPU-VMにSSHでアクセスします。
(local)$ gcloud compute tpus tpu-vm ssh my-tpu-vm --zone=us-central1-a
ソースコードのクローンとパッケージのインストール
lm1bを日本語データセットで学習できるように修正したコードをGitHubからクローンし、必要なPythonパッケージをインストールします。
(tpu-vm)$ git clone -b 1.0.0.RC3 https://github.com/FookieMonster/transformer-lm-japanese
(tpu-vm)$ cd ./transformer-lm-japanese/transformer_lm
(tpu-vm)$ pip install -r requirements.txt
(tpu-vm)$ pip install "jax[tpu]==0.4.13" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
設定ファイルを確認
今回トレーニングで使用する設定ファイル(japanese_default_v1.py)を少しだけ確認してみます。
まず、学習と評価に使用するデータセットがcc100の日本語データセットであることが分かります。
config.dataset_name = "huggingface:cc100/lang=ja"
config.train_split = "train[:98%]"
config.eval_dataset_name = "huggingface:cc100/lang=ja"
config.eval_split = "train[98%:]"
SentencePieceの語彙サイズは30000、サブワード学習に使用するテキストサイズ5MB、character_coverageを日本語の推奨値0.9995に設定しています。
config.vocab_size = 30_000
config.max_corpus_chars = (10**6) * 5
config.spm_train_options = "--character_coverage=0.9995"
トランスフォーマーのレイヤー数が6層、アテンションのQKVが512次元、単語埋め込み512次元、全結合層が2048次元、アテンションが8ヘッド、であることが分かります。
# Number of transformer layers.
config.num_layers = 6
# Size of query/key/value for attention.
config.qkv_dim = 512
# Size of embeddings.
config.emb_dim = 512
# Size of the MLP.
config.mlp_dim = 2048
# Number of attention heads.
config.num_heads = 8
トレーニングの途中でテキストをサンプリングする際のプロンプトも日本語にしています。
# Prompt for language model sampling.
config.prompts = "昔々あるところに"
トレーニングの開始
APIキーのアップロード
APIキーをローカルPC側からTPU-VMインスタンス側へコピーする必要があります。
gcloud CLIのscpコマンドでコピーします。scpコマンドの書式は以下のとおりです。
書式
gcloud compute tpus tpu-vm scp [ローカルファイルのパス] [TPU-VMインスタンス名]:[リモートファイルのパス] --zone=[ゾーン]
例) カレントディレクトリにあるAPIキーファイルを、my-tpu-vm側の/tmp/service-account-api-key.jsonにコピー
(local)$ gcloud compute tpus tpu-vm scp ./service-account-api-key.json my-tpu-vm:/tmp/service-account-api-key.json --zone=us-central1-a
環境変数のセット
GCSバケットにAPIアクセスするために必要なキーファイルを環境変数にセットします。
(tpu-vm)$ export GOOGLE_APPLICATION_CREDENTIALS="/tmp/service-account-api-key.json"
TensorFlow DatasetsのデータディレクトリをGCSのバケットに設定します。
TensorFlow Datasetsデフォルトのデータ保存先は~/tensorflow_datasetsのローカルディスクですが、環境変数のTFDS_DATA_DIRにパスを設定することで変更が可能です。
(tpu-vm)$ export TFDS_DATA_DIR=gs://my-tfds-data
トレーニング
GCS上のワークディレクトリを指定してトレーニングを開始します。
(tpu-vm)$ python3 main.py --workdir=gs://my-lm-work/japanese_default_v1 \
--config=configs/japanese_default_v1.py
トレーニング結果
トレーニングが正常に進んでいれば、GCSのバケット内のフォルダにTensorBoardのイベントログやチェックポイントファイルが保存されているはずです。イベントログをダウンロードし、ローカルのTensorBoardで表示します。
(local)$ tensorboard --logdir logs/japanese_default_v1
トレーニングが正しく行われていると、Perplexity=66、Loss=4.2ぐらいになるはずです。
正直あまり性能が良くないですね、今回は、単純にlm1bのデータセットをcc100/jaに置き換えただけで、ハイパーパラメーターの変更は特に何も行っていません。
次回の❸wiki40b編では、GPT-2 smallのモデルアーキテクチャを参考にして、12層のトランスフォーマーにするなど、ハイパーパラメーターのチューニングでモデルの性能向上を目指します。
参考資料
- JAX
- Flax
- TensorFlow Datasets
- Google Cloud Platform
- SentencePiece
- OpenAI GPT2
Discussion