PyTorch to JAX Migration Guide (Training GANs | Customizing TrainState)

Published 2021/10/13

Background

This series shows how to migrate PyTorch code to JAX using Flax, a JAX-based neural network library. This time we take up training Generative Adversarial Networks (GANs), and learn about customizing Flax's handy TrainState utility and about how JAX handles random numbers.

https://gist.github.com/yonetaniryo/bfb566429410f3d77a2a838bab5c936a
As usual, the code is uploaded to the gist above.

Example: learning a Gaussian mixture with a GAN

Preparing the training data

As a simple problem, let's learn a one-dimensional Gaussian mixture with two modes.

import numpy as np
from tensorflow.data import Dataset

data = np.concatenate([np.random.normal(0, 0.1, 10000), np.random.normal(0.5, 0.1, 10000)])

train_set = Dataset.from_tensor_slices(data).shuffle(len(data), reshuffle_each_iteration=True).batch(10).prefetch(1)

# visualization
import matplotlib.pyplot as plt
plt.hist(data, 100);

Training the GAN

Both the generator and the discriminator are MLPs with BatchNorm, and we train a least-squares GAN (LSGAN), which minimizes an MSE loss. To make the JAX migration easier, we factor one training step of each model out into the functions g_step and d_step.

import time

import torch
import torch.nn as nn
import torch.optim as optim

class Generator(nn.Module):
    def __init__(self, z_dim=512):
        super().__init__()
        # note: the second positional argument of nn.BatchNorm1d is eps, not momentum
        self.net = nn.Sequential(nn.Linear(z_dim, 64), nn.BatchNorm1d(64, 0.8), nn.ReLU(),
                                 nn.Linear(64, 32), nn.BatchNorm1d(32, 0.8), nn.ReLU(),
                                 nn.Linear(32, 1))
        
    def forward(self, x):
        return self.net(x)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(1, 32), nn.BatchNorm1d(32, 0.8), nn.ReLU(),
                                 nn.Linear(32, 64), nn.BatchNorm1d(64, 0.8), nn.ReLU(),
                                 nn.Linear(64, 1))
        
        
    def forward(self, x):
        return self.net(x)

z_dim = 512
generator = Generator(z_dim)
g_opt = optim.Adam(generator.parameters(), lr=0.00005)
discriminator = Discriminator()
d_opt = optim.Adam(discriminator.parameters(), lr=0.0001)

# Least-squares GAN, so we use MSE loss
criterion = nn.MSELoss()

# One training step for the generator
def g_step(z):
    valid = torch.ones((len(z), 1))
    fake_data = generator(z)
    g_loss = criterion(discriminator(fake_data), valid)
    g_opt.zero_grad()
    g_loss.backward()
    g_opt.step()

    return g_loss, fake_data

# One training step for the discriminator
def d_step(real_data, fake_data):
    valid = torch.ones((len(real_data), 1))
    fake = torch.zeros((len(fake_data), 1))
    d_loss_real = criterion(discriminator(real_data), valid)
    d_opt.zero_grad()
    d_loss_real.backward()
    d_opt.step()
    d_loss_fake = criterion(discriminator(fake_data.detach()), fake)
    d_opt.zero_grad()
    d_loss_fake.backward()
    d_opt.step()
    d_loss = (d_loss_real + d_loss_fake) / 2.

    return d_loss

for e in range(30):
    g_loss_avg, d_loss_avg = 0, 0
    tic = time.time()
    for x in train_set.as_numpy_iterator():
        n = len(x)
        real_data = torch.from_numpy(x).float().reshape(n, -1)
        z = torch.randn((n, z_dim))
        g_loss, fake_data = g_step(z)
        d_loss = d_step(real_data, fake_data)
        g_loss_avg += g_loss.item()
        d_loss_avg += d_loss.item()
    g_loss_avg /= len(train_set)
    d_loss_avg /= len(train_set)
    elapsed = time.time() - tic
    print(f"epoch: {e}, g_loss: {g_loss_avg:0.2f}, d_loss: {d_loss_avg:0.2f}, elapased: {elapsed:0.2f}")

z = torch.randn((10000, z_dim))
fake_data = generator(z).detach().flatten().numpy()
plt.figure()
plt.hist(fake_data, 100, alpha=.5);
plt.hist(data, 100, alpha=.5);
plt.show()

Let's train it on Colab (CPU). (This takes a fair while.)

epoch: 0, g_loss: 0.37, d_loss: 0.26, elapsed: 20.54
epoch: 1, g_loss: 0.33, d_loss: 0.24, elapsed: 20.47
epoch: 2, g_loss: 0.27, d_loss: 0.23, elapsed: 20.47
epoch: 3, g_loss: 0.29, d_loss: 0.24, elapsed: 20.47
epoch: 4, g_loss: 0.28, d_loss: 0.24, elapsed: 20.47
epoch: 5, g_loss: 0.28, d_loss: 0.24, elapsed: 20.47
epoch: 6, g_loss: 0.27, d_loss: 0.24, elapsed: 20.47
epoch: 7, g_loss: 0.27, d_loss: 0.24, elapsed: 12.33
epoch: 8, g_loss: 0.26, d_loss: 0.25, elapsed: 12.38
epoch: 9, g_loss: 0.26, d_loss: 0.25, elapsed: 12.54
epoch: 10, g_loss: 0.26, d_loss: 0.25, elapsed: 20.47
epoch: 11, g_loss: 0.26, d_loss: 0.25, elapsed: 20.47
epoch: 12, g_loss: 0.26, d_loss: 0.25, elapsed: 20.47
epoch: 13, g_loss: 0.26, d_loss: 0.25, elapsed: 20.47
epoch: 14, g_loss: 0.26, d_loss: 0.25, elapsed: 12.36
epoch: 15, g_loss: 0.26, d_loss: 0.25, elapsed: 20.47
epoch: 16, g_loss: 0.26, d_loss: 0.25, elapsed: 20.47
epoch: 17, g_loss: 0.26, d_loss: 0.25, elapsed: 20.47
epoch: 18, g_loss: 0.26, d_loss: 0.25, elapsed: 20.47
epoch: 19, g_loss: 0.26, d_loss: 0.25, elapsed: 12.32
epoch: 20, g_loss: 0.26, d_loss: 0.25, elapsed: 12.42
epoch: 21, g_loss: 0.26, d_loss: 0.25, elapsed: 12.19
epoch: 22, g_loss: 0.25, d_loss: 0.25, elapsed: 20.47
epoch: 23, g_loss: 0.25, d_loss: 0.25, elapsed: 20.47
epoch: 24, g_loss: 0.25, d_loss: 0.25, elapsed: 12.27
epoch: 25, g_loss: 0.26, d_loss: 0.25, elapsed: 12.19
epoch: 26, g_loss: 0.26, d_loss: 0.25, elapsed: 20.47
epoch: 27, g_loss: 0.26, d_loss: 0.25, elapsed: 20.47
epoch: 28, g_loss: 0.26, d_loss: 0.25, elapsed: 20.47
epoch: 29, g_loss: 0.26, d_loss: 0.25, elapsed: 12.25


Well, it sort of learned the distribution and sort of didn't... fitting a Gaussian mixture with a GAN is actually pretty hard.

Customizing TrainState: making BatchNorm easier to handle

From here we migrate to JAX (Flax). The models should look something like this.

import jax
import jax.numpy as jnp
import flax.linen as fnn

class Generator(fnn.Module):
    @fnn.compact
    def __call__(self, z, use_running_average):
        h = fnn.Dense(64)(z)
        h = fnn.BatchNorm(use_running_average=use_running_average, momentum=0.8)(h)
        h = fnn.relu(h)
        h = fnn.Dense(32)(h)
        h = fnn.BatchNorm(use_running_average=use_running_average, momentum=0.8)(h)
        h = fnn.relu(h)
        y = fnn.Dense(1)(h)
        return y

class Discriminator(fnn.Module):
    @fnn.compact
    def __call__(self, x, use_running_average):
        h = fnn.Dense(32)(x)
        h = fnn.BatchNorm(use_running_average=use_running_average, momentum=0.8)(h)
        h = fnn.relu(h)
        h = fnn.Dense(64)(h)
        h = fnn.BatchNorm(use_running_average=use_running_average, momentum=0.8)(h)
        h = fnn.relu(h)
        y = fnn.Dense(1)(h)
        return y

As explained in the previous post, we let __call__ take a use_running_average: bool argument that controls the BatchNorm behavior.

Let's initialize the models.

key = jax.random.PRNGKey(0)
g = Generator()
subkey, key = jax.random.split(key)
g_variables = g.init(subkey, jnp.ones((1, 512)), True)
d = Discriminator()
subkey, key = jax.random.split(key)
d_variables = d.init(subkey, jnp.ones((1, 1)), True)

These g_variables and d_variables hold each model's parameters and batch statistics, as we saw last time.
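
If you want to see what they contain, one quick way (a small aside, not part of the original gist) is to map jnp.shape over the variable trees:

# show the nested structure of params / batch_stats as a tree of shapes
print(jax.tree_util.tree_map(jnp.shape, g_variables))
print(jax.tree_util.tree_map(jnp.shape, d_variables))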

Random numbers in JAX

Incidentally, one of JAX's defining features is that any function involving randomness must be given a pseudo-random number generator key every time it is called, and such a function returns a deterministic value (the same one every time) for a given key.

key = jax.random.PRNGKey(0)
print(jax.random.normal(key))
print(jax.random.normal(key))

-0.20584235
-0.20584235

To generate fresh random numbers, use jax.random.split to derive a new subkey from the current key.

subkey, key = jax.random.split(key)
print(jax.random.normal(subkey))
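
As a small aside (not in the original article), jax.random.split can also return several keys at once, which is convenient when you need multiple independent draws in one go:

# split the current key into a carried-over key and three fresh subkeys
key, *subkeys = jax.random.split(key, 4)
print([jax.random.normal(k) for k in subkeys])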

Customizing TrainState

In Flax, the model definition (g = Generator()) and its parameters and internal state (the g_variables returned by g.init) are handled separately. This is clean in itself, but flax.training.train_state.TrainState, which manages training, only holds the model parameters g_variables["params"], so the model's internal state, i.e. the batch statistics g_variables["batch_stats"], had to be handled separately. Recall the training-step implementation from the previous post.

@partial(jax.jit, static_argnums=(4,))
# batch_stats is passed in separately from state
def step(x, y, state, batch_stats, is_training=True):
    def loss_fn(params, batch_stats):
        y_pred, mutated_vars = state.apply_fn({"params": params, "batch_stats": batch_stats}, x, is_training, mutable=["batch_stats"]) 
        new_batch_stats = mutated_vars["batch_stats"]
        loss = optax.softmax_cross_entropy(logits=y_pred, labels=y).mean()
        return loss, (y_pred, new_batch_stats)
    y = jnp.eye(10)[y]
    if is_training:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, (y_pred, new_batch_stats)), grads = grad_fn(state.params, batch_stats)
        state = state.apply_gradients(grads=grads)
    else:
        loss, (y_pred, new_batch_stats) = loss_fn(state.params, batch_stats)
    return loss, y_pred, state, new_batch_stats
    
...
# training loop
for x, y in train_set.as_numpy_iterator():
    loss, y_pred, state, batch_stats = step(x, y, state, batch_stats, is_training=True)

To tidy this up a bit, here we extend TrainState so that it also carries batch_stats. We also add methods for running the forward pass with those batch_stats and for updating batch_stats with the values changed by the forward pass.

class CustomTrainState(TrainState):
    batch_stats: dict
    
    # shortcut to apply_fn that handles mutable batch_stats
    def apply_fn_with_bn(self, *args, is_training, **nargs):
        output, mutated_vars = self.apply_fn(*args, **nargs, use_running_average=not is_training, mutable=["batch_stats"])
        new_batch_stats = mutated_vars["batch_stats"]
        return output, new_batch_stats
    
    def update_batch_stats(self, new_batch_stats):
        return self.replace(batch_stats=new_batch_stats)

With this in place, the step function above can be rewritten as follows.

def step(x, y, state, is_training=True):
    def loss_fn(params, batch_stats):
        y_pred, new_batch_stats = state.apply_fn_with_bn({"params": params, "batch_stats": batch_stats}, x, is_training=is_training)
        loss = optax.softmax_cross_entropy(logits=y_pred, labels=y).mean()
        return loss, (y_pred, new_batch_stats)
    y = jnp.eye(10)[y]
    if is_training:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, (y_pred, new_batch_stats)), grads = grad_fn(state.params, state.batch_stats)
        state = state.apply_gradients(grads=grads)
        state = state.update_batch_stats(new_batch_stats)
    else:
        loss, (y_pred, new_batch_stats) = loss_fn(state.params, state.batch_stats)
    return loss, y_pred, state

...
# training loop
for x, y in train_set.as_numpy_iterator():
    loss, y_pred, state = step(x, y, state, is_training=True)

As a result, batch_stats is now referenced only inside the step function, via state.
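
Creating such a state works just like a plain TrainState, except that batch_stats is passed in as well; for example, the generator state in the full code later in this post (assuming optax has been imported) is built like this:

g_tx = optax.adam(0.00005)
g_state = CustomTrainState.create(
    apply_fn=g.apply, params=g_variables["params"], tx=g_tx,
    batch_stats=g_variables["batch_stats"],
)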

Training the generator

Here, given random noise z, we update the generator's parameters and batch statistics, as well as the discriminator's batch statistics[1].

    
@partial(jax.jit, static_argnums=(3,))
def g_step(z, g_state, d_state, is_training=True):
    def loss_fn(g_params, g_bs, d_params, d_bs, is_training):
        gen_data, g_bs = g_state.apply_fn_with_bn({"params": g_params, "batch_stats": g_bs}, z, is_training=is_training)
        validity, d_bs = d_state.apply_fn_with_bn({"params": d_params, "batch_stats": d_bs}, gen_data, is_training=is_training)
        g_loss = optax.l2_loss(validity, jnp.ones((len(validity), 1))).mean() * 2.
        # multiplication by 2 to make the loss consistent with pytorch. see https://github.com/deepmind/optax/blob/master/optax/_src/loss.py#L32#L54
        return g_loss, (gen_data, g_bs, d_bs)
	
    if is_training:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (g_loss, (gen_data, g_bs, d_bs)), grads = grad_fn(g_state.params, g_state.batch_stats, d_state.params, d_state.batch_stats, True)
        g_state = g_state.apply_gradients(grads=grads)
        g_state = g_state.update_batch_stats(g_bs)
        d_state = d_state.update_batch_stats(d_bs)  # also update the discriminator's batch stats (see footnote [1])
    else:
        g_loss, (gen_data, g_bs, d_bs) = loss_fn(g_state.params, g_state.batch_stats, d_state.params, d_state.batch_stats, False)

    return g_loss, g_state, d_state, gen_data

Training the discriminator

Here we use the gen_data produced during the generator step and real_data drawn from the training set to update the discriminator's parameters and batch statistics.

@partial(jax.jit, static_argnums=(3,))
def d_step(real_data, fake_data, d_state, is_training=True):
    valid = jnp.ones((len(real_data), 1))
    fake = jnp.zeros((len(fake_data), 1))
    # loss_fn closes over x and y, which are bound by the loops below
    def loss_fn(d_params, d_bs, is_training):
        validity, d_bs = d_state.apply_fn_with_bn({"params": d_params, "batch_stats": d_bs}, x, is_training=is_training)
        d_loss = optax.l2_loss(validity, y).mean()
        return d_loss, d_bs
    
    loss = 0.
    if is_training:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        for x, y in zip([real_data, fake_data], [valid, fake]):
            (loss_, d_bs), grads = grad_fn(d_state.params, d_state.batch_stats, True)
            d_state = d_state.apply_gradients(grads=grads)
            d_state = d_state.update_batch_stats(d_bs)
            loss += loss_
    else:
        for x, y in zip([real_data, fake_data], [valid, fake]):
            loss_, d_bs = loss_fn(d_state.params, d_state.batch_stats, False)
            loss += loss_
        
    return loss, d_state

Done: after the JAX migration

And we are done. The whole program looks like the following. On each iteration we generate a fresh subkey and use it to sample the random noise z.

from functools import partial

import jax
import jax.numpy as jnp
import flax.linen as fnn
from flax.core import FrozenDict
from flax.training.train_state import TrainState
import optax


class Generator(fnn.Module):
    @fnn.compact
    def __call__(self, z, use_running_average):
        h = fnn.Dense(64)(z)
        h = fnn.BatchNorm(use_running_average=use_running_average, momentum=0.8)(h)
        h = fnn.relu(h)
        h = fnn.Dense(32)(h)
        h = fnn.BatchNorm(use_running_average=use_running_average, momentum=0.8)(h)
        h = fnn.relu(h)
        y = fnn.Dense(1)(h)
        return y

class Discriminator(fnn.Module):
    @fnn.compact
    def __call__(self, x, use_running_average):
        h = fnn.Dense(32)(x)
        h = fnn.BatchNorm(use_running_average=use_running_average, momentum=0.8)(h)
        h = fnn.relu(h)
        h = fnn.Dense(64)(h)
        h = fnn.BatchNorm(use_running_average=use_running_average, momentum=0.8)(h)
        h = fnn.relu(h)
        y = fnn.Dense(1)(h)
        return y

class CustomTrainState(TrainState):
    batch_stats: dict
    
    def apply_fn_with_bn(self, *args, is_training, **nargs):
        output, mutated_vars = self.apply_fn(*args, **nargs, use_running_average=not is_training, mutable=["batch_stats"])
        new_batch_stats = mutated_vars["batch_stats"]
        return output, new_batch_stats
    
    def update_batch_stats(self, new_batch_stats):
        return self.replace(batch_stats=new_batch_stats)
    
@partial(jax.jit, static_argnums=(3,))
def g_step(z, g_state, d_state, is_training=True):
    def loss_fn(g_params, g_bs, d_params, d_bs, is_training):
        gen_data, g_bs = g_state.apply_fn_with_bn({"params": g_params, "batch_stats": g_bs}, z, is_training=is_training)
        validity, d_bs = d_state.apply_fn_with_bn({"params": d_params, "batch_stats": d_bs}, gen_data, is_training=is_training)
        g_loss = optax.l2_loss(validity, jnp.ones((len(validity), 1))).mean() * 2.
        # multiplication by 2 to make the loss consistent with pytorch. see https://github.com/deepmind/optax/blob/master/optax/_src/loss.py#L32#L54
        return g_loss, (gen_data, g_bs, d_bs)
        
    if is_training:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (g_loss, (gen_data, g_bs, d_bs)), grads = grad_fn(g_state.params, g_state.batch_stats, d_state.params, d_state.batch_stats, True)
        g_state = g_state.apply_gradients(grads=grads)
        g_state = g_state.update_batch_stats(g_bs)
        d_state = d_state.update_batch_stats(d_bs)  # also update the discriminator's batch stats (see footnote [1])
    else:
        g_loss, (gen_data, g_bs, d_bs) = loss_fn(g_state.params, g_state.batch_stats, d_state.params, d_state.batch_stats, False)

    return g_loss, g_state, d_state, gen_data

@partial(jax.jit, static_argnums=(3,))
def d_step(real_data, fake_data, d_state, is_training=True):
    valid = jnp.ones((len(real_data), 1))
    fake = jnp.zeros((len(fake_data), 1))
    # loss_fn closes over x and y, which are bound by the loops below
    def loss_fn(d_params, d_bs, is_training):
        validity, d_bs = d_state.apply_fn_with_bn({"params": d_params, "batch_stats": d_bs}, x, is_training=is_training)
        d_loss = optax.l2_loss(validity, y).mean()
        return d_loss, d_bs
    
    loss = 0.
    if is_training:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        for x, y in zip([real_data, fake_data], [valid, fake]):
            (loss_, d_bs), grads = grad_fn(d_state.params, d_state.batch_stats, True)
            d_state = d_state.apply_gradients(grads=grads)
            d_state = d_state.update_batch_stats(d_bs)
            loss += loss_
    else:
        for x, y in zip([real_data, fake_data], [valid, fake]):
            loss_, d_bs = loss_fn(d_state.params, d_state.batch_stats, False)
            loss += loss_
        
    return loss, d_state
    
key = jax.random.PRNGKey(0)

z_dim = 512
g = Generator()
subkey, key = jax.random.split(key)
g_variables = g.init(subkey, jnp.ones((1, z_dim)), True)
g_tx = optax.adam(0.00005)
g_state = CustomTrainState.create(apply_fn=g.apply, params=g_variables["params"], tx=g_tx, batch_stats=g_variables["batch_stats"])

d = Discriminator()
subkey, key = jax.random.split(key)
d_variables = d.init(subkey, jnp.ones((1, 1)), True)
d_tx = optax.adam(0.0001)
d_state = CustomTrainState.create(apply_fn=d.apply, params=d_variables["params"], tx=d_tx, batch_stats=d_variables["batch_stats"])

for e in range(30):
    g_loss_avg, d_loss_avg = 0, 0
    tic = time.time()
    for x in train_set.as_numpy_iterator():
        subkey, key = jax.random.split(key)
        z = jax.random.normal(subkey, (32, z_dim))  # note: batch size 32 here, vs. 10 in the data batches
        real_data = jnp.array(x).reshape(-1, 1)
        g_loss, g_state, d_state, gen_data = g_step(z, g_state, d_state, True)
        d_loss, d_state = d_step(real_data, gen_data, d_state, True)
        g_loss_avg += g_loss
        d_loss_avg += d_loss
    g_loss_avg /= len(train_set)
    d_loss_avg /= len(train_set)
    elapsed = time.time() - tic
    print(f"epoch: {e}, g_loss: {g_loss_avg:0.2f}, d_loss: {d_loss_avg:0.2f}, elapased: {elapsed:0.2f}")
z = jax.random.normal(key, (10000, z_dim))
gen_data, g_bs = g_state.apply_fn_with_bn({"params": g_state.params, "batch_stats": g_state.batch_stats}, z, is_training=False)
plt.hist(gen_data.flatten(), 100);
plt.hist(data, 100, alpha=.5);
plt.legend(["Generated", "Data"])
plt.show()

Let's run it.

epoch: 0, g_loss: 0.31, d_loss: 0.21, elapsed: 9.47
epoch: 1, g_loss: 0.35, d_loss: 0.18, elapsed: 4.16
epoch: 2, g_loss: 0.39, d_loss: 0.16, elapsed: 5.11
epoch: 3, g_loss: 0.38, d_loss: 0.17, elapsed: 4.14
epoch: 4, g_loss: 0.40, d_loss: 0.16, elapsed: 4.09
epoch: 5, g_loss: 0.42, d_loss: 0.14, elapsed: 5.11
epoch: 6, g_loss: 0.45, d_loss: 0.13, elapsed: 3.95
epoch: 7, g_loss: 0.49, d_loss: 0.12, elapsed: 4.16
epoch: 8, g_loss: 0.52, d_loss: 0.11, elapsed: 4.12
epoch: 9, g_loss: 0.55, d_loss: 0.10, elapsed: 4.01
epoch: 10, g_loss: 0.59, d_loss: 0.09, elapsed: 5.11
epoch: 11, g_loss: 0.62, d_loss: 0.09, elapsed: 4.02
epoch: 12, g_loss: 0.65, d_loss: 0.08, elapsed: 4.12
epoch: 13, g_loss: 0.67, d_loss: 0.08, elapsed: 5.11
epoch: 14, g_loss: 0.68, d_loss: 0.08, elapsed: 4.16
epoch: 15, g_loss: 0.68, d_loss: 0.08, elapsed: 5.11
epoch: 16, g_loss: 0.68, d_loss: 0.08, elapsed: 4.09
epoch: 17, g_loss: 0.67, d_loss: 0.09, elapsed: 5.11
epoch: 18, g_loss: 0.65, d_loss: 0.09, elapsed: 4.23
epoch: 19, g_loss: 0.62, d_loss: 0.10, elapsed: 5.11
epoch: 20, g_loss: 0.58, d_loss: 0.12, elapsed: 5.11
epoch: 21, g_loss: 0.52, d_loss: 0.14, elapsed: 4.11
epoch: 22, g_loss: 0.46, d_loss: 0.16, elapsed: 5.11
epoch: 23, g_loss: 0.40, d_loss: 0.19, elapsed: 4.19
epoch: 24, g_loss: 0.37, d_loss: 0.20, elapsed: 4.11
epoch: 25, g_loss: 0.37, d_loss: 0.20, elapsed: 5.11
epoch: 26, g_loss: 0.38, d_loss: 0.19, elapsed: 5.11
epoch: 27, g_loss: 0.40, d_loss: 0.18, elapsed: 4.16
epoch: 28, g_loss: 0.41, d_loss: 0.18, elapsed: 5.11
epoch: 29, g_loss: 0.43, d_loss: 0.17, elapsed: 4.34

Again we get a substantial speedup: summed over the 30 epochs, the JAX version is 3.78x faster.

          min [s]  max [s]                                total, 30 epochs [s]
PyTorch   12.19    20.54                                  532.68
JAX       3.95     9.47 (first epoch, incl. compilation)  140.90


The training actually seems to go better than with PyTorch...? I'm sorry to say I couldn't pin down the reason; the two versions should be doing the same training. (Note, though, that the implementations are not strictly identical, e.g. the PyTorch models pass 0.8 as the second positional argument of nn.BatchNorm1d, which sets eps rather than momentum, and the JAX loop samples z with batch size 32 rather than the data batch size of 10, so some difference in the dynamics is perhaps not surprising.)

Summary

We have shown how to train a GAN in JAX. BatchNorm inevitably makes the parameter updates more involved, but customizing TrainState makes things somewhat easier to follow. The update-step functions (g_step, d_step) end up looking roughly the same every time, so it may be worth keeping them around as reusable snippets.

Footnotes
  1. On reflection, it is not entirely obvious whether the discriminator's batch statistics should be updated while training the generator. The PyTorch-GAN examples seem to update them, so we do the same here. ↩︎
