🚀

非線形最適制御:C/GMRESのJAX実装

2024/07/28に公開

本記事は東京大学の村山裕和(https://www.linkedin.com/in/裕和-村山-9b42252b1/ )による寄稿です。

はじめに

お久しぶりです。前回の記事を出してから一ヶ月以上経ってしまいました。iLQRの基礎理論解説の続編として、拡張ラグランジュ法により制約条件を考慮出来るようにしたAL-iLQRの解説記事を書こうとしているのですが、何故か論文通りに再現出来ず頭を抱えています。代わりと言ってはなんですが、今回は非線形最適制御/NMPC(Nonlinear Model Predictive Control)の手法として有名なC/GMRESのJAX実装について解説します。JAXを使わずとも実装は可能なのですが、JAXを使うと少しだけ嬉しい事があります。

この記事のシミュレーションコードはこちらの github リポジトリにも載せてあります。

C/GMRESとは?

非線形最適制御手法の一つです。京都大学の大塚敏之先生によって提案された手法であり、日本語の本や関連論文も充実しています。詳しい説明は大塚先生の著書『非線形最適制御入門』『実時間最適化による制御の実応用』を読んでいただければと思います。

また、個人的には以下の記事はC/GMRESを理解する上で非常に参考になりました。

JAXについて

JAXはGoogleによって開発されたPythonライブラリです。機械学習に使われる事が多いようですが、JITコンパイル機能、並列計算機能や自動微分機能を備えており、その他の様々な数値計算にも適用する事が出来ます。特に今回は自動微分機能が威力を発揮します。

JAXを用いる意義

JAXを用いる事により、以下の様な微分計算を自動微分により実装する事が出来ます

\frac{\partial H}{\partial u}, \frac{\partial H}{\partial x}, \frac{\partial f}{\partial x}, \frac{\partial f}{\partial u}...

つまり、手計算による微分から解放され、数式が複雑な場合でも気にせず実装する事が出来ます

また、C/GMRESでは以下の式が登場しました。

\frac{\partial F}{\partial U} \dot{U} = - \zeta F -\frac{\partial F}{\partial x} \dot{x} - \frac{\partial F}{\partial t}

\partial F / \partial U\partial F / \partial x を計算するのが難しいため、従来のアルゴリズムでは、これらを直接求めずに差分近似を用いていました。ここも、自動微分によって \partial F / \partial U\partial F / \partial x を直接計算する事が出来るようになります。

JAXによるC/GMRESの実装

では、早速実装して行きましょう。変数の文字や記号は大塚先生の著書と揃えています。

まずは下準備

最初にコントローラー本体のアルゴリズムと直接関係の無い所を実装してしまいます。

ライブラリのインポート

必要なライブラリを揃えます。JAX、及びJAXのNumpy機能をインポートします。その他、パラメータを格納するのに使うdataclassをインポートしています。

import jax
import jax.numpy as jnp
from dataclasses import dataclass

ロボットの状態方程式

制御対象となる非ホロノミックロボットの状態方程式を作ります。入力は良くあるTwist型を用います。デコレーターで@jax.jitとする事でJITコンパイルが適用されるようになり、通常のPythonよりも高速な計算が可能となります。

(ただこのデコレータを全ての関数につける必要があるのかはイマイチ良く分かっていません。念のため全てに付けています。)

#連続状態方程式
@jax.jit
def model_func(x, u):
    Bk = jnp.array([[jnp.cos(x[2]), 0],
                    [jnp.sin(x[2]), 0],
                    [0, 1]], dtype=jnp.float32)
    x_dot = Bk @ u
    return x_dot

パラメータ設定

予測ホライズンの長さ、評価関数の重みなど、主に人間が設計するパラメータを格納します。

# コントローラーに関するパラメーター
@dataclass
class Cont_Args:

    # コントローラーのパラメータ
    Ts = 0.02  # 制御周期
    tf = 1.0  # 予測ホライズンの最終長さ
    N = 50  # 予測ホライズンの分割数
    #dt = Ts  # 今回は意味の無いパラメータなので無視してください
    alpha = 0.5  # 予測ホライズンを変化させる計算のパラメータ
    zeta = 1  # U_dotを計算する時の係数パラメータ(zetaと書いてますがツェータです)


    # 評価関数中の重み
    # 状態変数の項
    Q = jnp.array([[1, 0, 0],
                   [0, 1, 0],
                   [0, 0, 0]], dtype=jnp.float32) * 0

    # 制御入力の項
    R = jnp.array([[100,0],
                   [0,10]], dtype=jnp.float32)

    # 最終地点の項
    S = jnp.array([[1, 0, 0],
                   [0, 1, 0],
                   [0, 0, 0]], dtype=jnp.float32) * 100

    # 目標地点
    x_ob = jnp.array([5, 0, 0], dtype=jnp.float32)

    # 目標入力
    u_ob = jnp.array([0, 0], dtype=jnp.float32)

    # 次元データ
    obss_dim = 3  # 状態変数の次元
    action_dim = 2  # 入力変数の次元

    # 状態と入力
    x = None
    u = None
    us = None

    # 障害物の場所(中心) 
    ev_pos = jnp.array([[2.5,0.15,0]],dtype=jnp.float32)
    # 障害物の半径
    d_ = 0.3
    # ロボットの半径
    r_ = 0.1
    # 障害物の中心から取るべき距離
    d = d_ + r_

    # 緩和対数バリア関数の緩和値
    del_bar = 0.05

    # 回避バリア関数の重み
    r = 50

    # 入力制限
    umax = jnp.array([1.0,1.0],dtype=jnp.float32)
    umin = jnp.array([-1.0,-1.0],dtype=jnp.float32)

    # 計算用行列
    bar_C = np.concatenate([jnp.eye(action_dim,dtype=jnp.float32),-jnp.eye(action_dim,dtype=jnp.float32)],0)
    bar_d = np.concatenate([umax,-umin],0)

    # 速度制限のバリア関数の重み
    b = 10

args = Cont_Args()

評価関数の設計

制御入力の拘束条件を満たすよう、緩和対数バリア関数を使います。また、障害物を避ける為の回避関数としても緩和対数バリア関数を使います。それをステージコストと終端コストに組み込みます。後々の計算で使用するので、終端コストの微分関数も定義しておきます。

バリア関数については、以下の記事を参照しています。

終端コスト \phi(x) とステージコスト l(x,u) を数式で提示しておきます。x_{ob} はゴール地点、u_{ob} は目標とする制御入力値です。つまり、なるべく u_{ob} に近い制御入力で目標地点 x_{ob} に向かう制御を目指す評価関数です。

\phi(x) = \frac{1}{2} (x-x_{ob})^{T} S (x-x_{ob}) \\ l(x,u) = \frac{1}{2} (x-x_{ob})^{T} Q (x-x_{ob}) + \frac{1}{2} (u-u_{ob})^{T} R (u-u_{ob}) + \text{Barrier function}
# バリア関数(-logか二次関数かを勝手に切り替える)
@jax.jit
def barrier_z(z):
    pred = z > args.del_bar
    true_fun = lambda z: - jnp.log(z)
    false_fun = lambda z: 0.5 * (((z - 2*args.del_bar) / args.del_bar)**2 - 1) - jnp.log(args.del_bar)
    return jax.lax.cond(pred, true_fun, false_fun, z)


# バリア関数(全体)
@jax.jit
def barrier(u):

    zs = args.bar_d - args.bar_C @ u

    def vmap_fun(b, z, margin=0.5):
        return b * jnp.where(z>=margin,barrier_z(margin),barrier_z(z))

    Bars = jax.vmap(vmap_fun, (None,0))(args.b, zs)
    Bar = jnp.sum(Bars)
    return Bar


# 対数バリア関数型回避関数
@jax.jit
def evasion(x):

    def vmap_fun(x, xe, r, d, margin=0.5):
        distance = jnp.linalg.norm(x-xe,ord=2)
        z = distance**2 - d**2
        ref = d + margin
        return r * jnp.where(distance>=ref, barrier_z(ref**2-d**2), barrier_z(z))

    evas = jax.vmap(vmap_fun, (None, 0, None, None))(x, args.ev_pos, args.r, args.d)
    eva = jnp.sum(evas)
    return eva



# ステージコスト
@jax.jit
def stage_cost(x,u):
    cost = 0.5 * ( (x-args.x_ob) @ args.Q @ (x-args.x_ob) \
                  + (u-args.u_ob) @ args.R @ (u-args.u_ob))  \
                  + evasion(x) \
                  + barrier(u)
    return cost

# 終端コスト
@jax.jit
def term_cost(x):
    cost = 0.5 * (x-args.x_ob) @ args.S @ (x-args.x_ob)
    return cost

# 終端コストの微分
grad_x_term = jax.jit(jax.grad(term_cost,0))

コントローラー本体の実装

いよいよC/GMRESのアルゴリズム本体を実装します。まずはコントローラー内部で定義される関数として、次の5つを作ります。

  • rollout
  • Hamilton
  • Backward
  • F
  • GMRES

rollout

現在時刻の状態変数 x と、制御入力の初期解 us (自分の場合、前の制御ステップで算出した制御入力を使っています)から、将来の状態変数 xs を予測する関数です。Forwardと呼んだ方がわかりやすいかもしれません。

状態方程式を \dot{x} = f(x,u) として、以下の計算を予測ホライズンの末端x_{N}まで繰り返します。

x_{k+1} = x_{k} + f(x_{k},u_{k}) \Delta \tau

xsは現在の状態変数 x_{0} と、予測した将来の状態変数 x_{1} \cdots x_{N} を格納した変数です。

def rollout(x_init,us,dt):

    def rollout_body(carry,u):
        x = carry
        x = x + model_func(x,u) * dt
        return x, x

    _, xs = jax.lax.scan(rollout_body, x_init, us)
    xs = jnp.vstack([x_init[None], xs])

    return xs

Hamilton

数学だとHと書かれる事が多いハミルトン関数の実装です。大塚先生の書籍での定義通りにコーディングします。ここで、先ほど作ったステージコストの関数を使います。

随伴変数を \lambdaとすれば、以下の様になります。

H(x_{k},u_{k},\lambda_{k+1},t) = l(x_{k},u_{k}) + \lambda_{k+1}f(x_{k},u_{k})

ここで、自動微分により \partial H/ \partial u\partial H/ \partial x の関数を作ります。これにより、これらを手計算で求める必要が無くなります。

def Hamilton(x,u,lambda_):
    H = stage_cost(x,u) + lambda_ @ model_func(x,u)
    return H

dHdx = jax.grad(Hamilton,0)
dHdu = jax.grad(Hamilton,1)

Backward

制御入力の初期解 us、及び rollout 関数で求めた xs から、随伴変数lambdasを求める関数です。自動微分で実装した \partial H / \partial x 関数の出番です。同じく自動微分で実装した終端コストの微分関数も登場します。

まず、次のように \lambda_{N} を求めます。

\lambda_{N} = \frac{\partial \phi}{\partial x}(x_{N})

そこから、以下の計算を繰り返します。

\lambda_{k} = \lambda_{k+1} + \frac{\partial H}{\partial x}(x_{k},u_{k},\lambda_{k+1},t) \Delta \tau
def Backward(xs,us,dt):

    def Backward_body(carry,val):
        lambda_ = carry
        x, u = val
        lambda_ = lambda_ + dHdx(x,u,lambda_) * dt
        return lambda_, lambda_

    lambda_ = grad_x_term(xs[-1])

    _, out_lambdas = jax.lax.scan(Backward_body,lambda_,(jnp.flip(xs[1:-1],0),jnp.flip(us[1:],0)))
    lambdas = jnp.flip(jnp.vstack([lambda_,out_lambdas]),axis=0)

    return lambdas

F

大塚先生の著書や論文では F と表されている関数です。\partial H / \partial u を予測ホライズンの最後まで並べます。自動微分で実装した \partial H/ \partial u 関数をjax.vmapで並列計算させることで楽に実装出来ます。

自動微分を使い、\partial F / \partial U\partial F / \partial x\partial F / \partial t も実装します。

def F_(x,us,t):

    us = jnp.reshape(us,(-1,args.action_dim))  # 計算の都合上、入力の時にusを横一列に並べ直しているので、ここで直す
    dt = (1-jnp.exp(-args.alpha*t)) * args.tf/args.N  # 予測ホライズンの分割幅を計算

    xs = rollout(x,us,dt)

    lambdas = Backward(xs,us,dt)

    F = jax.vmap(dHdu,(0,0,0))(xs[:-1],us,lambdas)
    F = jnp.reshape(F,(-1,))

    return F

dFdU_ = jax.jacrev(F_,1)
dFdx_ = jax.jacrev(F_,0)
dFdt_ = jax.jacrev(F_,2)

GMRES

GMRES法(Generalized Minimal RESidual method/一般化残差最小法)は連立一次方程式を解く数値計算手法の一つです。手法自体の説明はここでは割愛します。自分は以下の資料を元にコーディングしました。

本当はJAXそのものにGMRES関数は存在します。ただ、それだとC/GMRESのキモとなる部分を既存の関数に頼る事になって面白みが無いのもあり、自作しました。

最後、連立一次方程式を解く部分なのですが、自作した後退代入の関数では何故か計算結果が荒れてしまいました。仕方なく、ここだけjax.numpy.linalg.solveに頼る事にしました。計算誤差や前処理の問題でしょうか...?

#GMRES法関数(Ax=bの初期残差をrとする)
def GMRES(A, r, max_iter=5):

    def arnoldi(A, v1, m):
        n = v1.shape[0]
        Vm_1 = jnp.zeros((n, m+1))
        H = jnp.zeros((m+1, m))
        Vm_1 = Vm_1.at[:, 0].set(v1)

        def body_fun(j, val):
            Vm_1, H = val
            v = A @ Vm_1[:, j]

            def body_in(i,val):
                Vm_1, H, v = val
                H = H.at[i, j].set(jnp.dot(Vm_1[:, i], v))
                v = v - H[i, j] * Vm_1[:, i]
                return Vm_1, H, v

            Vm_1, H, v = jax.lax.fori_loop(0,j+1,body_in,(Vm_1,H,v))

            H = H.at[j+1, j].set(jnp.linalg.norm(v))
            Vm_1 = Vm_1.at[:, j+1].set(v / H[j+1, j])
            return Vm_1, H

        Vm_1, H = jax.lax.fori_loop(0, m, body_fun, (Vm_1, H))
        return Vm_1, H

    def givens_rotation(v1, v2):
        t = jnp.sqrt(v1**2 + v2**2)
        c = v1 / t
        s = -v2 / t
        return c, s

    #n = r.shape[0]
    Vm_1, H = arnoldi(A, r / jnp.linalg.norm(r), max_iter)

    beta = jnp.linalg.norm(r)
    e1 = jnp.zeros(max_iter + 1)
    e1 = e1.at[0].set(beta)

    def body_fun(i, val):
        H, e1 = val
        c, s = givens_rotation(H[i, i], H[i+1, i])
        Givens = jnp.array([[c, s], [-s, c]])
        H_col = jax.lax.dynamic_slice(H,(i,0),(2,max_iter))
        H = jax.lax.dynamic_update_slice(H, Givens @ H_col, (i,0))
        e1_slice = jax.lax.dynamic_slice(e1,(i,),(2,))
        e1 = jax.lax.dynamic_update_slice(e1, Givens @ e1_slice, (i,))
        return H, e1

    H, e1 = jax.lax.fori_loop(0, max_iter, body_fun, (H, e1))

    y = jnp.linalg.solve(H[:max_iter, :max_iter], e1[:max_iter])
    x = Vm_1[:, :max_iter] @ y

    return x

コントローラー全体の実装

最後に、この5つの関数を使い、コントローラーを完成させます。この記事の最後にシミュレーションも含めたコードを置いてあるので、自分のPCで試す際はそちらを使ってください。

# コントローラー関数
@jax.jit
def CGMRES_control(x,us,t):

    # ~中略~(ここで5つの関数を定義する)

    us_ = jnp.reshape(us,(-1,))  # Fの計算の都合上、横一列に並べ直す
    W = 0.1 * jnp.ones((args.action_dim*args.N),dtype=jnp.float32)  # GMRES法における初期解
    x_dot = model_func(x,us[0])
    dFdU = dFdU_(x,us_,t)  # 自動微分が使えるので、差分近似を使わずに計算できる
    dFdx = dFdx_(x,us_,t)
    dFdt = dFdt_(x,us_,t)

    r = - args.zeta * F_(x,us_,t) - dFdx @ x_dot - dFdt - dFdU @ W

    kai = GMRES(dFdU,r)
    U_dot = W + kai
    U_dot = jnp.reshape(U_dot,(-1,args.action_dim))

    U = us + U_dot * args.Ts

    return U

遊んでみる

それでは、出来たコントローラーを使って制御を簡単にシミュレーションしてみましょう。アニメーションにしてみると、障害物を避けて目的地に向かっているのがわかります。

最後に

学部の時はC/GMRESに触れる機会が多かったのですが、修士になってからiLQRばかり使っておりました。ふとC/GMRESをもう一度やってみようと思い書いてみましたが、思いのほか上手く行って何よりです。MPCやるといつも実感しますが、パラメータ調整は本当に厄介ですね。ある問題設定で上手く行っても他の問題設定だときちんと避けきってくれなかったりとかザラにありますから。機会があれば、ここら辺のオートチューニング機能とかも調べて記事にしてみたいですね。

実は、コード作るのに少しだけclaude.aiの力を借りました。勿論、きちんとしたコードを出力してもらうには細かく手法や仕様を指示する必要があるので、理解しないままコードだけ作らせるというのは難しいです。とは言え、そこがある程度出来れば、自分の思いつかなかった楽な実装法を教えてくれたりするのでとても助かります。凄い時代になったとしみじみ感じます。

参考文献

C/GMRESについて

バリア関数について

シミュレーションコード全体

長いので、こちらの github リポジトリにも載せてあります。

import jax
import jax.numpy as jnp
import numpy as np
import time
from dataclasses import dataclass
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
from matplotlib.patches import Circle



#連続状態方程式
@jax.jit
def model_func(x, u):
    Bk = jnp.array([[jnp.cos(x[2]), 0],
                    [jnp.sin(x[2]), 0],
                    [0, 1]], dtype=jnp.float32)
    x_dot = Bk @ u
    return x_dot


#model_dfdx = jax.jit(jax.jacfwd(model_func,0))
#model_dfdu = jax.jit(jax.jacfwd(model_func,1))


# コントローラーに関するパラメーター
@dataclass
class Cont_Args:

    # コントローラーのパラメータ
    Ts = 0.02  # 制御周期
    tf = 1.0  # 予測ホライズンの最終長さ
    N = 50  # 予測ホライズンの分割数
    #dt = Ts  # 今回は意味の無いパラメータなので無視してください
    alpha = 0.5  # 予測ホライズンを変化させる計算のパラメータ
    zeta = 1  # U_dotを計算する時の係数パラメータ(zetaと書いてますがツェータです)


    # 評価関数中の重み
    # 状態変数の項
    Q = jnp.array([[1, 0, 0],
                   [0, 1, 0],
                   [0, 0, 0]], dtype=jnp.float32) * 0

    # 制御入力の項
    R = jnp.array([[100,0],
                   [0,10]], dtype=jnp.float32)

    # 最終地点の項
    S = jnp.array([[1, 0, 0],
                   [0, 1, 0],
                   [0, 0, 0]], dtype=jnp.float32) * 100

    # 目標地点
    x_ob = jnp.array([5, 0, 0], dtype=jnp.float32)

    # 目標入力
    u_ob = jnp.array([0, 0], dtype=jnp.float32)

    # 次元データ
    obss_dim = 3  # 状態変数の次元
    action_dim = 2  # 入力変数の次元

    # 状態と入力
    x = None
    u = None
    us = None

    # 障害物の場所(中心) 
    ev_pos = jnp.array([[2.5,0.15,0]],dtype=jnp.float32)
    # 障害物の半径
    d_ = 0.3
    # ロボットの半径
    r_ = 0.1
    # 障害物の中心から取るべき距離
    d = d_ + r_

    # 緩和対数バリア関数の緩和値
    del_bar = 0.05

    # 回避バリア関数の重み
    r = 50

    # 入力制限
    umax = jnp.array([1.0,1.0],dtype=jnp.float32)
    umin = jnp.array([-1.0,-1.0],dtype=jnp.float32)

    # 計算用行列
    bar_C = jnp.concatenate([jnp.eye(action_dim,dtype=jnp.float32),-jnp.eye(action_dim,dtype=jnp.float32)],0)
    bar_d = jnp.concatenate([umax,-umin],0)

    # 速度制限のバリア関数の重み
    b = 10

args = Cont_Args()


# バリア関数(-logか二次関数かを勝手に切り替える)
@jax.jit
def barrier_z(z):
    pred = z > args.del_bar
    true_fun = lambda z: - jnp.log(z)
    false_fun = lambda z: 0.5 * (((z - 2*args.del_bar) / args.del_bar)**2 - 1) - jnp.log(args.del_bar)
    return jax.lax.cond(pred, true_fun, false_fun, z)


# バリア関数(全体)
@jax.jit
def barrier(u):

    zs = args.bar_d - args.bar_C @ u

    def vmap_fun(b, z, margin=0.5):
        return b * jnp.where(z>=margin,barrier_z(margin),barrier_z(z))

    Bars = jax.vmap(vmap_fun, (None,0))(args.b, zs)
    Bar = jnp.sum(Bars)
    return Bar


# 対数バリア関数型回避関数
@jax.jit
def evasion(x):

    def vmap_fun(x, xe, r, d, margin=0.5):
        distance = jnp.linalg.norm(x-xe,ord=2)
        z = distance**2 - d**2
        ref = d + margin
        return r * jnp.where(distance>=ref, barrier_z(ref**2-d**2), barrier_z(z))

    evas = jax.vmap(vmap_fun, (None, 0, None, None))(x, args.ev_pos, args.r, args.d)
    eva = jnp.sum(evas)
    return eva



# ステージコスト
@jax.jit
def stage_cost(x,u):
    cost = 0.5 * ( (x-args.x_ob) @ args.Q @ (x-args.x_ob) \
                  + (u-args.u_ob) @ args.R @ (u-args.u_ob))  \
                  + evasion(x) \
                  + barrier(u)
    return cost

# 終端コスト
@jax.jit
def term_cost(x):
    cost = 0.5 * (x-args.x_ob) @ args.S @ (x-args.x_ob)
    return cost

# ステージコストの微分
#grad_x_stage = jax.jit(jax.grad(stage_cost,0))
#grad_u_stage = jax.jit(jax.grad(stage_cost,1))
#hes_x_stage = jax.jit(jax.hessian(stage_cost,0))
#hes_u_stage = jax.jit(jax.hessian(stage_cost,1))
#hes_ux_stage = jax.jit(jax.jacfwd(jax.grad(stage_cost,1),0))

# 終端コストの微分
grad_x_term = jax.jit(jax.grad(term_cost,0))
#hes_x_term = jax.jit(jax.hessian(term_cost,0))


# コントローラー関数
@jax.jit
def CGMRES_control(x,us,t):

    def rollout(x_init,us,dt):

        def rollout_body(carry,u):
            x = carry
            x = x + model_func(x,u) * dt
            return x, x
        
        _, xs = jax.lax.scan(rollout_body, x_init, us)
        xs = jnp.vstack([x_init[None], xs])

        return xs
    
    
    def Hamilton(x,u,lambda_):
        H = stage_cost(x,u) + lambda_ @ model_func(x,u)
        return H
    
    dHdx = jax.grad(Hamilton,0)
    dHdu = jax.grad(Hamilton,1)
    
    
    def Backward(xs,us,dt):

        def Backward_body(carry,val):
            lambda_ = carry
            x, u = val
            lambda_ = lambda_ + dHdx(x,u,lambda_) * dt
            return lambda_, lambda_

        lambda_ = grad_x_term(xs[-1])

        _, out_lambdas = jax.lax.scan(Backward_body,lambda_,(jnp.flip(xs[1:-1],0),jnp.flip(us[1:],0)))
        lambdas = jnp.flip(jnp.vstack([lambda_,out_lambdas]),axis=0)

        return lambdas
    

    def F_(x,us,t):

        us = jnp.reshape(us,(-1,args.action_dim))  # 計算の都合上、入力の時にusを横一列に並べ直しているので、ここで直す
        dt = (1-jnp.exp(-args.alpha*t)) * args.tf/args.N  # 予測ホライズンの分割幅を計算

        xs = rollout(x,us,dt)

        lambdas = Backward(xs,us,dt)

        F = jax.vmap(dHdu,(0,0,0))(xs[:-1],us,lambdas)
        F = jnp.reshape(F,(-1,))

        return F
    
    dFdU_ = jax.jacrev(F_,1)
    dFdx_ = jax.jacrev(F_,0)
    dFdt_ = jax.jacrev(F_,2)

    
    #GMRES法関数(Ax=bの初期残差をrとする)
    def GMRES(A, r, max_iter=5):
        
        def arnoldi(A, v1, m):
            n = v1.shape[0]
            Vm_1 = jnp.zeros((n, m+1))
            H = jnp.zeros((m+1, m))
            Vm_1 = Vm_1.at[:, 0].set(v1)
            
            def body_fun(j, val):
                Vm_1, H = val
                v = A @ Vm_1[:, j]

                def body_in(i,val):
                    Vm_1, H, v = val
                    H = H.at[i, j].set(jnp.dot(Vm_1[:, i], v))
                    v = v - H[i, j] * Vm_1[:, i]
                    return Vm_1, H, v
                
                Vm_1, H, v = jax.lax.fori_loop(0,j+1,body_in,(Vm_1,H,v))

                H = H.at[j+1, j].set(jnp.linalg.norm(v))
                Vm_1 = Vm_1.at[:, j+1].set(v / H[j+1, j])
                return Vm_1, H
            
            Vm_1, H = jax.lax.fori_loop(0, m, body_fun, (Vm_1, H))
            return Vm_1, H

        def givens_rotation(v1, v2):
            t = jnp.sqrt(v1**2 + v2**2)
            c = v1 / t
            s = -v2 / t
            return c, s

        #n = r.shape[0]
        Vm_1, H = arnoldi(A, r / jnp.linalg.norm(r), max_iter)
        
        beta = jnp.linalg.norm(r)
        e1 = jnp.zeros(max_iter + 1)
        e1 = e1.at[0].set(beta)
        
        def body_fun(i, val):
            H, e1 = val
            c, s = givens_rotation(H[i, i], H[i+1, i])
            Givens = jnp.array([[c, s], [-s, c]])
            H_col = jax.lax.dynamic_slice(H,(i,0),(2,max_iter))
            H = jax.lax.dynamic_update_slice(H, Givens @ H_col, (i,0))
            e1_slice = jax.lax.dynamic_slice(e1,(i,),(2,))
            e1 = jax.lax.dynamic_update_slice(e1, Givens @ e1_slice, (i,))
            return H, e1
        
        H, e1 = jax.lax.fori_loop(0, max_iter, body_fun, (H, e1))
        
        y = jnp.linalg.solve(H[:max_iter, :max_iter], e1[:max_iter])
        x = Vm_1[:, :max_iter] @ y
    
        return x


    # ここから本計算

    us_ = jnp.reshape(us,(-1,))  # Fの計算の都合上、横一列に並べ直す
    W = 0.1 * jnp.ones((args.action_dim*args.N),dtype=jnp.float32)  # GMRES法における初期解
    x_dot = model_func(x,us[0])
    dFdU = dFdU_(x,us_,t)  # 自動微分が使えるので、差分近似を使わずに計算できる
    dFdx = dFdx_(x,us_,t)
    dFdt = dFdt_(x,us_,t)

    r = - args.zeta * F_(x,us_,t) - dFdx @ x_dot - dFdt - dFdU @ W

    kai = GMRES(dFdU,r)
    U_dot = W + kai
    U_dot = jnp.reshape(U_dot,(-1,args.action_dim))

    U = us + U_dot * args.Ts

    return U



# 初期条件
args.u = jnp.zeros((args.action_dim), dtype=jnp.float32)
args.us = jnp.zeros((args.N, args.action_dim), dtype=jnp.float32)
args.x = jnp.zeros((args.obss_dim), dtype=jnp.float32)

Time = 0.0
time_stamp = []
x_log = []
u_log = []

start = time.time()
while Time <= 20:
    print("-------------Position-------------")
    print(args.x)
    print("-------------Input-------------")
    print(args.u)

    time_stamp.append(Time)
    x_log.append(args.x)
    u_log.append(args.u)

    us = CGMRES_control(args.x,args.us,Time)
    #print(_)

    x_dot = model_func(args.x,args.u)
    x = args.x + x_dot * args.Ts
    
    Time += args.Ts
    args.x = x
    args.u = us[0]
    args.us = us

end = time.time()
loop_time = end - start

print("計算時間:{}[s]".format(loop_time))

time_log = np.array(time_stamp)
x_log = np.array(x_log)
u_log = np.array(u_log)
fig = plt.figure()
ax = plt.axes()
ax.set_xlim(-1,6)
ax.set_ylim(-2,2)
plt.axis("equal")

robot = Circle(xy=x_log[0][:2], radius=args.r_, fill=False)
ax.add_artist(robot)
line, = ax.plot([],[],'r-',lw=2)

def update(frame):
    x , y, theta = x_log[frame]
    robot.center = (x,y)
    line_x = [x, x+args.r_*np.cos(theta)]
    line_y = [y, y+args.r_*np.sin(theta)]
    line.set_data(line_x,line_y)
    return robot, line

obstacle = Circle(xy=args.ev_pos[0], radius=args.d_, ec="k")
ax.add_artist(obstacle)

anim = FuncAnimation(fig, update, frames=501, interval=20, blit=True)

writer = PillowWriter(fps=50)  # fpsはフレームレートを指定
anim.save("CGMRES.gif", writer=writer)

fig2 = plt.figure()
ax1 = fig2.add_subplot(211)
ax2 = fig2.add_subplot(212)

ax1.plot(time_log,u_log[:,0])
ax2.plot(time_log,x_log[:,1])

plt.show()

Discussion