【強化学習】cpprb に遷移のファイル保存機能を追加
こちらの Scrap で調査・検討していた機能をリリースしたので、記事化。
1. きっかけ
off-policy の強化学習でReplay Bufferに一時的に保存しておく遷移は、
個人開発している cpprb では、メモリ上の同じ位置を参照するように内部データを保持することでメモリ使用量を削減する機能をオプションで提供しています。
(この機能に触れた記事が2ヶ月前から書きかけの下書きのまま忘れられてることに気が付きました。。。)
内部に保持している遷移を取り出す際には、(普通は)ランダムに取り出すため、それぞれが独立したデータとして取り出されます。そのため、全データを取り出してファイルに保存するという戦略をとると、
そんな状況で、ユーザーからの「ファイルサイズ小さくしたい」というフィードバックがきっかけになって今回の機能開発に取り組みました。
2. 調査
2.1 そもそも、必要?
ファイル保存ライブラリの多くは圧縮をかけるので、実は面倒なことをしなくても圧縮アルゴリズムが重複データを排除してくれるかもしれない・・・そんな淡い期待は見事に破られました。
ファイルサイズが小さい時には、排除してくれましたが大きくなってくると、重複領域を見つけられませんでした。
2.2 使えそうなライブラリ(または関数)
-
pickle
- Pythonの標準ライブラリで、データの直接化(つまり、ファイル保存できる形式に変換)を行う。(他のライブラリも内部的に
pickle
を利用していることが多い) - 圧縮機能自体は無いが、別のライブラリ(標準ライブラリ含む)で実施できる
- Pythonの標準ライブラリで、データの直接化(つまり、ファイル保存できる形式に変換)を行う。(他のライブラリも内部的に
-
joblib.dump
- 複数の圧縮アルゴリズムを採用しており、使い勝手・性能ともに良い
- TensorFlow など他のライブラリからも使われるので、依存関係の複雑化になりうる
-
numpy.savez_compressed
- 複数の
ndarray
オブジェクトを一括して保存・圧縮する -
dtype=object
のndarray
も (内部でpickle
を使うことで) 利用可能
- 複数の
-
PyArrow
- 列指向フォーマットの parquet形式を扱う Apache Arrow の Python バインディング
-
fastparquet
- 同じく parquet 形式を扱うライブラリ。
- 並列分析ライブラリの dask が開発
joblib かなぁと実装を進めていましたが、新しい依存ライブラリを追加することになるので、途中で方針転換をして、 numpy.savez_compressed を利用することにしました。
3. 使い方
公式ドキュメントも参照
3.1 API
ReplayBuffer.save_transitions(self, file, *, safe=True)
と ReplayBuffer.load_transitions(self, file)
として実装しました。
デフォルトの safe=True
の時には、遷移を一旦取り出して保存するのと同じことをしています。ユーザーが safe=False
と手動で指定したときに始めて、上で書いたような重複排除の状態でデータを保存します。なぜなら、この圧縮は内部構造に依存しており、(可能な限り変換等で対応しますが)将来の変更時に影響を受ける可能性があるからです。safe=True
であれば、最悪ユーザーが手作業で中のデータをサルベージすることも可能であるとも思われます。
3.2 Example
from cpprb import ReplayBuffer
rb1 = ReplayBuffer(256,
{"obs": {"shape": 3}, "act": {},
"rew": {}, "done": {}},
next_of="obs")
# 保存する遷移を詰める。
# (本当は、エージェントを行動させて得た遷移をですが、ここでは適当に)
rb1.add(obs=[1, 2, 3], act=1, rew=1, next_obs=[2, 3, 4], done=0)
# ファイル拡張子は、 ".npz" (違うと勝手に付け足されます。)
rb1.save_transitions("transitions.npz")
# 通常はここでプログラムが終わって、別のプログラムで読み出す。
# 新しく Replay Buffer を(互換性のある構成で)作る。
rb2 = ReplayBuffer(256,
{"obs": {"shape": 3}, "act": {},
"rew": {}, "done": {}},
next_of="obs")
# ファイルを読み込む。(上書きではなく、追加。)
rb2.load_transitions("transitions.npz")
4. 余談
4.1 実は。。。
実は、もっと昔からファイルへの保存機能については要望が上がっていました。
だけれども、できないからReplay Bufferから遷移を取り出して自分で保存してと対応してきました。
(当時は、全てのデータを取り出す get_all_transitions(self)
すらなかったので、内部APIの _encode_sample(self, idx)
を使ってという、ちょっと微妙な回答ですが。)
というのも、ファイルにデータを保存するとなると、プログラムが動いている間の整合性だけではなくて、将来的な互換性についても考える必要があり、他の方法で実現できるのであれば、こちらで提供することがためらわれたからです。
今回は、単純な利便性だけの話ではなく、こちらで提供しないと重複排除は実現できないという理由があり、重い腰を上げて実装に取り組みました。(この機能が、将来の改善における足かせとならないことを期待します。)
4.2 バグが。。。
機能開発している中で、ベンチマークとして比較対象に入れている Ray/RLlib のバグに遭遇して issue を報告しておきました。メンバー変数名の Refactoring 時のミスだと思われます。(Prioritizedの方では問題なかったので、使われずに見落とされていたのだろうなぁと。)
4.3 テスト
今回は、1機能の実装ですが、影響が大きいので、新しくテスト用のファイルを作って、いつもよりかなり多めのテストを書いたつもりですが、はてさて。
Discussion