🧠

今更だけどLSTMを自作してみた

に公開

こんにちは!
株式会社アイディオットでデータサイエンティストをしています、秋田と申します。
今更ですが、PyTorchでLSTMを自作してみました。

「え、 torch.nn.LSTM ってモジュールがあるよ?」

おっしゃる通り...ですが、改めて自分で作成してみると公式のムズカシイドキュメントを眺めるよりもシンプルに理解ができるのと、改造が可能なことが大きな利点として挙げられるのではないでしょうか!
公式で提供されているLSTMのモジュールは、RNNクラスも継承しているため色々参照しなければいけないので大変なんですよね😅
今回自作したLSTMでは、極力見慣れたメソッドやクラスだけで構成しているので上記の心配はありません。

全体像

まずはプログラム全体をご紹介いたします。

from typing import Optional

import torch
from torch import nn


class LSTMCell(nn.Module):
    def __init__(self, input_size: int, hidden_size: int):
        super(LSTMCell, self).__init__()

        # 入力層サイズと隠れ層サイズを設定
        self.input_size = input_size
        self.hidden_size = hidden_size

        # 各種重みとバイアスの定義
        self.weight_ih = nn.Parameter(
            torch.randn(4 * self.hidden_size, self.input_size)
        )
        self.weight_hh = nn.Parameter(
            torch.randn(4 * self.hidden_size, self.hidden_size)
        )
        self.bias_ih = nn.Parameter(torch.zeros(4 * self.hidden_size))
        self.bias_hh = nn.Parameter(torch.zeros(4 * self.hidden_size))

        # 活性化関数の定義
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def forward(
        self, x_t: torch.Tensor, h_prev: torch.Tensor, c_prev: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        順伝播を行う関数

        Parameters
        ----------
        x_t: torch.Tensor
            入力データ
        h_prev: torch.Tensor
            1ステップ前の隠れ状態
        c_prev: torch.Tensor
            1ステップ前のセル

        Returns
        ----------
        h_t: torch.Tensor
            隠れ状態
        c_t: torch.Tensor
            セル
        """
        # 総合的なゲート計算
        gates = (
            torch.matmul(x_t, self.weight_ih.t()) +
            torch.matmul(h_prev, self.weight_hh.t()) +
            self.bias_ih + self.bias_hh
        )

        # ゲート分割
        i, f, g, o = gates.chunk(4, dim=1)

        # 活性化関数適用
        i = self.sigmoid(i)
        f = self.sigmoid(f)
        g = self.tanh(g)
        o = self.sigmoid(o)

        # セル・隠れ状態の更新
        c_t = f * c_prev + i * g
        h_t = o * self.tanh(c_t)

        return h_t, c_t


class LSTMLayer(nn.Module):
    def __init__(
        self, input_size: int, hidden_size: int,
        bidirectional: bool=False, residual: bool=False
    ):
        super(LSTMLayer, self).__init__()

        # 入力層サイズと隠れ層サイズを設定
        self.input_size = input_size
        self.hidden_size = hidden_size

        # 双方向にするかを設定
        self.bidirectional = bidirectional

        # 残差接続するかを設定
        self.residual = residual

        # (双方向の)LSTMセルのインスタンスの作成
        self.cell_fwd = LSTMCell(self.input_size, self.hidden_size)
        if self.bidirectional:
            self.cell_bwd = LSTMCell(self.input_size, self.hidden_size)

    def forward(
        self, x: torch.Tensor, h_0: torch.Tensor, c_0: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        順伝播を行う関数

        Parameters
        ----------
        x: torch.Tensor
            入力データ
        h_0: torch.Tensor
            初期化された隠れ状態
        c_0: torch.Tensor
            初期化されたセル

        Returns
        ----------
        outputs: torch.Tensor
            出力データ
        h_t: torch.Tensor
            隠れ状態
        c_t: torch.Tensor
            セル
        """
        # 入力長を取得
        seq_len, _, _ = x.size()

        # forward方向の処理
        h_fwd, c_fwd = h_0[0], c_0[0]
        outputs_fwd = []
        for t in range(seq_len):
            h_fwd, c_fwd = self.cell_fwd(x[t], h_fwd, c_fwd)
            outputs_fwd.append(h_fwd.unsqueeze(0))
        outputs_fwd = torch.cat(outputs_fwd, dim=0)

        # 双方向ではない場合用に分岐
        if not self.bidirectional:
            outputs = outputs_fwd
            h_t = h_fwd.unsqueeze(0)
            c_t = c_fwd.unsqueeze(0)
        else:
            # backward方向の処理
            h_bwd, c_bwd = h_0[1], c_0[1]
            outputs_bwd = []
            for t in reversed(range(seq_len)):
                h_bwd, c_bwd = self.cell_bwd(x[t], h_bwd, c_bwd)
                outputs_bwd.insert(0, h_bwd.unsqueeze(0))
            outputs_bwd = torch.cat(outputs_bwd, dim=0)

            # forward/backwardの出力を連結
            outputs = torch.cat([outputs_fwd, outputs_bwd], dim=2)
            h_t = torch.stack([h_fwd, h_bwd], dim=0)
            c_t = torch.stack([c_fwd, c_bwd], dim=0)

        # 残差接続
        if self.residual and outputs.shape == x.shape:
            outputs = outputs + x

        return outputs, h_t, c_t


class LSTM(nn.Module):
    def __init__(
        self, input_size: int, hidden_size: int, num_layers: int=1,
        bidirectional: bool=False, residual: bool=False
    ):
        super(LSTM, self).__init__()

        # 入力層サイズと隠れ層サイズを設定
        self.input_size = input_size
        self.hidden_size = hidden_size
        
        # 層数を設定
        self.num_layers = num_layers

        # 双方向にするかを設定
        self.bidirectional = bidirectional

        # 残差接続するかを設定
        self.residual = residual

        # LSTM全体の構造を作成する
        self.layers = nn.ModuleList()
        for layer in range(self.num_layers):
            in_size = self.input_size if layer == 0 else self.hidden_size * (
                2 if self.bidirectional else 1
            )
            self.layers.append(
                LSTMLayer(
                    in_size, self.hidden_size,
                    self.bidirectional, self.residual
                )
            )

    def forward(
        self, x: torch.Tensor,
        h_0: Optional[torch.Tensor] = None,
        c_0: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """
        順伝播を行う関数

        Parameters
        ----------
        x: torch.Tensor
            入力データ
        h_0: Optional[torch.Tensor]=None
            初期化された隠れ状態
        c_0: Optional[torch.Tensor]=None
            初期化されたセル

        Returns
        ----------
        outputs: torch.Tensor
            出力データ
        h_t: torch.Tensor
            隠れ状態
        c_t: torch.Tensor
            セル
        """
        # バッチサイズを取得
        _, batch_size, _ = x.size()

        # 方向数を指定
        num_directions = 2 if self.bidirectional else 1

        # 初期状態の隠れ状態とセルを定義
        if h_0 is None:
            h_0 = torch.zeros(
                self.num_layers * num_directions, batch_size,
                self.hidden_size, device=x.device
            )
        if c_0 is None:
            c_0 = torch.zeros(
                self.num_layers * num_directions, batch_size,
                self.hidden_size, device=x.device
            )

        # スタックの用意
        h_n = []
        c_n = []
        layer_input = x

        # 層ごとに処理・格納
        for layer_idx, layer in enumerate(self.layers):
            # 初期化
            h_start = layer_idx * num_directions
            h_i = h_0[h_start:h_start + num_directions]
            c_i = c_0[h_start:h_start + num_directions]

            # 次の層に更新
            layer_output, h_i_out, c_i_out = layer(layer_input, h_i, c_i)
            layer_input = layer_output

            # スタックに格納
            h_n.append(h_i_out)
            c_n.append(c_i_out)

        # 最終層の出力
        outputs = layer_output
        h_t = torch.cat(h_n, dim=0)
        c_t = torch.cat(c_n, dim=0)

        return outputs, (h_t, c_t)

3つのモジュールに分割し、 LSTMCell はLSTMの基本アルゴリズムを、 LSTMLayer はレイヤーレベルのまとまりを、 LSTM は層全体を司ります。
ただし、読み出しの線形層は含まれていないことに注意してください。

公式との顕著な差

公式にあるのにここには無いものはいくつもありますが、公式には無くてここにあるものが1つだけあります。
それが「残差接続」です。
残差接続というと、画像認識の分野でのResNetが有名ですが、層を深くしたら表現力が上がることが期待される一方で勾配が消失・爆発するという問題があったのをある程度解消することが出来たことで話題になりましたね。
これは時系列解析の分野でも同様で、学習の安定化のために残差接続を行うことが検討されてきています。
凄く簡単な修正で実装が済むものの、公式の nn.LSTM クラスから作るのは至難の業です。
ということでカスタマイズできるように1から作る必要がありました!

LSTMCell

class LSTMCell(nn.Module):
    def __init__(self, input_size: int, hidden_size: int):
        super(LSTMCell, self).__init__()

        # 入力層サイズと隠れ層サイズを設定
        self.input_size = input_size
        self.hidden_size = hidden_size

        # 各種重みとバイアスの定義
        self.weight_ih = nn.Parameter(
            torch.randn(4 * self.hidden_size, self.input_size)
        )
        self.weight_hh = nn.Parameter(
            torch.randn(4 * self.hidden_size, self.hidden_size)
        )
        self.bias_ih = nn.Parameter(torch.zeros(4 * self.hidden_size))
        self.bias_hh = nn.Parameter(torch.zeros(4 * self.hidden_size))

        # 活性化関数の定義
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def forward(
        self, x_t: torch.Tensor, h_prev: torch.Tensor, c_prev: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        順伝播を行う関数

        Parameters
        ----------
        x_t: torch.Tensor
            入力データ
        h_prev: torch.Tensor
            1ステップ前の隠れ状態
        c_prev: torch.Tensor
            1ステップ前のセル

        Returns
        ----------
        h_t: torch.Tensor
            隠れ状態
        c_t: torch.Tensor
            セル
        """
        # 総合的なゲート計算
        gates = (
            torch.matmul(x_t, self.weight_ih.t()) +
            torch.matmul(h_prev, self.weight_hh.t()) +
            self.bias_ih + self.bias_hh
        )

        # ゲート分割
        i, f, g, o = gates.chunk(4, dim=1)

        # 活性化関数適用
        i = self.sigmoid(i)
        f = self.sigmoid(f)
        g = self.tanh(g)
        o = self.sigmoid(o)

        # セル・隠れ状態の更新
        c_t = f * c_prev + i * g
        h_t = o * self.tanh(c_t)

        return h_t, c_t

まずはLSTMの構成として、入力層と隠れ層、それからそれらのバイアスについてのパラメータを用意する必要がありますね。
それぞれ4つのゲート(入力・忘却・セル入力・出力)に対してパラメータが必要なので 4 * self.hidden_size と隠れサイズを4倍しています。
4つのゲートの計算については次のように定義されます。

\begin{align*} i_t &= \sigma (W_{ii} x_{t} + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\ f_t &= \sigma (W_{if} x_{t} + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\ g_t &= \tanh (W_{ig} x_{t} + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\ o_t &= \sigma (W_{io} x_{t} + b_{io} + W_{ho} h_{t-1} + b_{ho}) \end{align*}

forwardの計算では、まずは上記の計算を一気に行うため、

        gates = (
            torch.matmul(x_t, self.weight_ih.t()) +
            torch.matmul(h_prev, self.weight_hh.t()) +
            self.bias_ih + self.bias_hh
        )

として、ゲート全体を定義し、この後 .chunk() で分割して活性化関数にかけます。
そして、セルと隠れ状態の更新を次の式の通りに行います。

\begin{align*} c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\ h_t &= o_t \odot \tanh (c_t) \end{align*}

こちらは次のように書いています。

        c_t = f * c_prev + i * g
        h_t = o * self.tanh(c_t)

ここまでが、LSTMが基本的に行うアルゴリズムの設計になります。

LSTMLayer

class LSTMLayer(nn.Module):
    def __init__(
        self, input_size: int, hidden_size: int,
        bidirectional: bool=False, residual: bool=False
    ):
        super(LSTMLayer, self).__init__()

        # 入力層サイズと隠れ層サイズを設定
        self.input_size = input_size
        self.hidden_size = hidden_size

        # 双方向にするかを設定
        self.bidirectional = bidirectional

        # 残差接続するかを設定
        self.residual = residual

        # (双方向の)LSTMセルのインスタンスの作成
        self.cell_fwd = LSTMCell(self.input_size, self.hidden_size)
        if self.bidirectional:
            self.cell_bwd = LSTMCell(self.input_size, self.hidden_size)

    def forward(
        self, x: torch.Tensor, h_0: torch.Tensor, c_0: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        順伝播を行う関数

        Parameters
        ----------
        x: torch.Tensor
            入力データ
        h_0: torch.Tensor
            初期化された隠れ状態
        c_0: torch.Tensor
            初期化されたセル

        Returns
        ----------
        outputs: torch.Tensor
            出力データ
        h_t: torch.Tensor
            隠れ状態
        c_t: torch.Tensor
            セル
        """
        # 入力長を取得
        seq_len, _, _ = x.size()

        # forward方向の処理
        h_fwd, c_fwd = h_0[0], c_0[0]
        outputs_fwd = []
        for t in range(seq_len):
            h_fwd, c_fwd = self.cell_fwd(x[t], h_fwd, c_fwd)
            outputs_fwd.append(h_fwd.unsqueeze(0))
        outputs_fwd = torch.cat(outputs_fwd, dim=0)

        # 双方向ではない場合用に分岐
        if not self.bidirectional:
            outputs = outputs_fwd
            h_t = h_fwd.unsqueeze(0)
            c_t = c_fwd.unsqueeze(0)
        else:
            # backward方向の処理
            h_bwd, c_bwd = h_0[1], c_0[1]
            outputs_bwd = []
            for t in reversed(range(seq_len)):
                h_bwd, c_bwd = self.cell_bwd(x[t], h_bwd, c_bwd)
                outputs_bwd.insert(0, h_bwd.unsqueeze(0))
            outputs_bwd = torch.cat(outputs_bwd, dim=0)

            # forward/backwardの出力を連結
            outputs = torch.cat([outputs_fwd, outputs_bwd], dim=2)
            h_t = torch.stack([h_fwd, h_bwd], dim=0)
            c_t = torch.stack([c_fwd, c_bwd], dim=0)

        # 残差接続
        if self.residual and outputs.shape == x.shape:
            outputs = outputs + x

        return outputs, h_t, c_t

ここでは2つ、双方向LSTMの構造にするか残差接続のある構造にするかのオプションをつけられるようにしています。
どちらもBooleanで与えられ、 True の場合にその構造にするよう設計してあります。

双方向LSTM

普通のLSTMは、系列データ(系列長 n)の i 番目を予測するのに、前(0i - 1)から学習を行います。
一方で、双方向LSTM(Bidirectional LSTM)では、系列の予測したいデータの前後(0i - 1, i + 1n - 1)から共に学習を行います。
劣化版Attentionといったところでしょうかね。

        self.cell_fwd = LSTMCell(self.input_size, self.hidden_size)
        if self.bidirectional:
            self.cell_bwd = LSTMCell(self.input_size, self.hidden_size)

__init__() の部分で、引数 bidirectional によって後ろの系列を予測するか判断し、双方向にするのであれば2つの LSTMCell のインスタンスを使ってモデルを構築します。

        h_fwd, c_fwd = h_0[0], c_0[0]
        outputs_fwd = []
        for t in range(seq_len):
            h_fwd, c_fwd = self.cell_fwd(x[t], h_fwd, c_fwd)
            outputs_fwd.append(h_fwd.unsqueeze(0))
        outputs_fwd = torch.cat(outputs_fwd, dim=0)

forward() メソッドの中で、まずは前からの計算処理の結果を持っておきます。
もし、 bidirectional=False にしていれば、この部分のみでLSTMの処理が完結すると考えてください。
ここで、 h_fwd, c_fwd = h_0[0], c_0[0] としていますが、 0 番目の要素を指定しているのは「前からの計算を行う」からです。
後ほど「後ろからの計算を行う」ときは、ここを 1 にします。
PyTorchの公式 nn.LSTM でも 0 が前、 1 が後ろであることが記載されています。

単方向であれば、次のコードで終了します。

        if not self.bidirectional:
            outputs = outputs_fwd
            h_t = h_fwd.unsqueeze(0)
            c_t = c_fwd.unsqueeze(0)

双方向の場合、後ろも同様に計算を行います。

            h_bwd, c_bwd = h_0[1], c_0[1]
            outputs_bwd = []
            for t in reversed(range(seq_len)):
                h_bwd, c_bwd = self.cell_bwd(x[t], h_bwd, c_bwd)
                outputs_bwd.insert(0, h_bwd.unsqueeze(0))
            outputs_bwd = torch.cat(outputs_bwd, dim=0)

前述の通り、 h_bwd, c_bwd = h_0[1], c_0[1] となっていますね。
また、for文が reversed() になっていることと、 outputs_bwdappend() ではなく insert() をしているところに注意です。
2つの結果を統合します。

            outputs = torch.cat([outputs_fwd, outputs_bwd], dim=2)
            h_t = torch.stack([h_fwd, h_bwd], dim=0)
            c_t = torch.stack([c_fwd, c_bwd], dim=0)

これで双方向LSTMの処理の流れは終わりです。

残差接続

そもそも残差接続とは、入力と出力の差を予測することで勾配消失・爆発を減らすというものでした。
時系列データでは、一連の系列に因果があり、データの分布が一定であるため差分の変化が極端ではないことが予想できます。

        if self.residual and outputs.shape == x.shape:
            outputs = outputs + x

差分を足すだけでシンプルですね!

LSTM

class LSTM(nn.Module):
    def __init__(
        self, input_size: int, hidden_size: int, num_layers: int=1,
        bidirectional: bool=False, residual: bool=False
    ):
        super(LSTM, self).__init__()

        # 入力層サイズと隠れ層サイズを設定
        self.input_size = input_size
        self.hidden_size = hidden_size
        
        # 層数を設定
        self.num_layers = num_layers

        # 双方向にするかを設定
        self.bidirectional = bidirectional

        # 残差接続するかを設定
        self.residual = residual

        # LSTM全体の構造を作成する
        self.layers = nn.ModuleList()
        for layer in range(self.num_layers):
            in_size = self.input_size if layer == 0 else self.hidden_size * (
                2 if self.bidirectional else 1
            )
            self.layers.append(
                LSTMLayer(
                    in_size, self.hidden_size,
                    self.bidirectional, self.residual
                )
            )

    def forward(
        self, x: torch.Tensor,
        h_0: Optional[torch.Tensor] = None,
        c_0: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """
        順伝播を行う関数

        Parameters
        ----------
        x: torch.Tensor
            入力データ
        h_0: Optional[torch.Tensor]=None
            初期化された隠れ状態
        c_0: Optional[torch.Tensor]=None
            初期化されたセル

        Returns
        ----------
        outputs: torch.Tensor
            出力データ
        h_t: torch.Tensor
            隠れ状態
        c_t: torch.Tensor
            セル
        """
        # バッチサイズを取得
        _, batch_size, _ = x.size()

        # 方向数を指定
        num_directions = 2 if self.bidirectional else 1

        # 初期状態の隠れ状態とセルを定義
        if h_0 is None:
            h_0 = torch.zeros(
                self.num_layers * num_directions, batch_size,
                self.hidden_size, device=x.device
            )
        if c_0 is None:
            c_0 = torch.zeros(
                self.num_layers * num_directions, batch_size,
                self.hidden_size, device=x.device
            )

        # スタックの用意
        h_n = []
        c_n = []
        layer_input = x

        # 層ごとに処理・格納
        for layer_idx, layer in enumerate(self.layers):
            # 初期化
            h_start = layer_idx * num_directions
            h_i = h_0[h_start:h_start + num_directions]
            c_i = c_0[h_start:h_start + num_directions]

            # 次の層に更新
            layer_output, h_i_out, c_i_out = layer(layer_input, h_i, c_i)
            layer_input = layer_output

            # スタックに格納
            h_n.append(h_i_out)
            c_n.append(c_i_out)

        # 最終層の出力
        outputs = layer_output
        h_t = torch.cat(h_n, dim=0)
        c_t = torch.cat(c_n, dim=0)

        return outputs, (h_t, c_t)

ポイントとしては、層数のパラメータを引数として受け取っているので、モデル全体のアーキテクチャを nn.ModuleList でまとめます。

        self.layers = nn.ModuleList()
        for layer in range(self.num_layers):
            in_size = self.input_size if layer == 0 else self.hidden_size * (
                2 if self.bidirectional else 1
            )
            self.layers.append(
                LSTMLayer(
                    in_size, self.hidden_size,
                    self.bidirectional, self.residual
                )
            )

あとは流すだけですね!

        h_n = []
        c_n = []
        layer_input = x

        for layer_idx, layer in enumerate(self.layers):
            h_start = layer_idx * num_directions
            h_i = h_0[h_start:h_start + num_directions]
            c_i = c_0[h_start:h_start + num_directions]

            layer_output, h_i_out, c_i_out = layer(layer_input, h_i, c_i)
            layer_input = layer_output

            h_n.append(h_i_out)
            c_n.append(c_i_out)

        outputs = layer_output
        h_t = torch.cat(h_n, dim=0)
        c_t = torch.cat(c_n, dim=0)

終わりに

各モジュールにおける処理内容はそこまで難しいものではないので、各自でカスタマイズしやすくて良いですね!

Discussion