🙇‍♂️

sktimeのドキュメントを読みながら遊ぶ(1)

2024/04/08に公開

sktimeはsckit-learn-likeなインターフェースを持つ、時系列予測に特化したライブラリです。sktimeを使った情報が思ったより少なかったので、今回はこのsktimeを使って、ドキュメントを読みながら予測までの一通りのフローを流してみようと思います。

ボリューム的に複数回になる匂いがしたので、タイトルのsuffixは(1)としています。

sktimeで時系列予測を行う際の手順は以下のようになります。

  1. データの準備
    • 今回はあらかじめ用意されているairlineデータを使います。
  2. 予測したいインデックス(時点)の指定
    • ForecastingHorizonの定義
  3. scikit-learnのような構文で予測器の作成
    • あのインターフェースってBaseEstimatorって呼ばれるらしいです。知らんかった。
  4. 学習
    • fit
  5. 予測
    • predict

この手順に沿って一連をやってみます。

最初に、必要なライブラリをimportしておきます。

import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# hide warnings
warnings.filterwarnings("ignore")

データ準備

今回はビルトインされているairlineデータを使用します。これは1949年から1960年までの国際航空会社の月間合計乗客数を表しており、値の単位は1000人です。

読み込みと、データの可視化をおこないます。可視化にはsktimeのplot_series関数を使用しました。

from sktime.datasets import load_airline
from sktime.utils.plotting import plot_series

y = load_airline()
plot_series(y)

本来はこの後に、データから外れ値を取り除いたり、特徴量の作成をおこなうのですが、今回は取り敢えずこのまま素のデータで予測までやります。

ForecastHorizonの設定

予測実行する時点の設定をおこなうのですが、その前に、データを学習用と評価用に分け、学習は予測(評価)時点以前のデータだけでおこなうようにしたいです。

from sktime.split import temporal_train_test_split
y_train, y_test = temporal_train_test_split(y, test_size=36)

sktimeのtrain_test_splitは、時系列データ予測の特性上、データを前後で区切ってくれます。時系列は過去のデータの動きが重要なので、飛び飛びのデータを抜き出しても意味ないですよね。

今回は全データのうち、後ろから36ヶ月分を評価用データとして抜き出しました。

y_train

time value
1949-01 112.0
1949-02 118.0
1949-03 132.0
... ...
1957-10 347.0
1957-11 305.0
1957-12 336.0

y_test

time value
1958-01 340.0
1958-02 318.0
1958-03 362.0
... ...
1960-10 461.0
1960-11 390.0
1960-12 432.0

プロットしてみます。

from sktime.utils import plotting

fig, ax = plotting.plot_series(
    y_train, y_test, labels=["y_train", "y_test"])

青色で示された系列が学習に使用するデータ、オレンジ色で示された系列は評価に使うデータになります。
評価用の系列の日付は1958-01から1960-12までの36ヶ月となりました。これを予測器に渡すために、ForecastHorizonを作成します。

ForecastHorizonでは、is_relative=Falseを設定することで、特定の時点を予測することができます。反対に、Trueにすると、相対的(例えば、学習データの最終時点を基点として、2ヶ月毎など)な時点を指定することができます。今回は36ヶ月分を絶対的に指定しました。

from sktime.forecasting.base import ForecastingHorizon

fh = ForecastingHorizon(
    pd.PeriodIndex(pd.date_range("1958-01", periods=36, freq="M")), is_relative=False
)
ForecastingHorizon(['1958-01', '1958-02', '1958-03', '1958-04', '1958-05', '1958-06',
             '1958-07', '1958-08', '1958-09', '1958-10', '1958-11', '1958-12',
             '1959-01', '1959-02', '1959-03', '1959-04', '1959-05', '1959-06',
             '1959-07', '1959-08', '1959-09', '1959-10', '1959-11', '1959-12',
             '1960-01', '1960-02', '1960-03', '1960-04', '1960-05', '1960-06',
             '1960-07', '1960-08', '1960-09', '1960-10', '1960-11', '1960-12'],
            dtype='period[M]', is_relative=False)

予測器の定義

今回はSARIMAを使って予測をおこなおうと思います。
ARIMAにsp(seasonal period)を12で設定してあげることで、季節性を考慮したSARIMA予測器を作成しました。

from sktime.forecasting.arima import AutoARIMA

# suppress_warning = True 
# 最尤推定が最適化しない場合にwarningが出る
# 本当は確認するべき
# 今回は後回しで別記事で触れる予定
sarima = AutoARIMA(sp=12, suppress_warnings=True)

学習

学習用データで学習を実行します。これはsklearn-likeにfitの引数にy_trainを渡すだけでおこなってくれます。

sarima.fit(y_train)

予測

予測も簡単にpredictにForecastHorizonを渡すだけで、指定されている時点の分だけ予測してくれます。

y_pred = sarima.predict(fh)

可視化

予測まで実行してみました。どのような結果になったでしょうか。可視化してみましょう。

plot_series(y_train, y_test, y_pred, labels = ["y_train","y_test","predicted"])

  • グラフの推移を見ると、最初の12ヶ月ぐらいはそこそこ精度も良いのではないかなと思います。
  • 気になるのは、2個目と3個目のピークで、伸びきれていないことです。
    これは、1957年のピークに対して、1958年はトレンド通り伸びているようにみえますが、1959年,1960年のピークは今までの増加トレンドと比較して、更に伸びが大きくなっているのが追随しきれなかった原因かと思っています。

評価

定量的な評価指標として、今回はMAPE(Mean Absolute Percentage Error)を用いることとします。
これもビルトインで実装されているので、簡単に使えます。

mean_absolute_percentage_error関数に評価期間のデータと、実際に予測されたデータを渡します。

from sktime.performance_metrics.forecasting import mean_absolute_percentage_error
mean_absolute_percentage_error(y_test, y_pred)

ピークは少し過小評価してしまいましたが、全体の平均でみれば、MAPEが4%程度と、思ったより精度がいいなという感想です。ピークを重要視して、RMSEでの評価も追加でおこなった方が良いかもしれません。

0.041489714363285066

Coverageを予測、可視化

predict_intervalを用いて、90%で実際の結果がこの中に入るぜ!という範囲を計算します。これが簡単にできるのは結構便利。

coverage = 0.90
y_pred_ints = sarima.predict_interval(coverage=coverage)
y_pred_ints
from sktime.utils import plotting

fig, ax = plotting.plot_series(
    y_train, y_test, y_pred, labels=["y_train","y_test" ,"y_pred"], pred_interval=y_pred_ints
)

最初の12ヶ月分は結構確信度が高く、時を経るほど範囲は大きくなっていきます。季節性が12ヶ月周期で、更に増加傾向も少し変わっているのでこんなものかなと思います。特徴量を増やしたりするとどうなるかは気になるところです。

まとめ

今回は使う機会のなかったsktimeを使って、SARIMAによる予測を一通りおこないました。sktimeの全体像はまだわからないですが、都度ドキュメントを読みながら備忘録も兼ねて記事にしていこうと思います。

というか、次回は追加学習周りを読んでみようと思います。

Goals Tech Blog

Discussion