今更だけどLSTMを自作してみた
こんにちは!
株式会社アイディオットでデータサイエンティストをしています、秋田と申します。
今更ですが、PyTorchでLSTMを自作してみました。
「え、 torch.nn.LSTM
ってモジュールがあるよ?」
おっしゃる通り...ですが、改めて自分で作成してみると公式のムズカシイドキュメントを眺めるよりもシンプルに理解ができるのと、改造が可能なことが大きな利点として挙げられるのではないでしょうか!
公式で提供されているLSTMのモジュールは、RNNクラスも継承しているため色々参照しなければいけないので大変なんですよね😅
今回自作したLSTMでは、極力見慣れたメソッドやクラスだけで構成しているので上記の心配はありません。
全体像
まずはプログラム全体をご紹介いたします。
from typing import Optional
import torch
from torch import nn
class LSTMCell(nn.Module):
def __init__(self, input_size: int, hidden_size: int):
super(LSTMCell, self).__init__()
# 入力層サイズと隠れ層サイズを設定
self.input_size = input_size
self.hidden_size = hidden_size
# 各種重みとバイアスの定義
self.weight_ih = nn.Parameter(
torch.randn(4 * self.hidden_size, self.input_size)
)
self.weight_hh = nn.Parameter(
torch.randn(4 * self.hidden_size, self.hidden_size)
)
self.bias_ih = nn.Parameter(torch.zeros(4 * self.hidden_size))
self.bias_hh = nn.Parameter(torch.zeros(4 * self.hidden_size))
# 活性化関数の定義
self.sigmoid = nn.Sigmoid()
self.tanh = nn.Tanh()
def forward(
self, x_t: torch.Tensor, h_prev: torch.Tensor, c_prev: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""
順伝播を行う関数
Parameters
----------
x_t: torch.Tensor
入力データ
h_prev: torch.Tensor
1ステップ前の隠れ状態
c_prev: torch.Tensor
1ステップ前のセル
Returns
----------
h_t: torch.Tensor
隠れ状態
c_t: torch.Tensor
セル
"""
# 総合的なゲート計算
gates = (
torch.matmul(x_t, self.weight_ih.t()) +
torch.matmul(h_prev, self.weight_hh.t()) +
self.bias_ih + self.bias_hh
)
# ゲート分割
i, f, g, o = gates.chunk(4, dim=1)
# 活性化関数適用
i = self.sigmoid(i)
f = self.sigmoid(f)
g = self.tanh(g)
o = self.sigmoid(o)
# セル・隠れ状態の更新
c_t = f * c_prev + i * g
h_t = o * self.tanh(c_t)
return h_t, c_t
class LSTMLayer(nn.Module):
def __init__(
self, input_size: int, hidden_size: int,
bidirectional: bool=False, residual: bool=False
):
super(LSTMLayer, self).__init__()
# 入力層サイズと隠れ層サイズを設定
self.input_size = input_size
self.hidden_size = hidden_size
# 双方向にするかを設定
self.bidirectional = bidirectional
# 残差接続するかを設定
self.residual = residual
# (双方向の)LSTMセルのインスタンスの作成
self.cell_fwd = LSTMCell(self.input_size, self.hidden_size)
if self.bidirectional:
self.cell_bwd = LSTMCell(self.input_size, self.hidden_size)
def forward(
self, x: torch.Tensor, h_0: torch.Tensor, c_0: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
順伝播を行う関数
Parameters
----------
x: torch.Tensor
入力データ
h_0: torch.Tensor
初期化された隠れ状態
c_0: torch.Tensor
初期化されたセル
Returns
----------
outputs: torch.Tensor
出力データ
h_t: torch.Tensor
隠れ状態
c_t: torch.Tensor
セル
"""
# 入力長を取得
seq_len, _, _ = x.size()
# forward方向の処理
h_fwd, c_fwd = h_0[0], c_0[0]
outputs_fwd = []
for t in range(seq_len):
h_fwd, c_fwd = self.cell_fwd(x[t], h_fwd, c_fwd)
outputs_fwd.append(h_fwd.unsqueeze(0))
outputs_fwd = torch.cat(outputs_fwd, dim=0)
# 双方向ではない場合用に分岐
if not self.bidirectional:
outputs = outputs_fwd
h_t = h_fwd.unsqueeze(0)
c_t = c_fwd.unsqueeze(0)
else:
# backward方向の処理
h_bwd, c_bwd = h_0[1], c_0[1]
outputs_bwd = []
for t in reversed(range(seq_len)):
h_bwd, c_bwd = self.cell_bwd(x[t], h_bwd, c_bwd)
outputs_bwd.insert(0, h_bwd.unsqueeze(0))
outputs_bwd = torch.cat(outputs_bwd, dim=0)
# forward/backwardの出力を連結
outputs = torch.cat([outputs_fwd, outputs_bwd], dim=2)
h_t = torch.stack([h_fwd, h_bwd], dim=0)
c_t = torch.stack([c_fwd, c_bwd], dim=0)
# 残差接続
if self.residual and outputs.shape == x.shape:
outputs = outputs + x
return outputs, h_t, c_t
class LSTM(nn.Module):
def __init__(
self, input_size: int, hidden_size: int, num_layers: int=1,
bidirectional: bool=False, residual: bool=False
):
super(LSTM, self).__init__()
# 入力層サイズと隠れ層サイズを設定
self.input_size = input_size
self.hidden_size = hidden_size
# 層数を設定
self.num_layers = num_layers
# 双方向にするかを設定
self.bidirectional = bidirectional
# 残差接続するかを設定
self.residual = residual
# LSTM全体の構造を作成する
self.layers = nn.ModuleList()
for layer in range(self.num_layers):
in_size = self.input_size if layer == 0 else self.hidden_size * (
2 if self.bidirectional else 1
)
self.layers.append(
LSTMLayer(
in_size, self.hidden_size,
self.bidirectional, self.residual
)
)
def forward(
self, x: torch.Tensor,
h_0: Optional[torch.Tensor] = None,
c_0: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""
順伝播を行う関数
Parameters
----------
x: torch.Tensor
入力データ
h_0: Optional[torch.Tensor]=None
初期化された隠れ状態
c_0: Optional[torch.Tensor]=None
初期化されたセル
Returns
----------
outputs: torch.Tensor
出力データ
h_t: torch.Tensor
隠れ状態
c_t: torch.Tensor
セル
"""
# バッチサイズを取得
_, batch_size, _ = x.size()
# 方向数を指定
num_directions = 2 if self.bidirectional else 1
# 初期状態の隠れ状態とセルを定義
if h_0 is None:
h_0 = torch.zeros(
self.num_layers * num_directions, batch_size,
self.hidden_size, device=x.device
)
if c_0 is None:
c_0 = torch.zeros(
self.num_layers * num_directions, batch_size,
self.hidden_size, device=x.device
)
# スタックの用意
h_n = []
c_n = []
layer_input = x
# 層ごとに処理・格納
for layer_idx, layer in enumerate(self.layers):
# 初期化
h_start = layer_idx * num_directions
h_i = h_0[h_start:h_start + num_directions]
c_i = c_0[h_start:h_start + num_directions]
# 次の層に更新
layer_output, h_i_out, c_i_out = layer(layer_input, h_i, c_i)
layer_input = layer_output
# スタックに格納
h_n.append(h_i_out)
c_n.append(c_i_out)
# 最終層の出力
outputs = layer_output
h_t = torch.cat(h_n, dim=0)
c_t = torch.cat(c_n, dim=0)
return outputs, (h_t, c_t)
3つのモジュールに分割し、 LSTMCell
はLSTMの基本アルゴリズムを、 LSTMLayer
はレイヤーレベルのまとまりを、 LSTM
は層全体を司ります。
ただし、読み出しの線形層は含まれていないことに注意してください。
公式との顕著な差
公式にあるのにここには無いものはいくつもありますが、公式には無くてここにあるものが1つだけあります。
それが「残差接続」です。
残差接続というと、画像認識の分野でのResNetが有名ですが、層を深くしたら表現力が上がることが期待される一方で勾配が消失・爆発するという問題があったのをある程度解消することが出来たことで話題になりましたね。
これは時系列解析の分野でも同様で、学習の安定化のために残差接続を行うことが検討されてきています。
凄く簡単な修正で実装が済むものの、公式の nn.LSTM
クラスから作るのは至難の業です。
ということでカスタマイズできるように1から作る必要がありました!
LSTMCell
class LSTMCell(nn.Module):
def __init__(self, input_size: int, hidden_size: int):
super(LSTMCell, self).__init__()
# 入力層サイズと隠れ層サイズを設定
self.input_size = input_size
self.hidden_size = hidden_size
# 各種重みとバイアスの定義
self.weight_ih = nn.Parameter(
torch.randn(4 * self.hidden_size, self.input_size)
)
self.weight_hh = nn.Parameter(
torch.randn(4 * self.hidden_size, self.hidden_size)
)
self.bias_ih = nn.Parameter(torch.zeros(4 * self.hidden_size))
self.bias_hh = nn.Parameter(torch.zeros(4 * self.hidden_size))
# 活性化関数の定義
self.sigmoid = nn.Sigmoid()
self.tanh = nn.Tanh()
def forward(
self, x_t: torch.Tensor, h_prev: torch.Tensor, c_prev: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""
順伝播を行う関数
Parameters
----------
x_t: torch.Tensor
入力データ
h_prev: torch.Tensor
1ステップ前の隠れ状態
c_prev: torch.Tensor
1ステップ前のセル
Returns
----------
h_t: torch.Tensor
隠れ状態
c_t: torch.Tensor
セル
"""
# 総合的なゲート計算
gates = (
torch.matmul(x_t, self.weight_ih.t()) +
torch.matmul(h_prev, self.weight_hh.t()) +
self.bias_ih + self.bias_hh
)
# ゲート分割
i, f, g, o = gates.chunk(4, dim=1)
# 活性化関数適用
i = self.sigmoid(i)
f = self.sigmoid(f)
g = self.tanh(g)
o = self.sigmoid(o)
# セル・隠れ状態の更新
c_t = f * c_prev + i * g
h_t = o * self.tanh(c_t)
return h_t, c_t
まずはLSTMの構成として、入力層と隠れ層、それからそれらのバイアスについてのパラメータを用意する必要がありますね。
それぞれ4つのゲート(入力・忘却・セル入力・出力)に対してパラメータが必要なので 4 * self.hidden_size
と隠れサイズを4倍しています。
4つのゲートの計算については次のように定義されます。
forwardの計算では、まずは上記の計算を一気に行うため、
gates = (
torch.matmul(x_t, self.weight_ih.t()) +
torch.matmul(h_prev, self.weight_hh.t()) +
self.bias_ih + self.bias_hh
)
として、ゲート全体を定義し、この後 .chunk()
で分割して活性化関数にかけます。
そして、セルと隠れ状態の更新を次の式の通りに行います。
こちらは次のように書いています。
c_t = f * c_prev + i * g
h_t = o * self.tanh(c_t)
ここまでが、LSTMが基本的に行うアルゴリズムの設計になります。
LSTMLayer
class LSTMLayer(nn.Module):
def __init__(
self, input_size: int, hidden_size: int,
bidirectional: bool=False, residual: bool=False
):
super(LSTMLayer, self).__init__()
# 入力層サイズと隠れ層サイズを設定
self.input_size = input_size
self.hidden_size = hidden_size
# 双方向にするかを設定
self.bidirectional = bidirectional
# 残差接続するかを設定
self.residual = residual
# (双方向の)LSTMセルのインスタンスの作成
self.cell_fwd = LSTMCell(self.input_size, self.hidden_size)
if self.bidirectional:
self.cell_bwd = LSTMCell(self.input_size, self.hidden_size)
def forward(
self, x: torch.Tensor, h_0: torch.Tensor, c_0: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
順伝播を行う関数
Parameters
----------
x: torch.Tensor
入力データ
h_0: torch.Tensor
初期化された隠れ状態
c_0: torch.Tensor
初期化されたセル
Returns
----------
outputs: torch.Tensor
出力データ
h_t: torch.Tensor
隠れ状態
c_t: torch.Tensor
セル
"""
# 入力長を取得
seq_len, _, _ = x.size()
# forward方向の処理
h_fwd, c_fwd = h_0[0], c_0[0]
outputs_fwd = []
for t in range(seq_len):
h_fwd, c_fwd = self.cell_fwd(x[t], h_fwd, c_fwd)
outputs_fwd.append(h_fwd.unsqueeze(0))
outputs_fwd = torch.cat(outputs_fwd, dim=0)
# 双方向ではない場合用に分岐
if not self.bidirectional:
outputs = outputs_fwd
h_t = h_fwd.unsqueeze(0)
c_t = c_fwd.unsqueeze(0)
else:
# backward方向の処理
h_bwd, c_bwd = h_0[1], c_0[1]
outputs_bwd = []
for t in reversed(range(seq_len)):
h_bwd, c_bwd = self.cell_bwd(x[t], h_bwd, c_bwd)
outputs_bwd.insert(0, h_bwd.unsqueeze(0))
outputs_bwd = torch.cat(outputs_bwd, dim=0)
# forward/backwardの出力を連結
outputs = torch.cat([outputs_fwd, outputs_bwd], dim=2)
h_t = torch.stack([h_fwd, h_bwd], dim=0)
c_t = torch.stack([c_fwd, c_bwd], dim=0)
# 残差接続
if self.residual and outputs.shape == x.shape:
outputs = outputs + x
return outputs, h_t, c_t
ここでは2つ、双方向LSTMの構造にするかと残差接続のある構造にするかのオプションをつけられるようにしています。
どちらもBooleanで与えられ、 True
の場合にその構造にするよう設計してあります。
双方向LSTM
普通のLSTMは、系列データ(系列長
一方で、双方向LSTM(Bidirectional LSTM)では、系列の予測したいデータの前後(
劣化版Attentionといったところでしょうかね。
self.cell_fwd = LSTMCell(self.input_size, self.hidden_size)
if self.bidirectional:
self.cell_bwd = LSTMCell(self.input_size, self.hidden_size)
__init__()
の部分で、引数 bidirectional
によって後ろの系列を予測するか判断し、双方向にするのであれば2つの LSTMCell
のインスタンスを使ってモデルを構築します。
h_fwd, c_fwd = h_0[0], c_0[0]
outputs_fwd = []
for t in range(seq_len):
h_fwd, c_fwd = self.cell_fwd(x[t], h_fwd, c_fwd)
outputs_fwd.append(h_fwd.unsqueeze(0))
outputs_fwd = torch.cat(outputs_fwd, dim=0)
forward()
メソッドの中で、まずは前からの計算処理の結果を持っておきます。
もし、 bidirectional=False
にしていれば、この部分のみでLSTMの処理が完結すると考えてください。
ここで、 h_fwd, c_fwd = h_0[0], c_0[0]
としていますが、 0
番目の要素を指定しているのは「前からの計算を行う」からです。
後ほど「後ろからの計算を行う」ときは、ここを 1
にします。
PyTorchの公式 nn.LSTM
でも 0
が前、 1
が後ろであることが記載されています。
単方向であれば、次のコードで終了します。
if not self.bidirectional:
outputs = outputs_fwd
h_t = h_fwd.unsqueeze(0)
c_t = c_fwd.unsqueeze(0)
双方向の場合、後ろも同様に計算を行います。
h_bwd, c_bwd = h_0[1], c_0[1]
outputs_bwd = []
for t in reversed(range(seq_len)):
h_bwd, c_bwd = self.cell_bwd(x[t], h_bwd, c_bwd)
outputs_bwd.insert(0, h_bwd.unsqueeze(0))
outputs_bwd = torch.cat(outputs_bwd, dim=0)
前述の通り、 h_bwd, c_bwd = h_0[1], c_0[1]
となっていますね。
また、for文が reversed()
になっていることと、 outputs_bwd
に append()
ではなく insert()
をしているところに注意です。
2つの結果を統合します。
outputs = torch.cat([outputs_fwd, outputs_bwd], dim=2)
h_t = torch.stack([h_fwd, h_bwd], dim=0)
c_t = torch.stack([c_fwd, c_bwd], dim=0)
これで双方向LSTMの処理の流れは終わりです。
残差接続
そもそも残差接続とは、入力と出力の差を予測することで勾配消失・爆発を減らすというものでした。
時系列データでは、一連の系列に因果があり、データの分布が一定であるため差分の変化が極端ではないことが予想できます。
if self.residual and outputs.shape == x.shape:
outputs = outputs + x
差分を足すだけでシンプルですね!
LSTM
class LSTM(nn.Module):
def __init__(
self, input_size: int, hidden_size: int, num_layers: int=1,
bidirectional: bool=False, residual: bool=False
):
super(LSTM, self).__init__()
# 入力層サイズと隠れ層サイズを設定
self.input_size = input_size
self.hidden_size = hidden_size
# 層数を設定
self.num_layers = num_layers
# 双方向にするかを設定
self.bidirectional = bidirectional
# 残差接続するかを設定
self.residual = residual
# LSTM全体の構造を作成する
self.layers = nn.ModuleList()
for layer in range(self.num_layers):
in_size = self.input_size if layer == 0 else self.hidden_size * (
2 if self.bidirectional else 1
)
self.layers.append(
LSTMLayer(
in_size, self.hidden_size,
self.bidirectional, self.residual
)
)
def forward(
self, x: torch.Tensor,
h_0: Optional[torch.Tensor] = None,
c_0: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""
順伝播を行う関数
Parameters
----------
x: torch.Tensor
入力データ
h_0: Optional[torch.Tensor]=None
初期化された隠れ状態
c_0: Optional[torch.Tensor]=None
初期化されたセル
Returns
----------
outputs: torch.Tensor
出力データ
h_t: torch.Tensor
隠れ状態
c_t: torch.Tensor
セル
"""
# バッチサイズを取得
_, batch_size, _ = x.size()
# 方向数を指定
num_directions = 2 if self.bidirectional else 1
# 初期状態の隠れ状態とセルを定義
if h_0 is None:
h_0 = torch.zeros(
self.num_layers * num_directions, batch_size,
self.hidden_size, device=x.device
)
if c_0 is None:
c_0 = torch.zeros(
self.num_layers * num_directions, batch_size,
self.hidden_size, device=x.device
)
# スタックの用意
h_n = []
c_n = []
layer_input = x
# 層ごとに処理・格納
for layer_idx, layer in enumerate(self.layers):
# 初期化
h_start = layer_idx * num_directions
h_i = h_0[h_start:h_start + num_directions]
c_i = c_0[h_start:h_start + num_directions]
# 次の層に更新
layer_output, h_i_out, c_i_out = layer(layer_input, h_i, c_i)
layer_input = layer_output
# スタックに格納
h_n.append(h_i_out)
c_n.append(c_i_out)
# 最終層の出力
outputs = layer_output
h_t = torch.cat(h_n, dim=0)
c_t = torch.cat(c_n, dim=0)
return outputs, (h_t, c_t)
ポイントとしては、層数のパラメータを引数として受け取っているので、モデル全体のアーキテクチャを nn.ModuleList
でまとめます。
self.layers = nn.ModuleList()
for layer in range(self.num_layers):
in_size = self.input_size if layer == 0 else self.hidden_size * (
2 if self.bidirectional else 1
)
self.layers.append(
LSTMLayer(
in_size, self.hidden_size,
self.bidirectional, self.residual
)
)
あとは流すだけですね!
h_n = []
c_n = []
layer_input = x
for layer_idx, layer in enumerate(self.layers):
h_start = layer_idx * num_directions
h_i = h_0[h_start:h_start + num_directions]
c_i = c_0[h_start:h_start + num_directions]
layer_output, h_i_out, c_i_out = layer(layer_input, h_i, c_i)
layer_input = layer_output
h_n.append(h_i_out)
c_n.append(c_i_out)
outputs = layer_output
h_t = torch.cat(h_n, dim=0)
c_t = torch.cat(c_n, dim=0)
終わりに
各モジュールにおける処理内容はそこまで難しいものではないので、各自でカスタマイズしやすくて良いですね!
Discussion