非線形システムの最適制御を勾配法のJAX実装で解く
はじめに
の7.2節の勉強ノートです.
によるとJAXにより非線形最適制御が効率良く解けるとのことだったので, シンプルな勾配法(最急降下法・共役勾配法)での実装にチャレンジしてみました!
問題設定
状態ベクトルの次元
を考えます. ここで最終時刻は
初期値は
とします. 評価関数は
において
とします. すなわち
を最小化する最適制御を考えます.
制御入力がない場合
まず, 制御入力を与えない場合(
パラメータ設定
共通のパラメータはグローバル変数として定義しておきます
import jax
import jax.numpy as jnp
# 問題設定
S_f = jnp.array([[13, 0], [0, 1]], dtype=float)
Q = jnp.array([[13, 0], [0, 1]], dtype=float)
R = jnp.array([[1]], dtype=float)
x_0 = jnp.array([[0.5], [0]], dtype=float)
# 解く区間
t0, t1 = 0, 4
dt = 0.01
評価関数の定義
@jax.jit
def compute_J(x, u):
N = x.shape[0]
dt_ = (t1 - t0) / N
x_T = x[-1] # 最後の時刻の状態, 形状は (n, 1)
terminal_cost = 0.5 * jnp.matmul(x_T.T, jnp.matmul(S_f, x_T)).squeeze()
xQx = jnp.einsum("nkj,ki,nij->n", x, Q, x) # 形状は (N,)
uRu = jnp.einsum("nkj,ki,nij->n", u, R, u) # 形状は (N,)
integral_cost = 0.5 * jnp.sum(xQx + uRu) * dt_
J = terminal_cost + integral_cost
return J
diffrax で常微分方程式を解く際の共通設定
常微分方程式はJAXベースの微分方程式ソルバー提供ライブラリである diffrax
を使用します
import diffrax
N = 1000
ts = jnp.linspace(t0, t1, N)
solver = diffrax.Tsit5()
saveat = diffrax.SaveAt(ts=ts)
状態方程式の定義
@jax.jit
def function_f(x, u):
"""状態方程式における f"""
x1, x2 = x
u1 = u.squeeze()
return jnp.array([x2, -2 * (x1**3) + 2 * x1 - x2 + u1], dtype=float)
def vector_field_x(t, x, args):
u_t = args.evaluate(t)
return function_f(x, u_t)
state_eq = diffrax.ODETerm(vector_field_x)
制御入力に対する状態ベクトルの可視化
状態ベクトル
制御入力がない場合, 安定平衡点
import matplotlib.pyplot as plt
def plot_control(u):
# u を与えて状態方程式の解 x を求める
u_func = diffrax.LinearInterpolation(ts=ts, ys=u)
sol = diffrax.diffeqsolve(state_eq, solver, t0, t1, dt, x_0, args=u_func, saveat=saveat)
x = sol.ys
X = jnp.array(x).reshape(N, 2)
U = jnp.array(u).reshape(N, 1)
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
for k in range(2):
plt.plot(ts, X[:, k], label=f"x_{k}")
plt.axhline(1, color="blue", linestyle="--", linewidth=0.7)
plt.axhline(0, color="black", linestyle="--", linewidth=0.7)
plt.ylim(-0.5, 1.2)
plt.legend()
plt.subplot(1, 2, 2)
for k in range(1):
plt.plot(ts, U[:, k], linestyle="--", label="u")
plt.axhline(0, color="black", linestyle="--", linewidth=0.7)
plt.ylim(-3.0, 0.2)
plt.legend()
plt.show()
score = float(compute_J(x, u))
print(f"{score=}")
u = jnp.zeros((N, 1, 1))
plot_control(u)
score=28.224021911621094
オイラー・ラグランジュ方程式
ハミルトニアンを
となります.
第一式を状態方程式, 第二式を随伴方程式と言います. 拘束条件を表す第三式は本記事では以降は明示的に使用しません.
最急降下法
以下のアルゴリズムにより最適制御を求めます
- 適当な
を制御入力の初期推定解とするu -
を用いて状態方程式を解いてu を, 随伴方程式を解いてx を求める\lambda -
,x ,u から\lambda を計算する.\frac{\partial H}{\partial u} -
とおくs=-\left(\frac{\partial H}{\partial u}\right)^T - 制御入力を
としたときの評価関数値u+\alpha s が最小になるスカラーJ[u+\alpha s] を求め,\alpha と更新してステップ2に戻るu=u+\alpha s - ただし, 以下の条件を満たす場合は収束したとみなす
- 勾配のノルム
が十分小さい再場合\left(\int_{0}^{T}\left\|s(t)\right\|^2\, dt\right)^{\frac{1}{2}} - 制御の変更のノルム
が十分小さい場合\alpha \left(\int_{0}^{T}\left\|s(t)\right\|^2\, dt\right)^{\frac{1}{2}}
- 勾配のノルム
- ただし, 以下の条件を満たす場合は収束したとみなす
随伴方程式の定義
ハミルトニアンの偏微分を用いて書けることから, 自動微分を用いて定義します.
また, 時間逆方向に解くので符号を逆転しておきます.
@jax.jit
def hamiltonian(x, u, lambda_):
L = 0.5 * (jnp.matmul(x.T, jnp.matmul(Q, x)) + jnp.matmul(u.T, jnp.matmul(R, u)))
f = function_f(x, u)
H = L + jnp.matmul(lambda_.T, f)
return H.squeeze()
grad_H_x = jax.grad(hamiltonian, argnums=0)
def vector_field_lambda(t, lambda_, args):
x_t = args[0].evaluate(t)
u = args[1].evaluate(t)
dot_lambda = grad_H_x(x_t, u, lambda_)
return dot_lambda
lambda_eq = diffrax.ODETerm(vector_field_lambda)
目的関数の勾配計算
制御入力に対する変分として, ハミルトニアンの勾配を用いるのでこちらも自動微分を用いて定義します.
直線探索の最適値
@jax.jit
def compute_sequential_hamiltonian_and_gradients(x, u, lambda_):
"""
各時刻におけるハミルトニアンとその u に関する勾配を計算します。
Parameters:
x (jnp.array): 状態変数、形状は (N, n, 1)
u (jnp.array): 制御入力、形状は (N, m, 1)
lambda_ (jnp.array): ラグランジュ乗数、形状は (N, n, 1)
Returns:
Tuple[jnp.array, jnp.array]: ハミルトニアンの配列と勾配の配列
"""
# ベクトル化されたハミルトニアン関数
H = jax.vmap(hamiltonian, in_axes=(0, 0, 0))(x, u, lambda_)
# ベクトル化された勾配関数
hamiltonian_grad_u = jax.grad(hamiltonian, argnums=1)
grad_u = jax.vmap(hamiltonian_grad_u, in_axes=(0, 0, 0))(x, u, lambda_)
return H, grad_u
@jax.jit
def J_alpha(alpha, u, s):
"""argmin_{alpha} J(u + alpha * s) を算出する"""
u_new = u + alpha * s
u_func = diffrax.LinearInterpolation(ts=ts, ys=u_new)
sol = diffrax.diffeqsolve(state_eq, solver, t0, t1, dt, x_0, args=u_func, saveat=saveat)
x_new = sol.ys
return compute_J(x_new, u_new)
最急降下法アルゴリズムの実行
%%time
from scipy.optimize import minimize_scalar
u = jnp.zeros((N, 1, 1))
eps1 = 1e-2
eps2 = 1e-7
for i in range(10**3):
# u を与えて状態方程式の解 x を得る
u_func = diffrax.LinearInterpolation(ts=ts, ys=u)
sol = diffrax.diffeqsolve(state_eq, solver, t0, t1, dt, x_0, args=u_func, saveat=saveat)
x = sol.ys
# u, x を与えて随伴方程式の解 lambda_ を得る
lambda_tf = S_f @ x[-1]
x_func = diffrax.LinearInterpolation(ts=ts, ys=x[::-1])
sol = diffrax.diffeqsolve(lambda_eq, solver, t0, t1, dt, lambda_tf, args=[x_func, u_func], saveat=saveat)
lambda_ = sol.ys[::-1]
H, grad_u = compute_sequential_hamiltonian_and_gradients(x, u, lambda_)
grad_norm = jnp.sqrt(jnp.sum(grad_u**2) * dt)
if grad_norm < eps1:
print(f"roop done for {i=}, because {grad_norm=:.5f}")
break
s = -grad_u
alpha_opt = minimize_scalar(J_alpha, bounds=(0, 1), args=(u, s)).x
diff_norm = alpha_opt * jnp.sqrt(jnp.sum(s**2) * dt)
if diff_norm < eps2:
print(f"roop done for {i=}, because {diff_norm=:.8f}")
break
if i % 100 == 0:
score = compute_J(x, u)
print(f"{score=:.7f}, {grad_norm=:.5f}, {diff_norm=:.8f}")
u += alpha_opt * s
score=28.2240219, grad_norm=14.06202, diff_norm=3.41370487
score=2.0676792, grad_norm=0.72200, diff_norm=0.00191596
score=2.0527604, grad_norm=1.48579, diff_norm=0.00058045
score=2.0234804, grad_norm=0.36918, diff_norm=0.00208106
score=2.0040922, grad_norm=0.02795, diff_norm=0.00001824
score=2.0040870, grad_norm=0.07580, diff_norm=0.00002677
score=2.0040841, grad_norm=0.05358, diff_norm=0.00003470
roop done for i=682, because grad_norm=0.00986
CPU times: user 13.7 s, sys: 0 ns, total: 13.7 s
Wall time: 13.7 s
最急降下法の計算結果
不安定平衡点の近くに状態を定める制御が構成できています
plot_control(u)
score=2.0040838718414307
共役勾配法
次のアルゴリズムで最適制御を求めます
- 適当な
を制御入力の初期推定解,u とし,s_-=0 は適当(計算に使われない)に置くd_- - 制御入力を
としたときの状態方程式の解u , 随伴方程式の解x に対して\lambda とするd = -\frac{\partial H}{\partial u} - ポラック・リビエ・ポリャック法やフレッチャー・リーブス法により
を\beta から定めるd_-, d - ポラック・リビエ・ポリャック法のほうが収束が速かったのでそちらを採用します
-
とするs=d + \beta s_- - 制御入力を
と したときの評価関数値u+\alpha s が最小になるスカラーJ[u+\alpha s] を求め,\alpha とおくu=u+\alpha s - ただし, 以下の条件を満たす場合は収束したとみなす
- 勾配のノルム
が十分小さい再場合\left(\int_{0}^{T}\left\|d(t)\right\|^2\, dt\right)^{\frac{1}{2}} - 制御の変更のノルム
が十分小さい場合\alpha \left(\int_{0}^{T}\left\|d(t)\right\|^2\, dt\right)^{\frac{1}{2}}
- 勾配のノルム
- ただし, 以下の条件を満たす場合は収束したとみなす
-
と代入するd_-=d, s_-=s - 共役方向の誤差が蓄積するので定期的にリセットする(これよくわからなかったけどやらないと収束しない)
共役勾配法アルゴリズムの実行
%time
u = jnp.zeros((N, 1, 1))
s_ = jnp.zeros((N, 1, 1))
d_ = jnp.ones((N, 1, 1))
eps1 = 1e-2
eps2 = 1e-7
for i in range(10**3):
# u を与えて状態方程式の解 x を得る
u_func = diffrax.LinearInterpolation(ts=ts, ys=u)
sol = diffrax.diffeqsolve(state_eq, solver, t0, t1, dt, x_0, args=u_func, saveat=saveat)
x = sol.ys
# u, x を与えて随伴方程式の解 lambda_ を得る
lambda_tf = S_f @ x[-1]
x_func = diffrax.LinearInterpolation(ts=ts, ys=x[::-1])
sol = diffrax.diffeqsolve(lambda_eq, solver, t0, t1, dt, lambda_tf, args=[x_func, u_func], saveat=saveat)
lambda_ = sol.ys[::-1]
_, grad_u = compute_sequential_hamiltonian_and_gradients(x, u, lambda_)
d = -grad_u
grad_norm = jnp.sqrt(jnp.sum(d**2) * dt).squeeze()
if grad_norm < eps1:
print(f"roop done for {i=}, because {grad_norm=:.5f}")
break
beta = jnp.einsum("ijk, ijk", d, d - d_) / jnp.einsum("ijk, ijk", d_, d_) # ポラック・リビエ・ポリャック法
# beta = jnp.einsum("ijk, ijk", d, d) / jnp.einsum("ijk, ijk", d_, d_) # フレッチャー・リーブス法
s = d + beta * s_
alpha_opt = minimize_scalar(J_alpha, bounds=(0, 1), args=(u, s)).x
diff_norm = alpha_opt * jnp.sqrt(jnp.sum(s**2) * dt)
if diff_norm < eps2:
print(f"roop done for {i=}, because {diff_norm=:.8f}")
break
if i % 3 == 0:
# s_ には共役方向の誤差が蓄積するので定期的にリセットする(?)
s_ = jnp.zeros((N, 1, 1))
if i % 100 == 0:
score = compute_J(x, u)
print(f"{score=:.7f}, {grad_norm=:.5f}, {diff_norm=:.8f}")
u += alpha_opt * s
d_, s_ = d, s
score=28.2240219, grad_norm=14.06202, diff_norm=3.41370487
roop done for i=71, because grad_norm=0.00364
CPU times: user 1.52 s, sys: 0 ns, total: 1.52 s
Wall time: 1.55 s
共役勾配法の計算結果
最急降下法より速く収束判定され, こちらでも制御ができています
plot_control(u)
score=2.004102945327759
わからなかったことメモ
- 終端時刻
を大きくすると全然収束しなくなった。。T=4 - 勾配法に従っているのに、
grad_norm
やdiff_norm
が単調減少していない - 共役勾配法で定期的に
s
をゼロに置き直す必要があるのはなぜかわからなかった
Discussion