📏

PythonでLU分解と連立一次方程式の数値解法

2022/08/10に公開

問題

N次元の正方行列Aと列ベクトルB, Xについて、AX=BXについて解きたい。

\begin{pmatrix} a_{11} & a_{12} & \dots & a_{1N} \\ a_{21} & a_{22} & \dots & a_{2N} \\ \vdots & \vdots & \ddots & \vdots \\ a_{N1} & a_{N2} & \dots & a_{NN} \end{pmatrix} \begin{pmatrix} x_{1}\\ x_{2}\\ \vdots \\ x_{N} \end{pmatrix} = \begin{pmatrix} b_{1}\\ b_{2}\\ \vdots \\ b_{N} \end{pmatrix}

方針

Aを下三角行列L(Lower triangular matrix)と上三角行列(Upper triangular matrix)の積で表すことができれば、A=LUなので、AX=B\iff LUX=Bとなります。ここでUX=Yとおくと、

\left\{ \begin{aligned} LY &= B \\ UX &= Y \end{aligned} \right.

となり、この連立方程式をY, Xの順番で解くことでAX=BXについて解いたことなります。L, Uは三角行列であるため、求まった未知数を順番に代入することで簡単に全ての未知数を求められます。[1]

実装

ソースコードは以下のリポジトリにあります。

https://github.com/3w36zj6/linear-equation-solver

solve.py
from decimal import Decimal

print_matrix = lambda matrix: print(*matrix, sep="\n")  # 行列表示用


def lu_decomposition(A: list[list[Decimal]]) -> tuple[list[list[Decimal]], list[list[Decimal]], list[list[Decimal]]]:
    """
    正方行列AをLU分解する。

    Parameters
    ----------
    A : list[list[Decimal]]
        LU分解する正方行列A。

    Returns
    -------
    L : list[list[Decimal]]
        下三角行列L。
    U : list[list[Decimal]]
        上三角行列U。
    P : list[list[Decimal]]
        ピボット選択の置換行列P。
    """

    L: list[list[Decimal]] = [[Decimal(1) if i == j else Decimal(0) for j in range(N)] for i in range(N)]  # 単位行列
    U: list[list[Decimal]] = [[Decimal(0) for j in range(N)] for i in range(N)]  # 零行列
    P: list[list[Decimal]] = [[Decimal(1) if i == j else Decimal(0) for j in range(N)] for i in range(N)]  # 単位行列

    # ピボット選択
    for k in range(N):
        abs_col: list[Decimal] = [abs(A[i][k]) for i in range(N)]
        max_index: int = abs_col.index(max(abs_col))  # 絶対値が最も大きい成分を探す

        # swap
        A[k], A[max_index] = A[max_index], A[k]
        P[k], P[max_index] = P[max_index], P[k]

    for k in range(N):
        # U
        for j in range(k, N):
            U[k][j] = A[k][j]
            for s in range(k):
                U[k][j] -= L[k][s] * U[s][j]

        # L
        for i in range(1 + k, N):
            L[i][k] = A[i][k]
            for s in range(k):
                L[i][k] -= L[i][s] * U[s][k]
            L[i][k] /= U[k][k]

    return L, U, P


def backward_substitution(A: list[list[Decimal]], B: list[Decimal]) -> list[Decimal]:
    """
    後退代入で方程式AX = Bを解く。

    Parameters
    ----------
    A : list[list[Decimal]]
        上三角行列A。
    B : list[Decimal]
        行列B。

    Returns
    -------
    X : list[Decimal]
        方程式の解。
    """
    X: list[Decimal] = [Decimal(0) for i in range(N)]
    for i in reversed(range(0, N)):
        X[i] = B[i]
        for k in range(i + 1, N):
            X[i] -= A[i][k] * X[k]
        X[i] /= A[i][i]

    return X


def multiply_permutation_matrix(P: list[list[Decimal]], B: list[Decimal]) -> list[Decimal]:
    """
    置換行列Pを行列Bに左からかける。

    Parameters
    ----------
    P : list[list[Decimal]]
        置換行列P。
    B : list[Decimal]
        行列B。

    Returns
    -------
    PB : list[Decimal]
        置換後の行列。
    """
    PB: list[Decimal] = [Decimal(0) for i in range(N)]
    for i in range(N):
        for j in range(N):
            PB[i] += P[i][j] * B[j]

    return PB


if __name__ == "__main__":
    # 標準入力
    N: int = int(input())
    A: list[list[Decimal]] = [list(map(Decimal, input().split())) for i in range(N)]
    B: list[Decimal] = list(map(Decimal, input().split()))

    # LU分解
    L: list[list[Decimal]]
    U: list[list[Decimal]]
    P: list[list[Decimal]]
    L, U, P = lu_decomposition(A)

    # LY = PBを解く
    L = [row[::-1] for row in L[:]][::-1]  # 後退代入と上下逆なので逆順に
    PB: list[Decimal] = multiply_permutation_matrix(P, B)[::-1]  # 置換行列をかけて後退代入と上下逆なので逆順に
    Y: list[Decimal] = backward_substitution(L, PB)[::-1]

    # UX = Yを解く
    X: list[Decimal] = backward_substitution(U, Y)

    # 標準出力
    print(*X, sep="\n")

標準入力

入力は以下の形式の標準入力で与えることにします。

N
a_{11} a_{12} \cdots a_{1N}
a_{21} a_{22} \cdots a_{2N}
\vdots
a_{N1} a_{N2} \cdots a_{NN}
b_{11} b_{12} \cdots b_{1N}

from decimal import Decimal

N: int = int(input())
A: list[list[Decimal]] = [list(map(Decimal, input().split())) for i in range(N)]
B: list[Decimal] = list(map(Decimal, input().split()))

DecimalはPython標準の固定小数点数型です。浮動小数点数の演算による計算誤差を小さくするために使用しています。

https://docs.python.org/ja/3/library/decimal.html

LU分解

まずLUは以下の式のようにおけます。

\begin{aligned} A&=LU \\ =& \begin{pmatrix} 1 & 0 & \dots & 0 \\ l_{21} & 1 & \dots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ l_{N1} & l_{N2} & \dots & 1 \end{pmatrix} \begin{pmatrix} u_{11} & u_{12} & \dots & u_{1N} \\ 0 & u_{22} & \dots & u_{2N} \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \dots & u_{NN} \end{pmatrix} \\ =& \begin{pmatrix} u_{11} & u_{12} & \dots & u_{1N} \\ l_{21}u_{11} & l_{21}u_{12}+u_{22} & \dots & l_{21}u_{1N}+u_{2N} \\ \vdots & \vdots & \ddots & \vdots \\ l_{N1}u_{11} & l_{N1}u_{12}+l_{N2}u_{22} & \dots & l_{N1}u_{1N}+\cdots+l_{N(N-1)}u_{(N-1)N}+u_{NN} \end{pmatrix} \end{aligned}

ここで、以下の操作をk=1からk=NまでN回繰り返すことにより、L, Uの成分を求めます。

  • Uk行目をAL, Uの既知の成分より求める。
  • Lk列目をAL, Uの既知の成分より求める。

計算量はO(n^3)です。

def lu_decomposition(A: list[list[Decimal]]) -> tuple[list[list[Decimal]], list[list[Decimal]]]:
    L: list[list[Decimal]] = [[Decimal(1) if i == j else Decimal(0) for j in range(N)] for i in range(N)]  # 単位行列
    U: list[list[Decimal]] = [[Decimal(0) for j in range(N)] for i in range(N)]  # 零行列

    for k in range(N):
        # U
        for j in range(k, N):
            U[k][j] = A[k][j]
            for s in range(k):
                U[k][j] -= L[k][s] * U[s][j]

        # L
        for i in range(1 + k, N):
            L[i][k] = A[i][k]
            for s in range(k):
                L[i][k] -= L[i][s] * U[s][k]
            L[i][k] /= U[k][k]

    return L, U

部分ピボット選択

上記の式とソースコードからわかるように、除算の部分で0で割ってしまう、つまりUの対角成分に0が含まれていると計算ができません。また0でなくとも小さな値での除算は誤差が大きくなってしまいます。そこで対角成分と同じ列から絶対値が最大の要素がある行を選び交換する操作を行います。[2]

行の交換の際に置換行列Pを作っておきます。すると解くことになる方程式は、

\left\{ \begin{aligned} LY &= PB \\ UX &= Y \end{aligned} \right.

となります。

@@ -1,6 +1,16 @@
-def lu_decomposition(A: list[list[Decimal]]) -> tuple[list[list[Decimal]], list[list[Decimal]]]:
+def lu_decomposition(A: list[list[Decimal]]) -> tuple[list[list[Decimal]], list[list[Decimal]], list[list[Decimal]]]:
     L: list[list[Decimal]] = [[Decimal(1) if i == j else Decimal(0) for j in range(N)] for i in range(N)]  # 単位行列
     U: list[list[Decimal]] = [[Decimal(0) for j in range(N)] for i in range(N)]  # 零行列
+    P: list[list[Decimal]] = [[Decimal(1) if i == j else Decimal(0) for j in range(N)] for i in range(N)]  # 単位行列
+
+    # ピボット選択
+    for k in range(N):
+        abs_col: list[Decimal] = [abs(A[i][k]) for i in range(N)]
+        max_index: int = abs_col.index(max(abs_col))  # 絶対値が最も大きい成分を探す
+
+        # swap
+        A[k], A[max_index] = A[max_index], A[k]
+        P[k], P[max_index] = P[max_index], P[k]
 
     for k in range(N):
         # U
@@ -16,4 +26,4 @@
                 L[i][k] -= L[i][s] * U[s][k]
             L[i][k] /= U[k][k]
 
-    return L, U
\ No newline at end of file
+    return L, U, P
\ No newline at end of file

後退代入

三角行列が求まっているので、あとは順番に代入して解くだけです。

計算量はO(n^2)です。

def backward_substitution(A: list[list[Decimal]], B: list[Decimal]) -> list[Decimal]:
    X: list[Decimal] = [Decimal(0) for i in range(N)]
    for i in reversed(range(0, N)):
        X[i] = B[i]
        for k in range(i + 1, N):
            X[i] -= A[i][k] * X[k]
        X[i] /= A[i][i]

    return X

NumPyで検算

solve_numpy.py
import numpy

N: int = int(input())
A: list[list[float]] = [list(map(float, input().split())) for i in range(N)]
B: list[float] = list(map(float, input().split()))

print(*numpy.linalg.solve(A, B), sep="\n")

参考

https://mathwords.net/lubunkai

脚注
  1. この過程を前進代入(forward substitution)または後退代入(backward substitution)といいます。 ↩︎

  2. 行のみを交換する操作を部分ピボット選択といい、列も交換する操作を完全ピボット選択といいます。 ↩︎

Discussion