Spiking Nural Networkで教師なしMNIST
はじめに
この記事はSNNでMNISTを解くことを目的とした記事です。おそらくZennでSNNの解説はこの記事が初めてだと思います。
また、この記事は基本的にAn Unsupervised Spiking Neural Network Inspired By Biologically Plausible Learning Rules and Connections
に従います。
SNNとは?
SNN(Spiking Nural Network)とは現在主流のANN(Artificial Neural Network)が実数値を利用しているのに対し、ニューロンの離散的発火(1か0か)を利用して推論を行おう!というモデルです。今回の記事ではメジャーなLIF(積分漏れ発火モデル)を利用してニューロンの発火を再現していこうと思います。
今回使用する学習則について
SNNには基本STDP学習則、代理勾配を利用した誤差逆伝播法の二つがあります。今回はSTDP学習則を拡張したSTB-STDPを利用します。
STDPとは?
STDPとは2つのニューロンの発火時刻差に依存したシナプス結合強化が行われる学習方式になります。つまり、発火した入力値Aに対して出力値Bが発火していた場合AーB間の重みが増加し、発火した入力値 Aに対し出力値Bが発火していない場合はA-B間の重みが減少します。ここら辺はA,Bの発火時刻との差で増加量、減少量が変わります。結果的にはモデルにただデータを流すだけで勝手にニューロン間の結合が変化し、学習されていきます。これがSTDPを使ったSNNが教師なし学習と言われる所以です。
WTA
WTA(Win Takes All)は複数ニューロンが発火した際に、その中からランダムに一つ選び、それ以外を抑制します(発火しなかったことにする)。
ASF
ASF(Adaptive Synaptic Filter)はATBと同時に使われます。ざっくりいうとInput→Conv2d→ASF→NuralEncoding→Linear→LIF→ATBの順で使われます。学習可能なパラメータを持つLinearをASFとATBで挟み込んでいるイメージです。
ASFはシナプスの短期可塑性 (STP)を模倣しており、シナプスの信号伝達効率を増加または減少させることにより、情報処理のためのフィルタ機能を提供を目的としています。ASFは非線形関数であり、入力がが閾値または静止電位の近くに集中するように振る舞います。
実際の数式は論文を読んだ方が早いと思うので、ここではコードを載せておきます。
class ASFModule(nn.Module):
def __init__(self) -> None:
super(ASFModule, self).__init__()
def forward(self, x: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
x = x.clamp(min=0) # 負の値を0にする
x = threshold / (
1 + torch.exp(-(x - 4 * threshold / 10) * (8 / threshold))
) # ASF
# tにおけるthresholdを1+exp δtで割る
return x
ATB
ATB(Adaptive Threshold Balance)はニューロンがスパイクを発火すると閾値が増加し、最大閾値に達するとすべてのニューロンの閾値が減少します。 この動的バランスにより、単一のニューロンが支配されるのを防ぎ、すべてのニューロンに発火する機会が確保されます。
実際にどのように閾値が変化してるかは以下のコードの通りです。
以下のコードは徐々に
def updatethresh(self): # ATB
if self.spike is None:
return
#発火したニューロンの入力を閾値に足す
self.node.threshold.data += (self.alpha*self.x * self.spike.detach()).sum(0)
theta_bias = self.node.threshold.max() - self.gamma
if theta_bias > 0:#閾値の最大がgmmaを超えたら
self.node.threshold.data -= theta_bias#超えた分引く
STB-STDP
STB-STDPは、複数のサンプルと時間ステップからの情報を集めた後に重みを更新します。
これにより、入力はバッチで行われ、forwardではT回、発火モデルの計算を行います。
例として全結合レイヤーのforwardを以下に示します。
def forward(self, x: torch.Tensor, T: int = 1) -> torch.Tensor:
x = x.detach()
xsum: float | torch.Tensor = 0.0
if self.init is False:
self.weight(x)
self.normweight()
x = self.weight(x) # 重みをかける
self.x = x.detach()
if not self.init:
self.node.threshold.data = (x.max(0)[0].detach()*3).to(device)
self.init=True
for t in range(T):
xori=x
y = self.node(x) # 発火
if y.max()>0:
y = self.WTA(y)
y = self.lateralinh(y, xori)#侧抑制
self.spike = y
xsum += y
if isinstance(xsum, float):
raise TypeError
return xsum
また、重みが分岐したりシフトしたりするのを防ぐために正規化を使用します。正規化は、重みをすべて同じ大きさにスケーリングします。これにより、重みが大きくなりすぎたり小さくなりすぎたりするのを防ぎ、問題を防ぐことができます。全結合層では、1ステップごとに実行され、エンコーディング層では初めの一回だけ呼び出されます。
例として全結合レイヤーの正規化メソッドを示します。
def normweight(self):
self.weight.weight.data = torch.clamp(
self.weight.weight.data, min=0, max=1.0
) # 0~1にclamp
self.weight.weight.data /= self.weight.weight.data.max(1, True)[0] / 0.1
# 重みの最大値を0.1にする
Voting
Votingでは、最終層の発火したニューロンと正解ラベルの紐付け、また、紐付け済みニューロンによる多数決を行います。STDPだけでは、ある特徴に発火しやすいニューロンを作ることができるだけなので、このVotingが推論器の役割を果たします。
class Voting(nn.Module):
def __init__(self, shape):
super().__init__()
self.label = torch.zeros(shape) - 1
self.assignments:torch.Tensor
self.alpha = 0.7
def assign_labels(self, spikes:torch.Tensor, labels, rates=None, n_labels=10):
# labels 正解ラベル
# spikes => batch * time * in_size
# print(spikes.size())
n_neurons = spikes.size(2)
if rates is None:
rates = torch.zeros(n_neurons, n_labels, device=device)#node数*label数
self.n_labels = n_labels
#時間方向に合計
spikes = spikes.cpu().sum(1).to(device)#batch*time*in_size->batch*in_size
for i in range(n_labels):
n_labeled = torch.sum(labels == i).float() # label iの数
if n_labeled > 0:
indices = torch.nonzero(labels == i).view(-1) # label iのインデックス
tmp = torch.sum(spikes[indices], 0) / n_labeled # label iのときのバッチのSpikeの合計をlabel iの回数で割る
rates[:, i] = self.alpha * rates[:, i] + tmp #過去からの引き継ぎ*Alpha+今回の値
self.assignments = torch.max(rates, 1)[1]#もっとも対応したラベルをニューロンに割り当てる
return self.assignments, rates
def get_label(self, spikes:torch.Tensor):
n_samples = spikes.size(0)#batch
#時間方向に合計
spikes = spikes.cpu().sum(1).to(device)#batch*time*in_size->batch*in_size
rates = torch.zeros(n_samples, self.n_labels, device=device)#batch*label数
for i in range(self.n_labels):
n_assigns = torch.sum(self.assignments == i).float() #そのクラスに対応したニューロンの数
if n_assigns > 0:
indices = torch.nonzero(self.assignments == i).view(-1) # そのクラスに対応したニューロンのインデックス
rates[:, i] = torch.sum(spikes[:, indices], 1) / n_assigns # そのクラスに対応したニューロンのスパイクの合計をそのクラスに対応したニューロンの数で割る
return torch.sort(rates, dim=1, descending=True)[1][:, 0]#もっともスパイクが多いクラスを返す
モデル概要
T=1にしたときのモデルは以下の通りですSTB-STDPを使用するため、Tを増やすとLateralInhibition、WTAの層は増加していきます。
==========================================================================================
SNN_Conv_Model [1, 256] --
├─ModuleList: 1-1 -- 11
│ └─SNNConv2DEncoding: 2-1 [1, 12, 28, 28] --
│ │ └─Conv2d: 3-1 [1, 12, 28, 28] (108)
│ │ └─ASFModule: 3-2 [1, 12, 28, 28] --
│ │ └─LateralInhibition: 3-3 -- (13)
│ └─MaxPool2d: 2-2 [1, 12, 14, 14] --
│ └─Normliaze: 2-3 [1, 12, 14, 14] --
│ └─Flatten: 2-4 [1, 2352] --
│ └─STDP_LIF_Layer: 2-5 [1, 256] --
│ │ └─Linear: 3-4 [1, 256] 602,112
│ │ └─LateralInhibition: 3-5 -- (257)
==========================================================================================
Total params: 602,501
Trainable params: 602,112
Non-trainable params: 389
Total mult-adds (M): 0.69
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.15
Params size (MB): 2.41
Estimated Total Size (MB): 2.57
==========================================================================================
ノートブック
それらを踏まえて作成したノートブックがこちらになります。
終わりに
Votingのためにエポックごとの推論データとラベルデータを保存しなければいけないため、結構メモリを食います。SNNをSTDPで学習させた後にLGBMなどで分類するのも面白いかもしれないですね。
Discussion