Google製GPU対応物理シミュレータBraxの調査
0. はじめに
自身のブログに英語で書いたまとめの日本語焼き直し記事です。
BraxはJAXを利用して書かれた物理シミュレータで、GPUなどのハードウェア・アクセラレータを活用することを最初から念頭において開発されています。
Braxといえば、先日もとても良い記事が公開されていましたね。
1. 環境
1.1 Brax Env
1.1.1 作成
BraxはOpenAI Gym類似の環境を定義しています。
brax.envs.create(env_name, episode_length=100, action_repeat=1, auto_reset=True, batch_size=None, **kwargs)
関数によって構築することができます。
# "import brax" では、 "envs"にアクセスできません。
# "envs" を以下のどちらかの方法で明示的に import する必要があります。
# - from brax import envs
# - import brax.envs
from brax import envs
ant = envs.create("ant")
引数 | 型 | 説明 |
---|---|---|
env_name |
str |
環境指定キー |
episode_length |
int |
最大エピソード長。この長さで自動的に打ち切られる。 |
action_repeat |
int |
(たぶんAtariでやってるみたいに複数回行動を繰り返す奴) |
auto_reset |
bool |
エピソード終端で自動的に環境をリセットするかどうか。 |
batch_size |
Optional[int] |
指定されていると、ベクトル化された環境ができる。 |
**kwargs |
環境のコンストラクタに渡される。(ただし、見た限りどのコンストラクタもうまく受け取れないので、将来の拡張用か?) |
調べると以下の、"env_name"
が定義されていました。
なんとなく想像がつく環境もありますが、説明がついていなかったものは空欄としています。
env_name |
説明 |
---|---|
"ant" |
OpenAI Gym Ant-v2 |
"fast" |
(ユニットテスト用のTrivialな環境) |
"fetch" |
3次元のボールを追いかける犬(?) |
"grasp" |
ボールをロボットハンドで掴む |
"halfcheetah" |
OpenAI Gym HalfCheetah-v2 |
"hopper" |
OpenAI Gym Hopper-v2 |
"humanoid" |
OpenAI Gym Humanoid-v2 |
"humanoidstandup" |
|
"inverted_pendulum" |
|
"inverted_double_pendulum" |
|
"reacher" |
OpenAI Gym Reacher-v2 |
"reacherangle" |
|
"swimmer" |
|
"ur5e" |
ur5e robot hand |
"walker2d" |
OpenAI Gym Walker2d-v2 |
1.1.2 シミュレート
最も大きな違いは、brax.envs.Env
はステートレスな実装になっていることです。
Env.step(self, state, action)
は行動だけでなく現在の状態(brax.envs.env.State
)を、Env.reset(self, rng)
と乱数状態をそれぞれ要求します。
JAXベースであるため、乱数状態の取り扱いは少し癖があります。
from brax import jumpy as jp
# JAX/Numpy を抽象化したファサードモジュールです。
# 1. 最初の乱数状態は明示的なランダムシードによって生成します。
rng = jp.random_prngkey(seed=42)
# 2. その後は、**利用していない** 乱数状態を分割することで利用しつつ更新します。
# 重要: 乱数分割に利用されているので、元の乱数状態は上書きするなりして捨てる必要があります。
rng, rng2 = jp.random_split(rng)
state = ant.reset(rng2)
# オプション: もしより多くの乱数状態が必要であれば、一度に生成することもできます。
rng, rng2, rng3 = jp.random_split(rng, 3)
state = ant.reset(rng2)
state = ant.reset(rng3)
また、他の違いとして、render
メソッドを所持しておらず、一連の状態を利用して別の関数で可視化します。(詳細は後述)
from brax import jumpy as jp
from brax import envs
from brax.io import html
# Jupyter Notebook想定
from IPython.display import HTML
ant = envs.create("ant")
rng = jp.random_prngkey(42)
rng, rng_use = jp.random_split(rng)
qps = []
state = ant.reset(rng_use)
qps.append(state.qp)
for _ in range(20):
rng, rng_use = jp.random_split(rng)
state = ant.step(state, jp.random_uniform(rng_use, (ant.action_size,)))
qps.append(state.qp)
HTML(html.render(ant.sys, qps))
antシミュレーションの例 (上記の場合、マウス等でカメラを動かせるビューアーが作成されます。今回はそのスクリーンショット)
1.2 Gym互換 Env
brax.envs.wrappers.GymWrapper
等を使って、類似ではなくGym互換の環境を作ることができます。
brax.envs.create_gym_env(env_name, batch_size=None, seed=0, backend=None, **kwargs)
関数を利用することで、前節と同様に作成することができます。
また、step
reset
などといったメソッドを、jax.jit
デコレータでラップしてくれるので、JAXの細かいことがわからない場合でも、それなりのパフォーマンスが出せるはずです。
引数 | 型 | 説明 |
---|---|---|
env_name |
str |
環境名キー |
batch_size |
Optional[int] |
指定されていればベクトル化される |
seed |
int |
乱数シード |
backend |
Optional[str] |
jax.jit に渡されるバックエンド。("cpu" 、"gpu" 、"tpu" が存在するが、この機能自体 experimental とJAXにコメントあり) |
**kwargs |
コンストラクタに渡される。 |
1.3 カスタム Env
ぱっと読んだ限り、独自の環境を作成するためのAPI等は用意されていないようでした。そのため、この節に記載する方法は若干ハック的で将来的に使えなくなる可能性があります。
1.3.1 事前定義済み環境へのカスタムパラメータ指定
"gravity"
、"dt"
(時間幅)、"substeps"
(時間幅内のシミュレーションサブステップ数) などのパラメータだけを差し替えたいときは、_SYSTEM_CONFIG
(例)の文字列を書き換えるのが最も簡単だと思われます。
from brax import envs
from brax.envs import ant, env
from brax import jumpy as jp
from brax.io import html
# Jupyter Notebook想定
from IPython.display import HTML
CUSTOM_CONFIG = ant._SYSTEM_CONFIG.replace("gravity { z: -9.8 }",
"gravity { z: 0.0 }")
assert CUSTOM_CONFIG != ant._SYSTEM_CONFIG
class SpaceAnt(ant.Ant):
def __init__(self, **kwargs):
# env.Env.__init__ を明示的に呼び出して、ant.Ant.__init__ をバイパスする
env.Env.__init__(self, CUSTOM_CONFIG, **kwargs)
# brax.envs.create に無理やり登録する
envs._envs["SpaceAnt"] = SpaceAnt
spaceAnt = envs.create("SpaceAnt")
rng = jp.random_prngkey(42)
rng, rng_use = jp.random_split(rng)
qps = []
state = spaceAnt.reset(rng_use)
qps.append(state.qp)
for _ in range(20):
rng, rng_use = jp.random_split(rng)
state = spaceAnt.step(state, jp.random_uniform(rng_use, (spaceAnt.action_size,)))
qps.append(state.qp)
HTML(html.render(spaceAnt.sys, qps))
無重力下で浮かび上がるSpaceAnt(仮称)。動いていると思わず笑ってしまいました。
1.3.2 Protocol Bufferを利用したフルスクラッチ構築
BraxのモデルはProtocol Bufferとして解釈される文字列で定義されています。
なので、同様にモデルを定義することができれば、フルスクラッチでモデルを作成することができます。Protocol Bufferのバリデーションは次の関数によって行われています。brax.physics.base.validate_config(config, resource_path=None)
import brax
from brax.physics.base import validate_config
from google.protobuf import text_format
config = """
... (ここにモデルをProtocol Bufferと解釈できる文字列として定義)
"""
config = text_format.Parse(config, brax.Config())
validate_config(config)
Protocol Bufferのスキーマは、(転載元の英語ブログには書き起こしたんですが)あんまりにも多くて長くなるので、この記事への記載は辞めておきますが、定義元のファイルは以下になります。(読み方知らなかったですが、読めたので分かりやすいはず。)
また、step
と reset
を実装する必要があります。
モデルに応じて独自に実装する必要がありますが、大枠は以下の様になるはずです。
from brax.envs import env
from brax import jumpy as jp
class CustomEnv(env.Env):
def __init__(self, **kwargs):
config = """
... (ここにモデルをProtocol Bufferとして解釈できる文字列として定義)
"""
super().__init__(config, **kwargs)
def reset(self, rng: jp.ndarray) -> env.State:
# ジョイントの角度と速度
# - Protocol Buffer定義のデフォルト角度は `default_angle()` で取得可能
# - 必要に応じてランダムノイズなどを `rng` を使って付加する。
qpos = self.sys.default_angle()
qvel = jp.zeros((self.sys.num_joint_dof,))
# `qp` (系の状態) をジョイント情報から構築する
qp = self.sys.default_qp(joint_angle=qpos, joint_velocity=qvel)
# `obs` (observation) を設定する
obs = ...
# `reward` (報酬)を計算する
reward = ...
# `metrics` を辞書形式で作成する
metrics = {
...
}
return env.State(qp, obs, reward, done, metrics)
def step(self, state: env.State, action: jp.ndarray) -> env.State:
# 物理エンジンを利用して、ステップをすすめる
qp, info = self.sys.step(state.qp, action)
# `obs` (observation) を設定する
obs = ...
# `reward`(報酬)を計算する
reward = ...
# エピソードが終了していないかチェックする (`done`)
done = ...
# `metrics`を更新する
state.metrics.update(
... # `key = value` スタイルで指定
)
return state.replace(qp=qp, obs=obs, reward=reward, done=done)
また、勿論前節のように create
から呼べるように(無理やり)登録することも可能です。
1.3.3 Python APIを利用したフルスクラッチ構築
Protocol Bufferからではなく、PythonのAPIを通して、モデルを構築することもできます。
ただ、ドキュメントもあまりなく公式のドキュメントや例から推定しています。
import brax
from brax.envs import env
class CustomEnv(env.Env):
def __init__(self, **kwargs):
# env.Env.__init__ を呼ばずに、手動で self.sys を構築します。
config = brax.Config(dt = 0.05, substeps = 4)
ground = config.bodies.add(name="ground") # "ground"という名前でボディーを追加
ground.frozen.all = True # 動かない固定化
plane = ground.colliders.add().plane # 衝突判定を平面タイプ指定。
plane.setInParent() # (たぶん)設置
ball = config.bodies.add(name="ball", mass=1) # "ball" という名前で 1kg のボディーを追加
capsule = ball.colliders.add().capsule # 衝突判定をカプセルタイプを指定。
capsule.radius, capsule.length = 0.5, 1 # カプセルの半径と長さを指定。(結果「球」になる。)
# 重力指定 m/s^2
config.gravity.z = -9.8
# 他にも色々モデルを追加したりできるはず。。。
self.sys = brax.System(config)
# 勿論 step / reset の実装が必要です。(ここでは省略)
def reset(self, rng: jp.ndarray) -> env.State:
pass
def step(self, state: env.State, action: jp.ndarray) -> env.State:
pass
2. 可視化
Braxでは大きく3種類の可視化方法があります。いずれもシステム(brax.System
)と系の状態(brax.QP
)を利用します。
2.1 HTML
brax.io.html.render(sys, qps, height = 480)
関数が、Jupyter Notebook等に埋め込めるインタラクティブなHTMLビューアーを返します。
Jupyter Notebook等を利用している際には、第一候補になると思います。ただ、いったん一連の状態をすべて確保してから関数を呼ぶ必要があるので、あまり長いのは適していないかもしれません。
(また、ずっと画面内に表示しているとPCのファンが回りだすので、処理が重たいのかもしれません。)
引数 | 型 | 説明 |
---|---|---|
sys |
brax.System |
シミュレーションの系 |
qps |
List[brax.QP] |
一連のシミュレーションの状態 |
height |
int |
ビューアーの高さ |
from IPython.display import HTML
from brax import envs
from brax import jumpy as jp
from brax.io import html
ant = envs.create("ant")
state = ant.reset(jp.random_prngkey(seed=42))
HTML(html.render(ant.sys, [state.qp]))
追記 (2022/1/3)
BraxのHTML viewerがheight
を無視するバグがある気がしたので、issueを立てました。
2.2 RGB配列
brax.io.image.render_array(sys, qp, width, height, light=None, camera=None, ssaa=2)
は1ステップの状態をRGBの配列に変換します。
BraxはPyTinyRendererを利用しています。
引数 | 型 | 説明 |
---|---|---|
sys |
brax.System |
シミュレーションの系 |
qp |
brax.QP |
シミュレーションのとある状態 |
width |
int |
出力配列の幅 |
height |
int |
出力配列の高さ |
light |
Optional[pytinyrenderer.TinyRenderLight] |
光源の位置。None の時、固定されたデフォルト位置を利用する。 |
camera |
Optional[pytinyrenderer.TinyRenderCamera] |
カメラの位置とアングル。None の時、最適なカメラの位置を推定する。 |
ssaa |
int |
Super-Sampling Anti-Aliasing。最初に、ssaa 倍大きな画像をレンダリングして、その後にssaa * ssaa ピクセルを平均する。 |
from brax import envs
from brax import jumpy as jp
from brax.io import image
ant = envs.create("ant")
state = ant.reset(jp.random_prngkey(seed=42))
rgb = image.render_array(ant.sys, state.qp)
2.3 画像ファイル
brax.io.image.render(sys, qps, width, height, light=None, cameras=None, ssaa=2, fmt='png')
関数は、画像フォーマットのbytes
データを作成します。(注意: 画像をファイルに書き出すわけではなく、メモリ上に画像フォーマットのデータを作成するだけです。)
内部では、各ステップごとに、brax.io.image.render_array
でRGB配列を作成しています。
引数 | 型 | 説明 |
---|---|---|
sys |
brax.System |
シミュレーションの系 |
qps |
List[brax.QP] |
一連のシミュレーションの状態 |
width |
int |
出力配列の幅 |
height |
int |
出力配列の高さ |
light |
Optional[pytinyrenderer.TinyRenderLight] |
光源の位置。None の時、固定されたデフォルト位置を利用する。 |
cameras |
Optional[List[pytinyrenderer.TinyRenderCamera]] |
一連のカメラの位置とアングル。None の時、最適なカメラの位置を推定する。 |
ssaa |
int |
Super-Sampling Anti-Aliasing。最初に、ssaa 倍大きな画像をレンダリングして、その後にssaa * ssaa ピクセルを平均する。 |
fmt |
str |
画層ファイルエンコーディング |
3. ハードウェア・アクセラレータ
3.1 GPU
BraxがベースとしているJAXは自動的に利用可能なGPUを利用するので、特に追加のコードは不要です。
3.2 TPU
TPUをGoogle Corab上で利用するには、そのセットアップをする必要があります。
といってもセットアップ用の補助関数がJAXに用意されているので簡単です。
import os
if 'COLAB_TPU_ADDR' in os.environ:
from jax.tools import colab_tpu
colab_tpu.setup_tpu()
Google Colab上で、ランタイムのタイプを変更 > ハードウェア アクセラレータ > TPU
を選択すると、環境変数 COLAB_TPU_ADDR
が設定されます。
この環境変数の中に、実際にTPUが存在する IPアドレス:ポート
が書かれているので、その値を利用してセットアップされます。
環境変数が存在しないと、エラーになるので、ここでは環境変数チェックを挟んでいます。
4. おわりに
Google製の物理シミュレータBraxの調査をしてモデルの作成方法等、基本的な利用方法をまとめました。(本当はちゃんとシミュレーションを動かしてみて、コード例とともにってところまで、やりたかったのですが、長くなってきたのでここまでにしておきます。)
Discussion