【LLM for NewsRec】大規模言語モデル(BERT)を活用したニュース推薦のPyTorchによる実装と評価
1. はじめに
世は大インターネット時代。「ニュースは紙ではなく、スマホで。」が当たり前。日々生み出される膨大なニュースの中から個人の嗜好に基づいた記事を抽出するニュース推薦システムの需要は高まり、Microsoft NewsやYahoo News、Smart Newsなど数多くのオンラインニュースメディアが、その分野に多大なる労力を割いています。そして、近年用いられる手法の多くは機械学習技術が用いられています。
ニュース推薦における推薦アイテムは、いうまでもなく「ニュース記事」。そしてその大部分はテキスト情報から構成されます。機械学習 x テキスト処理となると、今最もホットなトピックといえば、やはり大規模言語モデルの応用です。
大規模言語モデルは、膨大なコーパスによる事前学習を通して深い言語理解を獲得した大規模なニューラルネットです。文書分類や翻訳、対話応答など、様々な自然言語処理タスクにおいて、高い性能を出すことで知られています。そして、多くの先行研究で、ニュース推薦における「ニュース」や「ユーザ」のモデリング(ベクトル化)に大規模言語モデルを活用した手法が高い性能を挙げていることが報告されています [1-3, 9]。
さて、そこで今回は、PyTorchとhuggingface/transformersを用いて、言語モデルを活用したニュース推薦手法であるPLM-NR(NRMS-BERT)[1]の実装および評価を行いました。言語モデルにはBERT-base[4]とDistilBERT-base[5]を用いています。モデルの学習および評価のためのデータセットとしてはMIND[3]というMicrosoft発のデータセットを使用しています。
学習したモデルを検証データで評価した結果、PLM-NR[1]の提案論文に記載された結果に迫る高い性能を得ることができました。学習済みモデルに関しても一般公開しています。
以下、実装したリポジトリです。
2. 深層学習を用いたニュース推薦
推薦アルゴリズムには、行列分解やメモリベース協調フィルタリング、バンディットアルゴリズムを用いた推薦など、様々な手法が存在します。解きたいタスクやドメインに応じて、どのアルゴリズムが適切であるか、はケースバイケースです。
そんな中、ニュース推薦では、深層学習を用いたコンテンツベースの推薦モデルが特に高い性能を出すことが、近年の学術研究では知られています [1-3, 9]。
2.1 ニュース推薦モデルの構造
まずは、深層学習を用いたコンテンツベースの推薦モデルの一般構造を紹介します。以下は先行研究の論文から引用した図です。
図の
ニュース推薦モデルは、大きくわけて、以下の3つのコンポーネントから構成されます。
-
News Encoder: ニュースコンテンツ(
,D_c ~D_1 )からニュース記事を意味合いを反映したD_T 次元のベクトル(d ,h_c ~h_1 )を出力するh_T -
User Encoder: 過去にユーザが読んだニュース記事のリスト(
~D_1 )から、ユーザの嗜好を反映したD_T 次元のベクトル(d )を出力するu -
Click Predictor: User EncoderとNews Encoderの
次元の出力ベクトル(d ,u )の内積h_c を計算し、ユーザがu^T\cdot h_c をクリックする確率D_c を出力する\^y
すなわち、ニュース推薦モデルは、「過去にユーザが読んだ記事データから得られるユーザの嗜好」と「推薦対象となるニュース記事データ」をそれぞれ同じ次元数のベクトルに変換し、その類似度を、内積により計算することで、クリック率を予測しているモデルであると言えます。
「ユーザの嗜好やニュース記事データをどのようにしてベクトル化するのか(News EncoderとUser Encoderをどう構築するか)?」というのがニュース推薦における最も重要なトピックと言っても過言ではありません。
近年では、それらのベクトル化にBERTをはじめとした大規模言語モデルを活用した手法が、高い性能を出すことで知られています。
2.2 PLM-NR [1]
2019年のBERTの登場を皮切りに、RoBERTaやGPT、LLaMAなど、様々なTransformerをベースとした大規模言語モデルが提案・公開され、あらゆる自然言語処理タスクで高い性能を発揮しています。
今回実装したPLM-NR(Pre-trained Language Model empowered News Recommendation)は、事前学習済みの言語モデルを用いたニュース推薦手法で、2021年にWuらにより、その結果が報告されました[1]。
この研究では、2.1節でも説明があった通り、News EncoderとUser EncoderにBERTを用いています。彼らはNews EncoderとUser EncoderにBERTを用いた複数のニュース推薦モデルに対して、後述するMIND[3]を用いてオフライン検証を行い、それらのモデルが高い性能が出すことを確認しました。
彼らの報告はオフライン検証による性能評価に留まりません。オンライン実験として、PLM-NRによるニュース推薦システムを実際にMicrosoft Newsプラットフォーム上で稼働したところ、既存モデルと比較して、8.53%のクリック率の向上が見られたとのことです。
彼らのオフライン検証の中で最も高い性能が得られたのがNRMS[9]というモデルにBERTを適用したNRMS-BERTという、BERTからの出力系長にMultihead Attentionを適用した手法です。次章では、いよいよ今回実装したNRMS-BERTの詳細な理論および実装の説明をしていきます。
3. NRMS-BERTの理論と実装
3.1 使用技術・フォルダ構成
まず、プロジェクトを概観するために、フォルダ構成を見てみましょう。
$ tree -L 2
├── README.md
├── dataset/ # データセットのダウンロード用コード, ダウンロードされたデータセットが格納される
│ └── download_mind.py
├── pyproject.toml
├── requirements-dev.lock # Ryeにより自動生成されるファイル
├── requirements.lock # Ryeにより自動生成されるファイル
├── src/
│ ├── config/ # ハイパーパラメータ等が定義されたconfig.pyが存在
│ ├── const/ # プロジェクト共通の定数
│ ├── evaluation/ # MRRやnDCG等の評価指標計算
│ ├── experiment/ # 実験用コード(train.py)
│ ├── mind/ # データセット関連のコード (PyTorchのDatasetクラス等)
│ ├── recommendation/ # 推薦モデルの実装
│ │ └── nrms/ # NRMS-BERTの実装
│ │ ├── AdditiveAttention.py
│ │ ├── NRMS.py
│ │ ├── PLMBasedNewsEncoder.py
│ │ ├── UserEncoder.py
│ │ ├── __init__.py
│ └── utils/ # 汎用関数
└── test/ # テストコード
├── evaluation/
├── mind/
└── recommendation/
使用した技術は以下になります。
- 言語: Python
- 深層学習ライブラリ: PyTorch, huggingface/transformers
- パッケージ管理・仮想環境: Rye
- LinterやFormatter等の開発基盤周り: Ruff, mypy, black
今回、パッケージ管理・仮想環境の作成だけ、少し思い切ってRyeを採用しました。
RyeはFlaskの作者であるArmin Ronacherにより開発されたパッケージ管理ツールです。リポジトリのDisscussionのShould Rye Exist?に記載がある通り、Rustにおけるcargoやrustupのように、Pythonにおけるプロジェクト管理のデファクトスタンダードになることを目指しています。
Ryeはone-stop-shopと表現されるように、Pythonインタプリタの管理やパッケージのインストールといったPython開発に必要な管理を一通り担ってくれます。poetry + pyenv,asdfのように、パッケージ管理とインタプリタ管理でそれぞれツールをインストールする必要はありません。Ryeさえ入れておけば、rye sync
とコマンドを一つ打つだけで、pyproject.yaml
に則ったPython実行環境が作成されます。
現時点ではあくまで、experimentalでnot yet production readyと記載があるので、本番利用等を強く勧めることは、躊躇われます。しかし、個人の感想としては、悪名高きPython開発環境構築をDXしてくれる素晴らしいツールであると感じています。
3.2 PyTorchによるPLM-NR(NRMS-BERT)実装
プロジェクトの概観がつかめたところで、NRMS-BERTとそれを構成するレイヤーのPyTorch/transformersによる実装を見ていきます。なお、Multihead Attentionの実装・数式に関する説明はここでは割愛します。Transformerの原論文やインターネット上の技術記事等を参照してください。
Additive Attention
NewsEncoder・UserEncoderについてみて行く前に、まずはAdditive Attentionについてです。Additive Attentionは「ニュースの単語ベクトルの配列」や「ユーザが過去に読んだニュースの埋め込みベクトルの配列」などのベクトル系列を、重要度に基づいて重み付けし、一つのベクトルに集約する役割を持ちます。
以下が数式です:
Additive Attentionでは、
そして、重み
それでは、Additive Attentionの計算がわかったところで、PyTorchによる実装を見てみましょう。
import torch
from torch import nn
def init_weights(m: nn.Module) -> None:
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight.data)
if m.bias is not None:
nn.init.zeros_(m.bias)
class AdditiveAttention(nn.Module):
def __init__(self, input_dim: int, hidden_dim: int) -> None:
super().__init__()
self.attention = nn.Sequential(
nn.Linear(
input_dim, hidden_dim
), # in: (batch_size, M, d), out: (batch_size, M, hidden_dim)
nn.Tanh(), # in: (batch_size, M, hidden_dim), out: (batch_size, seq_len, hidden_dim)
nn.Linear(
hidden_dim, 1, bias=False
), # in: (batch_size, M, hidden_dim), out: (batch_size, M, 1)
nn.Softmax(dim=-2),
)
self.attention.apply(init_weights)
def forward(self, input: torch.Tensor) -> torch.Tensor:
attention_weight = self.attention(input) # = α
return input * attention_weight
以上がAdditive Attentionの説明になります。NRMS-BERTでは、
- NewsEncoderにおける単語ベクトルの系列の集約
- UserEncoderにおけるユーザが過去に読んだニュースベクトル系列の集約
の2つにAdditive Attentionが用いられています。
News Encoder
ここで、ニュース記事テキストの単語列を
NRMS-BERTのNews Encoderは、BERT Encoder・Multihead Attention・Additive Attentionの3つから構成されます。
News Encoderでは、まず、ニュース記事の記事タイトルテキスト
今回、BERT EncoderとMultihead Attentionは、transformersとPyTorchにより提供されている実装を活用しました。Additive Attentionも既に実装済みであるため、News Encoder自体の実装は非常にシンプルです。
それでは、以上を踏まえ、News Encoder(PLMBasedNewsEncoder
)の実装を見てみましょう。
import torch
from torch import nn
from transformers import AutoConfig, AutoModel
from .AdditiveAttention import AdditiveAttention
class PLMBasedNewsEncoder(nn.Module):
def __init__(
self,
pretrained: str = "bert-base-uncased",
multihead_attn_num_heads: int = 16,
additive_attn_hidden_dim: int = 200,
):
super().__init__()
self.plm = AutoModel.from_pretrained(pretrained)
plm_hidden_size = AutoConfig.from_pretrained(pretrained).hidden_size
self.multihead_attention = nn.MultiheadAttention(
embed_dim=plm_hidden_size, num_heads=multihead_attn_num_heads, batch_first=True
)
self.additive_attention = AdditiveAttention(plm_hidden_size, additive_attn_hidden_dim)
def forward(self, input_val: torch.Tensor) -> torch.Tensor:
V = self.plm(input_val).last_hidden_state # [batch_size, M] -> [batch_size, seq_len, d]
multihead_attn_output, _ = self.multihead_attention(
V, V, V
) # [batch_size, M, d] -> [batch_size, M, d]
additive_attn_output = self.additive_attention(
multihead_attn_output
) # [batch_size, M, d] -> [batch_size, M, d]
output = torch.sum(
additive_attn_output, dim=1
) # [batch_size, M, d] -> [batch_size, d]
return output
User Encoder
ここでは、ユーザが過去に読んだニュース記事リストから、News Encoderの出力次元と同じ
User Encoderでは、まず、既に実装したNews Encoderを用いて、ユーザが過去に読んだ
以上がUser Encoderの構造に関する説明です。こちらもNews Encoder, Additive Attention, Multihead Attentionといった構成要素は既に実装済みであるため、実装は非常にシンプルです。User Encoder(UserEncoder
)の実装を見てみましょう。
import torch
from torch import nn
from .AdditiveAttention import AdditiveAttention
class UserEncoder(nn.Module):
def __init__(
self,
hidden_size: int,
multihead_attn_num_heads: int = 16,
additive_attn_hidden_dim: int = 200,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.multihead_attention = nn.MultiheadAttention(
embed_dim=hidden_size, num_heads=multihead_attn_num_heads, batch_first=True
)
self.additive_attention = AdditiveAttention(hidden_size, additive_attn_hidden_dim)
def forward(self, news_histories: torch.Tensor, news_encoder: nn.Module) -> torch.Tensor:
batch_size, hist_size, seq_len = news_histories.size()
news_histories = news_histories.view(
batch_size * hist_size, seq_len
) # [batch_size, N, M] -> [batch_size*N, M]
news_histories_encoded = news_encoder(
news_histories
) # [batch_size*N, M] -> [batch_size*N, d]
news_histories_encoded = news_histories_encoded.view(
batch_size, hist_size, self.hidden_size
) # [batch_size*N, d] -> [batch_size, N, d]
multihead_attn_output, _ = self.multihead_attention(
news_histories_encoded, news_histories_encoded, news_histories_encoded
) # [batch_size, N, d] -> [batch_size, N, d]
additive_attn_output = self.additive_attention(
multihead_attn_output
) # [batch_size, N, d] -> [batch_size, d]
output = torch.sum(additive_attn_output, dim=1)
return output
NRMS-BERT(Click Predictor)
最後にNRMS-BERT(Click Predictor)の実装をみていきます。今までで実装したUser EncoderとNews Encoderを使い、推薦候補となるニュースやユーザをベクトル化(
import torch
from torch import nn
from transformers.modeling_outputs import ModelOutput
class NRMS(nn.Module):
def __init__(
self,
news_encoder: nn.Module,
user_encoder: nn.Module,
hidden_size: int,
loss_fn: nn.Module = nn.CrossEntropyLoss(),
) -> None:
super().__init__()
self.news_encoder: nn.Module = news_encoder
self.user_encoder: nn.Module = user_encoder
self.hidden_size: int = hidden_size
self.loss_fn = loss_fn
def forward(
self, candidate_news: torch.Tensor, news_histories: torch.Tensor, target: torch.Tensor
) -> torch.Tensor:
"""
Parameters
----------
candidate_news : torch.Tensor (shape = (batch_size, candidate_num, seq_len))
news_histories : torch.Tensor (shape = (batch_size, candidate_num, seq_len))
===========================================================================
Returns
----------
output: torch.Tensor (shape = (batch_size, candidate_num))
"""
batch_size, candidate_num, seq_len = candidate_news.size()
candidate_news = candidate_news.view(batch_size * candidate_num, seq_len)
news_candidate_encoded = self.news_encoder(
candidate_news
) # [batch_size * (candidate_num), seq_len] -> [batch_size * (candidate_num), emb_dim]
news_candidate_encoded = news_candidate_encoded.view(
batch_size, candidate_num, self.hidden_size
) # [batch_size * (candidate_num), emb_dim] -> [batch_size, (candidate_num), emb_dim]
news_histories_encoded = self.user_encoder(
news_histories, self.news_encoder
) # [batch_size, histories, seq_len] -> [batch_size, emb_dim]
news_histories_encoded = news_histories_encoded.unsqueeze(
-1
) # [batch_size, emb_dim] -> [batch_size, emb_dim, 1]
output = torch.bmm(
news_candidate_encoded, news_histories_encoded
) # [batch_size, (candidate_num), emb_dim] x [batch_size, emb_dim, 1] -> [batch_size, (1+npratio), 1, 1]
output = output.squeeze(-1).squeeze(-1) # [batch_size, (1+npratio), 1, 1] -> [batch_size, (1+npratio)]
# NOTE:
# when "val" mode(self.training == False) → not calculate loss score
# Multiple hot labels may exist on target.
# e.g.
# candidate_news = ["N24510","N39237","N9721"]
# target = [0,2](=[1, 0, 1] in one-hot format)
if not self.training:
return ModelOutput(logits=output, loss=torch.Tensor([-1]), labels=target)
loss = self.loss_fn(output, target)
return ModelOutput(logits=output, loss=loss, labels=target)
以上がモデルの実装に関する説明になります。次章からは、実装したモデルの訓練・検証を行います。今回、モデルの検証にはMINDというデータセットを用いました。
4. MIND: Microsoft News Dataset
今回、モデルの学習・評価にはMicrosoft Newsの実際の行動ログ・ニュースデータを収集することにより作成された Microsoft News Dataset(通称:MIND)[3]を用いました。MINDには「約16万件の英文ニュース記事データ」と「約100万のユーザから収集された1500万件以上の行動ログ」が保存されています。2020年にMicrosoftの研究者らによって公開されて以来、MINDは多くのニュース推薦に関する研究で用いられています
今回は、MIND内のニュース情報を格納したnews.tsv
とユーザのImpression情報を格納したbehaviors.tsv
の二つのtsvファイルを用いてモデルの学習・検証を行いました。Microsoftの公式サイトによると、それぞれのファイルには次の表に示すようなデータが格納されています。
news.tsv
カラム名 | 説明 | 具体例 |
---|---|---|
News ID | ニュースのID | N37378 |
Category | カテゴリ | sports |
Subcategory | サブカテゴリ | golf |
Title | タイトル | PGA Tour winners |
Abstract | 要約 | A gallery of recent winners on the PGA Tour. |
URL | URL | https://www.msn.com/en-us/sports/golf/pga-tour-winners/ss-AAjnQjj?ocid=chopendata |
Title Entities | このニュースのタイトルに含まれるエンティティ | - |
Abstract Entities | このニュースの要約に含まれるエンティティ | - |
behavior.tsv:
カラム名 | 説明 | 具体例 |
---|---|---|
Impression ID | インプレッションのID | 123 |
User ID | ユーザーの匿名ID | U131 |
Time | インプレッションの時間。形式は“MM/DD/YYYY HH:MM:SS AM/PM” | 11/13/2019 8:36:57 AM |
History | このインプレッションの前のユーザーのニュースクリック履歴(クリックされたニュースのIDリスト) | N11 N21 N103 |
Impressions | このインプレッションで表示されたニュースのリストと、それらのニュースに対するユーザーのクリック行動(1はクリック、0は非クリック) | N4-1 N34-1 N156-0 N207-0 N198-0 |
Historyカラムに示されたニュースリストをユーザが過去にクリックしたニュース(
MINDに関するより詳細な説明は、論文や公式サイトをご覧ください。なお、MINDにはユーザ数を50,000人分のみに限定したsmallサイズのデータセット( MIND-small )も用意されており、今回の実験ではそちらを利用しました。
5. モデルの訓練と評価
5.1 ネガティブサンプリング
NRMS-BERTは、モデルの訓練方法が少し特殊です。MINDにあるようなクリックログから、単純にクリックする(=1), クリックしない(=0)の二値分類として学習するのではなく、ネガティブサンプリングという手法を活用します。
ユーザがクリックしたニュース1本(正例,label=1)に対して、ユーザがクリックしなかったニュースK本(負例,label=0)をランダムにサンプリングし、(1 + K)クラスの多値分類として学習を行います。
ここでは、ネガティブサンプリングが実装されているPyTorchのDatasetクラスの実装を参考に見ていきます
...
class MINDTrainDataset(Dataset):
def __init__(
...
) -> None:
self.behavior_df: pl.DataFrame = behavior_df # behavior.tsvが格納されたデータフレーム
self.news_df: pl.DataFrame = news_df # news.tsvが格納されたデータフレーム
self.npratio: int = npratio # ネガティブサンプリング時の負例の数K
...
def __getitem__(self, behavior_idx: int) -> dict: # TODO: 一行あたりにpositiveが複数存在することも考慮した
...
# Extract Values
behavior_item = self.behavior_df[behavior_idx]
...
# Sampling Positive(clicked) & Negative(non-clicked) Sample
poss_idxes, neg_idxes = (
behavior_item["clicked_idxes"].to_list()[0],
behavior_item["non_clicked_idxes"].to_list()[0],
)
sample_poss_idxes, sample_neg_idxes = random.sample(poss_idxes, 1), self.__sampling_negative(
neg_idxes, self.npratio
)
sample_impression_idxes = sample_poss_idxes + sample_neg_idxes
random.shuffle(sample_impression_idxes)
sample_impressions = impressions[sample_impression_idxes]
...
def __sampling_negative(self, neg_idxes: list[int], npratio: int) -> list[int]:
if len(neg_idxes) < npratio:
return neg_idxes + [EMPTY_IMPRESSION_IDX] * (npratio - len(neg_idxes))
return random.sample(neg_idxes, self.npratio)
...
なお、今回は先行研究[1,9]と同様、負例の数をK = 4(npratio = 4
)、すなわち「5本のニュースの中でユーザがクリックしたニュースはどれか?」という5値分類として推薦モデルの学習を行いました。
5.2 評価指標
モデルの評価指標は元論文[1,9]に倣い、AUC, MRR, nDCG@Kの3つを採用しました。AUCは2値分類でよく用いられる評価指標の一つで、クリックした・しなかったを適切に予測できている時ほどスコアが高くなります。
MRRはレコメンドの評価指標としてよく用いられ、以下の式で表せます。
MRRは、予測した順位の中の最初の適合アイテムの順位、すなわち、予測したニュースのクリック率ランキングの中で最も最初に現れる実際にクリックしたニュースの順位の逆数の全ユーザに対する平均値です。
nDCGもレコメンドの評価指標としてポピュラーで以下の式で表せます。
なお、元論文に合わせて、Burgesらにより定義されたnDCG[8]を採用しています。
これらの評価指標はRecEvaluatorクラスに実装があります。長くなるのでここでは割愛しますが、興味があれば、ぜひご覧ください。
5.3 評価結果
さて、以上を踏まえて、いよいよ評価結果を見てみましょう。これまでで紹介したモデルをNVIDIA V100 GPU x 1上で、MIND-smallを用いて学習しました。また性能のベースラインとして、ランダム推薦も実装・評価を行いました。
実験結果を以下に示します。
Model | AUC | MRR | nDCG@5 | nDCG@10 |
---|---|---|---|---|
Random Recommendation | 0.500 | 0.201 | 0.203 | 0.267 |
NRMS + DistilBERT-base | 0.674 | 0.297 | 0.322 | 0.387 |
NRMS + BERT-base | 0.689 | 0.306 | 0.336 | 0.400 |
NRMS-BERT[1] (参考) | 0.695 | 0.347 | 0.380 | 0.437 |
ランダム推薦と比較すると、DistilBERT, BERT、いずれも明らかに高い性能が得られていることがわかります。
また、結果の表の最下部に、PLM-NRの論文に記載されていたNRMS-BERTの結果を参考として掲載しました。今回は、小規模なMIND-smallで学習しましたが、それでも、MIND全体で学習した論文記載の結果にかなり近い性能を出せていることがわかります。
なお、学習にかかった時間やハイパーパラメータ、学習済みモデルへのリンクは、Appendixに掲載しましたので、そちらをご覧ください。
6. まとめ
長くなりましたが、以上になります!
今回は、大規模言語モデル(BERT)を用いた推薦手法であるNRMS-BERTの実装・評価を行いました。MIND(MIND-small)というデータセットを用いて訓練・評価を行った結果、NRMS-BERTの現論文にかなり迫る性能を出すことができました。
今回は、BERT-base/DistilBERT-baseを使ったモデルをファインチューニングしたモデルを評価しましたが、2023年に入って発表された直近の研究では、GPT系モデル + LoRAチューニングを用いたニュース推薦の拡張[7]やニュース推薦のためのPrompt Learning[6]等、新しい手法が次々と紹介されています。
大規模言語モデルの応用分野の一つとして、ニュース推薦分野の研究は、ますます活発化していくことが予想(期待)されます。
本記事を面白いと思ってくれた方がいらっしゃいましたら、news-recommendation-llmにスターをしてくださると、励みになります。
7. 参考文献
[1] "Empowering News Recommendation with Pre-Trained Language Models." Wu, C., Wu, F., Qi, T., & Huang, Y. https://doi.org/10.1145/3404835.3463069
[2] "Personalized News Recommendation: A Survey." Wu, C., Wu, F., & Huang, Y. https://arxiv.org/abs/2106.08934.
[3] "MIND: A Large-scale Dataset for News Recommendation" Wu, F., Qiao, Y., Chen, J.-H., Wu, C., Qi, T., Lian, J., Liu, D., Xie, X., Gao, J., Wu, W., & Zhou, M. https://aclanthology.org/2020.acl-main.331
[4] "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding" Devlin, J., Chang, M.-W., Lee, K., & Toutanova, K. https://aclanthology.org/N19-1423
[5] "DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter" Sanh, V., Debut, L., Chaumond, J., & Wolf, T. https://arxiv.org/abs/1910.01108
[6] "Prompt Learning for News Recommendation" Zhang, Z., & Wang, B. arXiv preprint arXiv:2304.05263
[7] "ONCE: Boosting Content-based Recommendation with Both Open- and Closed-source Large Language Models." Liu, Q., Chen, N., Sakai, T., & Wu, X.-M. arXiv:2305.06566.
[8] "Learning to Rank Using Gradient Descent" Burges, C., Shaked, T., Renshaw, E., Lazier, A., Deeds, M., Hamilton, N., & Hullender, G. https://doi.org/10.1145/1102351.1102363
[9] "Neural News Recommendation with Multi-Head Self-Attention" Wu, C., Wu, F., Ge, S., Qi, T., Huang, Y., & Xie, X. https://doi.org/10.18653/v1/D19-1671
8. Appendix
Hyper parameters
Model | epoch | Learning Rate | batch size | K(npratio) | history size |
---|---|---|---|---|---|
NRMS + DistilBERT-base | 3 | 1e-4 | 128 | 4 | 50 |
NRMS + BERT-base | 3 | 1e-4 | 128 | 4 | 50 |
Time to train & Trained Model
Model | Trained Model | Time to Train |
---|---|---|
NRMS + DistilBERT-base | Google Drive | 15.0h |
NRMS + BERT-base | Google Drive | 28.5h |
Discussion