🖥️

MAML(Model-Agnostic Meta-Learning)を Pytorch で実装体験しました.

2022/01/23に公開約4,800字

はじめに

正月にメタ学習について理解を深めるために,MAML(Model-Agnostic Meta-Learning)を実装しました.

MAMLの論文(here)と日本語の解説記事(here)はハイパーリンクにあります.

MAML(Model-Agnostic Meta-Learning)とは

image.png

ICML2017に採択された論文で,回帰問題や分類問題,そして強化学習といった様々なタスクに対応できる一般的なモデルとなっております.

少ないパラメータ更新で未知のタスクを学習できるModel-Agnosticなアルゴリズムを持っているのが特徴です.

Model-Agnosticを翻訳すると「モデルにとらわれない」,すなわちMAMLは制約が少ない任意のモデルとなっております.

ここで数式を用いた簡単な流れを説明します.

パラメータ \theta に関するモデルを表す関数 f_{\theta} をタスク \mathcal{T}_{i} に対応させ,以下の更新式に基づいて各タスクに関するパラメータ \theta'_{i} を更新します.

\theta'_{i} = \theta - \alpha\nabla_{\theta}\mathcal{L}_{\mathcal{T}_{i}}\left(f_{\theta}\right)

パラメータ \theta が最小になるように,パラメータ \theta'_{i} に関するモデルを表す関数 f_{\theta'_{i}} を最適化します.

なお各タスクにおける損失関数の総和を取って最適化を行います.

\min_{\theta} \sum_{\mathcal{T}_{i} \sim p\left(\mathcal{T}\right)} \mathcal{L}_{\mathcal{T_{i}}}\left(f_{\theta'_{i}}\right) = \sum_{\mathcal{T}_{i} \sim p\left(\mathcal{T}\right)} \mathcal{L}_{\mathcal{T_{i}}}\left(f_{\theta - \alpha\nabla_{\theta}\mathcal{L}_{\mathcal{T}_{i}}\left(f_{\theta}\right)}\right)

そして以下の更新式に基づいて,パラメータ \theta を更新します.

\theta \leftarrow \theta - \beta\nabla_{\theta} \sum_{\mathcal{T}_{i} \sim p\left(\mathcal{T}\right)} \mathcal{L}_{\mathcal{T_{i}}}\left(f_{\theta'_{i}}\right)

なお \alpha, \beta はハイパーパラメータとなっております.

続いて,MAMLのアルゴリズムは以下の通りです.

まず最初にパラメータ \theta を初期化します.

続いてエポックごとのタスク \mathcal{T} を一定の数サンプリングしてデータとして用います.

そして,サンプリングされたタスクごとに各タスクに関するパラメータ \theta'_{i} の更新とf_{\theta'_{i}} の最適化を行います.

最後に,各タスクの総和を用いてパラメータ \theta を更新します.

実装

MAML+Pytorchの記事を参考にして,sin関数の回帰問題に挑戦してみました.

コードはhereに置いております.

MAMLモデル

MAMLモデルを以下の通りに実装しました.

forwardに入力データに加え,各パラメータを入力します.

なお私は出力層にReLU関数を間違って付けておりましたが,外してください(理由は後述).

class MAML(nn.Module):
    def __init__(self):
        super(MAML, self).__init__()

    def forward(self, x, params):
        x = F.relu(F.linear(x, params['input_net_weight'], params['input_net_bias']))
        x = F.relu(F.linear(x, params['latent_net_weight'], params['latent_net_bias']))
        x = F.linear(x, params['output_net_weight'], params['output_net_bias'])
        return x

1. パラメータの初期化

パラメータをTensorで初期化します.

'○○_net_weight'や'○○_net_bias'といったMAMLモデルの各層の名前とtorch.Tensor(範囲[-1, 1])を格納します.

params = OrderedDict([
        ('input_net_weight', torch.Tensor(32, 1).uniform_(-1., 1.).requires_grad_()),
        ('input_net_bias', torch.Tensor(32).uniform_(-1., 1.).requires_grad_()),

        ('latent_net_weight', torch.Tensor(32, 32).uniform_(-1., 1.).requires_grad_()),
        ('latent_net_bias', torch.Tensor(32).uniform_(-1., 1.).requires_grad_()),

        ('output_net_weight', torch.Tensor(1, 32).uniform_(-1., 1.).requires_grad_()),
        ('output_net_bias', torch.Tensor(1).uniform_(-1., 1.).requires_grad_())
    ])

2.タスクのサンプリング

サポートセット(従来の機械学習では学習データ)を用意します.

 b = 0 if random.choice([True, False]) else math.pi
 train_x = torch.rand(4, 1) * 4 * math.pi - 2 * math.pi
 train_y = torch.sin(train_x + b)

3.各タスクの学習

各タスクのエポック数は5にしています.

ちなみに,torch.autograd.gradは自動で損失関数をパラメータで偏微分してくれます(私は誤ってMAMLモデルの出力をReLU関数に介してしまっていたのに気づかずに,gradsで0.0(勾配0)を出力していました.).

 new_params = params

 for k in range(args.K): # K = 5
     pred_train_y = model(train_x, new_params)
     train_loss = F.l1_loss(pred_train_y, train_y)
	 
     grads = torch.autograd.grad(train_loss, new_params.values(), create_graph=True)
     new_params = OrderedDict((name, param - args.alpha * grad) for ((name, param), grad) in zip(params.items(), grads))

4.Fine-tuneを用いた学習

クエリセット(従来の機械学習でいうテストデータ)を用意し,あらかじめ学習済みのパラメータでFine-tuningして学習します.

 val_x = torch.rand(4, 1) * 4 * math.pi - 2 * math.pi
 val_y = torch.sin(val_x + b)
 
 pred_val_y = model(val_x, new_params)
 val_loss = F.l1_loss(pred_val_y, val_y)
 val_loss.backward(retain_graph=True)
 optim.step()

結果

以下は全エポック数が300,000回,各タスクのエポック数が5の時の結果です.

若干歪んでいるところがありますが,教師データのsin(x)に近い出力になっていることが分かります.

おわりに

本記事では,簡単なMAMLモデルを実装してみました.

今後は他のソースコード(here)を拝借してデータセットを変えつつ学習するつもりですが,データセットの作り方に苦戦しております.

参考文献

https://arxiv.org/abs/1703.03400
https://paperswithcode.com/method/maml
https://qiita.com/ku2482/items/ee2fd87bbb5353664f59
https://www.sagargv.com/blog/meta-learning-in-pytorch/

Discussion

ログインするとコメントできます