🚀

【0.1Bから作るLLM】JAX/Flaxで作るTransformer言語モデル ❶ lm1b編

2023/07/18に公開

JAX/Flaxで作るTransformer言語モデル ❶ lm1b編

概要

Flax公式のサンプルコード集には、lm1bという名前のTransformerのデコーダー型の言語モデルが含まれています。このサンプルコードは元々、One Billion Word Benchmark(lm1b)という英語データセットで学習することが可能です。本シリーズでは、そのサンプルコードを日本語データセットで学習できるよう修正したコードを使用し、クラウド上のTPUでトレーニングを行う手順を説明します。その結果、日本語のテキストを生成できる言語モデルを作ります。

さらに、lm1bのサンプルコードのモデルはデフォルト設定で6層のTransformerですが、本シリーズでは、GPT-2 smallのモデルアーキテクチャを参考に12層まで拡張します。これにより、最終的なパラメータ数が約1億3000万(0.1B)となるモデルを作成します。作成した0.1Bの学習済み言語モデルを使って、テキスト生成スクリプトで実際に日本語文章を生成します。

記事の構成

本シリーズは、以下の3つの記事から構成されています。

記事 説明
❶lm1b編 まずは、オリジナルと同様に英語データセット(lm1b)で学習させる基本的な手順を解説します。
❷cc100編 次に、lm1bとよく似た短文形式のデータセットであるcc100の日本語で学習させる手順を解説します。また、その際に必要なGoogleクラウドストレージ上のデータセットで学習する手順も解説します。
❸wiki40b編 最後に、wikipediaベースの日本語データセットで学習させます。さらに、GPT-2 smallのモデルアーキテクチャを参考に12層まで拡張し0.1Bのモデルを完成させます。テキスト生成スクリプトで実際に日本語文章を生成します。

使用するソースコード

本シリーズでは、lm1bを日本語データセットで学習できるように修正したコードを使用します。修正したソースコードはGitHubで公開しています。詳細は、GitHubのREADMEを参照して下さい。

https://github.com/FookieMonster/transformer-lm-japanese

Pythonコードの説明

日本語言語モデルの作成には、Flax公式のサンプルコード『lm1b』を日本語データセットで学習できるように修正したtransformer-lm-japaneseを使用します。オリジナルのサンプルコードは6つのPythonファイル(テストコードを除く)から構成されていますが、日本語データセットでの学習に対応するために、新しく2つのPythonファイルの追加(表●)と、一部オリジナルコードの修正を行っています。

Pythonコード 説明 NEW
train.py トレーニングのメインループ処理があります
tokenizer.py SentencePieceでサブワード学習とトークン化を行います
temperature_sampler.py 温度サンプリングで文章生成を行います
models.py Transformer言語モデルの定義本体です
main.py プログラムのエントリポイントです
input_pipeline.py データセットの読み取り等を行います
dataset_preprocessor.py 日本語データセットの前処理を行うプリプロセッサ
generate_text.py 学習済みのチェックポイントから日本語文章を生成するスクリプトです

設定ファイルの説明

言語モデルのアーキテクチャを決定する設定ファイルを事前に用意しています。本シリーズでは、これらの事前に用意された設定ファイルで言語モデルのトレーニングを行います。

設定ファイル 説明
configs/lm1b_default.py ❶ lm1b編で使う設定ファイルです。オリジナルのサンプルコードと同じく6層のTransformerです。
configs/japanese_default_v1.py ❷ cc100編で使う設定ファイルです。オリジナルのサンプルコードと同じく6層のTransformerですが、データセットがcc100の日本語データです。
configs/japanese_0.1b_v1.py ❸ wiki40b編で使う設定ファイルです。GPT-2 smallを参考に12層のTransformerです。データセットがwiki40bの日本語データです。

パラメータ数とトランスフォーマー層の設定は以下のとおりです。

設定ファイル データセット パラメータ数 Layers Dim Heads
configs/lm1b_default.py lm1b 0.05B 6 512 8
configs/japanese_default_v1.py cc100/ja 0.05B 6 512 8
configs/japanese_0.1b_v1.py wiki40b/ja 0.1B 12 768 12

追加された設定項目

transformer-lm-japaneseでは、Pythonコードだけではなく、設定ファイルの項目も一部追加されています。SentencePieceでサブワード学習する際に、日本語特有のオプションを設定ファイルから追加できるように、以下の設定項目を追加しています。

config.spm_train_options = "--character_coverage=0.9995 --byte_fallback=true"

❶ lm1b編 - 英語データセットでトレーニングする

まずは、オリジナルのサンプルコードと同様に、英語データセット(lm1b)を、Cloud上のTPUでトレーニングし、言語モデルを作成する手順を解説します。

GCPプロジェクトの作成

まず、Google Cloud Consoleから新しくプロジェクトを作成します。
ここでは、『Transformer LM』という名前のプロジェクトを作成したとします。

CLIのインストール

次に、gcloud CLI をインストールするのドキュメントに従ってgcloudコマンドをインストール&セットアップして下さい。

gcloudコマンドが正常にセットアップできたことを以下のコマンドで確認します。

(local)$ gcloud config list

新しく作成したプロジェクトIDが表示されていればセットアップは完了です。

[compute]
region = us-central1
zone = us-central1-a
[core]
account = xxxxx@gmail.com
disable_usage_reporting = False
project = (新しく作成したプロジェクトID)

Your active configuration is: [default]

Cloud TPU APIの有効化

Google Cloud Consoleから、[Compute Engine] - [TPU]を選択しCloud TPU APIを有効化します。

TPU-VMの作成

Google Cloud Consoleから、[Compute Engine]-[TPU]-[TPUノードを作成]を選択、以下の設定でTPU-VMを作成します。

  • 名前:my-tpu-vm
  • ゾーン:uscentral1-a
  • TPU VM アーキテクチャ(推奨)
  • TPUタイプ:v3-8
  • TPUソフトウェアバージョン:v2-alpha

プリエンプティブル TPUについて

TPU-VMにSSHアクセス

以下のコマンドでTPU-VMにSSHでアクセスします。

(local)$ gcloud compute tpus tpu-vm ssh my-tpu-vm --zone=us-central1-a

接続するとTPU-VMは実際にはUbuntu 20.04のVMインスタンスであることが分かります。

Welcome to Ubuntu 20.04.2 LTS (GNU/Linux 5.4.0-1043-gcp x86_64)

 * Documentation:  https://help.ubuntu.com
 * Management:     https://landscape.canonical.com
 * Support:        https://ubuntu.com/advantage

  System information as of Wed Jun  7 13:25:13 UTC 2023

  System load:  0.96               Processes:                1006
  Usage of /:   15.4% of 96.75GB   Users logged in:          0
  Memory usage: 0%                 IPv4 address for docker0: x.x.x.x
  Swap usage:   0%                 IPv4 address for ens8:    x.x.x.x

ソースコードのクローンとパッケージのインストール

lm1bを日本語データセットで学習できるように修正したコードをGitHubからクローンし、必要なPythonパッケージをインストールします。

(tpu-vm)$ git clone -b 1.0.0.RC3 https://github.com/FookieMonster/transformer-lm-japanese
(tpu-vm)$ cd ./transformer-lm-japanese/transformer_lm
(tpu-vm)$ pip install -r requirements.txt
(tpu-vm)$ pip install "jax[tpu]==0.4.13" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

データセットの事前ダウンロード

Pythonインタプリータを起動し、TensorFlow Datasetsのデータセット(lm1b)を事前ダウンロードします。

(tpu-vm)$ python3
>>> import tensorflow_datasets as tfds
>>> tfds.load('lm1b')
>>> exit()

設定ファイルを確認

今回トレーニングで使用する設定ファイル(lm1b_default.py)を少しだけ確認してみます。

https://github.com/FookieMonster/transformer-lm-japanese/blob/main/transformer_lm/configs/lm1b_default.py#L20-L123

まず、学習と評価に使用するデータセットがlm1bのtrain/testスプリットであることが分かります。

config.dataset_name = "lm1b"
config.train_split = "train"
config.eval_dataset_name = "lm1b"
config.eval_split = "test"

Transformerのレイヤー数が6層、アテンションのQKVが512次元、単語埋め込み512次元、全結合層が2048次元、アテンションが8ヘッド、であることが分かります。

# Number of transformer layers.
config.num_layers = 6

# Size of query/key/value for attention.
config.qkv_dim = 512
# Size of embeddings.
config.emb_dim = 512
# Size of the MLP.
config.mlp_dim = 2048

# Number of attention heads.
config.num_heads = 8

トレーニング開始

ワークディレクトリと設定ファイルを指定してトレーニングを開始します。

ワークディレクトリには、TensorBoardのイベントログや学習済みの重み(チェックポイント)ファイルが保存されます。

設定ファイルには、学習に利用するデータセットの名前や、Transformerのレイヤー数などのモデルアーキテクチャの設定などが記載されています。

(tpu-vm)$ python3 main.py --workdir=$HOME/logs/lm1b_default \
              --config=configs/lm1b_default.py

トレーニング結果

トレーニングが完了したら、TPV-VM側のワークディレクトリをローカルPC側に一旦コピーし、TensorBoardで学習の推移をグラフで確認します。

コピーには、gcloudのscpコマンドを使います。

書式:
gcloud compute tpus tpu-vm scp [TPU-VM名]:[コピー元パス] [コピー先パス] --zone=[ゾーン]

例)my-tpu-vmのユーザfooのホームにあるファイルをコピー

(local)$ gcloud compute tpus tpu-vm scp my-tpu-vm:/home/foo/logs/lm1b_default/* ./logs/lm1b_default --zone=us-central1-a

コピーしたフォルダを指定してTensorBoardを起動し、localhost:6006をブラウザで開きます。

(local)$ tensorboard --logdir logs/lm1b_default

トレーニングが正しく行われていると、Perplexity=22、Loss=3.1ぐらいになるはずです。
そうなれば、データセット(lm1b)によるTransformer言語モデルのトレーニングの再現は成功です。

次回の❷cc100編では、日本語データセットでの言語モデルのトレーニングを行います。

参考資料

Discussion