Closed14

cpprb のデータ保存機能のための調査・検討

ユーザーからの機能リクエスト

https://github.com/ymd-h/cpprb/discussions/10

s_ts_{t+1} をセットで取り扱うことに起因する重複をメモリ上だけではなく、ファイルに保存するときにもなんとかしたい。

簡単に実験をした限りでは、サイズが大きくなってくると圧縮アルゴリズムは重複データを見つけられなくなるので、はじめから重複部分を排除しておくことは効果があると思われる。

検討ポイント

最も重要なことは、後から再利用できること。

メモリ削減のために、内部データでは重複を排除しているが、内部データ構造は将来の互換性を保証していない(し、保証できない)ので、そのまま保存する形にはしたくない

Nstep 対応

Nstep 報酬利用時には、\lbrace s_t, s_{t+1}\rbrace \to \lbrace s_t, s_{t+N}\rbrace などと置き換えて保存している。

そのまま保存して、何も考えずに再度詰め直すと、Nstep の効果が2重に掛けられてしまう。

MPReplayBuffer 対応

マルチプロセスにおけるグローバルバッファとしての利用を前提とした MPReplayBuffer 自体には Nstep などの機能は無い。

しかし、ローカルバッファで、Nstep補正をしたものを取り込んでいる可能性があり、 MPReplayBuffer からはそのオン・オフを判別することはできない。

MPReplayBuffer から取り出したデータを MPReplayBuffer に再度読み込む場合であれば問題は無いが、MPReplayBuffer から取り出して、ReplayBuffer に読み込む、またはその逆の時にどうすれば良いか?

友人のアドバイスによると、マルチプロセスまで利用して行動方策に近いデータを大量に集めるスキームでは、データを保存するというニーズとは相容れないとのことだったので、一旦 MPReplayBuffer は対象外に

追加で要望が出てから検討しよう。

使えそうなもの

  1. pickle
    • 標準
    • 効率の良い protocol=5 が Python 3.8 から利用可能に (3.6 ユーザーもまだいるのでは??)
  2. joblib
    • 複数の圧縮アルゴリズムを採用しており、使い勝手も良い
    • 依存ライブラリが増える
    • TensorFlow等が要求するバージョンと不整合がおきる可能性はないか?
  3. numpy.savez_compressed
    • NumPyが提供する圧縮保存機能

pickle は protocol が上がっても、圧縮はしなさそう。(protocol が小さいと元のデータよりも大きくなっているが。)

効率的とはPEPを読む限りでは、書き込みサイズを大きくしたり (protocol=4)、 無駄なコピーを防いだり (protocol=5) ということみたい。

おそらく joblib の採用が有力。

標準のライブラリの組合せで、効率よく pickleしつつ圧縮掛けられるならそちらにしたい。

データサイズが大きいので、無駄なコピーは避けたい。

jpblib にしようかと思っていたが、依存ライブラリが増えるのがやっぱり気になったので、 numpy.savez_compressed にする。

ndarray じゃないといけないと思って、候補から外そうかと思っていたが、よく考えたら object 型の ndarray がありましたね。(使えることも確認済み)

API

save_transitions(self, file,* , safe=True) / load_transitions(self, file) が有力

PER の priority や、Nstep の途中のデータなど保存できないデータがあるので、バッファそのものを保存すると誤解を与えそうな save / load とはしない。

デフォルト値として、 safe=True を指定し、ユーザーが意図的に safe=False を指定したときのみ next_of などの内部データ構造に応じた積極的な圧縮を実施する。

(たぶん内部は joblib にするつもりだが、)joblibに渡す圧縮パラメータをユーザーが指定できるようにするべきか?
自由度が上がる反面、将来ライブラリを移行したいと思ったときの障害となりうる。

Format

numpy.savez_compressed に以下のキーで渡す。

以下、案。(適宜修正する)

キー 備考
safe True / False
version 1 将来の変更のため。知らないバージョンには、読み込み時に ValueError を投げる
data dict[str, np.ndarray] safe=True なら get_all_transitions()safe=False ならself.buffer
Nstep True / False 読み込み先バッファに、Nstepのサイズなどを正しく設定するのはユーザーの責任。Nstepの利用の有無が不整合がある際には、読み込み時に ValueError を投げる
cache dict or None safe=Falseならself.cache
next_of np.ndarray or None safe=Falseならself.next_of
このスクラップは4ヶ月前にクローズされました
ログインするとコメントできます