👻

時系列基盤モデルへ②:UniTS【論文】

2024/06/05に公開

UniTS: Building a Unified Time Series Model

今回も, 前回に引き続き時系列基盤モデルを確認します. 今回はUniTSというモデルです. とても簡単にまとめると以下の図のようになります.

関連リンク

問題設定

まずは, この論文でどのようなタスクを解きたいのかを設定します. 以後登場する文字は断りのない限りここで定義したものです.

与えられるのは, 複数ドメインのデータセットの集合 D=\{D_i|i=1,\ldots,n\} です. 各データセットは様々な長さの時系列データで構成されて, D_i=(\mathcal{X}_i, \mathcal{Y}_i) で表されます. ここで, \mathcal{X}_o は時系列データ, \mathcal{Y}_i はタスクを表します. この論文では予測, 分類, 異常検知, 補間の4タスクを扱います. F(\mathcal{X}, \theta) を, D の全てのデータセットで学習したモデルとします. また, \hat{\mathcal{X}}\mathcal{X} に含まれないout-of-domainなデータとし, \hat{\mathcal{Y}}\mathcal{Y} に含まれない新しいタスクとします.

時系列データは \boldsymbol{x}\in\mathbb{R}^{l_i\times v} で表されます. ここで, v は変量の数, l_i は系列長です. それ以外にも, 様々なトークンを設定します.

  • sequence token: \boldsymbol{z}_s\in\mathbb{R}^{l_s\times v\times d}
  • prompt token: \boldsymbol{z}_p\in\mathbb{R}^{l_p\times v\times d}
  • mask token: \boldsymbol{z}_m\in\mathbb{R}^{1\times v\times d}
  • CLS token: \boldsymbol{z}_c\in\mathbb{R}^{1\times v\times d}
  • category token: \boldsymbol{z}_e\in\mathbb{R}^{k\times v\times d}

l_s, l_p はそれぞれ系列, プロンプトの数です. d は特徴量の次元, k は分類問題におけるカテゴリの数を表します. モデルにはトークンを結合した \boldsymbol{z}_{\mathrm{in}}\in\mathbb{R}^{l\times v\times d} が入力されます.

時系列タスクでは, 様々な性質を持つデータと様々なタスクが考えられるので, それらに合わせてモデル設計をしていました. しかし, ここではそれを統一的に行うことを考えます. すなわち, 任意のタスクについて F(\mathcal{X}, \theta) は全て同じ重み \theta を用います. それは以下の3つの要件を満たすようにします.

  1. Multi-domain time series: モデル F は様々なソースからの時系列 x における系列長 l_{\mathrm{in}} を変量数 v の多様性があるので, 入力サンプル \mathcal{X} に依存してはいけない.
  2. Universal task specification: モデル F は全てのタスク \mathcal{Y} に適応可能な仕様 F(\mathcal{X}, \theta)\rightarrow\mathcal{Y} を満たす必要がある.
  3. No task-specific modules: タスク間で重み \theta を共有することで, モデル F はタスクに対するfine-tuningなしに複数のタスクを処理できる必要がある.

仰々しく書かれていますが, 様々な時系列データと時系列タスクに対応し, それをfine-tuningなしでこなせるモデルを作ることが目標になります.

UniTS Model

まず, UniTSの全体像を確認します. 下図の(b)や(c)をみてわかるように, promptベースのモデルになります. モデルへの入力はprompt token, sequence token, task tokenの3つのトークンを結合したものになります.

これらプロンプトを駆使して異なるタスクを統一的に扱います. それぞれ確認します.

sequence token

PatchTSTにしたがって, 時系列データ \boldsymbol{x}\in\mathcal{X}_il_i 次元に分割します. パッチサイズは p です. すると, 長さが l_s=l_p/p\boldsymbol{z}_{\hat{s}} が得られます. 各 \boldsymbol{z}_{\hat{s}} を線形層に通して次元数が固定されたsequence token \boldsymbol{z}_s にします. ここには学習可能なpositional embeddingが加えられます. v はドメインによって異なるので, tokenに変量の次元を保持します.

このアプローチはLLMsなどの他の分野からの統一モデルの直接的な適用を阻害します. これに対処するために, 任意の数の変量を処理できる柔軟なネットワーク構造を提案しています.

prompt token

prompt token \boldsymbol{z}_p は学習可能な埋め込みとして定義されます. 複数のタスクを解く状況下では, 各タスクに固有のprompt tokenを設定します. ここでのpromptはLLMで用いるpromptとは異なり, 直感的な理解が難しいので各タスクに必要なtoken取得のためにprompt tuningを用います.

task token

図の(b, c)に示されるように, task tokenは以下の2つの主要なタイプに分類されます.

  1. Mask token: 予測, 補完, 異常検知などの生成モデリングで用います.
  2. CLS token and category embeddings: 分類などの認識タスクで用います.

task tokenはタスクを表現する一般的なフォーマットを定義することで, 新しいタスクに柔軟に対応することが可能です.

  • 予測: mask token \boldsymbol{z}_m は, 任意の長さの予測のためにモデル入力で繰り返され, UniTS 出力の繰り返しmask tokenがシーケンスに戻されます
  • 分類: UNITS 出力の CLS token \boldsymbol{z}_c はcategory embeddings \boldsymbol{z}_e と一致させます
  • 補完: 欠損部分はmask tokenを用いて補完されます.
  • 異常検知: モデルが返すノイズ除去されたsequence tokenを用いて, 異常なデータポイントを特定します


どのようにtask tokenを作るかですが, Appendix C.1に書かれています. ここでは予測タスクの場合を確認します.

mask token \boldsymbol{z}_m を予測したい系列長 l_f にしたがって複製します. その後, sequence tokenとprompt tokenに結合してネットワークに入力します. 数式で書くと

\boldsymbol{z}_{\mathrm{Fore}}=\mathrm{CA}(\boldsymbol{z}_p, \boldsymbol{z}_s, \mathrm{RE}(\boldsymbol{z}_m, l_f))\in\mathbb{R}^{(l_p+l_s+l_f)\times v\times d}

です. CAはconcatenation, REはrepeatを表します.

続いて, アーキテクチャの部分を見ます. 先ほどの図の(d)を再掲します.

これを見ればわかるように, UniTS Blockを N 個並べた構造になっています. blockの出力をtowerという軽めのネットワークに通して出力とします. 見ればわかるように, UniTS Blockはgate, Sequence Multi-Head Self-Attention, Variable Multi-Head Self-Attention, Dynamic MLPの4つから構成されています.

以降では, 主要な3つのモジュールとDynamic MLPの中にあるDyLinearについて深掘りします.

Sequence and Variable MHSA

時系列データは系列長や変量の数が異なりますが, それらを統一的に扱うのがこの2つのモジュールになっています. 統一的に扱うとは言っても片方ずつ処理して最終的に幅広いデータに対応します. Sequence MHSAはPatchTSTでも用いられているMHSAをそのまま利用します. Variable MHSAは長い系列に対する計算量を抑えつつ, 系列全体にわたる変量の関係性を細くするために Q, K を平均化して \hat{Q}, \hat{K} を得ます.

\hat{Q}, \hat{K}=\mathrm{mean}_l(Q, K)

Variable MHSAの出力は当然

\mathrm{Softmax}\left(\dfrac{\hat{Q}\hat{K}^T}{\sqrt{d}}\right)V

です. 簡単のためSingle Headで記述されていますが, 実際にはMulti-Headで行います.

DyLinear

Sequence MHSAは普通のattentionなので類似度ベースです. それとは対照的に, token間の密な関係をモデリングする動的線形演算子DyLinearをここでは導入します. 様々な系列長に対応するための重み補完スキームと見ることができて, 長さ l_s のsequence token \boldsymbol{z}_s と, 事前位定義された重み \boldsymbol{w}\in\mathbb{R}^{w_i\times w_o} に対して以下のように定義します.

\begin{align*} &\mathrm{DyLinear}(\boldsymbol{z}_s;\boldsymbol{w})=\boldsymbol{W}_{\mathrm{Interp}}\boldsymbol{z}_s \\ &\boldsymbol{W}_{\mathrm{Interp}}=\mathrm{Interp}(\boldsymbol{w}) \end{align*}

Interpはバイリニア補完で \boldsymbol{w}\in\mathbb{R}^{w_i\times w_o} を入力系列長と出力系列長 l_s\times l_{out} にリサイズします.

Dynamic MLP

DyLinearを盛り込んで, 局所的な詳細と大域的な関係性を捉えるモジュールです. 局所的な詳細を捉えるためにkernel sizeが3の畳み込み層を適用します. その後, d 次元の特徴量を (\boldsymbol{z}_{mid}^1, \boldsymbol{z}_{mid}^2)\in\mathbb{R}^{l\times v\times d/2} の2つに分割します. そして, 次のように処理します.

\boldsymbol{z}_{out}=\mathrm{Linear}(\mathrm{Concat}(\mathrm{DyLinear}_M(\boldsymbol{z}_{mid}^1), \boldsymbol{z}_{mid}^2))

Training

生成タスクと認識タスクに対応するための学習を行います. 基本的にはmasked modelingですが, prompt tokenとCLS tokenの両方の意味内容を効果的に再構成します. lossは以下のように表されます.

L_u=|H_m(\boldsymbol{z}_p, \boldsymbol{z}_s)-x|^2+|H_m(\hat{\boldsymbol{z}_c}, \boldsymbol{z}_s)-x|^2

x はmaskされていない状態の系列全体です. \hat{\boldsymbol{z}_c}=H_c(\boldsymbol{z}_{CLS}) は, CLS tower H_c で処理されたCLS tokenの特徴量で, H_m はmask towerです.

masked modelingとは別の学習方法として, 教師あり学習を行います (公式実装を見る限り追加訓練ではないと思います). 1つのデータセットからランダムに取り出して行います. lossは

L_{\mathrm{total}}=\sum_{i=1}^I\lambda_i\cdot L_i(D_i)

で, I はサンプリングした数で, \lambda は重みです. 予測タスクではMSE, 分類タスクではcross-entropyを L として用います.

実験

データセットとベースラインを確認してから実験結果を見ます.

データセット

38のデータを集めて用います. human activity, healthcare, financeなどの多様なドメインを持ち, 20は系列長が60から720までの予測タスク, 18が2クラスから52クラスまでの分類タスクです. 以下に概要を示します.

ベースライン

7つのベースラインを用意します. iTransformer, TimesNet, PatchTST, Pyraformer, Autoformer, GPT4TS, LLMTimeです.

LLMを使うようなものでは予測にしか対応していない場合がありますが, その場合は分類モジュールを追加して対応します.

結果: Multi-Task Learning

早速結果を示します.

まず, UniTS Promptと書かれたmasked modelingを行った結果を見ます. 多くの場合でbestあるいはsecond bestの性能であることがわかります.

また, ベースラインは単一タスクのみで高性能で複数タスクには対応できていないことがわかります. 例えばTimesNetでは分類タスクの結果はいいですが, 予測タスクの結果は悪いです. 逆に, iTransformerは予測タスクはいいですが, 分類タスクは悪いです. それとは対照的に, UniTSは両方のタスクで高性能な結果です.

LLMを適用する例としてGPT4TSの結果が示されています. 訓練データの規模とモデルの規模の両方で大幅な差がありますが, 提案手法の方がcompetitiveあるいはoutperformであることがわかります.

一方で, 教師あり学習を行ったUniTS Supは教師なしで学習したUniTS Promptと同じくらいの性能にとどまっています. それどころか予測タスクのMAEではUniTS Promptの方がいい結果となっています. 著者らはそれを根拠にPrompt Learningがいいと主張しています. 個人的にはこれだけで主張するのは難しいと思います. 他にも論文には主張が書かれていますがそれらも説得力に欠けると思います.

結果: Zero-Shot New-Length Forecasting

これまでの手法での様々な系列長の予測は複数の予測器を訓練することによって達成されますが, これは未知の系列長の予測には対応できません. UniTSはmask tokenを繰り返すだけで訓練時にはない系列長の予測に対応できます. これを既存手法と比較したいのですが, 既存手法ではそのようなことはできないので新たに予測スキームを開発しています. モデルは固定されたwindowの長さで予測をし, そのwindowをずらすことによって新しい長さで予測します. 14のデータセットを用いて比較します.

この図から, ベースラインの三手法より性能が向上していることがわかります. One-step推論が可能なので最大の384系列を追加で予測したときはiTransformerより3倍ほどの高速化が達成されています.

結果: Zero-shot Forecasting on New Datasets

訓練データにはないデータでの実験をします. 以下の表に示す新しいデータを用いて予測を行います.

LLMTimeとの比較を行います.

ほとんどのデータにおいてLLMTimeを上回る性能が得られています. また, 推論速度は100倍程度になり, 高速に高精度な推論がzero-shotで可能であることが示されています.

結果: Few-shot Classification and Forecasting

few-shot learningでの結果を見ます. 6つの分類データ, 9つの予測データから構成され, fine-tuningでは5%, 15%, 20%をそれぞれ訓練で用います. 比較手法はiTransformerです.

結果を確認します. 全ての状況においてUniTSはiTransformerを凌駕しています. 特に, データを増やすと性能はより向上しています. さらに, prompt learningのUniTSは完全な教師あり学習であるiTransformerを上回っており, さらにデータが5%の場合はMSEとMAEにおいてfine-tuningをも上回っています. few-shotではよりzero-shotに近い方が汎用的な能力を確認できるので, prompt learningがデータが少ない場合にも効果を発揮することが示唆されます.

結果: Few-shot Imputation

imputation taskでの結果を確認します. TimesNetで用いられた6つのデータを使用します. データのうち10%のみを、用いてfine-tuningし, 25%および50%の欠損値を補間します.

この場合でもUniTSが高性能を発揮していることがわかります. 特筆すべき点として, prompt learningで行った場合に, ベースラインを上回るだけでなくfine-tuningと同等の性能になっています. これは, 適切なprompt tokenを選択するだけでimputationではUniTSを効果的に適応できることを示しています.

結果: Few-shot Anomaly Detection

Imputationと同じデータを用いて5%を訓練に用います.

全ての指標でベースラインを上回ることが確認できます. 他の実験と比較してこの実験は記述量が大幅に減少しており, もう書くことがありません.

まとめ

  • unifiedな時系列モデルのUniTSを提案
  • prompt learningでmasked modelingをして学習
  • 既存の教師あり学習モデルより高性能かつzero-shotやfew-shotの設定でも未知データに適応可能

思ったこと

  • MOMENTと比較してモデル構造などが練られた印象を受けました
  • 学習済みモデルもGitHubのReleasesで公開されているのはいいと思いますが, huggingfaceの方が最近は主流なのでは?と思います
  • 実験結果に対する分析が「xx%良くなりました」や「この指標でxxxとyyyなので〜」みたいなものが多く, もう少しそこから得られる考察を書いて欲しいと思いました
  • few-shotのデータの比率は恣意的なものを感じます

参考文献

  • Gao, S., Koker, T., Owen, Q., Thomas, H. Theodoros, T., and Marinka, Z. UniTS: Building a Unified Time Series Model. arXiv preprint arXiv:2403.00131, 2024.

Discussion