🔖

RetNetによる学習・評価・推論

2023/08/06に公開1

概要

Retentive Network(RetNet)はMicrosoft ResearchとTsinghua Universityが共同で発表した大規模言語モデル向けアーキテクチャであり、その学習速度の速さやメモリ効率の高さから、Transformerの後継モデルとして注目されています。
RetNetの公式ソースコードは、こちらにアップロードされています。ただし、8/6現在では、このソースコードの使用方法の記載は見当たりません。
そこで本記事では、この公式のソースコードを使用し、wiki-text103データセットをRetNetで学習・評価・推論する方法を説明します。

検証環境

項目 バージョン 備考
OS Windows 10
GPU RTX4070 VRAM12GB
python 3.10.4 pyenv localにて設定
ubuntu 18.04.05 wslにて使用。wsl2でもおそらく使用可能。

諸注意

本記事のコードブロック内のコマンドは、基本的にubuntu 18.04.05 on Windowsターミナル(以下、wsl)上で実行しています。
本記事の誤記や、より良い方法などがありましたら、コメント欄にてご指摘ください。

環境構築

  1. こちらの通りに、wslとpyenv-winの競合を解消し、wsl用のpyenvをinstallします

  2. 以下のコマンドを実行し、本記事用のフォルダとpython仮想環境を構築します

    mkdir RetNetTutorial
    cd RetNetTutorial
    mkdir .wsl_env
    cd .wsl_env
    pyenv local 3.10.4
    python -V
    # 3.10.4
    # 仮想環境構築
    python -m venv ./
    # pipを新しくしておく
    cd bin
    # 仮想環境起動
    source activate
    cd ../../
    
    # 以下は実行しない。備忘録。
    # PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/usr/lib/wsl/lib:/snap/bin"
    # source /etc/profile
    # source ~/.bashrc
    
  3. 以下のコマンドを実行し、必要なライブラリをインストールします

    # fairseqのshellスクリプトを使用するため、git cloneする
    git clone https://github.com/pytorch/fairseq
    git clone https://github.com/microsoft/torchscale.git
    cd torchscale
    pip install -e ./
    cd ../
    pip install git+https://github.com/shumingma/fairseq.git@moe
    pip install git+https://github.com/shumingma/infinibatch.git
    pip install iopath
    pip install numpy==1.23.0
    pip install tensorboardX
    pip install sentencepiece
    pip install boto3	
    # pip freeze > requirements.txt
    

データセットの作成

学習用データセットは、wikiText-103データセットを使用します。
Fairseq/examples/language_modelにprepare-wikitext-103.shというシェルスクリプトがあるので、こちらを使用します。

cd fairseq/examples/language_model/
sudo apt-get install dos2unix
# wslの場合、dos2unixを実行しないと$'\r': command not foundエラーが発生する
dos2unix prepare-wikitext-103.sh
bash prepare-wikitext-103.sh

前処理

fairseq-preprocessを使用し、データセットをバイナリ化します。
wsl上で、以下のコマンドを入力します。

# カレントディレクトリは/RetNetTutorial/fairseq/examples/language_model/
fairseq-preprocess \
--only-source \
--trainpref wikitext-103/wiki.train.tokens \
--validpref wikitext-103/wiki.valid.tokens \
--testpref wikitext-103/wiki.test.tokens \
--destdir data-bin/wikitext-103 \
--workers 20

出力結果は以下の通りです。

~省略~
2023-08-02 15:09:54 | INFO | fairseq_cli.preprocess | [None] Dictionary: 267744 types
2023-08-02 15:10:41 | INFO | fairseq_cli.preprocess | [None] wikitext-103/wiki.train.tokens: 1801350 sents, 103227021 tokens, 0.0% replaced by <unk>
2023-08-02 15:10:41 | INFO | fairseq_cli.preprocess | [None] Dictionary: 267744 types
2023-08-02 15:10:45 | INFO | fairseq_cli.preprocess | [None] wikitext-103/wiki.valid.tokens: 3760 sents, 217646 tokens, 0.0% replaced by <unk>
2023-08-02 15:10:45 | INFO | fairseq_cli.preprocess | [None] Dictionary: 267744 types
2023-08-02 15:10:50 | INFO | fairseq_cli.preprocess | [None] wikitext-103/wiki.test.tokens: 4358 sents, 245569 tokens, 0.0% replaced by <unk>
2023-08-02 15:10:50 | INFO | fairseq_cli.preprocess | Wrote preprocessed data to data-bin/wikitext-103

学習

torchscale/examples/fairseq/のtrain.pyを使用し、学習を行います。
wsl上で、以下のコマンドを入力します。(所要時間:30時間程度)

# GPUが複数ある場合等、以下の方法でCUDA_VISIBLE_DEVICESを設定
# export CUDA_VISIBLE_DEVICES=0
# printenv

# カレントディレクトリはfairseq/examples/
cd ../../../
cd torchscale/examples/fairseq/
python train.py ../../../fairseq/examples/language_model/data-bin/wikitext-103 \
--task language_modeling \
--save-dir checkpoints/transformer_wikitext-103 \
--arch retnet_base --share-decoder-input-output-embed \
--save-interval 5 \
--dropout 0.1 \
--optimizer adam --adam-betas '(0.9, 0.98)' --weight-decay 0.01 --clip-norm 0.0 \
--lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
--max-tokens 2048 --update-freq 16 \
--fp16 \
--batch-size 4 \
--max-update 50000 \
--tokens-per-sample 512 

出力結果は以下の通りです。

2023-08-03 22:07:58 | INFO | fairseq.checkpoint_utils | Preparing to save checkpoint for epoch 16 @ 50000 updates
2023-08-03 22:07:58 | INFO | fairseq.trainer | Saving checkpoint to checkpoints/transformer_wikitext-103/checkpoint_last.pt
2023-08-03 22:08:07 | INFO | fairseq.trainer | Finished saving checkpoint to checkpoints/transformer_wikitext-103/checkpoint_last.pt
2023-08-03 22:08:07 | INFO | fairseq.checkpoint_utils | Saved checkpoint checkpoints/transformer_wikitext-103/checkpoint_last.pt (epoch 16 @ 50000 updates, score 4.966) (writing took 9.198124899994582 seconds)
2023-08-03 22:08:07 | INFO | fairseq_cli.train | end of epoch 16 (average epoch stats below)
2023-08-03 22:08:07 | INFO | train | epoch 016 | loss 4.916 | ppl 30.19 | wps 21841.2 | ups 0.67 | wpb 32767.9 | bsz 64 | num_updates 50000 | lr 0.000141421 | gnorm 0.57 | loss_scale 16 | train_wall 4063 | cuda_gb_allocated 8.4 | cuda_gb_reserved 9.3 | cuda_gb_free 3.6 | wall 110664
2023-08-03 22:08:07 | INFO | fairseq_cli.train | done training in 110662.3 seconds

評価

torchscale/examples/fairseq/のgenerate.pyを使用し、評価を行います。
wsl上で、以下のコマンドを入力します(所要時間:10分程度)。

# カレントディレクトリはtorchscale/examples/fairseq/
python generate.py ../../../fairseq/examples/language_model/data-bin/wikitext-103 \
--path checkpoints/transformer_wikitext-103/checkpoint_best.pt \
--task language_modeling \
--tokens-per-sample 512 \
--max-tokens 512 \
--batch-size 2 \
--fp16 \
--beam 1

出力結果は以下の通りです。

~省略~
2023-08-06 08:32:53 | INFO | fairseq_cli.generate | NOTE: hypothesis and token scores are output in base 2
2023-08-06 08:32:53 | INFO | fairseq_cli.generate | Translated 480 sentences (37,031 tokens) in 441.4s (1.09 sentences/s, 83.90 tokens/s)
Generate test with beam=1: BLEU4 = 0.32, 98.5/97.1/96.2/95.4 (BP=0.003, ratio=0.149, syslen=36553, reflen=245553)

推論

こちらのREADMEのExample usageに記載の方法を、一部変更して使用します。

  1. /torchscale/examples/fairseq内にmain.pyを作成し、以下の内容を記載します。
    # main.py
    import torch
    from models.retnet import RetNetLanguageModel
    custom_lm = RetNetLanguageModel.from_pretrained('checkpoints/transformer_wikitext-103/', 'checkpoint_best.pt')
    # print(custom_lm.eval())
    # print(custom_lm.sample('Barack Obama', beam=1, sampling=False, sampling_topk=10, temperature=0.8))
    print(custom_lm.sample('Barack Obama', beam=2, sampling=False, temperature=0.1, no_repeat_ngram_size=2))
    # "Barack Obama (...)"
    
  2. wsl上で、以下のコマンドを入力します
    # 本記事の「環境構築」章で作成した.wsl_env仮想環境を起動しておく
    # カレントディレクトリはtorchscale/examples/fairseq/
    # main.py実行時にdict.txtが必要なため、以下のコマンドでfairseq/examples/language_model/data-bin/wikitext-103/dict.txtをコピーする
    cp ../../../fairseq/examples/language_model/data-bin/wikitext-103/dict.txt checkpoints/transformer_wikitext-103/dict.txt
    # 推論実行
    python main.py
    
  3. 以下のような出力が得られます。
    Barack Obama was the first to be elected to the United States Senate in the 2008 election .
    

TODO

  • 日本語データセットを使用した学習・評価
  • beamパラメータを3以上にするとエラーが発生する原因の調査
  • tensorboardXで可視化(この辺りの内容参照)

参考サイト

Discussion

yuji96yuji96

beamパラメータを 3 以上にするとエラーが発生する原因の調査

私もこの問題に直面しました.以下の部分を変更すると一応動きました.この修正方法が正しいのかどうかは確証がありません.

            prev_scale = prev_scale.view(bsz, -1, 1, 1)
            decay = decay.view(-1, self.num_heads, 1, 1)
            scale = prev_scale * decay + 1
            kv = prev_kv * (prev_scale.sqrt() * decay / scale.sqrt()) + kv / scale.sqrt()

https://github.com/microsoft/torchscale/blob/70e047a53bddda4e439ce453886ca52d65115560/torchscale/component/multiscale_retention.py#L107-L108

# 参考: --beam 7 --arch retnet_base (num_heads=2) のときの shape
print(prev_scale.shape)  # torch.Size([7, 1, 1, 1])  エラー時は torch.Size([7])
print(decay.shape)       # torch.Size([1, 2, 1, 1])  エラー時は torch.Size([2])
print(scale.shape)       # torch.Size([7, 2, 1, 1])
print(prev_kv.shape)     # torch.Size([7, 2, 512, 256])
print(kv.shape)          # torch.Size([7, 2, 512, 256])

ちなみに,--beam 2 のときは本来 ヘッド毎に掛かるはずの decay がビーム毎に掛かるので,動きはしますが想定外の処理になっているように思いました.