🕌

GPTを理解したい

2022/11/16に公開約4,600字

GPTを理解したい記事です。

GPTはtransformerを基礎としたモデルです。transformerの解説記事はこちら。

https://zenn.dev/sunbluesome/articles/078ac9a9afca6a

元論文「Improving Language Understanding by Generative Pre-Training」をベースにまとめていきます。

凄いポイント

  • 様々なタスクをモデルの構造をほとんど変えずに行える。しかも12タスク中9つでSOTA。

GPTの構造

Transformer Decoder (T-D)

GPTは[2]で提案されたTransformer Decoder (T-D)を用いたモデルで、以下のような構造をしています(下図左側)。


[1] Figure 1 より引用

初期のTransformerもそうですが、当時の言語モデルはencoder-decoderネットワークが主流でした。例えば英語をドイツ語に変換する翻訳タスクでは、encoderに英語の単語列(入力)を、decoderに過去のドイツ語の単語列(過去の出力)を入力し、次単語予測を行います。

Transformer Decorderはパラメータを半減することができる半面、decoder側の入力しか受け付けません。その為、入力列と過去の出力列を結合した系列を作り学習を行います。つまり、(m^1, \ldots, m^n) \mapsto (y^1, \ldots, y^n) のような系列変換写像があったとき、
これらを結合して (u^1, \ldots, u^{n+\eta+1}) = (m^1, \ldots, m^n, \delta, y^1, \ldots, y^\eta) のような1つの系列にして学習を行います。\delta は区切り文字トークンです。

教師なし事前学習

トークンの集合が \mathcal{U} = \left\{u_1, \ldots, u_n\right\} で与えられるとき、パラメータ \Theta で定義されるニューラルネットワークの対数尤度 L_1 を最大化することを考えます。

\begin{align} L_1 (\mathcal{U}) = \sum_i \log P(u_i | u_{i-k}, \ldots, u_{i-1};\Theta) \end{align}

ここで、kはコンテキストのウィンドウサイズです。

GPTは先述したTransformer Decoderの構造を利用しており*、以下のように定義されます。

\begin{align} h_0 &= UW_e + W_p \\ h_i &= \text{transformer\_block}(h_{i-1}) \quad \forall i \in [1, n]\\ P(u) &= \text{softmax}(h_nW_e^T) \end{align}

ここで、\mathcal{U} = (u_{i-k}, \ldots, u_{i-1})はトークンのコンテキストベクトル、nは層の数、W_eはトークン埋め込み行列(隠れ状態へ埋め込むベクトル)、W_pはposition encodingです(position encodingについてはこちらを参照)。

*transformer blockはmulti-headed self-attentionが使われていることに注意してください。

教師ありファインチューニング

式 (1) を最大化するように学習した後、目標タスクの教師あり学習を行います。

目標タスクの学習データセット(X, y) \in \mathcal{C}を準備します。\left\{x^1, \ldots, x^m\right\} \in Xは入力トークン列、yは入力トークン列に対応するラベルです。このとき、次の尤度関数を最大化するようにファインチューニングを行います。

\begin{align} L_2 (\mathcal{C}) = \sum_{(x, y)} \log P (y|x^1, \ldots, x^m) \end{align}

GPTでいうところのファインチューニングでは、事前学習したモデルの最終層のみパラメータ更新します。つまり、入力トークン列を事前学習したモデルへ通し、transformer blockの最終出力h_l^mを得ます。次に、h_l^mを線形レイヤーへ入力し、y を予測する線形レイヤーのパラメータ W_y を学習しています。

\begin{align} P(y|x^1, \ldots, x^m) = \text{softmax}(h_l^m W_y) \end{align}

auxiliary objective

auxiliary objective(補助目的)を含めてファインチューニングを行うと、汎化性能と収束性が向上する事が分かったと筆者らは言っています。具体的には、新たな目的関数

\begin{align} L_3(\mathcal{C}) = L_2(\mathcal{C}) + \lambda * L_1(\mathcal{C}) \end{align}

を導入してファインチューニングを行います。ここで、\lambda は任意の定数です。

タスクに合わせた入力設計

分類タスクであれば、上記のようなファインチューニングを素直に行えば良いのですが、文章とそれに対する質問に対して回答を得るタスクや、文章間の類似度評価を行うためには入力を工夫する必要があります。GPTモデルは1系列での学習にしか対応しておらず、2系列の入力が必要な文の類似度評価や、3系列の入力が必要になる文章とそれに対する質問に対して回答を得るタスクはそのままでは学習できないためです。

[1]ではTextual entailement、Similarity、Question Answering and Commonsense Reasoningでの例が紹介されています。


再掲:[1] Figure 1 より引用

上図の右側でそれぞれのタスクにおける学習方法の模式図があるので、適宜確認すると以下の説明が分かりやすくなると思います。

Texutual entailement

日本語では、テキスト含意タスクと言います。前提トークン p と仮説トークン h が与えられたときに、ph を含意するかどうかを判定するタスクです。要は、p から h を推論することは可能かどうかを判定するタスクということですね。単純に ph をデリミタトークン($)で繋げはOKです。

Similarity

2つの文の類似度を評価します。デリミタトークンを文間に挟みこむのはTextual entailementと同様です。ただ、2つの文間に順序関係は無いので、順序の異なる2つの入力系列を用意し、それぞれの独立にモデルへ渡して隠れ状態を得ます。得られた2つの隠れ状態の要素ごとの和をとり、最後にlinear layerへ渡します。

Question Answering and Commonsense Reasoning

z、質問 q から、回答 \{a_k\} を予測するタスクです。これまでと同様に、デリミタトークン($)で全てを結合し、[z; q; $; a_k] とします。回答ごとに独立にモデルへ渡し、最後にsoftmaxを掛けることで回答に対する出力分布を得ます。

GPTの性能

最後に様々なタスクに対するGPTの性能を見て終わりましょう。

自然言語推論タスクの性能


[1] table 2 より引用

6つの自然言語推論タスクで、当時のSOTAモデルとの性能比較を行っています。ファインチューニング後のGPTモデルは6つ中5つでSOTAを達成しています。

Question Answering and Commonsense Reasoningの性能

これは質問に対する回答を得るタスクでした。全てのタスクでSOTAを達成しています。


[1] table 3 より引用

Semantic similarity and classificationの性能

6つ中4つでSOTAです。


[1] table 4 より引用

まとめ

  • GPTはTransformerのDecoder側を変形したモデル。
  • ほとんどモデル構造や目的関数を変えることなく様々なタスクでSOTAを達成

にわかには信じがたい威力のモデルですが、GPT-nと総称される後継モデルが人間とほぼ同じクオリティの文章を生成できるなるなど、このモデルの威力は十分実証されていますね。

参考文献

  1. Radford, A., Narasimhan, K., Salimans, T. & Sutskever, I. (2018). Improving Language Understanding by Generative Pre-Training.
  2. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L. & Polosukhin, I. (2017). Attention Is All You Need. arXiv. https://doi.org/10.48550/arxiv.1706.03762

Discussion

ログインするとコメントできます