非線形最適制御:C/GMRESのJAX実装
本記事は東京大学の村山裕和(https://www.linkedin.com/in/裕和-村山-9b42252b1/ )による寄稿です。
はじめに
お久しぶりです。前回の記事を出してから一ヶ月以上経ってしまいました。iLQRの基礎理論解説の続編として、拡張ラグランジュ法により制約条件を考慮出来るようにしたAL-iLQRの解説記事を書こうとしているのですが、何故か論文通りに再現出来ず頭を抱えています。代わりと言ってはなんですが、今回は非線形最適制御/NMPC(Nonlinear Model Predictive Control)の手法として有名なC/GMRESのJAX実装について解説します。JAXを使わずとも実装は可能なのですが、JAXを使うと少しだけ嬉しい事があります。
この記事のシミュレーションコードはこちらの github リポジトリにも載せてあります。
C/GMRESとは?
非線形最適制御手法の一つです。京都大学の大塚敏之先生によって提案された手法であり、日本語の本や関連論文も充実しています。詳しい説明は大塚先生の著書『非線形最適制御入門』『実時間最適化による制御の実応用』を読んでいただければと思います。
また、個人的には以下の記事はC/GMRESを理解する上で非常に参考になりました。
- C/GMRESによる非線形モデル予測制御1(Hamachi,2024)
- C/GMRES法の例題実装(佐藤郁弥,2019)
- 非線形モデル予測制御におけるCGMRES法をpythonで実装する(MENDY,2019)
JAXについて
JAXはGoogleによって開発されたPythonライブラリです。機械学習に使われる事が多いようですが、JITコンパイル機能、並列計算機能や自動微分機能を備えており、その他の様々な数値計算にも適用する事が出来ます。特に今回は自動微分機能が威力を発揮します。
JAXを用いる意義
JAXを用いる事により、以下の様な微分計算を自動微分により実装する事が出来ます。
つまり、手計算による微分から解放され、数式が複雑な場合でも気にせず実装する事が出来ます。
また、C/GMRESでは以下の式が登場しました。
JAXによるC/GMRESの実装
では、早速実装して行きましょう。変数の文字や記号は大塚先生の著書と揃えています。
まずは下準備
最初にコントローラー本体のアルゴリズムと直接関係の無い所を実装してしまいます。
ライブラリのインポート
必要なライブラリを揃えます。JAX、及びJAXのNumpy機能をインポートします。その他、パラメータを格納するのに使うdataclass
をインポートしています。
import jax
import jax.numpy as jnp
from dataclasses import dataclass
ロボットの状態方程式
制御対象となる非ホロノミックロボットの状態方程式を作ります。入力は良くあるTwist型を用います。デコレーターで@jax.jitとする事でJITコンパイルが適用されるようになり、通常のPythonよりも高速な計算が可能となります。
(ただこのデコレータを全ての関数につける必要があるのかはイマイチ良く分かっていません。念のため全てに付けています。)
#連続状態方程式
@jax.jit
def model_func(x, u):
Bk = jnp.array([[jnp.cos(x[2]), 0],
[jnp.sin(x[2]), 0],
[0, 1]], dtype=jnp.float32)
x_dot = Bk @ u
return x_dot
パラメータ設定
予測ホライズンの長さ、評価関数の重みなど、主に人間が設計するパラメータを格納します。
# コントローラーに関するパラメーター
@dataclass
class Cont_Args:
# コントローラーのパラメータ
Ts = 0.02 # 制御周期
tf = 1.0 # 予測ホライズンの最終長さ
N = 50 # 予測ホライズンの分割数
#dt = Ts # 今回は意味の無いパラメータなので無視してください
alpha = 0.5 # 予測ホライズンを変化させる計算のパラメータ
zeta = 1 # U_dotを計算する時の係数パラメータ(zetaと書いてますがツェータです)
# 評価関数中の重み
# 状態変数の項
Q = jnp.array([[1, 0, 0],
[0, 1, 0],
[0, 0, 0]], dtype=jnp.float32) * 0
# 制御入力の項
R = jnp.array([[100,0],
[0,10]], dtype=jnp.float32)
# 最終地点の項
S = jnp.array([[1, 0, 0],
[0, 1, 0],
[0, 0, 0]], dtype=jnp.float32) * 100
# 目標地点
x_ob = jnp.array([5, 0, 0], dtype=jnp.float32)
# 目標入力
u_ob = jnp.array([0, 0], dtype=jnp.float32)
# 次元データ
obss_dim = 3 # 状態変数の次元
action_dim = 2 # 入力変数の次元
# 状態と入力
x = None
u = None
us = None
# 障害物の場所(中心)
ev_pos = jnp.array([[2.5,0.15,0]],dtype=jnp.float32)
# 障害物の半径
d_ = 0.3
# ロボットの半径
r_ = 0.1
# 障害物の中心から取るべき距離
d = d_ + r_
# 緩和対数バリア関数の緩和値
del_bar = 0.05
# 回避バリア関数の重み
r = 50
# 入力制限
umax = jnp.array([1.0,1.0],dtype=jnp.float32)
umin = jnp.array([-1.0,-1.0],dtype=jnp.float32)
# 計算用行列
bar_C = np.concatenate([jnp.eye(action_dim,dtype=jnp.float32),-jnp.eye(action_dim,dtype=jnp.float32)],0)
bar_d = np.concatenate([umax,-umin],0)
# 速度制限のバリア関数の重み
b = 10
args = Cont_Args()
評価関数の設計
制御入力の拘束条件を満たすよう、緩和対数バリア関数を使います。また、障害物を避ける為の回避関数としても緩和対数バリア関数を使います。それをステージコストと終端コストに組み込みます。後々の計算で使用するので、終端コストの微分関数も定義しておきます。
バリア関数については、以下の記事を参照しています。
終端コスト
# バリア関数(-logか二次関数かを勝手に切り替える)
@jax.jit
def barrier_z(z):
pred = z > args.del_bar
true_fun = lambda z: - jnp.log(z)
false_fun = lambda z: 0.5 * (((z - 2*args.del_bar) / args.del_bar)**2 - 1) - jnp.log(args.del_bar)
return jax.lax.cond(pred, true_fun, false_fun, z)
# バリア関数(全体)
@jax.jit
def barrier(u):
zs = args.bar_d - args.bar_C @ u
def vmap_fun(b, z, margin=0.5):
return b * jnp.where(z>=margin,barrier_z(margin),barrier_z(z))
Bars = jax.vmap(vmap_fun, (None,0))(args.b, zs)
Bar = jnp.sum(Bars)
return Bar
# 対数バリア関数型回避関数
@jax.jit
def evasion(x):
def vmap_fun(x, xe, r, d, margin=0.5):
distance = jnp.linalg.norm(x-xe,ord=2)
z = distance**2 - d**2
ref = d + margin
return r * jnp.where(distance>=ref, barrier_z(ref**2-d**2), barrier_z(z))
evas = jax.vmap(vmap_fun, (None, 0, None, None))(x, args.ev_pos, args.r, args.d)
eva = jnp.sum(evas)
return eva
# ステージコスト
@jax.jit
def stage_cost(x,u):
cost = 0.5 * ( (x-args.x_ob) @ args.Q @ (x-args.x_ob) \
+ (u-args.u_ob) @ args.R @ (u-args.u_ob)) \
+ evasion(x) \
+ barrier(u)
return cost
# 終端コスト
@jax.jit
def term_cost(x):
cost = 0.5 * (x-args.x_ob) @ args.S @ (x-args.x_ob)
return cost
# 終端コストの微分
grad_x_term = jax.jit(jax.grad(term_cost,0))
コントローラー本体の実装
いよいよC/GMRESのアルゴリズム本体を実装します。まずはコントローラー内部で定義される関数として、次の5つを作ります。
- rollout
- Hamilton
- Backward
- F
- GMRES
rollout
現在時刻の状態変数 x
と、制御入力の初期解 us
(自分の場合、前の制御ステップで算出した制御入力を使っています)から、将来の状態変数 xs
を予測する関数です。Forwardと呼んだ方がわかりやすいかもしれません。
状態方程式を
xs
は現在の状態変数
def rollout(x_init,us,dt):
def rollout_body(carry,u):
x = carry
x = x + model_func(x,u) * dt
return x, x
_, xs = jax.lax.scan(rollout_body, x_init, us)
xs = jnp.vstack([x_init[None], xs])
return xs
Hamilton
数学だと
随伴変数を
ここで、自動微分により
def Hamilton(x,u,lambda_):
H = stage_cost(x,u) + lambda_ @ model_func(x,u)
return H
dHdx = jax.grad(Hamilton,0)
dHdu = jax.grad(Hamilton,1)
Backward
制御入力の初期解 us
、及び rollout 関数で求めた xs
から、随伴変数lambdas
を求める関数です。自動微分で実装した
まず、次のように
そこから、以下の計算を繰り返します。
def Backward(xs,us,dt):
def Backward_body(carry,val):
lambda_ = carry
x, u = val
lambda_ = lambda_ + dHdx(x,u,lambda_) * dt
return lambda_, lambda_
lambda_ = grad_x_term(xs[-1])
_, out_lambdas = jax.lax.scan(Backward_body,lambda_,(jnp.flip(xs[1:-1],0),jnp.flip(us[1:],0)))
lambdas = jnp.flip(jnp.vstack([lambda_,out_lambdas]),axis=0)
return lambdas
F
大塚先生の著書や論文では jax.vmap
で並列計算させることで楽に実装出来ます。
自動微分を使い、
def F_(x,us,t):
us = jnp.reshape(us,(-1,args.action_dim)) # 計算の都合上、入力の時にusを横一列に並べ直しているので、ここで直す
dt = (1-jnp.exp(-args.alpha*t)) * args.tf/args.N # 予測ホライズンの分割幅を計算
xs = rollout(x,us,dt)
lambdas = Backward(xs,us,dt)
F = jax.vmap(dHdu,(0,0,0))(xs[:-1],us,lambdas)
F = jnp.reshape(F,(-1,))
return F
dFdU_ = jax.jacrev(F_,1)
dFdx_ = jax.jacrev(F_,0)
dFdt_ = jax.jacrev(F_,2)
GMRES
GMRES法(Generalized Minimal RESidual method/一般化残差最小法)は連立一次方程式を解く数値計算手法の一つです。手法自体の説明はここでは割愛します。自分は以下の資料を元にコーディングしました。
本当はJAXそのものにGMRES関数は存在します。ただ、それだとC/GMRESのキモとなる部分を既存の関数に頼る事になって面白みが無いのもあり、自作しました。
最後、連立一次方程式を解く部分なのですが、自作した後退代入の関数では何故か計算結果が荒れてしまいました。仕方なく、ここだけjax.numpy.linalg.solve
に頼る事にしました。計算誤差や前処理の問題でしょうか...?
#GMRES法関数(Ax=bの初期残差をrとする)
def GMRES(A, r, max_iter=5):
def arnoldi(A, v1, m):
n = v1.shape[0]
Vm_1 = jnp.zeros((n, m+1))
H = jnp.zeros((m+1, m))
Vm_1 = Vm_1.at[:, 0].set(v1)
def body_fun(j, val):
Vm_1, H = val
v = A @ Vm_1[:, j]
def body_in(i,val):
Vm_1, H, v = val
H = H.at[i, j].set(jnp.dot(Vm_1[:, i], v))
v = v - H[i, j] * Vm_1[:, i]
return Vm_1, H, v
Vm_1, H, v = jax.lax.fori_loop(0,j+1,body_in,(Vm_1,H,v))
H = H.at[j+1, j].set(jnp.linalg.norm(v))
Vm_1 = Vm_1.at[:, j+1].set(v / H[j+1, j])
return Vm_1, H
Vm_1, H = jax.lax.fori_loop(0, m, body_fun, (Vm_1, H))
return Vm_1, H
def givens_rotation(v1, v2):
t = jnp.sqrt(v1**2 + v2**2)
c = v1 / t
s = -v2 / t
return c, s
#n = r.shape[0]
Vm_1, H = arnoldi(A, r / jnp.linalg.norm(r), max_iter)
beta = jnp.linalg.norm(r)
e1 = jnp.zeros(max_iter + 1)
e1 = e1.at[0].set(beta)
def body_fun(i, val):
H, e1 = val
c, s = givens_rotation(H[i, i], H[i+1, i])
Givens = jnp.array([[c, s], [-s, c]])
H_col = jax.lax.dynamic_slice(H,(i,0),(2,max_iter))
H = jax.lax.dynamic_update_slice(H, Givens @ H_col, (i,0))
e1_slice = jax.lax.dynamic_slice(e1,(i,),(2,))
e1 = jax.lax.dynamic_update_slice(e1, Givens @ e1_slice, (i,))
return H, e1
H, e1 = jax.lax.fori_loop(0, max_iter, body_fun, (H, e1))
y = jnp.linalg.solve(H[:max_iter, :max_iter], e1[:max_iter])
x = Vm_1[:, :max_iter] @ y
return x
コントローラー全体の実装
最後に、この5つの関数を使い、コントローラーを完成させます。この記事の最後にシミュレーションも含めたコードを置いてあるので、自分のPCで試す際はそちらを使ってください。
# コントローラー関数
@jax.jit
def CGMRES_control(x,us,t):
# ~中略~(ここで5つの関数を定義する)
us_ = jnp.reshape(us,(-1,)) # Fの計算の都合上、横一列に並べ直す
W = 0.1 * jnp.ones((args.action_dim*args.N),dtype=jnp.float32) # GMRES法における初期解
x_dot = model_func(x,us[0])
dFdU = dFdU_(x,us_,t) # 自動微分が使えるので、差分近似を使わずに計算できる
dFdx = dFdx_(x,us_,t)
dFdt = dFdt_(x,us_,t)
r = - args.zeta * F_(x,us_,t) - dFdx @ x_dot - dFdt - dFdU @ W
kai = GMRES(dFdU,r)
U_dot = W + kai
U_dot = jnp.reshape(U_dot,(-1,args.action_dim))
U = us + U_dot * args.Ts
return U
遊んでみる
それでは、出来たコントローラーを使って制御を簡単にシミュレーションしてみましょう。アニメーションにしてみると、障害物を避けて目的地に向かっているのがわかります。
最後に
学部の時はC/GMRESに触れる機会が多かったのですが、修士になってからiLQRばかり使っておりました。ふとC/GMRESをもう一度やってみようと思い書いてみましたが、思いのほか上手く行って何よりです。MPCやるといつも実感しますが、パラメータ調整は本当に厄介ですね。ある問題設定で上手く行っても他の問題設定だときちんと避けきってくれなかったりとかザラにありますから。機会があれば、ここら辺のオートチューニング機能とかも調べて記事にしてみたいですね。
実は、コード作るのに少しだけclaude.aiの力を借りました。勿論、きちんとしたコードを出力してもらうには細かく手法や仕様を指示する必要があるので、理解しないままコードだけ作らせるというのは難しいです。とは言え、そこがある程度出来れば、自分の思いつかなかった楽な実装法を教えてくれたりするのでとても助かります。凄い時代になったとしみじみ感じます。
参考文献
C/GMRESについて
- 非線形最適制御入門(大塚敏之,2011)
- 実時間最適化による制御の実応用(大塚敏之,2015)
- C/GMRESによる非線形モデル予測制御1(Hamachi,2024)
- C/GMRES法の例題実装(佐藤郁弥,2019)
- 非線形モデル予測制御におけるCGMRES法をpythonで実装する(MENDY,2019)
- 大規模連立1次方程式に対する一般化残差最小法について(茂木渉,篠原正明,2010)
バリア関数について
- モデル予測制御における対数バリア関数の実装(絶望M2生,2023)
- Relaxed Logarithmic Barrier Function Based Model Predictive Control of Linear Systems (C.Feller and C. Ebenbauer, 2015, arxiv)
シミュレーションコード全体
長いので、こちらの github リポジトリにも載せてあります。
import jax
import jax.numpy as jnp
import numpy as np
import time
from dataclasses import dataclass
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
from matplotlib.patches import Circle
#連続状態方程式
@jax.jit
def model_func(x, u):
Bk = jnp.array([[jnp.cos(x[2]), 0],
[jnp.sin(x[2]), 0],
[0, 1]], dtype=jnp.float32)
x_dot = Bk @ u
return x_dot
#model_dfdx = jax.jit(jax.jacfwd(model_func,0))
#model_dfdu = jax.jit(jax.jacfwd(model_func,1))
# コントローラーに関するパラメーター
@dataclass
class Cont_Args:
# コントローラーのパラメータ
Ts = 0.02 # 制御周期
tf = 1.0 # 予測ホライズンの最終長さ
N = 50 # 予測ホライズンの分割数
#dt = Ts # 今回は意味の無いパラメータなので無視してください
alpha = 0.5 # 予測ホライズンを変化させる計算のパラメータ
zeta = 1 # U_dotを計算する時の係数パラメータ(zetaと書いてますがツェータです)
# 評価関数中の重み
# 状態変数の項
Q = jnp.array([[1, 0, 0],
[0, 1, 0],
[0, 0, 0]], dtype=jnp.float32) * 0
# 制御入力の項
R = jnp.array([[100,0],
[0,10]], dtype=jnp.float32)
# 最終地点の項
S = jnp.array([[1, 0, 0],
[0, 1, 0],
[0, 0, 0]], dtype=jnp.float32) * 100
# 目標地点
x_ob = jnp.array([5, 0, 0], dtype=jnp.float32)
# 目標入力
u_ob = jnp.array([0, 0], dtype=jnp.float32)
# 次元データ
obss_dim = 3 # 状態変数の次元
action_dim = 2 # 入力変数の次元
# 状態と入力
x = None
u = None
us = None
# 障害物の場所(中心)
ev_pos = jnp.array([[2.5,0.15,0]],dtype=jnp.float32)
# 障害物の半径
d_ = 0.3
# ロボットの半径
r_ = 0.1
# 障害物の中心から取るべき距離
d = d_ + r_
# 緩和対数バリア関数の緩和値
del_bar = 0.05
# 回避バリア関数の重み
r = 50
# 入力制限
umax = jnp.array([1.0,1.0],dtype=jnp.float32)
umin = jnp.array([-1.0,-1.0],dtype=jnp.float32)
# 計算用行列
bar_C = jnp.concatenate([jnp.eye(action_dim,dtype=jnp.float32),-jnp.eye(action_dim,dtype=jnp.float32)],0)
bar_d = jnp.concatenate([umax,-umin],0)
# 速度制限のバリア関数の重み
b = 10
args = Cont_Args()
# バリア関数(-logか二次関数かを勝手に切り替える)
@jax.jit
def barrier_z(z):
pred = z > args.del_bar
true_fun = lambda z: - jnp.log(z)
false_fun = lambda z: 0.5 * (((z - 2*args.del_bar) / args.del_bar)**2 - 1) - jnp.log(args.del_bar)
return jax.lax.cond(pred, true_fun, false_fun, z)
# バリア関数(全体)
@jax.jit
def barrier(u):
zs = args.bar_d - args.bar_C @ u
def vmap_fun(b, z, margin=0.5):
return b * jnp.where(z>=margin,barrier_z(margin),barrier_z(z))
Bars = jax.vmap(vmap_fun, (None,0))(args.b, zs)
Bar = jnp.sum(Bars)
return Bar
# 対数バリア関数型回避関数
@jax.jit
def evasion(x):
def vmap_fun(x, xe, r, d, margin=0.5):
distance = jnp.linalg.norm(x-xe,ord=2)
z = distance**2 - d**2
ref = d + margin
return r * jnp.where(distance>=ref, barrier_z(ref**2-d**2), barrier_z(z))
evas = jax.vmap(vmap_fun, (None, 0, None, None))(x, args.ev_pos, args.r, args.d)
eva = jnp.sum(evas)
return eva
# ステージコスト
@jax.jit
def stage_cost(x,u):
cost = 0.5 * ( (x-args.x_ob) @ args.Q @ (x-args.x_ob) \
+ (u-args.u_ob) @ args.R @ (u-args.u_ob)) \
+ evasion(x) \
+ barrier(u)
return cost
# 終端コスト
@jax.jit
def term_cost(x):
cost = 0.5 * (x-args.x_ob) @ args.S @ (x-args.x_ob)
return cost
# ステージコストの微分
#grad_x_stage = jax.jit(jax.grad(stage_cost,0))
#grad_u_stage = jax.jit(jax.grad(stage_cost,1))
#hes_x_stage = jax.jit(jax.hessian(stage_cost,0))
#hes_u_stage = jax.jit(jax.hessian(stage_cost,1))
#hes_ux_stage = jax.jit(jax.jacfwd(jax.grad(stage_cost,1),0))
# 終端コストの微分
grad_x_term = jax.jit(jax.grad(term_cost,0))
#hes_x_term = jax.jit(jax.hessian(term_cost,0))
# コントローラー関数
@jax.jit
def CGMRES_control(x,us,t):
def rollout(x_init,us,dt):
def rollout_body(carry,u):
x = carry
x = x + model_func(x,u) * dt
return x, x
_, xs = jax.lax.scan(rollout_body, x_init, us)
xs = jnp.vstack([x_init[None], xs])
return xs
def Hamilton(x,u,lambda_):
H = stage_cost(x,u) + lambda_ @ model_func(x,u)
return H
dHdx = jax.grad(Hamilton,0)
dHdu = jax.grad(Hamilton,1)
def Backward(xs,us,dt):
def Backward_body(carry,val):
lambda_ = carry
x, u = val
lambda_ = lambda_ + dHdx(x,u,lambda_) * dt
return lambda_, lambda_
lambda_ = grad_x_term(xs[-1])
_, out_lambdas = jax.lax.scan(Backward_body,lambda_,(jnp.flip(xs[1:-1],0),jnp.flip(us[1:],0)))
lambdas = jnp.flip(jnp.vstack([lambda_,out_lambdas]),axis=0)
return lambdas
def F_(x,us,t):
us = jnp.reshape(us,(-1,args.action_dim)) # 計算の都合上、入力の時にusを横一列に並べ直しているので、ここで直す
dt = (1-jnp.exp(-args.alpha*t)) * args.tf/args.N # 予測ホライズンの分割幅を計算
xs = rollout(x,us,dt)
lambdas = Backward(xs,us,dt)
F = jax.vmap(dHdu,(0,0,0))(xs[:-1],us,lambdas)
F = jnp.reshape(F,(-1,))
return F
dFdU_ = jax.jacrev(F_,1)
dFdx_ = jax.jacrev(F_,0)
dFdt_ = jax.jacrev(F_,2)
#GMRES法関数(Ax=bの初期残差をrとする)
def GMRES(A, r, max_iter=5):
def arnoldi(A, v1, m):
n = v1.shape[0]
Vm_1 = jnp.zeros((n, m+1))
H = jnp.zeros((m+1, m))
Vm_1 = Vm_1.at[:, 0].set(v1)
def body_fun(j, val):
Vm_1, H = val
v = A @ Vm_1[:, j]
def body_in(i,val):
Vm_1, H, v = val
H = H.at[i, j].set(jnp.dot(Vm_1[:, i], v))
v = v - H[i, j] * Vm_1[:, i]
return Vm_1, H, v
Vm_1, H, v = jax.lax.fori_loop(0,j+1,body_in,(Vm_1,H,v))
H = H.at[j+1, j].set(jnp.linalg.norm(v))
Vm_1 = Vm_1.at[:, j+1].set(v / H[j+1, j])
return Vm_1, H
Vm_1, H = jax.lax.fori_loop(0, m, body_fun, (Vm_1, H))
return Vm_1, H
def givens_rotation(v1, v2):
t = jnp.sqrt(v1**2 + v2**2)
c = v1 / t
s = -v2 / t
return c, s
#n = r.shape[0]
Vm_1, H = arnoldi(A, r / jnp.linalg.norm(r), max_iter)
beta = jnp.linalg.norm(r)
e1 = jnp.zeros(max_iter + 1)
e1 = e1.at[0].set(beta)
def body_fun(i, val):
H, e1 = val
c, s = givens_rotation(H[i, i], H[i+1, i])
Givens = jnp.array([[c, s], [-s, c]])
H_col = jax.lax.dynamic_slice(H,(i,0),(2,max_iter))
H = jax.lax.dynamic_update_slice(H, Givens @ H_col, (i,0))
e1_slice = jax.lax.dynamic_slice(e1,(i,),(2,))
e1 = jax.lax.dynamic_update_slice(e1, Givens @ e1_slice, (i,))
return H, e1
H, e1 = jax.lax.fori_loop(0, max_iter, body_fun, (H, e1))
y = jnp.linalg.solve(H[:max_iter, :max_iter], e1[:max_iter])
x = Vm_1[:, :max_iter] @ y
return x
# ここから本計算
us_ = jnp.reshape(us,(-1,)) # Fの計算の都合上、横一列に並べ直す
W = 0.1 * jnp.ones((args.action_dim*args.N),dtype=jnp.float32) # GMRES法における初期解
x_dot = model_func(x,us[0])
dFdU = dFdU_(x,us_,t) # 自動微分が使えるので、差分近似を使わずに計算できる
dFdx = dFdx_(x,us_,t)
dFdt = dFdt_(x,us_,t)
r = - args.zeta * F_(x,us_,t) - dFdx @ x_dot - dFdt - dFdU @ W
kai = GMRES(dFdU,r)
U_dot = W + kai
U_dot = jnp.reshape(U_dot,(-1,args.action_dim))
U = us + U_dot * args.Ts
return U
# 初期条件
args.u = jnp.zeros((args.action_dim), dtype=jnp.float32)
args.us = jnp.zeros((args.N, args.action_dim), dtype=jnp.float32)
args.x = jnp.zeros((args.obss_dim), dtype=jnp.float32)
Time = 0.0
time_stamp = []
x_log = []
u_log = []
start = time.time()
while Time <= 20:
print("-------------Position-------------")
print(args.x)
print("-------------Input-------------")
print(args.u)
time_stamp.append(Time)
x_log.append(args.x)
u_log.append(args.u)
us = CGMRES_control(args.x,args.us,Time)
#print(_)
x_dot = model_func(args.x,args.u)
x = args.x + x_dot * args.Ts
Time += args.Ts
args.x = x
args.u = us[0]
args.us = us
end = time.time()
loop_time = end - start
print("計算時間:{}[s]".format(loop_time))
time_log = np.array(time_stamp)
x_log = np.array(x_log)
u_log = np.array(u_log)
fig = plt.figure()
ax = plt.axes()
ax.set_xlim(-1,6)
ax.set_ylim(-2,2)
plt.axis("equal")
robot = Circle(xy=x_log[0][:2], radius=args.r_, fill=False)
ax.add_artist(robot)
line, = ax.plot([],[],'r-',lw=2)
def update(frame):
x , y, theta = x_log[frame]
robot.center = (x,y)
line_x = [x, x+args.r_*np.cos(theta)]
line_y = [y, y+args.r_*np.sin(theta)]
line.set_data(line_x,line_y)
return robot, line
obstacle = Circle(xy=args.ev_pos[0], radius=args.d_, ec="k")
ax.add_artist(obstacle)
anim = FuncAnimation(fig, update, frames=501, interval=20, blit=True)
writer = PillowWriter(fps=50) # fpsはフレームレートを指定
anim.save("CGMRES.gif", writer=writer)
fig2 = plt.figure()
ax1 = fig2.add_subplot(211)
ax2 = fig2.add_subplot(212)
ax1.plot(time_log,u_log[:,0])
ax2.plot(time_log,x_log[:,1])
plt.show()
Discussion