🙆

操作ログ方式でOptuna用ストレージを実装してみた話

2021/12/18に公開

これはOptuna Advent Calendar 2021の18日目の記事です。

以前に、現在のOptuna(v2.10.0)で採用されているものとは異なる設計で、ストレージ部分の実装を試したことがあるので、その紹介となります。
リポジトリはsile/optjournalです。

TL;DR

  • 現在のOptunaのストレージ実装は、StudyTrial状態をそのまま格納する方式
    • RDB等へのI/O負荷を減らすためにキャッシュの仕組みがある
  • optjournalは、StudyやTrialに対する操作のログ(ジャーナル)を追記していく方式
  • optjournalのメリット
    • シンプルな仕組みで、I/O回数をほぼ最小限にできる
    • 新しいバックエンドの追加が比較的容易
      • 現時点ではRDB以外に、ファイルシステムバックエンドにも対応(NFSでも動くはず)
  • optjournalのデメリット
    • StudyやTrialの一部の状態だけしか必要がない時に非効率
    • 性能が大幅に劣化する病的なケースが存在する
  • 上記「メリット」は開発者視点のもので、ユーザの立場なら、普通にOptuna公式ストレージの使用を推奨

※ 現在Optunaではv3ロードマップにもとづく開発が進行中で、その中にはストレージに関する改善も含まれているため、v3では本記事の内容が必ずしも当てはまらない可能性があります

Optuna(v2)のストレージについて

まずは、Optunaの現在のストレージについての説明から始めます。

Optunaでのストレージの位置づけ

次の図はOptunaの論文から引用したものですが、一番左側にストレージがあり、全ての構成要素が最終的にはそこと繋がっているのが見て取れます。
Overview of Optuna's system design

図中の「Suggest Algo」と「Pruning Algo」は、それぞれSamplerPrunerに該当しますが、これらのクラスは、過去および実行中の試行についての結果や状態を直接は(インスタンス変数として)保持せずに、必要に応じてストレージから読み書きする形となっています。
このような設計によって、複数のPythonプロセスからでも、ストレージ(とスタディ名)さえ共有していれば、簡単に分散最適化が行えるようになっています。

BaseStorageとその実装クラス

Optunaで利用可能なストレージクラスを実装する場合には、BaseStorageというクラスを継承して、その抽象メソッド群を実装する必要があります。
BaseStorageには、現在23個の抽象メソッドがあり、以下はその一例になります:

BaseStorageのメソッド名 対応する公開API
create_new_study optuna.create_study()
get_study_user_attrs Study.user_attrs
create_new_trial Study.optimize() or Study.ask()
get_all_trials Study.trials
set_trial_param Suggest API (e.g., Trial.suggest_float())
set_trial_intermediate_value Trial.report()

上の表の右側に記載の通り、BaseStorageのほとんどの抽象メソッドには、ユーザが触る公開APIに直接的な対応物があり、StudyTrialの機能を比較的素直に反映しているインタフェースとなっています。

公式に提供されているBaseStorageの実装クラスは、以下の通りです:

  • InMemoryStorage: デフォルトで使われるメモリ上のストレージ
  • RDBStorage: RDB(e.g., SQLite, PostgreSQL, MySQL)をバックエンドとしたストレージ
  • RedisStorage: Redisをバックエンドとしたストレージ(experimental)

他にも、ユーザには露出していない内部的なクラスとして_CachedStorageというものも存在しますが、これについてはまた後で触れます。

これ以降では、永続化や分散最適化に対応しているものの中で一番利用されているRDBStorageを中心として話を進めていきます。
※ キャッシュの話の一部を除いて、基本的にはRedisStorageにも当てはまる話が多いです

現在のストレージの設計・実装

概要

RDBStorageの設計および実装の概要は、以下のようになります:

  • StudyおよびTrialの状態を格納するためのテーブル群が存在する
    • e.g., StudyModel, StudyUserAttributeModel, TrialModel, TrialParamModel
    • 基本的には「RDBにStudyTrialの実体があり、必要に応じて読み書き」しているだけ
      • Study.set_user_attr()Trial.report()等の更新系処理は、呼び出し毎に上記の(対応するRDBStorageのメソッドを経由して)RDBテーブルを更新する
      • Study.user_attrsStudy.trialsTrial.params等は取得系処理はその反対
      • ※ キャッシュの関係で実際にはもう少し複雑になる
  • 特徴的な点としては、Optunaには「実行が終わったTrialの状態は変更できない(削除もない)」およびという仕様があり、これによってRDBStorageの前段にインメモリのキャッシュを置くことが可能となっています(詳細は後述)
    • 他にも「実行中のTrialの状態を更新するのは一度に一プロセスだけ(同時更新不可)」という要求もあったりします
    • この辺りの詳細に興味がある人はBaseStorageのドキュメントを参照してみてください

SamplerおよびPrunerのストレージ利用方法

ストレージの性能特性を把握するためには、SamplerPrunerが、ハイパラのサンプリング(suggest)やトライアルの枝刈り判定を行う際に、ストレージをどのように利用しているかも知る必要があります。
その際の流れは、基本的には以下のようになります:

  1. Study.get_trials() (or Storage.get_all_trials())を使用して、同じスタディ内の全トライアル(試行履歴)を取得
  2. その履歴を使って、次に探索すべき点や枝刈りすべきかどうかを決定する
    RandomSamplerGridSamplerThresholdPrunerのようにトライアルの履歴を必要としないものは例外

つまり「Trial.suggest_float()Trial.should_prune()が呼ばれる度に、内部ではStorage.get_all_trials()が実行されている」と考えてもらえれば、おおむね大丈夫です。

ただしget_all_trials()呼び出しの度にRDBにクエリを発行していては、I/Oの量がトライアル数に対して線形に増加していってしまい大変なことになるので、Optunaではキャッシュの導入によって、ほぼ定数オーダーのI/Oで済むように工夫が行われています。

ストレージのキャッシュ

RDBStorage用のキャッシュ機能は_CacheStorageで実装されており、get_all_trials()のキャッシュの仕組みは、ざっくりと言えば、以下のようになります:

  1. 一番最初の呼び出しでは、普通にRDBから全トライアルのデータを取得
  2. その中でTrialStateCOMPLETEPRUNEDFAILのいずれかだったトライアルについては、PythonのDictに保存しておく
  3. 次回のget_all_trials()呼び出しでは、上記のDictに格納されているトライアルは除外するように、RDBにクエリを発行する
  4. 新規取得分と既存キャッシュ分をマージして呼び出し元に返す
  5. 以後は、2~4を繰り返す

「概要」部分で書いた通り、Optunaでは終了したトライアルの状態が更新されることはないため、このようなキャッシュが安全に行えます。
これによって、最初の一回の呼び出しを除けば、get_all_trials()では、その時点で実行中のトライアルのデータのみがRDBから取得されるようになりました(トライアル数に対してではなく、分散最適化の同時実行数への比例で済むようになった)。

なお、上で触れたもの以外にも、Study関連だったり、取得ではなく更新時のキャッシュ等もあったり、その他細かい話はいろいろあるのですが、今回は割愛します。
キャッシュ周りの詳細に興味がある人は、以下のPRを参照してみてください:

コード量

後でoptjournalとの比較に使うためRDBStorageおよびRedisStorageのコード行数も載せておきます:

$ pwd
/.../optuna/optuna/storages

# `RDBStorage`系 (コメントと空行は除外)
$ grep -c -v -E '^[ ]*#|^[ ]*$' _cached_storage.py _rdb/models.py _rdb/storage.py
_cached_storage.py:358
_rdb/models.py:384
_rdb/storage.py:1090  # 合計: 1832

# `RedisStorage`系
$ grep -c -v -E '^[ ]*#|^[ ]*$' _redis.py
511

課題(私見)

以上で、Optunaの現在のストレージ実装(特にRDBStorage)についての説明は終わりとなります。

現在の実装でも、キャッシュを活用することで、不要なI/Oを極力削減するように頑張られていますが、個人的にはいくつか課題も残っていると感じるため、最後にそれを書いておきます:

  • キャッシュの仕組みは上手く機能してはいるけれど複雑度は増す
    • これは「RDBにStudyTrialの状態をそのまま格納」している以上、仕方がない問題
      • 採用した設計に起因する問題なので、抜本的に改善しようとするとRDBStorage等の設計自体を変える必要がでてくる
    • また(キャッシュ導入のために追加された)「終わったトライアルは更新できない」という制約が困る場面もまれにあったりする(e.g., optuna#2936)
  • BaseStorageを継承した自前ストレージを実装するのが大変
    • 要望としてはたまにある
      • e.g., 初期の頃にFirestore用のストレージをユーザが実装していた (optuna#309)
      • RedisStorageもユーザのPRによって導入されたもの (optuna#974)
    • ただし、これまでに書いた通りBaseStorageには23個の実装すべき抽象メソッドがあったり、効率を出すためにキャッシュのことも考慮が必要だったり、分散最適化(マルチスレッド含む)実行時にも問題なく動作するようにする必要がある、etcで、完全なものを実装するのは容易ではない
      • RedisStorageのPRもマージは2020年4月だけれど、RDBStorageの修正や機能追加に追従しきれておらず、まだexperimental扱いとなっている
      • ※ 補足ですが、Optunaのv3では、BaseStorageの要実装メソッド数を減らすことが検討されていたりします(optuna#2943
    • こちらについては設計と実装が原因の半々、といった印象
      • 例えば、それぞれの永続化ストレージで共通の機能を括りだして別のクラスに分離することで、個々のストレージバックエンド(e.g., RDB, Redis)クラスの実装負荷は減らせる可能性はある(これはRDBのテーブル定義等には影響を与えずに、実装の工夫だけで対処可能な話)

optjournalは、こういった課題に対処する(方法を探る)ために開発されたライブラリです。

optjournalの設計や実装

ここからが、ようやく、本題のoptjournal (v0.0.2)の話となります。

状態の読み書きではなく、適用された操作の追記

optjournalは、RDBStorage等とは異なり、RDB等のストレージバックエンドの中にStudyTrialの完全な状態がそのまま格納されている訳ではありません。
その代わり、例えば「新しいトライアルを開始した」とか「ユーザ属性が更新された」等といった操作のログ(ジャーナル)を、RDB等に追記していく方式を採用しています。

具体的には、RDBの場合、以下の二つのテーブルのみが存在します(optunajournal/_models.py):

# Studyの名前とIDの対応を管理するためのテーブル
class StudyModel(_BaseModel):
    __tablename__ = "optjournal_studies"
    __table_args__ = (UniqueConstraint("name"),)
    id = Column(Integer, primary_key=True)
    name = Column(String(256), index=True, nullable=False)

# Study/Trialに適用された操作ログを保持するためのテーブル
class OperationModel(_BaseModel):
    __tablename__ = "optjournal_operations"
    id = Column(Integer, primary_key=True)
    study_id = Column(Integer, ForeignKey("optjournal_studies.id"), index=True, nullable=False)
    data = Column(String(4096), nullable=False) # [操作ID, 操作固有のデータ]形式のJSON

操作一覧はoptjournal/_operation.pyで定義されています:

class _Operation(enum.Enum):
    CREATE_STUDY = 0
    CREATE_TRIAL = 1
    SET_STUDY_USER_ATTR = 2
    SET_STUDY_SYSTEM_ATTR = 3
    SET_STUDY_DIRECTIONS = 4
    SET_TRIAL_PARAM = 5
    SET_TRIAL_VALUES = 6
    SET_TRIAL_USER_ATTR = 7
    SET_TRIAL_SYSTEM_ATTR = 8
    SET_TRIAL_STATE = 9
    SET_TRIAL_INTERMEDIATE_VALUE = 10

StudyTrialの更新系メソッドが呼び出された際には、対応するエントリがストレージ内の操作ログに追記されていくことになります。
取得時の流れは、もう少し複雑で、ざっくりと言えば以下のようになります:

  1. 最初は、対象スタディの全操作ログを取得する
  2. 取得したログを、先頭から走査し、各操作をローカルに存在する_Studyおよび_Trialオブジェクトに反映する
    • e.g., 操作が[SET_TRIAL_USER_ATTR, {"key": k, "value": v}]なら_Trial.user_attrs[k] = vを実行
    • _Trialoptuna.trial.FrozenTrialの単純なラッパーで、_Studyも似たような感じ
    • ログの最後まで反映すれば、StudyTrialの最新の状態がローカルで再構築できたことになる
  3. ログの最後の操作のIDを覚えておく
  4. 取得メソッドの呼び出し元には、構築された_Studyおよび_Trialから値を返す
  5. 次の取得系メソッド呼び出し時も基本的には流れは同様
    • ただし「全操作ログ」ではなく「3で覚えておいたIDよりも新しいもの」を対象にする(差分取得)
    • これによって、RDB等から取得するデータ量は「新たに適用された更新系操作の数」にしか比例しなくなるので、ほぼ最小限となる(逆に効率が悪くなるケースについては後述)

細かい話は他にもいろいろとあるのですが、基本的な設計・方針としては以上となります。

以降では「なぜこの設計だと上述のRDBStorage等の課題が解決・軽減できるのか」を見ていきます。
また、その後にはoptjournal側の問題点にも触れます。

optjournalのメリット

キャッシュが不要

上述の通り、optjournalは「永続化された操作ログ」と「そのログが反映されたローカルオブジェクト」の二つから構成されています。
操作ログの差分をローカルオブジェクトに反映さえしてしまえば、取得系メソッドではそのローカルオブジェクトの値を直接参照できるので、BaseStorage.get_all_trials()のようなメソッドを複数回呼び出したとしても、I/Oという観点では、ほとんどコストが掛かりません。
そのため、RDBStorageのように前段に別途キャッシュを設ける必要がなくなり、実装をシンプルにすることができるのがメリットの一つです。
(ローカルオブジェクトがある種のキャッシュの役割を果たしている、とも言えるかもしれません)

TrialStateに依存したキャッシュが不要となるので「終了したTrialは更新不可」という制約を外すことも可能です。

なお、以前に雑に計測したベンチマークでは、optjournalのRDBを使用したストレージ実装の性能は、RDBStorageとほぼ同等といった感じでした。

ストレージバックエンドの責務が少なく、自前実装の追加が比較的容易

optjournalが提供しているストレージ群も、Optunaのストレージとして利用可能になっています(以下はその例):

from optjournal import JournalStorage
import optuna

def objective(trial):
    x = trial.suggest_float('x', 0, 1)
    y = trial.suggest_float('y', 0, 1)
    return x * y

storage = JournalStorage("sqlite:///optuna.db")
study = optuna.create_study(storage=storage)
study.optimize(objective, n_trials=100)

上のコード内のJournalStorageは、当然BaseStorageは継承し、必要なメソッド群を実装している訳なのですが、その辺りは個々のストレージバックエンド(e.g., RDB、ファイルシステム)とは独立した共通部分として切り出されています(ファイルとしてはoptjournal/_storage.py)。

各ストレージバックエンドに要求するインタフェースは、別にDatabaseというクラスで抽象化されています(optjournal/_database.py):

# 各バックエンドが実装する必要があるメソッドは以下の七個
class Database(object, metaclass=abc.ABCMeta):
    def create_study(self, study_name: str) -> _models.StudyModel:
    def find_study(self, study_id: int) -> Optional[_models.StudyModel]:
    def find_study_by_name(self, study_name: str) -> Optional[_models.StudyModel]:
    def delete_study(self, study_id: int) -> Optional[_models.StudyModel]:
    def get_all_studies(self) -> List[_models.StudyModel]:
    def append_operations(self, ops: List[_models.OperationModel]) -> None:
    def read_operations(self, study_id: int, next_op_id: int) -> List[_models.OperationModel]:

このため、個々のバックエンドの実装は、以下のように小規模で済んでいます:

# RDBバックエンド
$ grep -c -v -E '^[ ]*#|^[ ]*$' optjournal/_rdb.py
103

# ファイルシステムバックエンド:
# - 各スタディに対して、別個のファイルが割り当てられ、そこに操作ログ追記される
# - NFSでも動作するように実装されている(はず)
$ grep -c -v -E '^[ ]*#|^[ ]*$' optjournal/_file_system.py
176

# 参考までに、optjournal全体の行数
$ grep -c -v -E '^[ ]*#|^[ ]*$' optjournal/*.py
optjournal/version.py:1
optjournal/_database.py:26
optjournal/_file_system.py:176
optjournal/_id.py:7
optjournal/_lazy_study_summary.py:56
optjournal/_models.py:24
optjournal/_operation.py:13
optjournal/_rdb.py:103
optjournal/_storage.py:255
optjournal/_study.py:177
optjournal/__init__.py:3  # 合計: 841行

※ Optunaのv2.5.0RDBStorageに導入されたハートビード機能には未対応だったりするので、その辺りも全部ちゃんと追従しようとするともう少し行数は増えるかもしれません(逆に、_Operationenumへの追加+αくらいで済むなら、_rdb.py_file_system.pyの行数は変わらない可能性もあります)

これくらいなら、自前ストレージの追加も、そこそこ現実的な印象です。

optjournalのデメリット

ここまでは良い点ばかりに触れてきましたが、optjournal(の現状の実装)にはデメリットもあるため、次はそれらの紹介です。

StudyTrialの一部の状態だけしか必要がない時に非効率

RDBStorageでは、RDB内にStudyTrialの完全な情報(スナップショット)が格納されています。そのため、その一部だけを取得する、といったことが極めて効率的に行えます。例えばユーザがStudy.directionプロパティにアクセスした場合には、StudyDirectionModelテーブルのdirectionカラムの値をRDBから取得するだけで済みます。

それに対してoptjournalのJournalStorageは、Database内には操作ログの情報しか存在しないため、一度それらを全て取得して、ローカルでStudyTrialの情報を再構築する必要があります。
幸い、Optunaの通常のユースケース(e.g., 最適化や可視化)では、いずれにせよどこかのタイミングで、スタディ内の全てのトライアルへのアクセスが必要となることが普通なので、この最初のコストは償却できることがほとんどです。
ただし、ユーザが本当に(例えば)Study.direction等の一部の情報のみが必要な場合には、無駄がとても大きくなります。

性能が大幅に劣化する病的なケースが存在する

他にも(単純な)操作ログ方式だと無駄が多くなるケースは存在します。
以下のシナリオを考えてみましょう:

import optuna

def objective(trial):
    # 同じユーザ属性に大量の更新を行う
    for i in range(0, 10000):
        trial.set_user_attr("key", i)
    return 0
    
study = optuna.create_study(storage=...)
study.optimize(objective, n_trials=1)

この場合、RDBStorageRedisStorageでは、対象のユーザ属性のエントリが更新されるだけです。
対して、JournalStorageの場合には、一万個の操作ログエントリがDatabase内に保存され、また再構築時にはそれらを全て順番に適用する必要があります。一番最後のエントリ以外は、単に捨てられるだけなので、これは明らかに無駄です。

これらの問題に対処することも不可能ではありません。
例えば、以下のようなスナップショット機能の導入が考えられます:

  • 操作ログの数が一定数(e.g, 1000)溜まったら、その時点でのStudyの状態(スナップショット)を保存する
    • スナップショットの形式は、例えばRDBStorageのテーブル定義をそのまま使っても良い
    • あるいは、Study全体を単にJSONシリアライズした結果を保存しても良い
      • こちらの場合は「一部の情報だけを効率的に取得」には対応不可
  • スナップショットの取得が終わったら、それ以前の操作ログは削除する
  • ローカルオブジェクトの初期値には、スナップショットの状態を反映する
  • それ以外は、以前と同様

これにより上で取り上げた二つの性能問題を解消することは可能ではありますが、代わりにコードが複雑になってしまいます。
上述の問題点は、いずれもOptunaの典型的なユースケースでは発生しなさそう、ということもあり現在のoptjournalでは、特に対処しないままとなっています。

終わりに

長々と書いてきましたが、ユーザ側の視点からすれば、optjournalを採用する理由はほぼないと思うので、Optunaが公式で提供しているストレージを使用するのをお勧めします。

ただ、万が一需要があれば、もう少し真面目にリポジトリにドキュメントを書いたり、PyPIに登録したりするかもしれないので、issue等で教えて貰えればと思います🙏

Discussion