RetNetによる学習・評価・推論
概要
Retentive Network(RetNet)はMicrosoft ResearchとTsinghua Universityが共同で発表した大規模言語モデル向けアーキテクチャであり、その学習速度の速さやメモリ効率の高さから、Transformerの後継モデルとして注目されています。
RetNetの公式ソースコードは、こちらにアップロードされています。ただし、8/6現在では、このソースコードの使用方法の記載は見当たりません。
そこで本記事では、この公式のソースコードを使用し、wiki-text103データセットをRetNetで学習・評価・推論する方法を説明します。
検証環境
項目 | バージョン | 備考 |
---|---|---|
OS | Windows 10 | |
GPU | RTX4070 VRAM12GB | |
python | 3.10.4 | pyenv localにて設定 |
ubuntu | 18.04.05 | wslにて使用。wsl2でもおそらく使用可能。 |
諸注意
本記事のコードブロック内のコマンドは、基本的にubuntu 18.04.05 on Windowsターミナル(以下、wsl)上で実行しています。
本記事の誤記や、より良い方法などがありましたら、コメント欄にてご指摘ください。
環境構築
-
こちらの通りに、wslとpyenv-winの競合を解消し、wsl用のpyenvをinstallします
-
以下のコマンドを実行し、本記事用のフォルダとpython仮想環境を構築します
mkdir RetNetTutorial cd RetNetTutorial mkdir .wsl_env cd .wsl_env pyenv local 3.10.4 python -V # 3.10.4 # 仮想環境構築 python -m venv ./ # pipを新しくしておく cd bin # 仮想環境起動 source activate cd ../../ # 以下は実行しない。備忘録。 # PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/usr/lib/wsl/lib:/snap/bin" # source /etc/profile # source ~/.bashrc
-
以下のコマンドを実行し、必要なライブラリをインストールします
# fairseqのshellスクリプトを使用するため、git cloneする git clone https://github.com/pytorch/fairseq git clone https://github.com/microsoft/torchscale.git cd torchscale pip install -e ./ cd ../ pip install git+https://github.com/shumingma/fairseq.git@moe pip install git+https://github.com/shumingma/infinibatch.git pip install iopath pip install numpy==1.23.0 pip install tensorboardX pip install sentencepiece pip install boto3 # pip freeze > requirements.txt
データセットの作成
学習用データセットは、wikiText-103データセットを使用します。
Fairseq/examples/language_modelにprepare-wikitext-103.shというシェルスクリプトがあるので、こちらを使用します。
cd fairseq/examples/language_model/
sudo apt-get install dos2unix
# wslの場合、dos2unixを実行しないと$'\r': command not foundエラーが発生する
dos2unix prepare-wikitext-103.sh
bash prepare-wikitext-103.sh
前処理
fairseq-preprocessを使用し、データセットをバイナリ化します。
wsl上で、以下のコマンドを入力します。
# カレントディレクトリは/RetNetTutorial/fairseq/examples/language_model/
fairseq-preprocess \
--only-source \
--trainpref wikitext-103/wiki.train.tokens \
--validpref wikitext-103/wiki.valid.tokens \
--testpref wikitext-103/wiki.test.tokens \
--destdir data-bin/wikitext-103 \
--workers 20
出力結果は以下の通りです。
~省略~
2023-08-02 15:09:54 | INFO | fairseq_cli.preprocess | [None] Dictionary: 267744 types
2023-08-02 15:10:41 | INFO | fairseq_cli.preprocess | [None] wikitext-103/wiki.train.tokens: 1801350 sents, 103227021 tokens, 0.0% replaced by <unk>
2023-08-02 15:10:41 | INFO | fairseq_cli.preprocess | [None] Dictionary: 267744 types
2023-08-02 15:10:45 | INFO | fairseq_cli.preprocess | [None] wikitext-103/wiki.valid.tokens: 3760 sents, 217646 tokens, 0.0% replaced by <unk>
2023-08-02 15:10:45 | INFO | fairseq_cli.preprocess | [None] Dictionary: 267744 types
2023-08-02 15:10:50 | INFO | fairseq_cli.preprocess | [None] wikitext-103/wiki.test.tokens: 4358 sents, 245569 tokens, 0.0% replaced by <unk>
2023-08-02 15:10:50 | INFO | fairseq_cli.preprocess | Wrote preprocessed data to data-bin/wikitext-103
学習
torchscale/examples/fairseq/のtrain.pyを使用し、学習を行います。
wsl上で、以下のコマンドを入力します。(所要時間:30時間程度)
# GPUが複数ある場合等、以下の方法でCUDA_VISIBLE_DEVICESを設定
# export CUDA_VISIBLE_DEVICES=0
# printenv
# カレントディレクトリはfairseq/examples/
cd ../../../
cd torchscale/examples/fairseq/
python train.py ../../../fairseq/examples/language_model/data-bin/wikitext-103 \
--task language_modeling \
--save-dir checkpoints/transformer_wikitext-103 \
--arch retnet_base --share-decoder-input-output-embed \
--save-interval 5 \
--dropout 0.1 \
--optimizer adam --adam-betas '(0.9, 0.98)' --weight-decay 0.01 --clip-norm 0.0 \
--lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
--max-tokens 2048 --update-freq 16 \
--fp16 \
--batch-size 4 \
--max-update 50000 \
--tokens-per-sample 512
出力結果は以下の通りです。
2023-08-03 22:07:58 | INFO | fairseq.checkpoint_utils | Preparing to save checkpoint for epoch 16 @ 50000 updates
2023-08-03 22:07:58 | INFO | fairseq.trainer | Saving checkpoint to checkpoints/transformer_wikitext-103/checkpoint_last.pt
2023-08-03 22:08:07 | INFO | fairseq.trainer | Finished saving checkpoint to checkpoints/transformer_wikitext-103/checkpoint_last.pt
2023-08-03 22:08:07 | INFO | fairseq.checkpoint_utils | Saved checkpoint checkpoints/transformer_wikitext-103/checkpoint_last.pt (epoch 16 @ 50000 updates, score 4.966) (writing took 9.198124899994582 seconds)
2023-08-03 22:08:07 | INFO | fairseq_cli.train | end of epoch 16 (average epoch stats below)
2023-08-03 22:08:07 | INFO | train | epoch 016 | loss 4.916 | ppl 30.19 | wps 21841.2 | ups 0.67 | wpb 32767.9 | bsz 64 | num_updates 50000 | lr 0.000141421 | gnorm 0.57 | loss_scale 16 | train_wall 4063 | cuda_gb_allocated 8.4 | cuda_gb_reserved 9.3 | cuda_gb_free 3.6 | wall 110664
2023-08-03 22:08:07 | INFO | fairseq_cli.train | done training in 110662.3 seconds
評価
torchscale/examples/fairseq/のgenerate.pyを使用し、評価を行います。
wsl上で、以下のコマンドを入力します(所要時間:10分程度)。
# カレントディレクトリはtorchscale/examples/fairseq/
python generate.py ../../../fairseq/examples/language_model/data-bin/wikitext-103 \
--path checkpoints/transformer_wikitext-103/checkpoint_best.pt \
--task language_modeling \
--tokens-per-sample 512 \
--max-tokens 512 \
--batch-size 2 \
--fp16 \
--beam 1
出力結果は以下の通りです。
~省略~
2023-08-06 08:32:53 | INFO | fairseq_cli.generate | NOTE: hypothesis and token scores are output in base 2
2023-08-06 08:32:53 | INFO | fairseq_cli.generate | Translated 480 sentences (37,031 tokens) in 441.4s (1.09 sentences/s, 83.90 tokens/s)
Generate test with beam=1: BLEU4 = 0.32, 98.5/97.1/96.2/95.4 (BP=0.003, ratio=0.149, syslen=36553, reflen=245553)
推論
こちらのREADMEのExample usageに記載の方法を、一部変更して使用します。
- /torchscale/examples/fairseq内にmain.pyを作成し、以下の内容を記載します。
# main.py import torch from models.retnet import RetNetLanguageModel custom_lm = RetNetLanguageModel.from_pretrained('checkpoints/transformer_wikitext-103/', 'checkpoint_best.pt') # print(custom_lm.eval()) # print(custom_lm.sample('Barack Obama', beam=1, sampling=False, sampling_topk=10, temperature=0.8)) print(custom_lm.sample('Barack Obama', beam=2, sampling=False, temperature=0.1, no_repeat_ngram_size=2)) # "Barack Obama (...)"
- wsl上で、以下のコマンドを入力します
# 本記事の「環境構築」章で作成した.wsl_env仮想環境を起動しておく # カレントディレクトリはtorchscale/examples/fairseq/ # main.py実行時にdict.txtが必要なため、以下のコマンドでfairseq/examples/language_model/data-bin/wikitext-103/dict.txtをコピーする cp ../../../fairseq/examples/language_model/data-bin/wikitext-103/dict.txt checkpoints/transformer_wikitext-103/dict.txt # 推論実行 python main.py
- 以下のような出力が得られます。
Barack Obama was the first to be elected to the United States Senate in the 2008 election .
TODO
- 日本語データセットを使用した学習・評価
- beamパラメータを3以上にするとエラーが発生する原因の調査
- tensorboardXで可視化(この辺りの内容参照)
Discussion
私もこの問題に直面しました.以下の部分を変更すると一応動きました.この修正方法が正しいのかどうかは確証がありません.
ちなみに,
--beam 2
のときは本来 ヘッド毎に掛かるはずのdecay
がビーム毎に掛かるので,動きはしますが想定外の処理になっているように思いました.