PyTorch to JAX Migration Guide (Training GANs | Customizing TrainState)
Background
This series shows how to migrate PyTorch code to JAX using Flax, a JAX-based neural-network library. This installment covers training Generative Adversarial Networks (GANs); along the way we look at how to customize Flax's handy TrainState and how random numbers are handled in JAX.
As usual, the code is uploaded as a gist.
Example: learning a Gaussian mixture with a GAN
Preparing the training data
As a simple problem, let's learn a one-dimensional Gaussian mixture with two modes.
import numpy as np
from tensorflow.data import Dataset
data = np.concatenate([np.random.normal(0, 0.1, 10000), np.random.normal(0.5, 0.1, 10000)])
train_set = Dataset.from_tensor_slices(data).shuffle(len(data), reshuffle_each_iteration=True).batch(10).prefetch(1)
# Visualization
import matplotlib.pyplot as plt
plt.hist(data, 100);
Training the GAN
Both the generator and the discriminator are MLPs with BatchNorm, and we train a Least Squares GAN that minimizes an MSE loss. To make the migration to JAX easier, the single training step of each model is factored out into the functions g_step and d_step.
import time
import torch
import torch.nn as nn
import torch.optim as optim
class Generator(nn.Module):
    def __init__(self, z_dim=512):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(z_dim, 64), nn.BatchNorm1d(64, 0.8), nn.ReLU(),
                                 nn.Linear(64, 32), nn.BatchNorm1d(32, 0.8), nn.ReLU(),
                                 nn.Linear(32, 1))

    def forward(self, x):
        return self.net(x)


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(1, 32), nn.BatchNorm1d(32, 0.8), nn.ReLU(),
                                 nn.Linear(32, 64), nn.BatchNorm1d(64, 0.8), nn.ReLU(),
                                 nn.Linear(64, 1))

    def forward(self, x):
        return self.net(x)


z_dim = 512
generator = Generator(z_dim)
g_opt = optim.Adam(generator.parameters(), lr=0.00005)
discriminator = Discriminator()
d_opt = optim.Adam(discriminator.parameters(), lr=0.0001)
# MSE loss, since this is a Least Squares GAN
criterion = nn.MSELoss()

# One training step for the generator
def g_step(z):
    valid = torch.ones((len(z), 1))
    fake_data = generator(z)
    g_loss = criterion(discriminator(fake_data), valid)
    g_opt.zero_grad()
    g_loss.backward()
    g_opt.step()
    return g_loss, fake_data

# One training step for the discriminator
def d_step(real_data, fake_data):
    valid = torch.ones((len(real_data), 1))
    fake = torch.zeros((len(fake_data), 1))
    d_loss_real = criterion(discriminator(real_data), valid)
    d_opt.zero_grad()
    d_loss_real.backward()
    d_opt.step()
    d_loss_fake = criterion(discriminator(fake_data.detach()), fake)
    d_opt.zero_grad()
    d_loss_fake.backward()
    d_opt.step()
    d_loss = (d_loss_real + d_loss_fake) / 2.
    return d_loss

for e in range(30):
    g_loss_avg, d_loss_avg = 0, 0
    tic = time.time()
    for x in train_set.as_numpy_iterator():
        n = len(x)
        real_data = torch.from_numpy(x).float().reshape(n, -1)
        z = torch.randn((n, z_dim))
        g_loss, fake_data = g_step(z)
        d_loss = d_step(real_data, fake_data)
        g_loss_avg += g_loss.item()
        d_loss_avg += d_loss.item()
    g_loss_avg /= len(train_set)
    d_loss_avg /= len(train_set)
    elapsed = time.time() - tic
    print(f"epoch: {e}, g_loss: {g_loss_avg:0.2f}, d_loss: {d_loss_avg:0.2f}, elapased: {elapsed:0.2f}")

z = torch.randn((10000, z_dim))
fake_data = generator(z).detach().flatten().numpy()
plt.figure()
plt.hist(fake_data, 100, alpha=.5);
plt.hist(data, 100, alpha=.5);
plt.show()
We run the training on Colab (CPU). It takes a fair while.
epoch: 0, g_loss: 0.37, d_loss: 0.26, elapased: 20.54
epoch: 1, g_loss: 0.33, d_loss: 0.24, elapased: 20.47
epoch: 2, g_loss: 0.27, d_loss: 0.23, elapased: 20.47
epoch: 3, g_loss: 0.29, d_loss: 0.24, elapased: 20.47
epoch: 4, g_loss: 0.28, d_loss: 0.24, elapased: 20.47
epoch: 5, g_loss: 0.28, d_loss: 0.24, elapased: 20.47
epoch: 6, g_loss: 0.27, d_loss: 0.24, elapased: 20.47
epoch: 7, g_loss: 0.27, d_loss: 0.24, elapased: 12.33
epoch: 8, g_loss: 0.26, d_loss: 0.25, elapased: 12.38
epoch: 9, g_loss: 0.26, d_loss: 0.25, elapased: 12.54
epoch: 10, g_loss: 0.26, d_loss: 0.25, elapased: 20.47
epoch: 11, g_loss: 0.26, d_loss: 0.25, elapased: 20.47
epoch: 12, g_loss: 0.26, d_loss: 0.25, elapased: 20.47
epoch: 13, g_loss: 0.26, d_loss: 0.25, elapased: 20.47
epoch: 14, g_loss: 0.26, d_loss: 0.25, elapased: 12.36
epoch: 15, g_loss: 0.26, d_loss: 0.25, elapased: 20.47
epoch: 16, g_loss: 0.26, d_loss: 0.25, elapased: 20.47
epoch: 17, g_loss: 0.26, d_loss: 0.25, elapased: 20.47
epoch: 18, g_loss: 0.26, d_loss: 0.25, elapased: 20.47
epoch: 19, g_loss: 0.26, d_loss: 0.25, elapased: 12.32
epoch: 20, g_loss: 0.26, d_loss: 0.25, elapased: 12.42
epoch: 21, g_loss: 0.26, d_loss: 0.25, elapased: 12.19
epoch: 22, g_loss: 0.25, d_loss: 0.25, elapased: 20.47
epoch: 23, g_loss: 0.25, d_loss: 0.25, elapased: 20.47
epoch: 24, g_loss: 0.25, d_loss: 0.25, elapased: 12.27
epoch: 25, g_loss: 0.26, d_loss: 0.25, elapased: 12.19
epoch: 26, g_loss: 0.26, d_loss: 0.25, elapased: 20.47
epoch: 27, g_loss: 0.26, d_loss: 0.25, elapased: 20.47
epoch: 28, g_loss: 0.26, d_loss: 0.25, elapased: 20.47
epoch: 29, g_loss: 0.26, d_loss: 0.25, elapased: 12.25
Well, it sort of learned something, sort of didn't... fitting a Gaussian mixture with a GAN is actually fairly hard.
Customizing TrainState: making BatchNorm easier to use
From here on, we migrate to JAX (Flax). The models should look something like this.
class Generator(fnn.Module):
    @fnn.compact
    def __call__(self, z, use_running_average):
        h = fnn.Dense(64)(z)
        h = fnn.BatchNorm(use_running_average=use_running_average, momentum=0.8)(h)
        h = fnn.relu(h)
        h = fnn.Dense(32)(h)
        h = fnn.BatchNorm(use_running_average=use_running_average, momentum=0.8)(h)
        h = fnn.relu(h)
        y = fnn.Dense(1)(h)
        return y


class Discriminator(fnn.Module):
    @fnn.compact
    def __call__(self, x, use_running_average):
        h = fnn.Dense(32)(x)
        h = fnn.BatchNorm(use_running_average=use_running_average, momentum=0.8)(h)
        h = fnn.relu(h)
        h = fnn.Dense(64)(h)
        h = fnn.BatchNorm(use_running_average=use_running_average, momentum=0.8)(h)
        h = fnn.relu(h)
        y = fnn.Dense(1)(h)
        return y
As explained in the previous article, we expose a use_running_average: bool argument on __call__ to control the behavior of BatchNorm.
Initialize the models.
key = jax.random.PRNGKey(0)
g = Generator()
subkey, key = jax.random.split(key)
g_variables = g.init(subkey, jnp.ones((1, 512)), True)
d = Discriminator()
subkey, key = jax.random.split(key)
d_variables = d.init(subkey, jnp.ones((1, 1)), True)
As before, these g_variables and d_variables hold each model's parameters and batch statistics.
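If you want a quick look at what init actually returned, one option is to map over the variable tree and print the shapes. This is a small sketch using the standard jax.tree_util API; the exact key names simply follow from the module definitions above.

# Sketch: inspect the variables returned by init().
# g_variables is a nested dict with a "params" and a "batch_stats" collection.
print(jax.tree_util.tree_map(lambda a: a.shape, g_variables))
# Roughly:
# {'batch_stats': {'BatchNorm_0': {'mean': (64,), 'var': (64,)}, ...},
#  'params': {'Dense_0': {'bias': (64,), 'kernel': (512, 64)}, ...}}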
Random numbers in JAX
Incidentally, one of JAX's defining features is that every function involving randomness must be passed a pseudo-random number generator key each time it is called. These functions then return a deterministic value (the same one every time) for a given key.
key = jax.random.PRNGKey(0)
print(jax.random.normal(key))
print(jax.random.normal(key))
↓
-0.20584235
-0.20584235
When you want fresh random numbers, you use jax.random.split to derive new subkeys from the current key.
subkey, key = jax.random.split(key)
print(jax.random.normal(subkey))
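If several fresh keys are needed at once, split can also produce more than two keys in a single call. A minimal sketch, using only jax.random APIs:

# Sketch: one carry-over key plus two fresh subkeys from a single split.
key = jax.random.PRNGKey(0)
key, *subkeys = jax.random.split(key, 3)
print(jax.random.normal(subkeys[0]))  # each subkey gives a different (but reproducible) value
print(jax.random.normal(subkeys[1]))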
Customizing TrainState
In Flax, a model's definition (g = Generator()) and its parameters and internal state (the g_variables returned by g.init) are kept separate. This is clean in itself, but flax.training.train_state.TrainState, which manages training, only accepts the model parameters g_variables["params"], so the batch statistics g_variables["batch_stats"], which are part of the model's internal state, had to be handled separately. Recall the training-step implementation from the previous article.
# batch_stats has to be passed in separately, in addition to state
@partial(jax.jit, static_argnums=(4,))
def step(x, y, state, batch_stats, is_training=True):
    def loss_fn(params, batch_stats):
        y_pred, mutated_vars = state.apply_fn({"params": params, "batch_stats": batch_stats}, x, is_training, mutable=["batch_stats"])
        new_batch_stats = mutated_vars["batch_stats"]
        loss = optax.softmax_cross_entropy(logits=y_pred, labels=y).mean()
        return loss, (y_pred, new_batch_stats)

    y = jnp.eye(10)[y]
    if is_training:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, (y_pred, new_batch_stats)), grads = grad_fn(state.params, batch_stats)
        state = state.apply_gradients(grads=grads)
    else:
        loss, (y_pred, new_batch_stats) = loss_fn(state.params, batch_stats)
    return loss, y_pred, state, new_batch_stats

...

# training loop
for x, y in train_set.as_numpy_iterator():
    loss, y_pred, state, batch_stats = step(x, y, state, batch_stats, is_training=True)
To tidy this up, here we extend TrainState so that it also holds batch_stats internally. We also give it a method for the forward pass that uses those batch_stats, and a method for updating the batch_stats that change as a result of the forward pass.
class CustomTrainState(TrainState):
    batch_stats: dict

    # shortcut to apply_fn that takes care of the mutable batch_stats
    def apply_fn_with_bn(self, *args, is_training, **nargs):
        output, mutated_vars = self.apply_fn(*args, **nargs, use_running_average=not is_training, mutable=["batch_stats"])
        new_batch_stats = mutated_vars["batch_stats"]
        return output, new_batch_stats

    def update_batch_stats(self, new_batch_stats):
        return self.replace(batch_stats=new_batch_stats)
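For reference, such a state is created just like a regular TrainState, with batch_stats passed as the extra field. The snippet below mirrors how the generator state is built in the full script further down.

# Sketch: building the extended state from the init() output (same pattern as in the full script below).
g_tx = optax.adam(0.00005)
g_state = CustomTrainState.create(apply_fn=g.apply, params=g_variables["params"], tx=g_tx, batch_stats=g_variables["batch_stats"])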
With this in place, the step function above can be rewritten as follows.
@partial(jax.jit, static_argnums=(3,))
def step(x, y, state, is_training=True):
    def loss_fn(params, batch_stats):
        # apply_fn_with_bn already unpacks the mutated variables and returns (output, new_batch_stats)
        y_pred, new_batch_stats = state.apply_fn_with_bn({"params": params, "batch_stats": batch_stats}, x, is_training=is_training)
        loss = optax.softmax_cross_entropy(logits=y_pred, labels=y).mean()
        return loss, (y_pred, new_batch_stats)

    y = jnp.eye(10)[y]
    if is_training:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, (y_pred, new_batch_stats)), grads = grad_fn(state.params, state.batch_stats)
        state = state.apply_gradients(grads=grads)
        state = state.update_batch_stats(new_batch_stats)
    else:
        loss, (y_pred, new_batch_stats) = loss_fn(state.params, state.batch_stats)
    return loss, y_pred, state

...

# training loop
for x, y in train_set.as_numpy_iterator():
    loss, y_pred, state = step(x, y, state, is_training=True)
As a result, batch_stats is now referenced and used only inside the step function.
Training the generator
This step takes the random noise z and updates the generator's parameters and batch statistics, as well as the discriminator's batch statistics [1].
@partial(jax.jit, static_argnums=(3,))
def g_step(z, g_state, d_state, is_training=True):
    def loss_fn(g_params, g_bs, d_params, d_bs, is_training):
        gen_data, g_bs = g_state.apply_fn_with_bn({"params": g_params, "batch_stats": g_bs}, z, is_training=is_training)
        validity, d_bs = d_state.apply_fn_with_bn({"params": d_params, "batch_stats": d_bs}, gen_data, is_training=is_training)
        g_loss = optax.l2_loss(validity, jnp.ones((len(validity), 1))).mean() * 2.
        # multiplication by 2 to make the loss consistent with pytorch. see https://github.com/deepmind/optax/blob/master/optax/_src/loss.py#L32#L54
        return g_loss, (gen_data, g_bs, d_bs)

    if is_training:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (g_loss, (gen_data, g_bs, d_bs)), grads = grad_fn(g_state.params, g_state.batch_stats, d_state.params, d_state.batch_stats, True)
        g_state = g_state.apply_gradients(grads=grads)
        g_state = g_state.update_batch_stats(g_bs)
        d_state = d_state.update_batch_stats(d_bs)  # also update the discriminator's batch statistics (see footnote [1])
    else:
        g_loss, (gen_data, g_bs, d_bs) = loss_fn(g_state.params, g_state.batch_stats, d_state.params, d_state.batch_stats, False)
    return g_loss, g_state, d_state, gen_data
Training the discriminator
This step uses the gen_data produced during the generator step above, together with real_data drawn from the training set, to update the discriminator's parameters and batch statistics.
@partial(jax.jit, static_argnums=(3,))
def d_step(real_data, fake_data, d_state, is_training=True):
    valid = jnp.ones((len(real_data), 1))
    fake = jnp.zeros((len(fake_data), 1))

    def loss_fn(d_params, d_bs, x, y, is_training):
        validity, d_bs = d_state.apply_fn_with_bn({"params": d_params, "batch_stats": d_bs}, x, is_training=is_training)
        d_loss = optax.l2_loss(validity, y).mean()
        return d_loss, d_bs

    # optax.l2_loss is half the squared error, so summing the real and fake terms
    # matches pytorch's (d_loss_real + d_loss_fake) / 2.
    loss = 0.
    if is_training:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        for x, y in zip([real_data, fake_data], [valid, fake]):
            (loss_, d_bs), grads = grad_fn(d_state.params, d_state.batch_stats, x, y, True)
            d_state = d_state.apply_gradients(grads=grads)
            d_state = d_state.update_batch_stats(d_bs)
            loss += loss_
    else:
        for x, y in zip([real_data, fake_data], [valid, fake]):
            loss_, d_bs = loss_fn(d_state.params, d_state.batch_stats, x, y, False)
            loss += loss_
    return loss, d_state
Finished: after the JAX migration
And that's it. The whole program looks like the following. On every iteration we create a new subkey and use it to sample the random noise z.
from functools import partial
import jax
import jax.numpy as jnp
import flax.linen as fnn
from flax.core import FrozenDict
from flax.training.train_state import TrainState
import optax
class Generator(fnn.Module):
    @fnn.compact
    def __call__(self, z, use_running_average):
        h = fnn.Dense(64)(z)
        h = fnn.BatchNorm(use_running_average=use_running_average, momentum=0.8)(h)
        h = fnn.relu(h)
        h = fnn.Dense(32)(h)
        h = fnn.BatchNorm(use_running_average=use_running_average, momentum=0.8)(h)
        h = fnn.relu(h)
        y = fnn.Dense(1)(h)
        return y


class Discriminator(fnn.Module):
    @fnn.compact
    def __call__(self, x, use_running_average):
        h = fnn.Dense(32)(x)
        h = fnn.BatchNorm(use_running_average=use_running_average, momentum=0.8)(h)
        h = fnn.relu(h)
        h = fnn.Dense(64)(h)
        h = fnn.BatchNorm(use_running_average=use_running_average, momentum=0.8)(h)
        h = fnn.relu(h)
        y = fnn.Dense(1)(h)
        return y


class CustomTrainState(TrainState):
    batch_stats: dict

    def apply_fn_with_bn(self, *args, is_training, **nargs):
        output, mutated_vars = self.apply_fn(*args, **nargs, use_running_average=not is_training, mutable=["batch_stats"])
        new_batch_stats = mutated_vars["batch_stats"]
        return output, new_batch_stats

    def update_batch_stats(self, new_batch_stats):
        return self.replace(batch_stats=new_batch_stats)
@partial(jax.jit, static_argnums=(3,))
def g_step(z, g_state, d_state, is_training=True):
    def loss_fn(g_params, g_bs, d_params, d_bs, is_training):
        gen_data, g_bs = g_state.apply_fn_with_bn({"params": g_params, "batch_stats": g_bs}, z, is_training=is_training)
        validity, d_bs = d_state.apply_fn_with_bn({"params": d_params, "batch_stats": d_bs}, gen_data, is_training=is_training)
        g_loss = optax.l2_loss(validity, jnp.ones((len(validity), 1))).mean() * 2.
        # multiplication by 2 to make the loss consistent with pytorch. see https://github.com/deepmind/optax/blob/master/optax/_src/loss.py#L32#L54
        return g_loss, (gen_data, g_bs, d_bs)

    if is_training:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (g_loss, (gen_data, g_bs, d_bs)), grads = grad_fn(g_state.params, g_state.batch_stats, d_state.params, d_state.batch_stats, True)
        g_state = g_state.apply_gradients(grads=grads)
        g_state = g_state.update_batch_stats(g_bs)
        d_state = d_state.update_batch_stats(d_bs)  # also update the discriminator's batch statistics (see footnote [1])
    else:
        g_loss, (gen_data, g_bs, d_bs) = loss_fn(g_state.params, g_state.batch_stats, d_state.params, d_state.batch_stats, False)
    return g_loss, g_state, d_state, gen_data


@partial(jax.jit, static_argnums=(3,))
def d_step(real_data, fake_data, d_state, is_training=True):
    valid = jnp.ones((len(real_data), 1))
    fake = jnp.zeros((len(fake_data), 1))

    def loss_fn(d_params, d_bs, x, y, is_training):
        validity, d_bs = d_state.apply_fn_with_bn({"params": d_params, "batch_stats": d_bs}, x, is_training=is_training)
        d_loss = optax.l2_loss(validity, y).mean()
        return d_loss, d_bs

    # optax.l2_loss is half the squared error, so summing the real and fake terms
    # matches pytorch's (d_loss_real + d_loss_fake) / 2.
    loss = 0.
    if is_training:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        for x, y in zip([real_data, fake_data], [valid, fake]):
            (loss_, d_bs), grads = grad_fn(d_state.params, d_state.batch_stats, x, y, True)
            d_state = d_state.apply_gradients(grads=grads)
            d_state = d_state.update_batch_stats(d_bs)
            loss += loss_
    else:
        for x, y in zip([real_data, fake_data], [valid, fake]):
            loss_, d_bs = loss_fn(d_state.params, d_state.batch_stats, x, y, False)
            loss += loss_
    return loss, d_state
key = jax.random.PRNGKey(0)
z_dim = 512
g = Generator()
subkey, key = jax.random.split(key)
g_variables = g.init(subkey, jnp.ones((1, z_dim)), True)
g_tx = optax.adam(0.00005)
g_state = CustomTrainState.create(apply_fn=g.apply, params=g_variables["params"], tx=g_tx, batch_stats=g_variables["batch_stats"])
d = Discriminator()
subkey, key = jax.random.split(key)
d_variables = d.init(subkey, jnp.ones((1, 1)), True)
d_tx = optax.adam(0.0001)
d_state = CustomTrainState.create(apply_fn=d.apply, params=d_variables["params"], tx=d_tx, batch_stats=d_variables["batch_stats"])
for e in range(30):
    g_loss_avg, d_loss_avg = 0, 0
    tic = time.time()
    for x in train_set.as_numpy_iterator():
        subkey, key = jax.random.split(key)
        z = jax.random.normal(subkey, (32, z_dim))
        real_data = jnp.array(x).reshape(-1, 1)
        g_loss, g_state, d_state, gen_data = g_step(z, g_state, d_state, True)
        d_loss, d_state = d_step(real_data, gen_data, d_state, True)
        g_loss_avg += g_loss
        d_loss_avg += d_loss
    g_loss_avg /= len(train_set)
    d_loss_avg /= len(train_set)
    elapsed = time.time() - tic
    print(f"epoch: {e}, g_loss: {g_loss_avg:0.2f}, d_loss: {d_loss_avg:0.2f}, elapased: {elapsed:0.2f}")
z = jax.random.normal(key, (10000, z_dim))
gen_data, g_bs = g_state.apply_fn_with_bn({"params": g_state.params, "batch_stats": g_state.batch_stats}, z, is_training=False)
plt.hist(gen_data.flatten(), 100);
plt.hist(data, 100, alpha=.5);
plt.legend(["Generated", "Data"])
plt.show()
Let's run it.
epoch: 0, g_loss: 0.31, d_loss: 0.21, elapased: 9.47
epoch: 1, g_loss: 0.35, d_loss: 0.18, elapased: 4.16
epoch: 2, g_loss: 0.39, d_loss: 0.16, elapased: 5.11
epoch: 3, g_loss: 0.38, d_loss: 0.17, elapased: 4.14
epoch: 4, g_loss: 0.40, d_loss: 0.16, elapased: 4.09
epoch: 5, g_loss: 0.42, d_loss: 0.14, elapased: 5.11
epoch: 6, g_loss: 0.45, d_loss: 0.13, elapased: 3.95
epoch: 7, g_loss: 0.49, d_loss: 0.12, elapased: 4.16
epoch: 8, g_loss: 0.52, d_loss: 0.11, elapased: 4.12
epoch: 9, g_loss: 0.55, d_loss: 0.10, elapased: 4.01
epoch: 10, g_loss: 0.59, d_loss: 0.09, elapased: 5.11
epoch: 11, g_loss: 0.62, d_loss: 0.09, elapased: 4.02
epoch: 12, g_loss: 0.65, d_loss: 0.08, elapased: 4.12
epoch: 13, g_loss: 0.67, d_loss: 0.08, elapased: 5.11
epoch: 14, g_loss: 0.68, d_loss: 0.08, elapased: 4.16
epoch: 15, g_loss: 0.68, d_loss: 0.08, elapased: 5.11
epoch: 16, g_loss: 0.68, d_loss: 0.08, elapased: 4.09
epoch: 17, g_loss: 0.67, d_loss: 0.09, elapased: 5.11
epoch: 18, g_loss: 0.65, d_loss: 0.09, elapased: 4.23
epoch: 19, g_loss: 0.62, d_loss: 0.10, elapased: 5.11
epoch: 20, g_loss: 0.58, d_loss: 0.12, elapased: 5.11
epoch: 21, g_loss: 0.52, d_loss: 0.14, elapased: 4.11
epoch: 22, g_loss: 0.46, d_loss: 0.16, elapased: 5.11
epoch: 23, g_loss: 0.40, d_loss: 0.19, elapased: 4.19
epoch: 24, g_loss: 0.37, d_loss: 0.20, elapased: 4.11
epoch: 25, g_loss: 0.37, d_loss: 0.20, elapased: 5.11
epoch: 26, g_loss: 0.38, d_loss: 0.19, elapased: 5.11
epoch: 27, g_loss: 0.40, d_loss: 0.18, elapased: 4.16
epoch: 28, g_loss: 0.41, d_loss: 0.18, elapased: 5.11
epoch: 29, g_loss: 0.43, d_loss: 0.17, elapased: 4.34
Once again we get a sizable speedup: summed over the 30 epochs, it is 3.78x faster.
| | min (s) | max (s) | total over 30 epochs (s) |
|---|---|---|---|
| PyTorch | 12.19 | 20.54 | 532.68 |
| JAX | 3.95 | 9.47 (first epoch, incl. compilation) | 140.90 |
Training seems to progress better than with PyTorch...? Sorry, I could not figure out why. It should be doing the same training...
Summary
We showed how to train a GAN in JAX. BatchNorm inevitably makes the parameter updates more complicated, but customizing TrainState makes things somewhat easier to follow. The update-step functions (g_step, d_step) end up looking roughly the same every time, so it may be worth keeping them around as reusable snippets.
[1] On reflection, it is not entirely obvious whether the discriminator's batch statistics should be updated when updating the generator. The PyTorch-GAN examples appear to do so, so we update them here as well. ↩︎