🕰️

TimesNet解説 Part3 モデルの実装と比較

2024/03/24に公開

TimesNet解説のPart3です。Part1,Part2はこちらからどうぞ。

この記事は、TimesNet: The Latest Advance in Time Series Forecastingを参考に記述しています。英語が読める方はこちらを確認して下さい。

3. TimesNetによる予測

では、実際にTimesNetによる予測をmarcopeixさんの実装を参考に解説していきます。

3.1 ライブラリのインポート

必要なライブラリのインポートを行います。ここではNIXTLAのNeuralForecastで利用可能な実装を使用します。

# ライブラリのインポート
!pip install neuralforecast

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from neuralforecast.core import NeuralForecast
from neuralforecast.models import NHITS, NBEATS, TimesNet

from neuralforecast.losses.numpy import mae, mse

plt.rcParams["figure.figsize"] = (9,6)

3.2 データのロード

データをロードします。今回は時系列予測のベンチマークとしてよく利用されている、変圧器の油温予測を行うEttデータセットを使用します。

# Ettデータのロード
df = pd.read_csv('/kaggle/input/etth1-data/etth1.csv')

df['ds'] = pd.to_datetime(df['ds'])

df.head()

# 可視化
fig, ax = plt.subplots()

ax.plot(df['y'])
ax.set_xlabel('Time')
ax.set_ylabel('Oil temperature')

fig.autofmt_xdate()
plt.tight_layout()

・Ettデータセット

3.3 モデル定義

モデルを定義します。
予測期間は長期予測の一般的な期間である96時間に設定します。

N-BEATS、N-HiTS、TimesNetのモデルのリストを定義し、NeuralForecastsオブジェクトをインスタンス化します。
すべてのモデルのデフォルトパラメータをそのまま使用し、エポックの最大数を50に制限します。デフォルトでは、TimesNetはデータ内の最も重要な期間の上位5つを選択します。

# モデル定義
horizon = 96

models = [NHITS(h=horizon,
               input_size=2*horizon,
               max_steps=50),
         NBEATS(h=horizon,
               input_size=2*horizon,
               max_steps=50),
         TimesNet(h=horizon,
                 input_size=2*horizon,
                 max_steps=50)]

# インスタンス化
nf = NeuralForecast(models=models, freq='H')

3.4 モデル評価

3.4.1 交差検証

交差検証を使用して、モデルの性能を評価します。(Testデータセットに対して予測を行います)

# 交差検証
preds_df = nf.cross_validation(df=df, step_size=horizo​​n, n_windows= 2 )
# 各モデル予測結果を表示
preds_df.head()

# 各モデル予測を可視化
fig, ax = plt.subplots()

ax.plot(preds_df['y'], label='actual')
ax.plot(preds_df['NHITS'], label='N-HITS', ls='--')
ax.plot(preds_df['NBEATS'], label='N-BEATS', ls=':')
ax.plot(preds_df['TimesNet'], label='TimesNet', ls='-.')

ax.legend(loc='best')
ax.set_xlabel('Time steps')
ax.set_ylabel('Oil temperature')

fig.autofmt_xdate()
plt.tight_layout()
plt.savefig('forecast.png')
plt.show()

・各モデルの予測結果

・各モデルの予測結果の可視化

上記の結果では、どのモデルもデータの変動を予測できていないように見えます。
またN-BEATSとN-HiTSは、TimesNetの予測とは異なり、元データの周期的なパターンを少し捉えていますが、予測と呼ぶには精度が不十分です。

3.4.2 MAE,MSE

ここで、予測精度を定量的に評価するために、MAE(平均絶対誤差)とMSE(平均二乗誤差)を計算します。
MAEとMSEは、予測結果が元のデータからどの程度離れていたか、という指標で精度を評価します。

# MAE,MSEの計算
data = {'N-HiTS': [mae(preds_df['NHITS'], preds_df['y']), mse(preds_df['NHITS'], preds_df['y'])],
       'N-BEATS': [mae(preds_df['NBEATS'], preds_df['y']), mse(preds_df['NBEATS'], preds_df['y'])],
       'TimesNet': [mae(preds_df['TimesNet'], preds_df['y']), mse(preds_df['TimesNet'], preds_df['y'])]}

metrics_df = pd.DataFrame(data=data)
metrics_df.index = ['mae', 'mse']

metrics_df.style.highlight_min(color='lightgreen', axis=1)

・MAE,MSE

結果を見ると、N-HiTSがどちらの指標でも優れているようです。
※ N-BEATSもほぼ同様の精度を達成しています

従って、TimesNetは最高精度を達成できなかったことになります。しかし、注意すべきは単純な96×2stepの予測であり、ハイパーパラメータのチューニングなども行われていないということです。

3.5 結論

TimesNetとN-HiTS、N-BEATSについて、簡単な予測モデルを比較したところ、TimesNetは両モデルに及びませんでした。
しかし、TimesNetの構想である、「周期性を捉えるために、フーリエ変換を用いてデータを一次元から二次元に変換、重み付けしたものを特徴として扱う」という考え方は、幅広いタスクへの応用が考えられます。
また各予測における最適な手法は異なるため、様々なモデルを試すことが推奨されます。

まとめ

今回はTimesNetの実装から他モデルとの比較までを行いました。
TimesNet解説は今回で終了です。読んでいただきありがとうございました。

参考

(1)marcopeix/time-series-analysis/TimesNet.ipynb
(2)TimesNet: The Latest Advance in Time Series Forecasting
(3)TimesNet: Temporal 2D-Variation Modeling for General Time Series Analysis
(4)Going Deeper with Convolutions(GoogLeNet)

Discussion