🤖
微分可能シルエットレンダリングによる3Dキューブのポーズ最適化
3Dコンピュータグラフィックスと機械学習の融合は、近年めざましい進展を遂げています。今回は、微分可能レンダリングの手法を用いて、3Dキューブの姿勢(位置と回転)を推定する実験について紹介します。
概要
この手法では、目標となる3Dキューブのシルエット画像が与えられたとき、別の初期状態からスタートしたキューブを、目標シルエットと一致するように最適化します。具体的には、勾配降下法を用いて、キューブの位置(x, y, z)と回転(rx, ry, rz)のパラメータを少しずつ更新していきます。
技術的アプローチ
この実験では以下の技術を組み合わせています:
符号付き距離関数(SDF) - 3D空間上の任意の点から物体表面までの最短距離を計算
微分可能レンダリング - 画像の生成過程を微分可能にすることで、逆問題を解ける
シミュレーテッドアニーリング - 局所解に陥るのを防ぐためのランダム探索
実装の流れ
目標となるキューブの位置と回転を設定
目標シルエットを非微分のレイキャスティングでレンダリング
初期パラメータからスタートし、微分可能SDF関数を使ってシルエットを生成
生成シルエットと目標シルエットの差を損失として計算
勾配降下法でパラメータを更新
過程中にシミュレーテッドアニーリングを適用し、局所解を回避
結果と考察
最適化プロセスを通じて、初期状態(中央から少しずれた正面向きのキューブ)から、目標状態(右上に移動し、X軸に30度、Y軸に45度、Z軸に22.5度回転したキューブ)へ徐々に近づいていくことが確認できました。
500回のイテレーション後、キューブの位置と回転角度は目標値に非常に近い値に収束しました。特に損失グラフを見ると、最初の100イテレーションで急速に減少し、その後緩やかに改善していく様子が分かります。
実用的な応用先
この手法は以下のような分野で応用できます:
ロボットビジョンでの物体ポーズ推定
拡張現実(AR)での物体トラッキング
コンピュータビジョンでの3D再構成
医療画像処理での臓器位置推定
まとめ
微分可能レンダリングを用いた3Dポーズ最適化は、2D画像から3D情報を逆推論する強力な手法です。今回のシンプルなキューブの例でも、回転や位置のパラメータを正確に推定できることを示しました。
この技術の本質は、レンダリング過程を微分可能にすることで、「見た目」から「形状・姿勢」への逆問題を解けるようにした点にあります。これは従来のコンピュータグラフィックスの一方向性(形状から画像への生成)を超える革新的なアプローチといえるでしょう。
(コードは以下に一括で掲載します)
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
# 再現性のための乱数シード設定
torch.manual_seed(42)
np.random.seed(42)
# 画像の解像度
width, height = 100, 100
# 3Dキューブの頂点座標
vertices = torch.tensor([
[-0.5, -0.5, 0.5], # 0: 前面左下
[-0.5, 0.5, 0.5], # 1: 前面左上
[0.5, 0.5, 0.5], # 2: 前面右上
[0.5, -0.5, 0.5], # 3: 前面右下
[-0.5, -0.5, -0.5], # 4: 背面左下
[-0.5, 0.5, -0.5], # 5: 背面左上
[0.5, 0.5, -0.5], # 6: 背面右上
[0.5, -0.5, -0.5], # 7: 背面右下
], dtype=torch.float32)
# キューブの三角形面(12枚の三角形、各面2枚ずつ)
faces = torch.tensor([
[0, 1, 2], [0, 2, 3], # 前面
[4, 5, 6], [4, 6, 7], # 背面
[0, 1, 5], [0, 5, 4], # 左面
[3, 2, 6], [3, 6, 7], # 右面
[1, 2, 6], [1, 6, 5], # 上面
[0, 3, 7], [0, 7, 4] # 下面
], dtype=torch.int64)
# キューブ用のSDF(符号付き距離関数)
def cube_sdf(points, center, half_size, rotation_matrix):
# 点を物体の座標系に変換(回転と平行移動の逆変換)
# 回転の逆変換
inv_rotation = rotation_matrix.transpose(0, 1)
# 平行移動の逆変換を適用した後、回転の逆変換を適用
centered_points = points - center
transformed_points = torch.matmul(centered_points, inv_rotation)
# キューブのSDF計算
# 各軸に対する距離を計算
d = torch.abs(transformed_points) - half_size
# 外部の距離:max(d, 0)のノルム
outside_distance = torch.norm(torch.maximum(d, torch.zeros_like(d)), dim=1)
# 内部の距離:min(max(d.x, d.y, d.z), 0)
inside_distance = torch.minimum(torch.max(d[:, 0], torch.max(d[:, 1], d[:, 2])), torch.zeros(d.shape[0]))
# 合計SDF
return outside_distance + inside_distance
# メイン実行部分
if __name__ == "__main__":
print("微分可能シルエットレンダリングの初期化...")
# カメラパラメータ
camera_position = torch.tensor([0.0, 0.0, -3.0], requires_grad=False)
focal_length = 1.0
# 目標パラメータ(真の位置と回転)
target_translation = torch.tensor([0.5, 0.3, 0.0], requires_grad=False)
target_rotation_x = torch.tensor(np.pi/6, requires_grad=False) # 30度
target_rotation_y = torch.tensor(np.pi/4, requires_grad=False) # 45度
target_rotation_z = torch.tensor(np.pi/8, requires_grad=False) # 22.5度
# 画像平面の正規化座標の準備
x = torch.linspace(-1, 1, width)
y = torch.linspace(-1, 1, height)
y_grid, x_grid = torch.meshgrid(y, x, indexing='ij')
# 各ピクセルのレイ方向を作成
ray_dirs = torch.stack([
x_grid.flatten(),
y_grid.flatten(),
torch.ones_like(x_grid.flatten()) * focal_length
], dim=1)
# レイ方向を正規化
ray_lengths = torch.sqrt(torch.sum(ray_dirs**2, dim=1, keepdim=True))
ray_dirs = ray_dirs / ray_lengths
# 回転行列を作成する関数(自動微分のためにPyTorch演算を使用)
def create_rotation_matrix(rx, ry, rz):
# 各軸の回転を計算するためのサイン・コサイン
cos_x, sin_x = torch.cos(rx), torch.sin(rx)
cos_y, sin_y = torch.cos(ry), torch.sin(ry)
cos_z, sin_z = torch.cos(rz), torch.sin(rz)
# X軸周りの回転行列
R_x = torch.zeros((3, 3), dtype=torch.float32)
R_x[0, 0] = 1.0
R_x[1, 1] = cos_x
R_x[1, 2] = -sin_x
R_x[2, 1] = sin_x
R_x[2, 2] = cos_x
# Y軸周りの回転行列
R_y = torch.zeros((3, 3), dtype=torch.float32)
R_y[0, 0] = cos_y
R_y[0, 2] = sin_y
R_y[1, 1] = 1.0
R_y[2, 0] = -sin_y
R_y[2, 2] = cos_y
# Z軸周りの回転行列
R_z = torch.zeros((3, 3), dtype=torch.float32)
R_z[0, 0] = cos_z
R_z[0, 1] = -sin_z
R_z[1, 0] = sin_z
R_z[1, 1] = cos_z
R_z[2, 2] = 1.0
# 合成回転行列(順序: Z → Y → X)
return torch.matmul(torch.matmul(R_z, R_y), R_x)
# 頂点を回転と平行移動で変換する関数
def transform_vertices(vertices, rotation_matrix, translation):
# 回転を適用
rotated_vertices = torch.matmul(vertices, rotation_matrix.transpose(-1, -2))
# 平行移動を適用
transformed_vertices = rotated_vertices + translation
return transformed_vertices
# 非微分関数: レイキャスティングによるシルエット生成(ターゲット用)
def render_silhouette(transformed_vertices, faces, camera_position, ray_dirs):
# 衝突マスクの初期化
hit_mask = torch.zeros(ray_dirs.shape[0], dtype=torch.float32)
# 各三角形に対して、すべてのレイとの交差をチェック
for face_idx in range(faces.shape[0]):
face = faces[face_idx]
# 三角形の頂点を取得
v0 = transformed_vertices[face[0]]
v1 = transformed_vertices[face[1]]
v2 = transformed_vertices[face[2]]
# 三角形のエッジを計算
edge1 = v1 - v0
edge2 = v2 - v0
# Möller–Trumboreアルゴリズムによるレイと三角形の交差判定
h = torch.cross(ray_dirs, edge2.unsqueeze(0).expand(ray_dirs.shape[0], -1))
a = torch.sum(edge1.unsqueeze(0).expand(ray_dirs.shape[0], -1) * h, dim=1)
# 平行に近いレイを除外
epsilon = 1e-8
mask = torch.abs(a) > epsilon
if not mask.any():
continue
# 交差パラメータの計算
f = torch.zeros_like(a)
f[mask] = 1.0 / a[mask]
s = camera_position.unsqueeze(0) - v0.unsqueeze(0)
u = f * torch.sum(s * h, dim=1)
# 交差が三角形内にあるかチェック
mask = mask & (u >= 0.0) & (u <= 1.0)
if not mask.any():
continue
q = torch.cross(s, edge1.unsqueeze(0).expand(ray_dirs.shape[0], -1))
v = f * torch.sum(ray_dirs * q, dim=1)
mask = mask & (v >= 0.0) & (u + v <= 1.0)
if not mask.any():
continue
# 交差がカメラの前にあるか確認(tが正)
t = f * torch.sum(edge2.unsqueeze(0).expand(ray_dirs.shape[0], -1) * q, dim=1)
mask = mask & (t > epsilon)
# ヒットマスクを更新
hit_mask[mask] = 1.0
# 画像サイズに整形
return hit_mask.reshape(height, width)
# ターゲットシルエットの生成
print("ターゲットシルエットの生成...")
target_rotation_matrix = create_rotation_matrix(
target_rotation_x, target_rotation_y, target_rotation_z
)
target_transformed_vertices = transform_vertices(
vertices, target_rotation_matrix, target_translation
)
with torch.no_grad():
target_silhouette = render_silhouette(
target_transformed_vertices, faces, camera_position, ray_dirs
)
# 最適化するパラメータの初期値設定
translation = torch.tensor([-0.3, -0.2, 0.0], requires_grad=True)
rotation_x = torch.tensor(0.0, requires_grad=True)
rotation_y = torch.tensor(0.0, requires_grad=True)
rotation_z = torch.tensor(0.0, requires_grad=True)
# 学習率設定(回転は高めに設定)
lr_translation = 0.01
lr_rotation = 0.1
iterations = 500 # イテレーション数を増やす
# オプティマイザの作成(並行と回転で別々に設定)
optimizer_translation = torch.optim.Adam([translation], lr=lr_translation)
optimizer_rotation = torch.optim.Adam([rotation_x, rotation_y, rotation_z], lr=lr_rotation)
# キューブの半サイズ(各軸方向の長さの半分)
half_size = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32)
# 微分可能なレンダラー(キューブのSDFを使用)
def differentiable_render_cube(translation, rotation_matrix, camera_position, ray_dirs, sharpness=10.0):
# カメラ位置から各レイ方向に沿ったサンプルポイントを生成
# 複数のサンプル距離
sample_distances = torch.linspace(0.1, 10.0, 50)
# 全てのレイに対するサンプルポイントを計算
n_rays = ray_dirs.shape[0]
n_samples = sample_distances.shape[0]
all_sample_points = camera_position.unsqueeze(0).unsqueeze(1) + \
ray_dirs.unsqueeze(1) * sample_distances.unsqueeze(0).unsqueeze(2)
# all_sample_pointsの形状: [n_rays, n_samples, 3]
# n_rays * n_samplesの2D配列に平坦化
flat_sample_points = all_sample_points.view(-1, 3)
# 各サンプルポイントに対してSDFを計算
sdf_values = cube_sdf(flat_sample_points, translation, half_size, rotation_matrix)
# SDFを元の形状に戻す [n_rays, n_samples]
sdf_values = sdf_values.view(n_rays, n_samples)
# 各レイに沿った最小SDF値を見つける(オブジェクトに最も近い点)
min_sdf_values, _ = torch.min(sdf_values, dim=1)
# 最小SDF値をシルエット値に変換(SDF値が負ならオブジェクト内、正ならオブジェクト外)
silhouette = torch.sigmoid(-sharpness * min_sdf_values)
return silhouette.reshape(height, width)
# 最適化の進捗を記録
loss_history = []
param_history = []
# アニーリングパラメータ
annealing_temp_initial = 0.2
annealing_temp = annealing_temp_initial
annealing_decay = 0.98
rotation_gradient_scale = 10.0 # 回転勾配スケール
# 最適化ループ
print("最適化を開始...")
for i in tqdm(range(iterations)):
optimizer_translation.zero_grad()
optimizer_rotation.zero_grad()
# 現在の回転行列を作成
rotation_matrix = create_rotation_matrix(rotation_x, rotation_y, rotation_z)
# 現在のシルエットをレンダリング
current_silhouette = differentiable_render_cube(
translation, rotation_matrix, camera_position, ray_dirs
)
# 損失を計算(MSE)
loss = torch.mean((current_silhouette - target_silhouette) ** 2)
# 履歴を保存
loss_history.append(loss.item())
param_history.append({
'tx': translation[0].item(),
'ty': translation[1].item(),
'tz': translation[2].item(),
'rx': rotation_x.item(),
'ry': rotation_y.item(),
'rz': rotation_z.item()
})
# 勾配を計算
loss.backward()
# 勾配のデバッグ出力
if i == 0 or (i + 1) % 50 == 0:
print(f" Translation gradient: {translation.grad}")
print(f" Rotation gradient before scaling: ({rotation_x.grad}, {rotation_y.grad}, {rotation_z.grad})")
# 回転の勾配をスケーリング
if rotation_x.grad is not None:
rotation_x.grad *= rotation_gradient_scale
if rotation_y.grad is not None:
rotation_y.grad *= rotation_gradient_scale
if rotation_z.grad is not None:
rotation_z.grad *= rotation_gradient_scale
if i == 0 or (i + 1) % 50 == 0:
print(f" Rotation gradient after scaling: ({rotation_x.grad}, {rotation_y.grad}, {rotation_z.grad})")
# パラメータ更新
optimizer_translation.step()
optimizer_rotation.step()
# シミュレーテッドアニーリング(初期の150イテレーションのみ)
if i % 10 == 0 and i < 150:
with torch.no_grad():
# 回転にランダムな摂動を加える
rotation_x.add_(torch.randn(1).item() * annealing_temp)
rotation_y.add_(torch.randn(1).item() * annealing_temp)
rotation_z.add_(torch.randn(1).item() * annealing_temp)
# 温度を下げる
annealing_temp *= annealing_decay
# 回転角を正規化(-πからπの範囲に制限)
with torch.no_grad():
rotation_x.data = torch.remainder(rotation_x.data + np.pi, 2 * np.pi) - np.pi
rotation_y.data = torch.remainder(rotation_y.data + np.pi, 2 * np.pi) - np.pi
rotation_z.data = torch.remainder(rotation_z.data + np.pi, 2 * np.pi) - np.pi
# 進捗表示
if (i + 1) % 50 == 0:
print(f"Iteration {i+1}/{iterations}, Loss: {loss.item():.6f}")
print(f" Translation: ({translation[0].item():.4f}, {translation[1].item():.4f}, {translation[2].item():.4f})")
print(f" Rotation: ({rotation_x.item()*180/np.pi:.1f}°, {rotation_y.item()*180/np.pi:.1f}°, {rotation_z.item()*180/np.pi:.1f}°)")
# 最終結果
print("\n最適化完了!")
print("ターゲットパラメータ:")
print(f" Translation: ({target_translation[0].item():.4f}, {target_translation[1].item():.4f}, {target_translation[2].item():.4f})")
print(f" Rotation: ({target_rotation_x.item()*180/np.pi:.1f}°, {target_rotation_y.item()*180/np.pi:.1f}°, {target_rotation_z.item()*180/np.pi:.1f}°)")
print("\n最適化後のパラメータ:")
print(f" Translation: ({translation[0].item():.4f}, {translation[1].item():.4f}, {translation[2].item():.4f})")
print(f" Rotation: ({rotation_x.item()*180/np.pi:.1f}°, {rotation_y.item()*180/np.pi:.1f}°, {rotation_z.item()*180/np.pi:.1f}°)")
print(f"最終損失: {loss_history[-1]:.6f}")
# 結果の可視化
plt.figure(figsize=(15, 10))
# ターゲットシルエット
plt.subplot(2, 3, 1)
plt.imshow(target_silhouette.detach().numpy(), cmap='gray')
plt.title('ターゲットシルエット')
plt.axis('off')
# 初期シルエット
with torch.no_grad():
initial_rotation_matrix = create_rotation_matrix(
torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)
)
initial_silhouette = differentiable_render_cube(
torch.tensor([-0.3, -0.2, 0.0]),
initial_rotation_matrix,
camera_position,
ray_dirs
)
plt.subplot(2, 3, 2)
plt.imshow(initial_silhouette.detach().numpy(), cmap='gray')
plt.title('初期シルエット')
plt.axis('off')
# 最適化後のシルエット
with torch.no_grad():
final_rotation_matrix = create_rotation_matrix(
rotation_x, rotation_y, rotation_z
)
final_silhouette = differentiable_render_cube(
translation,
final_rotation_matrix,
camera_position,
ray_dirs
)
plt.subplot(2, 3, 3)
plt.imshow(final_silhouette.detach().numpy(), cmap='gray')
plt.title('最適化後のシルエット')
plt.axis('off')
# 損失履歴
plt.subplot(2, 3, 4)
plt.plot(loss_history)
plt.title('損失の推移')
plt.xlabel('イテレーション')
plt.ylabel('損失')
plt.grid(True)
# 平行移動履歴
plt.subplot(2, 3, 5)
tx = [p['tx'] for p in param_history]
ty = [p['ty'] for p in param_history]
tz = [p['tz'] for p in param_history]
plt.plot(tx, label='tx')
plt.plot(ty, label='ty')
plt.plot(tz, label='tz')
plt.axhline(y=target_translation[0].item(), color='r', linestyle='--', alpha=0.5, label='Target tx')
plt.axhline(y=target_translation[1].item(), color='g', linestyle='--', alpha=0.5, label='Target ty')
plt.axhline(y=target_translation[2].item(), color='b', linestyle='--', alpha=0.5, label='Target tz')
plt.title('平行移動の推移')
plt.xlabel('イテレーション')
plt.ylabel('位置')
plt.legend()
plt.grid(True)
# 回転履歴
plt.subplot(2, 3, 6)
rx = [p['rx']*180/np.pi for p in param_history]
ry = [p['ry']*180/np.pi for p in param_history]
rz = [p['rz']*180/np.pi for p in param_history]
plt.plot(rx, label='rx')
plt.plot(ry, label='ry')
plt.plot(rz, label='rz')
plt.axhline(y=target_rotation_x.item()*180/np.pi, color='r', linestyle='--', alpha=0.5, label='Target rx')
plt.axhline(y=target_rotation_y.item()*180/np.pi, color='g', linestyle='--', alpha=0.5, label='Target ry')
plt.axhline(y=target_rotation_z.item()*180/np.pi, color='b', linestyle='--', alpha=0.5, label='Target rz')
plt.title('回転角の推移 (度)')
plt.xlabel('イテレーション')
plt.ylabel('角度')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()
Discussion