操作ログ方式でOptuna用ストレージを実装してみた話
これはOptuna Advent Calendar 2021の18日目の記事です。
以前に、現在のOptuna(v2.10.0)で採用されているものとは異なる設計で、ストレージ部分の実装を試したことがあるので、その紹介となります。
リポジトリはsile/optjournalです。
TL;DR
- 現在のOptunaのストレージ実装は、StudyやTrialの状態をそのまま格納する方式
- RDB等へのI/O負荷を減らすためにキャッシュの仕組みがある
- optjournalは、StudyやTrialに対する操作のログ(ジャーナル)を追記していく方式
- optjournalのメリット
- シンプルな仕組みで、I/O回数をほぼ最小限にできる
- 新しいバックエンドの追加が比較的容易
- 現時点ではRDB以外に、ファイルシステムバックエンドにも対応(NFSでも動くはず)
- optjournalのデメリット
- StudyやTrialの一部の状態だけしか必要がない時に非効率
- 性能が大幅に劣化する病的なケースが存在する
- 上記「メリット」は開発者視点のもので、ユーザの立場なら、普通にOptuna公式ストレージの使用を推奨
※ 現在Optunaではv3ロードマップにもとづく開発が進行中で、その中にはストレージに関する改善も含まれているため、v3では本記事の内容が必ずしも当てはまらない可能性があります
Optuna(v2)のストレージについて
まずは、Optunaの現在のストレージについての説明から始めます。
Optunaでのストレージの位置づけ
次の図はOptunaの論文から引用したものですが、一番左側にストレージがあり、全ての構成要素が最終的にはそこと繋がっているのが見て取れます。
図中の「Suggest Algo」と「Pruning Algo」は、それぞれSamplerとPrunerに該当しますが、これらのクラスは、過去および実行中の試行についての結果や状態を直接は(インスタンス変数として)保持せずに、必要に応じてストレージから読み書きする形となっています。
このような設計によって、複数のPythonプロセスからでも、ストレージ(とスタディ名)さえ共有していれば、簡単に分散最適化が行えるようになっています。
BaseStorage
とその実装クラス
Optunaで利用可能なストレージクラスを実装する場合には、BaseStorageというクラスを継承して、その抽象メソッド群を実装する必要があります。
BaseStorage
には、現在23個の抽象メソッドがあり、以下はその一例になります:
BaseStorage のメソッド名 |
対応する公開API |
---|---|
create_new_study | optuna.create_study() |
get_study_user_attrs | Study.user_attrs |
create_new_trial | Study.optimize() or Study.ask() |
get_all_trials | Study.trials |
set_trial_param | Suggest API (e.g., Trial.suggest_float()) |
set_trial_intermediate_value | Trial.report() |
上の表の右側に記載の通り、BaseStorage
のほとんどの抽象メソッドには、ユーザが触る公開APIに直接的な対応物があり、Study
やTrial
の機能を比較的素直に反映しているインタフェースとなっています。
公式に提供されているBaseStorage
の実装クラスは、以下の通りです:
- InMemoryStorage: デフォルトで使われるメモリ上のストレージ
- RDBStorage: RDB(e.g., SQLite, PostgreSQL, MySQL)をバックエンドとしたストレージ
- RedisStorage: Redisをバックエンドとしたストレージ(experimental)
他にも、ユーザには露出していない内部的なクラスとして_CachedStorageというものも存在しますが、これについてはまた後で触れます。
これ以降では、永続化や分散最適化に対応しているものの中で一番利用されているRDBStorage
を中心として話を進めていきます。
※ キャッシュの話の一部を除いて、基本的にはRedisStorage
にも当てはまる話が多いです
現在のストレージの設計・実装
概要
RDBStorage
の設計および実装の概要は、以下のようになります:
-
Study
およびTrial
の状態を格納するためのテーブル群が存在する- e.g.,
StudyModel
,StudyUserAttributeModel
,TrialModel
,TrialParamModel
- 基本的には「RDBに
Study
やTrial
の実体があり、必要に応じて読み書き」しているだけ-
Study.set_user_attr()
やTrial.report()
等の更新系処理は、呼び出し毎に上記の(対応するRDBStorage
のメソッドを経由して)RDBテーブルを更新する -
Study.user_attrs
やStudy.trials
、Trial.params
等は取得系処理はその反対 - ※ キャッシュの関係で実際にはもう少し複雑になる
-
- e.g.,
- 特徴的な点としては、Optunaには「実行が終わった
Trial
の状態は変更できない(削除もない)」およびという仕様があり、これによってRDBStorage
の前段にインメモリのキャッシュを置くことが可能となっています(詳細は後述)- 他にも「実行中の
Trial
の状態を更新するのは一度に一プロセスだけ(同時更新不可)」という要求もあったりします - この辺りの詳細に興味がある人はBaseStorageのドキュメントを参照してみてください
- 他にも「実行中の
Sampler
およびPruner
のストレージ利用方法
ストレージの性能特性を把握するためには、Sampler
やPruner
が、ハイパラのサンプリング(suggest)やトライアルの枝刈り判定を行う際に、ストレージをどのように利用しているかも知る必要があります。
その際の流れは、基本的には以下のようになります:
-
Study.get_trials()
(orStorage.get_all_trials()
)を使用して、同じスタディ内の全トライアル(試行履歴)を取得 - その履歴を使って、次に探索すべき点や枝刈りすべきかどうかを決定する
※RandomSampler
やGridSampler
、ThresholdPruner
のようにトライアルの履歴を必要としないものは例外
つまり「Trial.suggest_float()
やTrial.should_prune()
が呼ばれる度に、内部ではStorage.get_all_trials()
が実行されている」と考えてもらえれば、おおむね大丈夫です。
ただしget_all_trials()
呼び出しの度にRDBにクエリを発行していては、I/Oの量がトライアル数に対して線形に増加していってしまい大変なことになるので、Optunaではキャッシュの導入によって、ほぼ定数オーダーのI/Oで済むように工夫が行われています。
ストレージのキャッシュ
RDBStorage
用のキャッシュ機能は_CacheStorage
で実装されており、get_all_trials()
のキャッシュの仕組みは、ざっくりと言えば、以下のようになります:
- 一番最初の呼び出しでは、普通にRDBから全トライアルのデータを取得
- その中で
TrialState
がCOMPLETE
、PRUNED
、FAIL
のいずれかだったトライアルについては、PythonのDict
に保存しておく - 次回の
get_all_trials()
呼び出しでは、上記のDict
に格納されているトライアルは除外するように、RDBにクエリを発行する - 新規取得分と既存キャッシュ分をマージして呼び出し元に返す
- 以後は、2~4を繰り返す
「概要」部分で書いた通り、Optunaでは終了したトライアルの状態が更新されることはないため、このようなキャッシュが安全に行えます。
これによって、最初の一回の呼び出しを除けば、get_all_trials()
では、その時点で実行中のトライアルのデータのみがRDBから取得されるようになりました(トライアル数に対してではなく、分散最適化の同時実行数への比例で済むようになった)。
なお、上で触れたもの以外にも、Study
関連だったり、取得ではなく更新時のキャッシュ等もあったり、その他細かい話はいろいろあるのですが、今回は割愛します。
キャッシュ周りの詳細に興味がある人は、以下のPRを参照してみてください:
- optuna#349 Add cache to RDBStorage
- optuna#1140 Add storage cache
- optuna#1263 Move caching mechanism from
RDBStorage
to_CacheStroage
- optuna#1264 Cache study-related info in
_CacheStorage
- optuna#3112 Remove flush update functionality from _CachedStorage
コード量
後でoptjournalとの比較に使うためRDBStorage
およびRedisStorage
のコード行数も載せておきます:
$ pwd
/.../optuna/optuna/storages
# `RDBStorage`系 (コメントと空行は除外)
$ grep -c -v -E '^[ ]*#|^[ ]*$' _cached_storage.py _rdb/models.py _rdb/storage.py
_cached_storage.py:358
_rdb/models.py:384
_rdb/storage.py:1090 # 合計: 1832
# `RedisStorage`系
$ grep -c -v -E '^[ ]*#|^[ ]*$' _redis.py
511
課題(私見)
以上で、Optunaの現在のストレージ実装(特にRDBStorage
)についての説明は終わりとなります。
現在の実装でも、キャッシュを活用することで、不要なI/Oを極力削減するように頑張られていますが、個人的にはいくつか課題も残っていると感じるため、最後にそれを書いておきます:
- キャッシュの仕組みは上手く機能してはいるけれど複雑度は増す
- これは「RDBに
Study
やTrial
の状態をそのまま格納」している以上、仕方がない問題- 採用した設計に起因する問題なので、抜本的に改善しようとすると
RDBStorage
等の設計自体を変える必要がでてくる
- 採用した設計に起因する問題なので、抜本的に改善しようとすると
- また(キャッシュ導入のために追加された)「終わったトライアルは更新できない」という制約が困る場面もまれにあったりする(e.g., optuna#2936)
- これは「RDBに
-
BaseStorage
を継承した自前ストレージを実装するのが大変- 要望としてはたまにある
- e.g., 初期の頃にFirestore用のストレージをユーザが実装していた (optuna#309)
-
RedisStorage
もユーザのPRによって導入されたもの (optuna#974)
- ただし、これまでに書いた通り
BaseStorage
には23個の実装すべき抽象メソッドがあったり、効率を出すためにキャッシュのことも考慮が必要だったり、分散最適化(マルチスレッド含む)実行時にも問題なく動作するようにする必要がある、etcで、完全なものを実装するのは容易ではない-
RedisStorage
のPRもマージは2020年4月だけれど、RDBStorage
の修正や機能追加に追従しきれておらず、まだexperimental扱いとなっている - ※ 補足ですが、Optunaのv3では、
BaseStorage
の要実装メソッド数を減らすことが検討されていたりします(optuna#2943)
-
- こちらについては設計と実装が原因の半々、といった印象
- 例えば、それぞれの永続化ストレージで共通の機能を括りだして別のクラスに分離することで、個々のストレージバックエンド(e.g., RDB, Redis)クラスの実装負荷は減らせる可能性はある(これはRDBのテーブル定義等には影響を与えずに、実装の工夫だけで対処可能な話)
- 要望としてはたまにある
optjournalは、こういった課題に対処する(方法を探る)ために開発されたライブラリです。
optjournalの設計や実装
ここからが、ようやく、本題のoptjournal (v0.0.2)の話となります。
状態の読み書きではなく、適用された操作の追記
optjournalは、RDBStorage
等とは異なり、RDB等のストレージバックエンドの中にStudy
やTrial
の完全な状態がそのまま格納されている訳ではありません。
その代わり、例えば「新しいトライアルを開始した」とか「ユーザ属性が更新された」等といった操作のログ(ジャーナル)を、RDB等に追記していく方式を採用しています。
具体的には、RDBの場合、以下の二つのテーブルのみが存在します(optunajournal/_models.py):
# Studyの名前とIDの対応を管理するためのテーブル
class StudyModel(_BaseModel):
__tablename__ = "optjournal_studies"
__table_args__ = (UniqueConstraint("name"),)
id = Column(Integer, primary_key=True)
name = Column(String(256), index=True, nullable=False)
# Study/Trialに適用された操作ログを保持するためのテーブル
class OperationModel(_BaseModel):
__tablename__ = "optjournal_operations"
id = Column(Integer, primary_key=True)
study_id = Column(Integer, ForeignKey("optjournal_studies.id"), index=True, nullable=False)
data = Column(String(4096), nullable=False) # [操作ID, 操作固有のデータ]形式のJSON
操作一覧はoptjournal/_operation.pyで定義されています:
class _Operation(enum.Enum):
CREATE_STUDY = 0
CREATE_TRIAL = 1
SET_STUDY_USER_ATTR = 2
SET_STUDY_SYSTEM_ATTR = 3
SET_STUDY_DIRECTIONS = 4
SET_TRIAL_PARAM = 5
SET_TRIAL_VALUES = 6
SET_TRIAL_USER_ATTR = 7
SET_TRIAL_SYSTEM_ATTR = 8
SET_TRIAL_STATE = 9
SET_TRIAL_INTERMEDIATE_VALUE = 10
Study
やTrial
の更新系メソッドが呼び出された際には、対応するエントリがストレージ内の操作ログに追記されていくことになります。
取得時の流れは、もう少し複雑で、ざっくりと言えば以下のようになります:
- 最初は、対象スタディの全操作ログを取得する
- 取得したログを、先頭から走査し、各操作をローカルに存在する
_Study
および_Trial
オブジェクトに反映する- e.g., 操作が
[SET_TRIAL_USER_ATTR, {"key": k, "value": v}]
なら_Trial.user_attrs[k] = v
を実行 -
_Trial
はoptuna.trial.FrozenTrial
の単純なラッパーで、_Study
も似たような感じ - ログの最後まで反映すれば、
Study
やTrial
の最新の状態がローカルで再構築できたことになる
- e.g., 操作が
- ログの最後の操作のIDを覚えておく
- 取得メソッドの呼び出し元には、構築された
_Study
および_Trial
から値を返す - 次の取得系メソッド呼び出し時も基本的には流れは同様
- ただし「全操作ログ」ではなく「3で覚えておいたIDよりも新しいもの」を対象にする(差分取得)
- これによって、RDB等から取得するデータ量は「新たに適用された更新系操作の数」にしか比例しなくなるので、ほぼ最小限となる(逆に効率が悪くなるケースについては後述)
細かい話は他にもいろいろとあるのですが、基本的な設計・方針としては以上となります。
以降では「なぜこの設計だと上述のRDBStorage
等の課題が解決・軽減できるのか」を見ていきます。
また、その後にはoptjournal側の問題点にも触れます。
optjournalのメリット
キャッシュが不要
上述の通り、optjournalは「永続化された操作ログ」と「そのログが反映されたローカルオブジェクト」の二つから構成されています。
操作ログの差分をローカルオブジェクトに反映さえしてしまえば、取得系メソッドではそのローカルオブジェクトの値を直接参照できるので、BaseStorage.get_all_trials()
のようなメソッドを複数回呼び出したとしても、I/Oという観点では、ほとんどコストが掛かりません。
そのため、RDBStorage
のように前段に別途キャッシュを設ける必要がなくなり、実装をシンプルにすることができるのがメリットの一つです。
(ローカルオブジェクトがある種のキャッシュの役割を果たしている、とも言えるかもしれません)
TrialState
に依存したキャッシュが不要となるので「終了したTrial
は更新不可」という制約を外すことも可能です。
なお、以前に雑に計測したベンチマークでは、optjournalのRDBを使用したストレージ実装の性能は、RDBStorage
とほぼ同等といった感じでした。
ストレージバックエンドの責務が少なく、自前実装の追加が比較的容易
optjournalが提供しているストレージ群も、Optunaのストレージとして利用可能になっています(以下はその例):
from optjournal import JournalStorage
import optuna
def objective(trial):
x = trial.suggest_float('x', 0, 1)
y = trial.suggest_float('y', 0, 1)
return x * y
storage = JournalStorage("sqlite:///optuna.db")
study = optuna.create_study(storage=storage)
study.optimize(objective, n_trials=100)
上のコード内のJournalStorage
は、当然BaseStorage
は継承し、必要なメソッド群を実装している訳なのですが、その辺りは個々のストレージバックエンド(e.g., RDB、ファイルシステム)とは独立した共通部分として切り出されています(ファイルとしてはoptjournal/_storage.py)。
各ストレージバックエンドに要求するインタフェースは、別にDatabase
というクラスで抽象化されています(optjournal/_database.py):
# 各バックエンドが実装する必要があるメソッドは以下の七個
class Database(object, metaclass=abc.ABCMeta):
def create_study(self, study_name: str) -> _models.StudyModel:
def find_study(self, study_id: int) -> Optional[_models.StudyModel]:
def find_study_by_name(self, study_name: str) -> Optional[_models.StudyModel]:
def delete_study(self, study_id: int) -> Optional[_models.StudyModel]:
def get_all_studies(self) -> List[_models.StudyModel]:
def append_operations(self, ops: List[_models.OperationModel]) -> None:
def read_operations(self, study_id: int, next_op_id: int) -> List[_models.OperationModel]:
このため、個々のバックエンドの実装は、以下のように小規模で済んでいます:
# RDBバックエンド
$ grep -c -v -E '^[ ]*#|^[ ]*$' optjournal/_rdb.py
103
# ファイルシステムバックエンド:
# - 各スタディに対して、別個のファイルが割り当てられ、そこに操作ログ追記される
# - NFSでも動作するように実装されている(はず)
$ grep -c -v -E '^[ ]*#|^[ ]*$' optjournal/_file_system.py
176
# 参考までに、optjournal全体の行数
$ grep -c -v -E '^[ ]*#|^[ ]*$' optjournal/*.py
optjournal/version.py:1
optjournal/_database.py:26
optjournal/_file_system.py:176
optjournal/_id.py:7
optjournal/_lazy_study_summary.py:56
optjournal/_models.py:24
optjournal/_operation.py:13
optjournal/_rdb.py:103
optjournal/_storage.py:255
optjournal/_study.py:177
optjournal/__init__.py:3 # 合計: 841行
※ Optunaのv2.5.0でRDBStorage
に導入されたハートビード機能には未対応だったりするので、その辺りも全部ちゃんと追従しようとするともう少し行数は増えるかもしれません(逆に、_Operation
enumへの追加+αくらいで済むなら、_rdb.py
や_file_system.py
の行数は変わらない可能性もあります)
これくらいなら、自前ストレージの追加も、そこそこ現実的な印象です。
optjournalのデメリット
ここまでは良い点ばかりに触れてきましたが、optjournal(の現状の実装)にはデメリットもあるため、次はそれらの紹介です。
Study
やTrial
の一部の状態だけしか必要がない時に非効率
RDBStorage
では、RDB内にStudy
やTrial
の完全な情報(スナップショット)が格納されています。そのため、その一部だけを取得する、といったことが極めて効率的に行えます。例えばユーザがStudy.direction
プロパティにアクセスした場合には、StudyDirectionModel
テーブルのdirection
カラムの値をRDBから取得するだけで済みます。
それに対してoptjournalのJournalStorage
は、Database
内には操作ログの情報しか存在しないため、一度それらを全て取得して、ローカルでStudy
やTrial
の情報を再構築する必要があります。
幸い、Optunaの通常のユースケース(e.g., 最適化や可視化)では、いずれにせよどこかのタイミングで、スタディ内の全てのトライアルへのアクセスが必要となることが普通なので、この最初のコストは償却できることがほとんどです。
ただし、ユーザが本当に(例えば)Study.direction
等の一部の情報のみが必要な場合には、無駄がとても大きくなります。
性能が大幅に劣化する病的なケースが存在する
他にも(単純な)操作ログ方式だと無駄が多くなるケースは存在します。
以下のシナリオを考えてみましょう:
import optuna
def objective(trial):
# 同じユーザ属性に大量の更新を行う
for i in range(0, 10000):
trial.set_user_attr("key", i)
return 0
study = optuna.create_study(storage=...)
study.optimize(objective, n_trials=1)
この場合、RDBStorage
やRedisStorage
では、対象のユーザ属性のエントリが更新されるだけです。
対して、JournalStorage
の場合には、一万個の操作ログエントリがDatabase
内に保存され、また再構築時にはそれらを全て順番に適用する必要があります。一番最後のエントリ以外は、単に捨てられるだけなので、これは明らかに無駄です。
これらの問題に対処することも不可能ではありません。
例えば、以下のようなスナップショット機能の導入が考えられます:
- 操作ログの数が一定数(e.g, 1000)溜まったら、その時点での
Study
の状態(スナップショット)を保存する- スナップショットの形式は、例えば
RDBStorage
のテーブル定義をそのまま使っても良い - あるいは、
Study
全体を単にJSONシリアライズした結果を保存しても良い- こちらの場合は「一部の情報だけを効率的に取得」には対応不可
- スナップショットの形式は、例えば
- スナップショットの取得が終わったら、それ以前の操作ログは削除する
- ローカルオブジェクトの初期値には、スナップショットの状態を反映する
- それ以外は、以前と同様
これにより上で取り上げた二つの性能問題を解消することは可能ではありますが、代わりにコードが複雑になってしまいます。
上述の問題点は、いずれもOptunaの典型的なユースケースでは発生しなさそう、ということもあり現在のoptjournalでは、特に対処しないままとなっています。
終わりに
長々と書いてきましたが、ユーザ側の視点からすれば、optjournalを採用する理由はほぼないと思うので、Optunaが公式で提供しているストレージを使用するのをお勧めします。
ただ、万が一需要があれば、もう少し真面目にリポジトリにドキュメントを書いたり、PyPIに登録したりするかもしれないので、issue等で教えて貰えればと思います🙏
Discussion