PyTorchで学ぶ勾配降下法: 円の位置最適化問題
こんにちは、ZENブログの読者の皆さん。今回は機械学習の基礎である「勾配降下法」について、視覚的に理解しやすい例を用いて解説します。
今回の目標
今回は、PyTorchを使って「ある位置にある円を、目標となる位置まで自動的に移動させる」という問題に取り組みます。これは単純ですが、深層学習の最適化問題の本質を含んでいます。
コードの概要
このコードでは、次のことを行います:
目標となる円の画像を生成
初期位置に円を配置
勾配降下法を用いて円の位置を最適化
最適化の過程を可視化
技術的な解説
1. 環境のセットアップ
まず必要なライブラリをインポートし、画像サイズを設定します。
pythonimport torch
import matplotlib.pyplot as plt
from tqdm import tqdm
画像サイズ
width, height = 64, 64
- ターゲット(目標)の設定
目標となる円の位置とサイズを設定します。これが最終的に到達したい状態です。
python# ターゲット(円)のパラメータ設定
target_x = torch.tensor(0.5, requires_grad=False)
target_y = torch.tensor(0.3, requires_grad=False)
target_radius = torch.tensor(0.2, requires_grad=False)
- 画像座標系の作成
画像の各ピクセルの座標を計算します。これにより、円を描画する際に各ピクセルが円の内側か外側かを判断できます。
python# 画像座標系の作成
image_x = torch.linspace(-1, 1, width)
image_y = torch.linspace(-1, 1, height)
image_y_grid, image_x_grid = torch.meshgrid(image_y, image_x)
- ターゲット画像の生成
目標となる円の画像を生成します。各ピクセルが円の中心からの距離に基づいて値が決まります。
python# ターゲット画像の生成 - 修正した距離関数
dist_squared_target = (image_x_grid - target_x)**2 + (image_y_grid - target_y)**2
sharpness = 10.0
target_image = 0.5 - 0.5 * torch.tanh(sharpness * (dist_squared_target - target_radius**2))
ここで tanh 関数を使って、円の境界をスムーズにしています。sharpness パラメータは境界のシャープさを制御します。
8. 最適化するパラメータの設定
最適化したいパラメータ(円の中心座標)を設定します。requires_grad=True を設定することで、PyTorchにこのパラメータを最適化するよう指示します。
python# 最適化するパラメータ
optim_x = torch.tensor(-0.4, requires_grad=True)
optim_y = torch.tensor(-0.3, requires_grad=True)
initial_radius = torch.tensor(0.2, requires_grad=False)
- 最適化のハイパーパラメータとオプティマイザーの設定
最適化のプロセスを制御するパラメータとオプティマイザーを設定します。
python# 最適化のハイパーパラメータ
iterations = 100 # イテレーション数
lr = 0.02 # 学習率
オプティマイザーの設定
parameters = [optim_x, optim_y]
optimizer = torch.optim.Adam(parameters, lr=lr)
ここでは「Adam」オプティマイザーを使用しています。これは勾配降下法の拡張版で、より効率的に最適解に収束する特性があります。
7. 最適化ループ
この部分が勾配降下法の核心です。各イテレーションで以下のステップを繰り返します:
現在のパラメータで円を描画
目標画像との差(損失)を計算
損失を最小化する方向に勾配を計算
パラメータを更新
python# 最適化ループ
for i in range(iterations):
# オプティマイザーの勾配をリセット
optimizer.zero_grad()
# 現在の状態をレンダリング
dist_squared = (image_x_grid - optim_x)**2 + (image_y_grid - optim_y)**2
rendered_image = 0.5 - 0.5 * torch.tanh(sharpness * (dist_squared - initial_radius**2))
# 損失を計算
loss = torch.sum((rendered_image - target_image)**2)
loss_history.append(loss.item())
# 勾配を計算
loss.backward()
# パラメータを更新
optimizer.step()
損失関数として、レンダリングされた画像と目標画像の二乗誤差を使用しています。
8. 結果の可視化
最適化の過程と結果を可視化します。
python# 最適化の結果を表示
print(f"最適化前: x={initial_x_value:.4f}, y={initial_y_value:.4f}")
print(f"最適化後: x={optim_x.item():.4f}, y={optim_y.item():.4f}")
print(f"目標値: x={target_x.item():.4f}, y={target_y.item():.4f}")
勾配降下法の仕組み
このコードで学べる最も重要な概念は「勾配降下法」です。これは機械学習の最適化問題を解くための基本的なアルゴリズムで、以下のように機能します:
損失関数の定義: 現在の状態と目標状態の差を測る関数を定義
勾配の計算: 損失関数が最も急速に減少する方向(勾配)を計算
パラメータの更新: 計算された勾配の方向にパラメータを少しずつ更新
繰り返し: 損失が十分に小さくなるまで、または指定された回数まで繰り返す
PyTorchの自動微分機能(loss.backward())により、勾配の計算が自動的に行われるため、複雑なモデルでも効率的に最適化できます。
実行結果
最適化の過程を通じて、初期位置(x=-0.4, y=-0.3)からスタートした円が、目標位置(x=0.5, y=0.3)に向かって移動していきます。損失のグラフを見ると、イテレーションを重ねるごとに損失が減少し、最適化が進んでいることがわかります。
まとめ
この例は単純ですが、ニューラルネットワークの学習過程の本質を捉えています。ディープラーニングでは、何百万もの重みパラメータを同時に最適化しますが、基本的な原理はこの例と同じです。
PyTorchのような深層学習フレームワークの強みは、複雑な計算グラフの自動微分を可能にし、効率的な最適化を実現することです。
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
# 画像サイズ
width, height = 64, 64
# ターゲット(円)のパラメータ設定
target_x = torch.tensor(0.5, requires_grad=False)
target_y = torch.tensor(0.3, requires_grad=False)
target_radius = torch.tensor(0.2, requires_grad=False)
# 画像座標系の作成
image_x = torch.linspace(-1, 1, width)
image_y = torch.linspace(-1, 1, height)
image_y_grid, image_x_grid = torch.meshgrid(image_y, image_x)
# ターゲット画像の生成 - 修正した距離関数
dist_squared_target = (image_x_grid - target_x)**2 + (image_y_grid - target_y)**2
sharpness = 10.0
target_image = 0.5 - 0.5 * torch.tanh(sharpness * (dist_squared_target - target_radius**2))
# 最適化するパラメータ
optim_x = torch.tensor(-0.4, requires_grad=True)
optim_y = torch.tensor(-0.3, requires_grad=True)
initial_radius = torch.tensor(0.2, requires_grad=False)
# 最適化のハイパーパラメータ
iterations = 100 # イテレーション数を増やす
lr = 0.02 # 学習率を調整
# オプティマイザーの設定
parameters = [optim_x, optim_y]
optimizer = torch.optim.Adam(parameters, lr=lr)
# 損失の履歴を保存するリスト
loss_history = []
# 初期状態を保存
initial_x_value = optim_x.item()
initial_y_value = optim_y.item()
# 図の設定
plt.figure(figsize=(15, 5))
# 最適化ループ
for i in range(iterations):
# オプティマイザーの勾配をリセット
optimizer.zero_grad()
# 現在の状態をレンダリング - 修正した距離関数
dist_squared = (image_x_grid - optim_x)**2 + (image_y_grid - optim_y)**2
rendered_image = 0.5 - 0.5 * torch.tanh(sharpness * (dist_squared - initial_radius**2))
# 損失を計算 - sum を使用
loss = torch.sum((rendered_image - target_image)**2)
loss_history.append(loss.item())
# 勾配を計算
loss.backward()
# パラメータを更新
optimizer.step()
# 10イテレーションごとに表示
if (i + 1) % 10 == 0:
print(f"Iteration {i+1}/{iterations}, Loss: {loss.item():.6f}, x: {optim_x.item():.4f}, y: {optim_y.item():.4f}")
# 図を更新
plt.clf()
plt.subplot(1, 3, 1)
plt.imshow(target_image.detach().numpy(), cmap='gray')
plt.title(f'Target: x={target_x.item():.2f}, y={target_y.item():.2f}')
plt.axis('off')
plt.subplot(1, 3, 2)
plt.imshow(rendered_image.detach().numpy(), cmap='gray')
plt.title(f'Current: x={optim_x.item():.2f}, y={optim_y.item():.2f}')
plt.axis('off')
plt.subplot(1, 3, 3)
plt.plot(loss_history)
plt.title(f'Loss: {loss.item():.2f}')
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.grid(True)
plt.tight_layout()
plt.pause(0.1)
# 最適化の結果を表示
print(f"最適化前: x={initial_x_value:.4f}, y={initial_y_value:.4f}")
print(f"最適化後: x={optim_x.item():.4f}, y={optim_y.item():.4f}")
print(f"目標値: x={target_x.item():.4f}, y={target_y.item():.4f}")
# 結果の可視化
plt.figure(figsize=(15, 5))
# 初期状態
plt.subplot(1, 3, 1)
initial_dist_squared = (image_x_grid - torch.tensor(initial_x_value))**2 + (image_y_grid - torch.tensor(initial_y_value))**2
initial_image = 0.5 - 0.5 * torch.tanh(sharpness * (initial_dist_squared - initial_radius**2))
plt.imshow(initial_image.detach().numpy(), cmap='gray')
plt.title(f'Initial: x={initial_x_value:.2f}, y={initial_y_value:.2f}')
plt.axis('off')
# 最適化後
plt.subplot(1, 3, 2)
plt.imshow(rendered_image.detach().numpy(), cmap='gray')
plt.title(f'Optimized: x={optim_x.item():.2f}, y={optim_y.item():.2f}')
plt.axis('off')
# ターゲット
plt.subplot(1, 3, 3)
plt.imshow(target_image.detach().numpy(), cmap='gray')
plt.title(f'Target: x={target_x.item():.2f}, y={target_y.item():.2f}')
plt.axis('off')
plt.tight_layout()
plt.show()
# 損失の推移をプロット
plt.figure(figsize=(10, 5))
plt.plot(loss_history)
plt.title('Optimization Loss')
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.grid(True)
plt.show()
Discussion