📌

JuliaでGPTモデルを作成してみる

2023/09/29に公開

最近流行りのTransformerについて理解を深めたい
→ 内部構造を理解するなら一言語で完結するJuliaでしょ!
(単に研究である強化学習の実装にJuliaを使用しているだけ)

言語タスクを学習させるのは面倒だったので、算術タスクを学習させて性能を評価することにしました。2桁同士の加算、一桁同士の和積+"()"を使用した優先度付きの演算など

環境

  • Windows 11
  • Julia 1.9.3
  • cuda 12.2

実装

https://github.com/char5742/simple-gpt-julia
まずは問題の作成

@enum MathToken begin
    Pad = 1
    # 演算子
    Plus
    Times
    Equal
    # 数字
    Zero
    One
    Two
    Three
    Four
    Five
    Six
    Seven
    Eight
    Nine
    LeftParen
    RightParen
    End
end

const numbers = [Zero, One, Two, Three, Four, Five, Six, Seven, Eight, Nine]
const operators = [Plus, Times]

算術では数字(10個)、括弧(2個)、演算子(3個)、パディングと終端(2個)の合計17個のトークンを使用することにしました。

"""
全ての単純な式を列挙する

下記のような式を全て列挙する  
8+6 = [Eight, Plus, Six]
"""
function enumerate_simple_exprs(k::Int)
    res = Vector{Vector{MathToken}}()
    max_num = 10^k - 1
    for i in 0:max_num
        for j in 0:max_num
            expr = [num_to_tokens(i)..., Plus, num_to_tokens(j)...]
            push!(res, expr)
        end
    end
    res
end

これでK桁どうしの足し算を作ることができます

モデルの実装

using Lux, LuxCUDA, Random, Zygote, MLUtils, Optimisers, ProgressBars, Printf, JLD2, StatsBase, NamedTupleTools

使用するライブラリ

PositionalEncodingLayerの実装

struct PositionalEncodingLayer <: Lux.LuxCore.AbstractExplicitLayer
    features::Int
    seq_len::Int
end

Lux.initialparameters(rng::AbstractRNG, layer::PositionalEncodingLayer) = NamedTuple()

function Lux.initialstates(rng::AbstractRNG, layer::PositionalEncodingLayer)
    features = layer.features
    seq_len = layer.seq_len
    pos_enc = [
        i % 2 == 1 ? sin(j / 10000^(2 * i / features)) : cos(j / 10000^(2 * i / features))
        for i in 1:features, j in 1:seq_len, _ in 1:1
    ]
    (pos_enc=pos_enc,)
end
Lux.parameterlength(l::PositionalEncodingLayer) = 0
Lux.statelength(l::PositionalEncodingLayer) = l.features * l.seq_len

function (::PositionalEncodingLayer)(x, ps, st)
    return x .+ st.pos_enc, st
end

位置情報の扱いについては、オリジナルそのままの正弦波を使用しました。

MultiHeadAttentionの実装

struct MultiHeadAttention <: Lux.LuxCore.AbstractExplicitContainerLayer{(:query, :key, :value, :dropout, :output)}
    query
    key
    value
    dropout
    output
    nheads
end

function (m::MultiHeadAttention)((x, mask), ps, st)
    # x = (feature, seq_len, batch_size)
    q, _ = m.query(x, ps.query, st.query)
    k, _ = m.key(x, ps.key, st.key)
    v, _ = m.value(x, ps.value, st.value)
    st_dropout = nothing
    function fdrop(x)
        x, st_dropout = m.dropout(x, ps.dropout, st.dropout)
        x
    end
    values, _ = dot_product_attention(q, k, v; mask=mask, fdrop=fdrop, nheads=m.nheads)

    output, _ = m.output(values, ps.output, st.output)
    st = merge(st, (dropout=st_dropout,))
    return output, st
end

NNlib.jlにあるdot_product_attentionを使用して実装しました。fropについてはもう少し良い実装があるような気がしますが、とりあえずはローカル関数でst_dropoutを取り出すようにしました。

DecoderBlockの実装

struct DecoderBlock <: Lux.LuxCore.AbstractExplicitContainerLayer{(:mha, :layer_norm1, :ffn, :dropout, :layer_norm2)}
    mha
    layer_norm1
    ffn
    dropout
    layer_norm2
end

function (m::DecoderBlock)((x, mask), ps, st)
    norm_out, st_layer_norm1 = m.layer_norm1(x, ps.layer_norm1, st.layer_norm1)
    attn_output, st_mha = m.mha((norm_out, mask), ps.mha, st.mha)
    x = attn_output + x

    norm_out, st_layer_norm2 = m.layer_norm2(x, ps.layer_norm2, st.layer_norm2)
    ffn_output, _ = m.ffn(norm_out, ps.ffn, st.ffn)
    dropout_out, st_dropout = m.dropout(ffn_output, ps.dropout, st.dropout)
    x = dropout_out + x

    st = merge(st, (mha=st_mha, layer_norm1=st_layer_norm1, layer_norm2=st_layer_norm2, dropout=st_dropout,))
    return (x, mask), st
end

Decoderの実装

struct Decoder <: Lux.LuxCore.AbstractExplicitContainerLayer{(:embedding, :positional_encoding, :blocks, :output)}
    embedding
    positional_encoding
    blocks
    output
end

function (m::Decoder)((x, mask), ps, st)
    # x = (token_size, seq_len, batch_size)
    x, _ = m.embedding(x, ps.embedding, st.embedding)
    x, _ = m.positional_encoding(x, ps.positional_encoding, st.positional_encoding)
    # x = (features, seq_len, batch_size)

    (x, _), st_blocks = m.blocks((x, mask), ps.blocks, st.blocks)

    x = flatten(x)
    output, st_out = m.output(x, ps.output, st.output)
    # output = (token_size, seq_len, batch_size)
    st = merge(st, (blocks=st_blocks, output=st_out))
    return output, st
end

パディングマスクの作成

function create_mask(x)
    # x shape: (seq_len, batch_size)
    mask = (x .== 1)
    # Expand mask dimensions to shape (seq_len, 1, 1, batch_size)
    mask_expanded = reshape(mask, 1, size(mask, 1), 1, size(mask, 2))
    mask = repeat(mask_expanded, size(mask, 1), 1, 1, 1)
    return mask
end

これにより、Padに対して注意しなくなる。

学習

Transformerの効果を確認するために比較モデルとしてただのDNN層であるシンプルモデルも用意しました。
(モデルのパラメータはチューニングしていないので、あまり意味のあるものではないですが...)

シンプルモデル
Decoder(
    embedding = Dense(17 => 48),        # 864 parameters
    positional_encoding = PositionalEncodingLayer(),
    blocks = NoOpLayer(),
    output = Chain(
        layer_1 = Dense(912 => 912, relu),  # 832_656 parameters
        layer_2 = Dropout(0.2),
        layer_3 = Dense(912 => 912, relu),  # 832_656 parameters
        layer_4 = Dropout(0.2),
        layer_5 = Dense(912 => 912, relu),  # 832_656 parameters
        layer_6 = Dropout(0.2),
        layer_7 = Dense(912 => 17),     # 15_521 parameters
    ),
)         # Total: 2_514_353 parameters,
          #        plus 918 states.
GPTモデル
Decoder(
    embedding = Dense(17 => 48),        # 864 parameters
    positional_encoding = PositionalEncodingLayer(),
    blocks = Chain(
        layer_1 = DecoderBlock(
            mha = MultiHeadAttention(
                query = Dense(48 => 48),  # 2_352 parameters
                key = Dense(48 => 48),  # 2_352 parameters
                value = Dense(48 => 48),  # 2_352 parameters
                dropout = Dropout(0.1),
                output = Dense(48 => 48),  # 2_352 parameters
            ),
            layer_norm1 = LayerNorm((48, 1), affine=true, dims=Colon()),  # 96 parameters
            ffn = Chain(
                layer_1 = Dense(48 => 192, gelu),  # 9_408 parameters
                layer_2 = Dense(192 => 48),  # 9_264 parameters
                layer_3 = Dropout(0.1),
            ),
            dropout = Dropout(0.1),
            layer_norm2 = LayerNorm((48, 1), affine=true, dims=Colon()),  # 96 parameters
        ),
        layer_2 = DecoderBlock(
            mha = MultiHeadAttention(
                query = Dense(48 => 48),  # 2_352 parameters
                key = Dense(48 => 48),  # 2_352 parameters
                value = Dense(48 => 48),  # 2_352 parameters
                dropout = Dropout(0.1),
                output = Dense(48 => 48),  # 2_352 parameters
            ),
            layer_norm1 = LayerNorm((48, 1), affine=true, dims=Colon()),  # 96 parameters
            ffn = Chain(
                layer_1 = Dense(48 => 192, gelu),  # 9_408 parameters
                layer_2 = Dense(192 => 48),  # 9_264 parameters
                layer_3 = Dropout(0.1),
            ),
            dropout = Dropout(0.1),
            layer_norm2 = LayerNorm((48, 1), affine=true, dims=Colon()),  # 96 parameters
        ),
        layer_3 = DecoderBlock(
            mha = MultiHeadAttention(
                query = Dense(48 => 48),  # 2_352 parameters
                key = Dense(48 => 48),  # 2_352 parameters
                value = Dense(48 => 48),  # 2_352 parameters
                dropout = Dropout(0.1),
                output = Dense(48 => 48),  # 2_352 parameters
            ),
            layer_norm1 = LayerNorm((48, 1), affine=true, dims=Colon()),  # 96 parameters
            ffn = Chain(
                layer_1 = Dense(48 => 192, gelu),  # 9_408 parameters
                layer_2 = Dense(192 => 48),  # 9_264 parameters
                layer_3 = Dropout(0.1),
            ),
            dropout = Dropout(0.1),
            layer_norm2 = LayerNorm((48, 1), affine=true, dims=Colon()),  # 96 parameters
        ),
    ),
    output = Dense(912 => 17),          # 15_521 parameters
)         # Total: 101_201 parameters,
          #        plus 930 states.

学習パラメータ

学習率: 6e-4
オプティマイザ: AdamW(β = (0.9, 0.95), γ = 0.1, ϵ = 1e-8)
※ bias, scaleのWeightDecayは0
学習データ: 10000個の足し算式を Train:Test=1:3 に分割

学習結果

シンプルモデル GPTモデル
学習時間 27m 学習時間 1h14m
MaxTestAcc 86.3% MaxTestAcc 99.9%

シンプルなDNNモデルはTrainの精度とTestの精度が連動していますが、GPTの方はTrainデータのaccが上限になってもTestデータのaccが上がり続けていますね。

感想

汎化性能の違いを明確に見ることができてよかったです。モデルサイズや学習データのTrain, Test比率などの調整にはちょっと苦労しましたが、満足な結果となりました。

Discussion