AutoEncoderを使った異常検知を試してみた
テーブルデータの異常検知を学ぶため、下記のサイトのコードを試してみました。
データセットは kaggle の心電図データを使っています。
PTB診断ECGデータベースの正常データをAutoEncoderで学習し、正常データと異常データの異常スコアを算定して比較しました。
データセットの詳細はこちらをご覧ください。
Colab上で実行したコードはこちら。
データセットは Google Drive にアップロードして実行しています。
なお、ニューラルネット詳しくないマンなので、変なところがありましたら優しくご指摘いただけると喜びます。
参考サイトからのコードの変更点
コードは参考サイトから流用しているので詳細は参考サイト(再掲)か、Colab のデータをご覧ください。
なお、使用したデータセットのカラム数 188 に合わせてネットワークの層の数を変更し、1つ層を増やしています。
class AutoEncoder(nn.Module):
def __init__(self):
super(AutoEncoder, self).__init__()
self.dense_enc1 = nn.Linear(188*20, 1024)
self.bn1 = nn.BatchNorm1d(1024)
self.dense_enc2 = nn.Linear(1024, 512)
self.bn2 = nn.BatchNorm1d(512)
self.dense_enc3 = nn.Linear(512, 256)
self.bn3 = nn.BatchNorm1d(256)
self.dense_enc4 = nn.Linear(256,128)
self.dense_dec1 = nn.Linear(128,256)
self.bn4 = nn.BatchNorm1d(256)
self.dense_dec2 = nn.Linear(256, 512)
self.bn5 = nn.BatchNorm1d(512)
self.dense_dec3 = nn.Linear(512, 1024)
self.bn6 = nn.BatchNorm1d(1024)
self.drop1 = nn.Dropout(p=0.2)
self.dense_dec4 = nn.Linear(1024, 188*20)
# 以下略
使用データ
正常データは ptbdb_normal.csv を使っています。
なお、下記グラフは心電図5つを1つのグラフに表示しています(異常データも同様)。
また、6割を学習用データ、4割を評価用データとして分割しています。
異常データは ptbdb_abnormal.csv を使っています。
異常スコア
数が少ないのが気になりますが、正常データの異常スコアは約 0.01 となっています。
異常データの異常スコアは 0.03~0.05となっており、正常データに比べスコアが高くなっております。
0.02~0.03を閾値にすれば異常検知ができそうです。
その他
Azure の異常検知サービス、Anomaly Detectorや、Amazon Lookout for Metricsも気になるのでそのうち試してみたいと思います。
また、異常検知の学習にあたり、下記の本を読みました。
機械学習や統計の考え方や時系列解析の考え方が丁寧に解説されていました。
以上になります、最後までお読みいただきありがとうございました。
Discussion