🤖

【強化学習】Braxの可視化を容易にする機能を実装して公開

2022/01/04に公開

0. はじめに

先日調査記事を書きましたGoogle製のGPU対応物理シミュレータBraxの可視化を容易にする機能を実装して、拙作のライブラリ Gym-Notebook-Wrapper(gnwrapper) の一部として公開しました。

https://zenn.dev/ymd_h/articles/092e3888e19046

1. Braxの可視化

先日の調査記事にも記載しましたが、Braxには3種類の可視化機能があります。

環境を可視化するという点では、brax.io.html.render() 関数により、HTML viewerでの可視化が視点操作や再生・停止をインタラクティブに操作できて便利です。

Google Corab/Jupyter Notebook上
!pip install brax

from IPython.display import HTML
from brax.io import html
from brax import envs
import brax.jumpy as jp

rng = jp.random_prngkey(42)
ant = envs.create("ant")

rng, rng_use = jp.random_split(rng)
state = ant.reset(rng_use)

qps = [state.qp]
while True:
    rng, rng_use = jp.random_split(rng)
    state = ant.step(state, jp.random_uniform(rng_use, (ant.action_size,)))
    qps.append(state.qp)
    if state.done:
        break

display(HTML(html.render(ant.sys, qps)))

また他にもRGB配列でレンダリングした画面を取得する方法や、(アニメーション)画像形式にして(メモリ上に)書き出す方法があります。これらは視野が固定のデータ形式で見る際の機能が制限されていますし、(私の書き方が悪かったのか)画面レンダリングが遅かったので、画像データを入力にモデルを学習させるのでなければ、率先して利用するメリットはなさそうに感じました。

2. gnwrapper への実装

前節で書いた処理は定形ですが、毎回書くのは面倒に感じたため、OpenAI GymをJupyter Notebook上で可視化するためのWrapperクラスを提供している拙作の Gym-Notebook-Wrapper (gnwrapper) ライブラリに新規にWrapperクラスを実装することにしました。

https://qiita.com/ymd_h/items/c393797deb72e1779269

2.1 Brax Env (brax.envs.Env) の可視化

Wrapperとして gnwrapper.brax.BraxHTML を実装したので、brax.envs.Envをラップして使うことになります。

step() メソッドの中で自動的に系の状態を記録し、state.done=1となるエピソード終端でHTMLファイルに書き出します。

reset() メソッドでエピソード番号を増加するとともに、保持している一連の系の状態をクリアします。

gnwrapperの従来のWrapper群と同様にdisplay() メソッドでNotebook上に埋め込みHTML viewerを表示します。引数を何も指定しなければ、記録済みの全エピソードを、intまたはList[int]で指定すればそのエピソードだけを表示します。

記録済みのエピソード一覧は、recorded_episodes()メソッドで取得できます。

また、副次的なメリットですが、HTMLファイルとして書き出すため、そのファイルをブラウザで直接開いてみることも可能です。(データ自体はファイル内に書き込まれていますが、ViewerはCDNにホストされているため、インターネット接続が必要です。)

BraxHTML例
!pip install brax gym-notebook-wrapper

from brax import envs
import brax.jumpy as jp
from gnwrapper.brax import BraxHTML

rng = jp.random_prngkey(42)
ant = BraxHTML(envs.create("ant", auto_reset=False))

rng, rng_use = jp.random_split(rng)
state = ant.reset(rng_use)

while True:
    rng, rng_use = jp.random_split(rng)
    state = ant.step(state, jp.random_uniform(rng_use, (ant.action_size,)))
    if state.done:
        break

episodes = ant.recorded_episodes()

ant.display() # 記録済みすべて
ant.display([1]) # エピソード1のみ (エピソード1が記録されていなければ何も表示されない)

gnwrapper.brax.BraxHTMLのコンストラクタの引数は以下のとおりです。このWrapperにPythonレベルでの副作用があり、あとからjax.jitでラップすることが(たぶん)できないので、渡される環境のstep()/reset()を内部でjax.jitするオプションを付けています。

引数(=デフォルト値) 説明
env brax.envs.Env ラップされる環境
directory=None Optional[str] ファイルを保存するディレクトリ。未指定の場合はタイムスタンプ "%Y%m%d-%H%M%S" を利用
height=480 int 出力Viewerの高さ[1]
video_callable=None Optional[Callable[[int], bool]] エピソードが記録されるかを判断するための関数。デフォルトでは、1000未満の立法数(n^3)と1000エピソード毎に記録される。尚エピソード番号はreset()メソッドを呼ぶたびに増加し、構築時には0のため事実上1スタート。
jit=True bool 渡された環境のstep()/reset()jax.jitでラップするかどうか。

2.2 Gym互換のBrax Env (brax.envs.wrappers.GymWrapper) の可視化

Braxでは、brax.envs.create_gym_env()関数を利用することで、Braxの環境をGym互換APIで利用することができます。(互換レイヤーbrax.envs.wrapper.GymWrapperでラップされている。)

上記の方法で作成した環境は gnwrapper.brax.GymHTML を使うことで前節と同様に記録・可視化できます。

GymHTML例
!pip install brax gym-notebook-wrapper

from brax import envs
import brax.jumpy as jp
from gnwrapper.brax import GymHTML

rng = jp.random_prngkey(42)
ant = GymHTML(envs.create_gym_env("ant", auto_reset=False, seed=0))

obs = ant.reset()

while True:
    rng, rng_use = jp.random_split(rng)
    obs, reward, done, _ = ant.step(jp.random_uniform(rng_use, ant.action_space.shape))
    if done:
        break

episodes = ant.recorded_episodes()

ant.display() # 記録済みすべて
ant.display([1]) # エピソード1のみ (エピソード1が記録されていなければ何も表示されない)

gnwrapper.brax.GymHTMLのコンストラクタの引数は以下のとおり。基本的には gnwrapper.brax.BraxHTMLと同様ですが、brax.envs.wrappers.GymWrapperが既に内部のstep()/reset()jax.jitでラップしているので、追加でラップするような機能はありません。

引数(=デフォルト値) 説明
env brax.envs.wrapeprs.GymWrapper ラップされるGym互換のBrax環境
directory=None Optional[str] ファイルを保存するディレクトリ。未指定の場合はタイムスタンプ "%Y%m%d-%H%M%S" を利用
height=480 int 出力Viewerの高さ[1:1]
video_callable=None Optional[Callable[[int], bool]] エピソードが記録されるかを判断するための関数。デフォルトでは、1000未満の立法数(n^3)と1000エピソード毎に記録される。尚エピソード番号はreset()メソッドを呼ぶたびに増加し、構築時には0のため事実上1スタート。

2.3 注意・制限事項

gnwrapper.brax サブモジュールはBraxを利用するため、単に import gnwrapper としただけではインポートされないように意図的にしています。
そのため、import gnwrapper.braxfrom gnwrapper import braxfrom gnwrapper.brax import BraxHTML等と明示的にインポートする必要があります。

また、Braxにはreset()メソッドを呼ばなくて良いようにするAuto Resetの機能 (brax.envs.wrappers.AutoResetにより提供) がありますが、エピソード終端処理を正しく行えなくなるので、サポートしていません。auto_reset=Falseを指定して環境 (Environment) を作成してください。

from brax import envs

ant = envs.create("ant", auto_reset=False)
ant_gym = envs.create_gym_env("ant", auto_reset=False)

同様に、Auto Resetの利用を(ほぼ)前提としているVectorizedされた環境も対応できていません。batch_sizeを指定しないでください。
(学習にはVectorizedされた環境を利用して、評価用の環境だけを可視化するという使い分けはあり得ると思います。)

3. その他

内部的には、brax.io.html.save_html()関数を真似してHTMLに書き出しています。(この関数自体は、heightパラメータを受け入れてくれないので、採用しませんでした。)

brax.io.json.save()という環境と軌道を保存する機能もあるのですが、保存機能だけでその後の読み出し & 利用方法を見つけられなかったので、採用しませんでした。
(JSONと比べても、HTMLのタグ要素が多少増加しているだけなので、記録ファイルのサイズ面でのメリットは殆どなく、JSONに拘る必要もありませんでした。HTMLだとブラウザで直接開けるメリットもありましたし。)

4. おわりに

Braxの可視化を容易にする機能を実装して、Gym-Notebook-Wrapperの一部として公開しました。

また、今回のテーマとはほとんど関係ないですが、brax/toolsとか、brax/experimental/composerとか気になっているので、そのうち調べたいです。

脚注
  1. Braxにバグらしき挙動があり、現時点ではうまく働いていません。(issue) → 2022/1/6追記: PRがマージされたので次の更新で修正されるでしょう。 ↩︎ ↩︎

Discussion