⛰️

# pythonでの直接法による最適計算の一例

2023/08/27に公開

これをpythonでやってみるだけのメモ

# 例題1

1次元の運動

\dot{x} = v
\dot{v} = u

を考える。ここで x は位置、v は速度、u は推力（加速度）で、

-1 <= u <= 1
と制限される。

x(0) = 0
v(0) = 0
v(1) = 0

\max x(1)

つまり、静止状態から出発して時刻1で再び静止する（v(1) = 0）という条件のもとで、到達位置 x(1) が最大となるような

ロケットの軌跡（推力の履歴）を求めろ。

## コード

### pulp

import pulp
N = 10      # number of time steps
dt = 1/N    # step width; the horizon is 0 <= t <= 1

problem = pulp.LpProblem('case1', pulp.LpMaximize)

# One continuous variable per time index 0..N for position x and velocity v.
x = pulp.LpVariable.dicts('x', range(N+1), cat='Continuous')
v = pulp.LpVariable.dicts('v', range(N+1), cat='Continuous')

# Boundary conditions: x(0) = 0, v(0) = 0, v(1) = 0
problem += x[0] == 0
problem += v[0] == 0
problem += v[N] == 0

# Constraints at each time step
for i in range(N):
    # Trapezoidal rule: the position difference quotient equals the mean
    # of the velocities at both ends of the step.
    problem += (x[i+1] - x[i])/dt - (v[i+1] + v[i])/2 == 0
    # Thrust bound: the velocity difference quotient stays within [-1, 1].
    problem += (v[i+1] - v[i])/dt >= -1
    problem += (v[i+1] - v[i])/dt <= 1

# Objective: the final position x[N] (maximized, see LpMaximize above)
problem += x[N]


# Show the dict that maps each time index to the LpVariable created for it.
print(x)

>>> {0: x_0,
1: x_1,
2: x_2,
3: x_3,
4: x_4,
5: x_5,
6: x_6,
7: x_7,
8: x_8,
9: x_9,
10: x_10}

# Print the model (objective and constraints) in readable form.
print(problem)
>>> case1:
MAXIMIZE
1*x_10 + 0
SUBJECT TO
_C1: x_0 = 0
_C2: v_0 = 0
_C3: v_10 = 0
_C4: - 0.5 v_0 - 0.5 v_1 - 10 x_0 + 10 x_1 = 0
_C5: - 10 v_0 + 10 v_1 >= -1
_C6: - 10 v_0 + 10 v_1 <= 1
...


# Solve the problem and print the textual name of the returned status code.
status = problem.solve()
print('Status:', pulp.LpStatus[status])

>>> Status: Optimal


# Print the value of the objective, i.e. the maximized final position x[N].
print(pulp.value(problem.objective))
>>> 0.25


import matplotlib.pyplot as plt

# Extract the solved position and velocity at every time index.
xs = [pulp.value(var) for var in x.values()]
vs = [pulp.value(var) for var in v.values()]
ts = range(N + 1)  # x-axis is the time-step index 0..N

# Common y-range spanning both series.
ymin = min(xs + vs)
ymax = max(xs + vs)

# Position is drawn with 'o' markers, velocity with 'x' markers.
for series, mark in ((xs, 'o'), (vs, 'x')):
    plt.plot(ts, series, marker=mark)
plt.xlim(0, N)
plt.ylim(ymin, ymax)

plt.xlabel('time', fontsize=10)
plt.tick_params(labelsize=10)
plt.grid(True)
plt.show()


### cvxopt

from cvxopt import matrix, solvers
import numpy as np

# cvxopt solves the LP in the standard form
#     min  c·x
#     s.t. Gx ≤ h, Ax = b
#
# The unknowns are stacked per time step:
#     x = [
#         x0, v0,
#         x1, v1,
#         x2, v2,
#         ...
#         xN, vN
#     ]
#     -> xi = x[i*2], vi = x[i*2+1]

N = 10
dt = 1/N
dim = 2  # entries per time step: position and velocity

# Objective: maximizing x_N is expressed as minimizing -x_N.
cn = np.zeros(dim*(N+1))
cn[dim*N] = -1

# Inequality constraints: rows 0..N-1 bound the velocity increment from
# above, rows N..2N-1 bound it from below.
Gn = np.zeros((2*N, dim*(N+1)))
hn = np.zeros(2*N)
for i in range(N):
    # v_{i+1} - v_i <= dt
    Gn[i, (i+1)*dim+1] = 1
    Gn[i, i*dim+1] = -1
    hn[i] = dt

    # v_i - v_{i+1} <= dt
    Gn[i+N, (i+1)*dim+1] = -1
    Gn[i+N, i*dim+1] = 1
    hn[i+N] = dt

# Equality constraints: N dynamics rows followed by 3 boundary rows.
An = np.zeros((N + 3, dim*(N+1)))
bn = np.zeros(N + 3)
for i in range(N):
    # (x_{i+1} - x_i)/dt - (v_{i+1} + v_i)/2 = 0
    An[i, (i+1)*dim] = 1/dt
    An[i, i*dim] = -1/dt
    An[i, (i+1)*dim+1] = -0.5
    An[i, i*dim+1] = -0.5

# Boundary conditions: x_0 = 0, v_0 = 0, v_N = 0
An[N, 0] = 1
An[N+1, 1] = 1
An[N+2, dim*N+1] = 1

# Convert to cvxopt matrices and solve.
A = matrix(An)
b = matrix(bn)
G = matrix(Gn)
h = matrix(hn)
c = matrix(cn)
sol = solvers.lp(c, G, h, A, b)


# De-interleave the solution vector: even entries go to xs, odd entries to vs.
z = sol['x']
xs = [z[k] for k in range(0, len(z), 2)]
vs = [z[k] for k in range(1, len(z), 2)]


### cvxpy

import numpy as np
import cvxpy as cp

N = 10      # number of time steps
dt = 1/N    # step width

# Variables: position and velocity at time indices 0..N
x = cp.Variable(N+1)
v = cp.Variable(N+1)

# Objective: maximize the final position
objective = cp.Maximize(x[N])

# Constraints, starting with the boundary conditions
constraints = [
    x[0] == 0,
    v[0] == 0,
    v[N] == 0,
]
for i in range(N):
    constraints += [
        # Thrust bound: |v_{i+1} - v_i| <= dt, written as two linear rows.
        v[i+1] - v[i] <= dt,
        v[i] - v[i+1] <= dt,
        # Trapezoidal rule linking position and velocity.
        (x[i+1] - x[i])/dt - (v[i+1] + v[i])/2 == 0,
    ]

problem = cp.Problem(objective, constraints)
result = problem.solve()


# Read back the solved values of the position and velocity variables.
xs, vs = x.value, v.value


# 例題2

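2次元の運動

\dot{x} = u
\dot{y} = v
\dot{u}^2 + \dot{v}^2 <= a^2

（x, y は位置、u, v はそれぞれ x, y 方向の速度、a は推力による加速度の大きさの上限）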

x(0) = 0
y(0) = 0
u(0) = 0
v(0) = 0
y(1) = h
v(1) = 0

\max u(1)

つまり、時刻1で y = h に到達し、かつ y 方向の速度 v が 0 になるという条件のもとで、x 方向の速度 u(1) が最大となるような推力の履歴を求めろ。

## コード

h=1, a=8とする。

なお文献では推力の拘束条件の記述が誤っているのでコードに落とし込む時に修正している。

cvxoptは式の記述がめんどくさかったのでcvxpyのみ。

### cvxpy

import numpy as np
import cvxpy as cp

h = 1       # required final value of y
a = 8       # upper bound on the acceleration magnitude
N = 10      # number of time steps
dt = 1/N    # step width

# Variables at time indices 0..N: positions (x, y) and velocities (u, v)
x = cp.Variable(N+1)
u = cp.Variable(N+1)
y = cp.Variable(N+1)
v = cp.Variable(N+1)

# Objective: maximize the final x-direction velocity
objective = cp.Maximize(u[N])

# Constraints, starting with the boundary conditions
constraints = [
    x[0] == 0,
    u[0] == 0,
    y[0] == 0,
    v[0] == 0,
    y[N] == h,
    v[N] == 0,
]
for i in range(N):
    constraints += [
        # Thrust bound: squared acceleration magnitude is at most a^2.
        cp.power((u[i+1] - u[i])/dt, 2) + cp.power((v[i+1] - v[i])/dt, 2) <= a**2,
        # Trapezoidal rule linking each position to its velocity.
        (x[i+1] - x[i])/dt - (u[i+1] + u[i])/2 == 0,
        (y[i+1] - y[i])/dt - (v[i+1] + v[i])/2 == 0,
    ]

problem = cp.Problem(objective, constraints)
result = problem.solve()


# Print the value returned by problem.solve().
print(result)
>>> 7.168105386107712


import matplotlib.pyplot as plt

# Read back the solved state histories.
xs, us, ys, vs = x.value, u.value, y.value, v.value

ts = range(N + 1)  # x-axis is the time-step index 0..N
histories = (xs, us, ys, vs)

# Common y-range spanning all four series.
ymin = min(min(series) for series in histories)
ymax = max(max(series) for series in histories)

# One line per state, markers in order: x 'o', u 'x', y '+', v '*'.
for series, mark in zip(histories, 'ox+*'):
    plt.plot(ts, series, marker=mark)
plt.xlim(0, N)
plt.ylim(ymin, ymax)

plt.xlabel('time', fontsize=10)
plt.tick_params(labelsize=10)
plt.grid(True)
plt.show()