🐱

Tensorflow Recommendersの精度を上げるためのテクニック

2022/09/19に公開

はじめに

Tensorflow Recommenders (TFRS) は Tensorflow による推薦システム構築のためのライブラリです。特に大規模サービスへの応用が念頭に置かれており、two-tower アーキテクチャの推薦モデルの構築と、近似近傍探索による高速な推論を可能にしてくれます。

公式のチュートリアルで基本的な使い方を学ぶことができます。また以下のようなブログ記事も参考になるでしょう。

TensorFlow Recommenderで映画のレコメンダーシステムを構築
TensorFlow Recommendersの紹介
大規模サービスで効率よくレコメンドを提供するためにTensorflow Recommendersを活用する

今回この記事を書こうと思ったのは、公式のチュートリアルなどではTFRSを使う上で重要な精度面に関する情報が不足していると思ったからです。この後実演するように、公式のチュートリアルに沿って作成した推薦モデルでは、単なる most popular 推薦にも勝つことができません。実はTFRSのAPIドキュメントや github issue・ソースコードなどをよく読むと精度を改善するためのテクニックが存在することがわかります。今回はそれについて紹介します。

TFRS とは

TFRSはニューラルネットワークによる推薦モデルを作成するためのライブラリです。モデルは以下の二つの tower から構成され、two-tower モデルとも呼ばれます。

  • Query tower: 推薦の入力クエリ情報をベクトルに変換する。ユーザーのIDや特徴量で構成する。
  • Candidate tower: 推薦対象となるアイテムをベクトルに変換する。アイテムのIDや特徴量で構成する。

データセットは過去アクションのあった T 個のユーザー・アイテムペア (x_i, y_i) (i=1,\dots, T) からなるとします。負例データが explicit には存在しない、implicit feedback を仮定しています。二つの towers はそれぞれこの (x_i, y_i)\bm u(x_i, \theta)\bm v(y_i, \theta) という埋め込みベクトルに変換します。\theta はニューラルネットワークのパラメータです。以下では記法の簡略化のために \bm u_i, \bm v_i とも書きます。

学習ではある x_i に対してペアを組んでいた y_i を正例アイテム、ミニバッチ内のその他のアイテム y_j (j\neq i) を負例アイテムとした二値分類を解きます。TFRSのデフォルトでは以下のようなソフトマックスをロス関数として用いています。

L = \sum_{i\in B}\frac{e^{\bm u_i \cdot \bm v_i}}{e^{\bm u_i \cdot \bm v_i} + \sum_{j\in B, j\neq i}e^{\bm u_i \cdot \bm v_j}}

B はあるミニバッチを指しており、分母の normalization はミニバッチ内についての和になっています。バッチサイズが大きいほど負例として多くのアイテムを使うことができ性能向上が期待できることから、通常は1024や4096など大きめのバッチで学習します。

学習がうまくいけば、正例の \bm u_i, \bm v_i が埋め込み空間上で近くに位置するようになります。推論時には入力情報を query tower で \bm u に変換した上で(近似)近傍探索によって、相性の良いアイテムを取得することになります。

モデルの作成

まず上で説明した two-tower モデルをBuilding deep retrieval modelsチュートリアルに沿って作成してみます。

使用するデータは映画レビューのデータセットである movielens 100k です。このデータはユーザーが、見た映画に5段階評価をつけているものですが、ここでは点数の情報は無視して、implicit feedback データとして使います。入力は (ユーザーID, timestamp) であり、ターゲットは映画のタイトルです。

コード詳細はチュートリアルページを見てください。二つの tower から得た埋め込みベクトルを使ってロスを計算する部分は以下のようになります

class MovielensModel(tfrs.models.Model):

  def __init__(self, layer_sizes):
    super().__init__()
    self.query_model = QueryModel(layer_sizes)
    self.candidate_model = CandidateModel(layer_sizes)
    self.task = tfrs.tasks.Retrieval(
        metrics=tfrs.metrics.FactorizedTopK(
            candidates=movies.batch(128).map(self.candidate_model),
        ),
    )

  def compute_loss(self, features, training=False):
    query_embeddings = self.query_model({
        "user_id": features["user_id"],
        "timestamp": features["timestamp"],
    })
    movie_embeddings = self.candidate_model(features["movie_title"])

    return self.task(
        query_embeddings, movie_embeddings, compute_metrics=not training)

query_embeddings, movie_embeddingsがそれぞれ \bm u(x, \theta), \bm v(y, \theta)tfrs.tasks.Retrieval がロスになります。

チュートリアルではmovielens100k データセットでこのモデルを学習させた結果、test set に対して、top 100 categorical accuracy というメトリックで 0.29 という数値を報告しています。Top 100 categorical accuracy というのは、入力クエリ (ユーザーID, timestamp) に対して上位100件の映画を推薦し、ターゲットの映画(一つ)が入っていれば 1、入っていなければ 0 として、全クエリで平均を取ったものです。今回の問題設定ではRecall@100と言っても同じです。

後でルールベースロジックと比較するため、tensorflow とは独立にRecall@Kを評価する関数を作っておきます。

def recall(truth_list, pred_list, k=100):
    if not pred_list:
        return 0

    pred_list = pred_list[:k]
    tp = set(truth_list).intersection(pred_list)
    return len(tp) / len(truth_list)

def mean_recall(df, k=100):
    res = 0.0
    for _, row in df.iterrows():
        truth_list = [row["movie_title"]]
        pred_list = row["prediction"]
        res += recall(truth_list, pred_list, k) / len(df)

    return res

これを使って学習・評価をしてみます。コードの概略は以下のようになります(実験に使ったコードはこの記事の最下部にあります)

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_recommenders as tfrs

ratings = tfds.load("movielens/100k-ratings", split="train")
movies = tfds.load("movielens/100k-movies", split="train")

ratings = ratings.map(lambda x: {
    "movie_title": x["movie_title"],
    "user_id": x["user_id"],
    "timestamp": x["timestamp"],
})
movies = movies.map(lambda x: x["movie_title"])

tf.random.set_seed(42)
shuffled = ratings.shuffle(100_000, seed=42, reshuffle_each_iteration=False)

train = shuffled.take(80_000)
test = shuffled.skip(80_000).take(20_000)
...
# train で 300 epoch 学習して test に対して推論
num_epochs = 300

model = MovielensModel([64, 32])
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))
model.fit(...)
...
# ground truth と推論結果を df_tfrs に入れておく
...
mean_recall(df_tfrs, 100)
# 0.2926999999999759

となり、Recall@100 = 0.293 でした。

Most popular 推薦との比較

次に、上の結果を most popular 推薦と比較してみましょう。Most popular というのは train データで人気なアイテム上位K件を推薦するルールベースの手法です。

df_train = tfds.as_dataframe(train)
most_popular_100 = df_train.groupby("movie_title").size().sort_values(ascending=False).iloc[:100].index.tolist()
df_most_popular["prediction"] = [most_popular_100] * len(df_most_popular)

ちなみに上位10件の映画は以下のようになっています

>>> most_popular_100[:10]
[b'Star Wars (1977)',
 b'Return of the Jedi (1983)',
 b'Fargo (1996)',
 b'English Patient, The (1996)',
 b'Scream (1996)',
 b'Contact (1997)',
 b'Liar Liar (1997)',
 b'Toy Story (1995)',
 b'Raiders of the Lost Ark (1981)',
 b'Air Force One (1997)']

先ほどと同様Recall@100を評価してみると

>>> mean_recall(df_most_popular, 100)
0.29889999999997524

となりました。TFRSモデルではおよそ 0.293 だったので、most popular の方が僅かに良い精度を出しています。これではわざわざ時間とコストを使って機械学習モデルを導入するメリットはないでしょう。

精度を改善する

なぜこれほどTFRSモデルの精度が悪いのでしょうか?これはTFRSモデルのロス関数と関係があります。最初に説明したように、二値分類のロスを計算するため、TFRSではバッチ内で正例ペアを組んでいないアイテムを負例として使っています。当然人気アイテムほど学習データに出現する回数は多くなるので、この方法は実質的にアイテムの登場頻度(=人気度)分布に応じた負例サンプリングをしていることになります。人気アイテムは正例になる可能性が大きいので、このサンプリングでは人気アイテムに対するスコアを underestimate し、不人気なアイテムのスコアを overestimate する方向に学習が進んでしまいます。

このようなサンプリングバイアスを除去する方法がRecSys2019でGoogleから提案されています。
Sampling-Bias-Corrected Neural Modeling for Large Corpus Item Recommendations

Two-tower モデルではクエリ x_i とアイテム y_j のスコアはその埋め込みベクトルの内積 \bm u_i \cdot \bm v_j で計算されました。論文ではこれを

\bm u_i \cdot \bm v_j \to \bm u_i \cdot \bm v_j - \log(p_j)

で置き換えることが提案されています。ここで p_j はアイテム j があるバッチに入る確率です。人気がないアイテム(=p_jが小さいアイテム)ほど大きな正の補正を受け、その結果それらの埋め込みベクトルが過剰に大きくなるのを防いでくれます。

サンプリングバイアスの補正をして再実験

TFRSには上記補正のための機能がすでに実装されており、 tfrs.tasks.Retrieval を呼び出すときに、p_j を渡すことで利用できます。具体的には、先ほどの MovielensModelcompute_loss で、以下のように candidate_sampling_probability を渡します。

class MovielensModel(tfrs.models.Model):
  ...

  def compute_loss(self, features, training=False):
    query_embeddings = self.query_model({
        "user_id": features["user_id"],
        "timestamp": features["timestamp"],
    })
    movie_embeddings = self.candidate_model(features["movie_title"])

    # p_j
    candidate_sampling_probability = tf.cast(features["candidate_sampling_probability"], tf.float32)

    return self.task(
        query_embeddings,
        movie_embeddings,
        candidate_sampling_probability=candidate_sampling_probability,
        compute_metrics=not training)

p_j は今回のような固定のデータセットでは、(アイテム j が正例となるレコード数) / (全レコード数) で見積もれます。例えば以下のような関数であらかじめ計算しておけば良いでしょう

def add_candidate_sampling_probability(df):
    tmp = df.groupby(by="movie_title").size() / len(df)
    tmp = tmp.to_frame("candidate_sampling_probability")

    return df.join(tmp, on="movie_title")

df_train = tfds.as_dataframe(train)
df_train = add_candidate_sampling_probability(df_train)
processed_train = tf.data.Dataset.from_tensor_slices(dict(df_train))

これで再度上でやったのと同じ実験をすると

>>> mean_recall(df_tfrs, 100)
0.4228499999999616

となりました。結果をまとめてみましょう

Recall@100
Most popular 0.299
TFRS 0.293
TFRS w/ candidate_sampling_probability 0.423

candidate_sampling_probability の導入によって、精度が大幅に向上していることがわかります。

まとめ

Movielens データに限らず、two-tower モデルは、経験上 candidate_sampling_probability の有無で精度が大きく変わります。実務で使用する際には、真っ先に気にするべき点だと思います。他にも(現時点で)チュートリアルに書かれていない機能がいくつか存在するので、APIドキュメントを読んでおくと良いです。

今回の実験に使ったコードは以下の gist に上げています。
https://gist.github.com/yng87/ac60ce26f61d18eb8cf7374ddf290a29

また以下の Issue を参考にしました。
https://github.com/tensorflow/recommenders/issues/257

Discussion