Bytedance のリアルタイム推薦システム
はじめに
この記事は、Bytedance から出版され RecSys'22 の Workshop で取り上げられた以下の論文(Monolith: Real Time Recommendation System With Collisionless Embedding Table)の内容のメモです。詳細はあまり踏み込まないので興味ある方は論文を読んでいただければと思います。
この論文は、一言で言うと、「実際の推薦サービスにおいてオンライン学習するためには」ということにフォーカスしてます。多くの企業では基本的にモデルの学習は、オフライン環境のバッチ処理で行うことが普通だと思います。なぜオンラインで学習する必要があるのかも含めて見ていこうと思います。この論文を選んだのは、実サービスで動くオンライン学習システムを扱っていたからです。
また一部[1]実装が公開されてます。
以下では注釈がない画像は上の論文より拝借してます。
Introduction
論文内では、推薦モデルが言語モデルや画像モデルと大きく異なる点として
- (Sparcity and Dynamism) 特徴量の大半がスパースで、カテゴリカルかつダイナミックに変化する
- (Non-stationary Distribution) 学習データの分布が非定常である(Concept Drift)
という2点が挙げられています。
Sparcity and Dynamism
- 言語モデルではトークン数は多くても 100k ~ 200k のオーダーだが、推薦モデルのユーザ数やアイテム数はそれよりオーダーが大きくなることがあり、その場合は一つのホストのメモリには乗り切らない。
- さらに、新規ユーザーや新規アイテムによって embedding テーブルのサイズが時間の経過とともに大きくなることが多い。
(応用上の)既存手法
- 諦める
- 学習データで頻度の高いもの上位 N 件のみ別々に扱いそれ以外は1つの embedding にマッピングする。
- hash(id) mod j
- collision を防ぐ
- linear probing
- double hash
- 一致を避ける
- 組み合わせの利用 Quotient-Remainder Embedding
これらの手法は、全ての ID が同じ頻度で現れるとして平等に扱っているが、実際はロングテールで少数のユーザやアイテムが高頻度で出現。
また、時間と共にテーブルが大きくなり、衝突が起こることでモデルの精度低下につながる。
Non-stationary Distribution
言語モデルや画像モデルでは数ヶ月から長いと百年というタイムオーダ(つまりほぼ時間による影響は無視)だが、推薦システムでは同じユーザでも分単位で興味が変化することがある。つまり、データは非定常であり、学習時と推論時でデータの分布が変わることを Concept Drift という。
解決法
上の2つの課題を
- 衝突しない embedding テーブル
- 高い耐障害性を持つ生産準備が整ったリアルタイムオンライントレーニングアーキテクチャ
で解決しようとした論文。
Design
当該の Monolith は Tensorflow の distributed Worker-ParameterServer という機能を用いている。
Worker はデータを読み込んで gradient を計算し、Parameter Server は gradient を受け取りモデルパラメタを更新するアーキテクチャである。
Hash Table
以下の3点が主な特徴です。
- Cuckoo Hashmap の利用
- 出現頻度が閾値を超えて初めて挿入する
- TTL
下2つは自明だと思うので、Cuckoo Hashmap だけ少し見ます。
Cuckoo Hashmap の利用
衝突なしに key を挿入できるハッシュマップです。
-
でルックアップと削除が可能O(1) - ならし
で挿入が可能O(1)
仮にサイクルになった場合は、再ハッシュされる。
余談ですが、Cuckoo はカッコウ(鳥)で、以下のような恐ろしい習性を持っています。
オオヨシキリの卵をひとつ取り出し、そのあとで自分の卵をひとつ産んだのです。巣の中には、カッコウの卵と、オオヨシキリの卵がひとつずつ残されていました。およそ12日で、カッコウのヒナが誕生します。カッコウは、オオヨシキリより早く孵ります。数時間もするとヒナはオオヨシキリの卵を外に捨てます。最後にはカッコウのヒナしか残りません。
この「追い出す」ところが同じであるところから、Cuckoo Hashmap という名前が付けられたようです。
Online Traning
- 学習はバッチとオンライン両方に対応
-
推論時の特徴量をユーザログに join する(おそらく)
- ユーザログが届いた時点の特徴量ではない
- ユーザログの遅れ(数日レベル)も join できるように disk レベルのキャッシュも利用
-
negative sampling
- log adds correction を用いて推論時に分布を補正
-
インクリメンタルなパラメタ同期
- モデルサイズは数テラバイト。推論を止めず、ネットワークなどにも影響を出さない同期が必要
- sparse なパラメタがモデルの大半を占めるので、前回の同期から更新されたパラメタのみを分レベルの間隔で同期
- dense なパラメタは同期頻度を小さくする(モーメンタムを利用したアルゴリズムでは急激な変化は起こりにくいため)
-
パラメタの snapshot は daily
- これはパラメタサーバが落ちるなどシステムの異常時のため
実験
「衝突なし」ハッシュテーブルは有効か
- どちらのケースでも「衝突なし」ハッシュテーブルの方が AUC が大きく性能が高い
- 実データの結果(右図)で AUC が上下しているのは Concept Drift による。(この時のパラメタ同期頻度は不明)
リアルタイムオンライン学習は重要か
- 頻度にかかわらずオンライン学習なしよりもありの方が性能が高い
- 頻度が短くなるほど性能が高くなる
- 実システムの Ads モデルでは大きな性能向上が見られた
パラメタ同期の仕組みはロバストか
略
感想
- モデルサイズがこれ以上大きくなると(今話題の)分散学習も必要になってきそうだが、 model parallelism と組み合わせたりするとかなり難易度が上がりそう。
- 他にもオンライン学習を行なっている実サービスの話があれば教えてください。
参考
Discussion