🤖

Google製GPU対応物理シミュレータBraxの調査

12 min read

0. はじめに

自身のブログに英語で書いたまとめの日本語焼き直し記事です。

BraxJAXを利用して書かれた物理シミュレータで、GPUなどのハードウェア・アクセラレータを活用することを最初から念頭において開発されています。

Braxといえば、先日もとても良い記事が公開されていましたね。

https://kngwyu.github.io/rlog/ja/2021/12/18/jax-brax-haiku.html

1. 環境

1.1 Brax Env

1.1.1 作成

BraxはOpenAI Gym類似の環境を定義しています。
brax.envs.create(env_name, episode_length=100, action_repeat=1, auto_reset=True, batch_size=None, **kwargs) 関数によって構築することができます。

Brax環境作成
# "import brax" では、 "envs"にアクセスできません。
# "envs" を以下のどちらかの方法で明示的に import する必要があります。
# - from brax import envs
# - import brax.envs
from brax import envs

ant = envs.create("ant")
引数 説明
env_name str 環境指定キー
episode_length int 最大エピソード長。この長さで自動的に打ち切られる。
action_repeat int (たぶんAtariでやってるみたいに複数回行動を繰り返す奴)
auto_reset bool エピソード終端で自動的に環境をリセットするかどうか。
batch_size Optional[int] 指定されていると、ベクトル化された環境ができる。
**kwargs 環境のコンストラクタに渡される。(ただし、見た限りどのコンストラクタもうまく受け取れないので、将来の拡張用か?)

調べると以下の、"env_name" が定義されていました。
なんとなく想像がつく環境もありますが、説明がついていなかったものは空欄としています。

env_name 説明
"ant" OpenAI Gym Ant-v2
"fast" (ユニットテスト用のTrivialな環境)
"fetch" 3次元のボールを追いかける犬(?)
"grasp" ボールをロボットハンドで掴む
"halfcheetah" OpenAI Gym HalfCheetah-v2
"hopper" OpenAI Gym Hopper-v2
"humanoid" OpenAI Gym Humanoid-v2
"humanoidstandup"
"inverted_pendulum"
"inverted_double_pendulum"
"reacher" OpenAI Gym Reacher-v2
"reacherangle"
"swimmer"
"ur5e" ur5e robot hand
"walker2d" OpenAI Gym Walker2d-v2

1.1.2 シミュレート

最も大きな違いは、brax.envs.Env はステートレスな実装になっていることです。

Env.step(self, state, action) は行動だけでなく現在の状態(brax.envs.env.State)を、Env.reset(self, rng) と乱数状態をそれぞれ要求します。
JAXベースであるため、乱数状態の取り扱いは少し癖があります。

Braxでの乱数の取り扱い例
from brax import jumpy as jp
# JAX/Numpy を抽象化したファサードモジュールです。

# 1. 最初の乱数状態は明示的なランダムシードによって生成します。
rng = jp.random_prngkey(seed=42)

# 2. その後は、**利用していない** 乱数状態を分割することで利用しつつ更新します。
# 重要: 乱数分割に利用されているので、元の乱数状態は上書きするなりして捨てる必要があります。
rng, rng2 = jp.random_split(rng)
state = ant.reset(rng2)

# オプション: もしより多くの乱数状態が必要であれば、一度に生成することもできます。
rng, rng2, rng3 = jp.random_split(rng, 3)
state = ant.reset(rng2)
state = ant.reset(rng3)

また、他の違いとして、render メソッドを所持しておらず、一連の状態を利用して別の関数で可視化します。(詳細は後述)

シミュレーション例
from brax import jumpy as jp
from brax import envs
from brax.io import html

# Jupyter Notebook想定
from IPython.display import HTML

ant = envs.create("ant")

rng = jp.random_prngkey(42)
rng, rng_use = jp.random_split(rng)

qps = []

state = ant.reset(rng_use)
qps.append(state.qp)

for _ in range(20):
    rng, rng_use = jp.random_split(rng)
    state = ant.step(state, jp.random_uniform(rng_use, (ant.action_size,)))
    qps.append(state.qp)

HTML(html.render(ant.sys, qps))


antシミュレーションの例 (上記の場合、マウス等でカメラを動かせるビューアーが作成されます。今回はそのスクリーンショット)

1.2 Gym互換 Env

brax.envs.wrappers.GymWrapper等を使って、類似ではなくGym互換の環境を作ることができます。
brax.envs.create_gym_env(env_name, batch_size=None, seed=0, backend=None, **kwargs)関数を利用することで、前節と同様に作成することができます。
また、step reset などといったメソッドを、jax.jit デコレータでラップしてくれるので、JAXの細かいことがわからない場合でも、それなりのパフォーマンスが出せるはずです。

引数 説明
env_name str 環境名キー
batch_size Optional[int] 指定されていればベクトル化される
seed int 乱数シード
backend Optional[str] jax.jitに渡されるバックエンド。("cpu""gpu""tpu"が存在するが、この機能自体 experimental とJAXにコメントあり)
**kwargs コンストラクタに渡される。

1.3 カスタム Env

ぱっと読んだ限り、独自の環境を作成するためのAPI等は用意されていないようでした。そのため、この節に記載する方法は若干ハック的で将来的に使えなくなる可能性があります。

1.3.1 事前定義済み環境へのカスタムパラメータ指定

"gravity""dt" (時間幅)、"substeps" (時間幅内のシミュレーションサブステップ数) などのパラメータだけを差し替えたいときは、_SYSTEM_CONFIG ()の文字列を書き換えるのが最も簡単だと思われます。

パラメータ自体にはEnv構築後も Env.sys.config.dt等とアクセスできますが、インテグレータ等の他の要素がEnv構築時に一緒に作られてしまうので、構築後に変更してはいけません。

カスタムパラメータ
from brax import envs
from brax.envs import ant, env

from brax import jumpy as jp
from brax.io import html

# Jupyter Notebook想定
from IPython.display import HTML

CUSTOM_CONFIG = ant._SYSTEM_CONFIG.replace("gravity { z: -9.8 }",
                                           "gravity { z: 0.0 }")
assert CUSTOM_CONFIG != ant._SYSTEM_CONFIG

class SpaceAnt(ant.Ant):
    def __init__(self, **kwargs):
        # env.Env.__init__ を明示的に呼び出して、ant.Ant.__init__ をバイパスする
        env.Env.__init__(self, CUSTOM_CONFIG, **kwargs)


# brax.envs.create に無理やり登録する
envs._envs["SpaceAnt"] = SpaceAnt
spaceAnt = envs.create("SpaceAnt")


rng = jp.random_prngkey(42)
rng, rng_use = jp.random_split(rng)

qps = []

state = spaceAnt.reset(rng_use)
qps.append(state.qp)

for _ in range(20):
    rng, rng_use = jp.random_split(rng)
    state = spaceAnt.step(state, jp.random_uniform(rng_use, (spaceAnt.action_size,)))
    qps.append(state.qp)

HTML(html.render(spaceAnt.sys, qps))


無重力下で浮かび上がるSpaceAnt(仮称)。動いていると思わず笑ってしまいました。

1.3.2 Protocol Bufferを利用したフルスクラッチ構築

BraxのモデルはProtocol Bufferとして解釈される文字列で定義されています。
なので、同様にモデルを定義することができれば、フルスクラッチでモデルを作成することができます。Protocol Bufferのバリデーションは次の関数によって行われています。brax.physics.base.validate_config(config, resource_path=None)

モデル定義とバリデーション
import brax
from brax.physics.base import validate_config
from google.protobuf import text_format

config = """
... (ここにモデルをProtocol Bufferと解釈できる文字列として定義)
"""

config = text_format.Parse(config, brax.Config())
validate_config(config)

Protocol Bufferのスキーマは、(転載元の英語ブログには書き起こしたんですが)あんまりにも多くて長くなるので、この記事への記載は辞めておきますが、定義元のファイルは以下になります。(読み方知らなかったですが、読めたので分かりやすいはず。)

https://github.com/google/brax/blob/main/brax/physics/config.proto

また、stepreset を実装する必要があります。
モデルに応じて独自に実装する必要がありますが、大枠は以下の様になるはずです。

カスタムEnv
from brax.envs import env
from brax import jumpy as jp

class CustomEnv(env.Env):
    def __init__(self, **kwargs):
        config = """
        ... (ここにモデルをProtocol Bufferとして解釈できる文字列として定義)
        """
        super().__init__(config, **kwargs)

    def reset(self, rng: jp.ndarray) -> env.State:
        # ジョイントの角度と速度
        # - Protocol Buffer定義のデフォルト角度は `default_angle()` で取得可能
        # - 必要に応じてランダムノイズなどを `rng` を使って付加する。
        qpos = self.sys.default_angle()
        qvel = jp.zeros((self.sys.num_joint_dof,))

        # `qp` (系の状態) をジョイント情報から構築する
        qp = self.sys.default_qp(joint_angle=qpos, joint_velocity=qvel)

        # `obs` (observation) を設定する
        obs = ...

        # `reward` (報酬)を計算する
        reward = ...

        # `metrics` を辞書形式で作成する
        metrics = {
            ...
        }

        return env.State(qp, obs, reward, done, metrics)

    def step(self, state: env.State, action: jp.ndarray) -> env.State:
        # 物理エンジンを利用して、ステップをすすめる
        qp, info = self.sys.step(state.qp, action)

        # `obs` (observation) を設定する
        obs = ...

        # `reward`(報酬)を計算する
        reward = ...

        # エピソードが終了していないかチェックする (`done`) 
        done = ...

        # `metrics`を更新する
        state.metrics.update(
            ... # `key = value` スタイルで指定
        )

        return state.replace(qp=qp, obs=obs, reward=reward, done=done)

また、勿論前節のように create から呼べるように(無理やり)登録することも可能です。

1.3.3 Python APIを利用したフルスクラッチ構築

Protocol Bufferからではなく、PythonのAPIを通して、モデルを構築することもできます。
ただ、ドキュメントもあまりなく公式のドキュメントや例から推定しています。

カスタムEnv(Python API)
import brax
from brax.envs import env

class CustomEnv(env.Env):
    def __init__(self, **kwargs):
        # env.Env.__init__ を呼ばずに、手動で self.sys を構築します。

        config = brax.Config(dt = 0.05, substeps = 4)

        ground = config.bodies.add(name="ground") # "ground"という名前でボディーを追加
        ground.frozen.all = True # 動かない固定化
        plane = ground.colliders.add().plane # 衝突判定を平面タイプ指定。
        plane.setInParent() # (たぶん)設置

        ball = config.bodies.add(name="ball", mass=1) # "ball" という名前で 1kg のボディーを追加
        capsule = ball.colliders.add().capsule # 衝突判定をカプセルタイプを指定。
        capsule.radius, capsule.length = 0.5, 1 # カプセルの半径と長さを指定。(結果「球」になる。)

        # 重力指定 m/s^2
        config.gravity.z = -9.8
	
	# 他にも色々モデルを追加したりできるはず。。。

        self.sys = brax.System(config)


    # 勿論 step / reset の実装が必要です。(ここでは省略)
    def reset(self, rng: jp.ndarray) -> env.State:
        pass

    def step(self, state: env.State, action: jp.ndarray) -> env.State:
        pass

2. 可視化

Braxでは大きく3種類の可視化方法があります。いずれもシステム(brax.System)と系の状態(brax.QP)を利用します。

2.1 HTML

brax.io.html.render(sys, qps, height = 480)関数が、Jupyter Notebook等に埋め込めるインタラクティブなHTMLビューアーを返します。

Jupyter Notebook等を利用している際には、第一候補になると思います。ただ、いったん一連の状態をすべて確保してから関数を呼ぶ必要があるので、あまり長いのは適していないかもしれません。
(また、ずっと画面内に表示しているとPCのファンが回りだすので、処理が重たいのかもしれません。)

引数 説明
sys brax.System シミュレーションの系
qps List[brax.QP] 一連のシミュレーションの状態
height int ビューアーの高さ
HTMLビューアーの例
from IPython.display import HTML

from brax import envs
from brax import jumpy as jp
from brax.io import html

ant = envs.create("ant")
state = ant.reset(jp.random_prngkey(seed=42))

HTML(html.render(ant.sys, [state.qp]))

追記 (2022/1/3)

BraxのHTML viewerがheightを無視するバグがある気がしたので、issueを立てました。

https://github.com/google/brax/issues/142

2.2 RGB配列

brax.io.image.render_array(sys, qp, width, height, light=None, camera=None, ssaa=2) は1ステップの状態をRGBの配列に変換します。

BraxはPyTinyRendererを利用しています。

引数 説明
sys brax.System シミュレーションの系
qp brax.QP シミュレーションのとある状態
width int 出力配列の幅
height int 出力配列の高さ
light Optional[pytinyrenderer.TinyRenderLight] 光源の位置。Noneの時、固定されたデフォルト位置を利用する。
camera Optional[pytinyrenderer.TinyRenderCamera] カメラの位置とアングル。Noneの時、最適なカメラの位置を推定する。
ssaa int Super-Sampling Anti-Aliasing。最初に、ssaa倍大きな画像をレンダリングして、その後にssaa * ssaaピクセルを平均する。
RGB配列化
from brax import envs
from brax import jumpy as jp
from brax.io import image

ant = envs.create("ant")
state = ant.reset(jp.random_prngkey(seed=42))

rgb = image.render_array(ant.sys, state.qp)

2.3 画像ファイル

brax.io.image.render(sys, qps, width, height, light=None, cameras=None, ssaa=2, fmt='png')関数は、画像フォーマットのbytesデータを作成します。(注意: 画像をファイルに書き出すわけではなく、メモリ上に画像フォーマットのデータを作成するだけです。)

内部では、各ステップごとに、brax.io.image.render_arrayでRGB配列を作成しています。

引数 説明
sys brax.System シミュレーションの系
qps List[brax.QP] 一連のシミュレーションの状態
width int 出力配列の幅
height int 出力配列の高さ
light Optional[pytinyrenderer.TinyRenderLight] 光源の位置。Noneの時、固定されたデフォルト位置を利用する。
cameras Optional[List[pytinyrenderer.TinyRenderCamera]] 一連のカメラの位置とアングル。Noneの時、最適なカメラの位置を推定する。
ssaa int Super-Sampling Anti-Aliasing。最初に、ssaa倍大きな画像をレンダリングして、その後にssaa * ssaaピクセルを平均する。
fmt str 画層ファイルエンコーディング

3. ハードウェア・アクセラレータ

3.1 GPU

BraxがベースとしているJAXは自動的に利用可能なGPUを利用するので、特に追加のコードは不要です。

3.2 TPU

TPUをGoogle Corab上で利用するには、そのセットアップをする必要があります。
といってもセットアップ用の補助関数がJAXに用意されているので簡単です。

Corab上でのTPU利用
import os
if 'COLAB_TPU_ADDR' in os.environ:
  from jax.tools import colab_tpu
  colab_tpu.setup_tpu()

Google Colab上で、ランタイムのタイプを変更 > ハードウェア アクセラレータ > TPUを選択すると、環境変数 COLAB_TPU_ADDR が設定されます。
この環境変数の中に、実際にTPUが存在する IPアドレス:ポート が書かれているので、その値を利用してセットアップされます。
環境変数が存在しないと、エラーになるので、ここでは環境変数チェックを挟んでいます。

4. おわりに

Google製の物理シミュレータBraxの調査をしてモデルの作成方法等、基本的な利用方法をまとめました。(本当はちゃんとシミュレーションを動かしてみて、コード例とともにってところまで、やりたかったのですが、長くなってきたのでここまでにしておきます。)

Discussion

ログインするとコメントできます