🤖

微分可能シルエットレンダリングによる3Dキューブのポーズ最適化

に公開

3Dコンピュータグラフィックスと機械学習の融合は、近年めざましい進展を遂げています。今回は、微分可能レンダリングの手法を用いて、3Dキューブの姿勢(位置と回転)を推定する実験について紹介します。

概要

この手法では、目標となる3Dキューブのシルエット画像が与えられたとき、別の初期状態からスタートしたキューブを、目標シルエットと一致するように最適化します。具体的には、勾配降下法を用いて、キューブの位置(x, y, z)と回転(rx, ry, rz)のパラメータを少しずつ更新していきます。

技術的アプローチ

この実験では以下の技術を組み合わせています:

符号付き距離関数(SDF) - 3D空間上の任意の点から物体表面までの最短距離を計算

微分可能レンダリング - 画像の生成過程を微分可能にすることで、逆問題を解ける

シミュレーテッドアニーリング - 局所解に陥るのを防ぐためのランダム探索

実装の流れ

目標となるキューブの位置と回転を設定

目標シルエットを非微分のレイキャスティングでレンダリング

初期パラメータからスタートし、微分可能SDF関数を使ってシルエットを生成

生成シルエットと目標シルエットの差を損失として計算

勾配降下法でパラメータを更新

過程中にシミュレーテッドアニーリングを適用し、局所解を回避

結果と考察

最適化プロセスを通じて、初期状態(中央から少しずれた正面向きのキューブ)から、目標状態(右上に移動し、X軸に30度、Y軸に45度、Z軸に22.5度回転したキューブ)へ徐々に近づいていくことが確認できました。

500回のイテレーション後、キューブの位置と回転角度は目標値に非常に近い値に収束しました。特に損失グラフを見ると、最初の100イテレーションで急速に減少し、その後緩やかに改善していく様子が分かります。

実用的な応用先

この手法は以下のような分野で応用できます:

ロボットビジョンでの物体ポーズ推定

拡張現実(AR)での物体トラッキング

コンピュータビジョンでの3D再構成

医療画像処理での臓器位置推定

まとめ

微分可能レンダリングを用いた3Dポーズ最適化は、2D画像から3D情報を逆推論する強力な手法です。今回のシンプルなキューブの例でも、回転や位置のパラメータを正確に推定できることを示しました。

この技術の本質は、レンダリング過程を微分可能にすることで、「見た目」から「形状・姿勢」への逆問題を解けるようにした点にあります。これは従来のコンピュータグラフィックスの一方向性(形状から画像への生成)を超える革新的なアプローチといえるでしょう。

(コードは以下に一括で掲載します)

import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# 再現性のための乱数シード設定
torch.manual_seed(42)
np.random.seed(42)

# 画像の解像度
width, height = 100, 100

# 3Dキューブの頂点座標
vertices = torch.tensor([
    [-0.5, -0.5, 0.5],  # 0: 前面左下
    [-0.5, 0.5, 0.5],   # 1: 前面左上
    [0.5, 0.5, 0.5],    # 2: 前面右上
    [0.5, -0.5, 0.5],   # 3: 前面右下
    [-0.5, -0.5, -0.5], # 4: 背面左下
    [-0.5, 0.5, -0.5],  # 5: 背面左上
    [0.5, 0.5, -0.5],   # 6: 背面右上
    [0.5, -0.5, -0.5],  # 7: 背面右下
], dtype=torch.float32)

# キューブの三角形面(12枚の三角形、各面2枚ずつ)
faces = torch.tensor([
    [0, 1, 2], [0, 2, 3],  # 前面
    [4, 5, 6], [4, 6, 7],  # 背面
    [0, 1, 5], [0, 5, 4],  # 左面
    [3, 2, 6], [3, 6, 7],  # 右面
    [1, 2, 6], [1, 6, 5],  # 上面
    [0, 3, 7], [0, 7, 4]   # 下面
], dtype=torch.int64)

# キューブ用のSDF(符号付き距離関数)
def cube_sdf(points, center, half_size, rotation_matrix):
    # 点を物体の座標系に変換(回転と平行移動の逆変換)
    # 回転の逆変換
    inv_rotation = rotation_matrix.transpose(0, 1)
    # 平行移動の逆変換を適用した後、回転の逆変換を適用
    centered_points = points - center
    transformed_points = torch.matmul(centered_points, inv_rotation)
    
    # キューブのSDF計算
    # 各軸に対する距離を計算
    d = torch.abs(transformed_points) - half_size
    
    # 外部の距離:max(d, 0)のノルム
    outside_distance = torch.norm(torch.maximum(d, torch.zeros_like(d)), dim=1)
    
    # 内部の距離:min(max(d.x, d.y, d.z), 0)
    inside_distance = torch.minimum(torch.max(d[:, 0], torch.max(d[:, 1], d[:, 2])), torch.zeros(d.shape[0]))
    
    # 合計SDF
    return outside_distance + inside_distance

# メイン実行部分
if __name__ == "__main__":
    print("微分可能シルエットレンダリングの初期化...")
    
    # カメラパラメータ
    camera_position = torch.tensor([0.0, 0.0, -3.0], requires_grad=False)
    focal_length = 1.0
    
    # 目標パラメータ(真の位置と回転)
    target_translation = torch.tensor([0.5, 0.3, 0.0], requires_grad=False)
    target_rotation_x = torch.tensor(np.pi/6, requires_grad=False)  # 30度
    target_rotation_y = torch.tensor(np.pi/4, requires_grad=False)  # 45度
    target_rotation_z = torch.tensor(np.pi/8, requires_grad=False)  # 22.5度
    
    # 画像平面の正規化座標の準備
    x = torch.linspace(-1, 1, width)
    y = torch.linspace(-1, 1, height)
    y_grid, x_grid = torch.meshgrid(y, x, indexing='ij')
    
    # 各ピクセルのレイ方向を作成
    ray_dirs = torch.stack([
        x_grid.flatten(), 
        y_grid.flatten(), 
        torch.ones_like(x_grid.flatten()) * focal_length
    ], dim=1)
    
    # レイ方向を正規化
    ray_lengths = torch.sqrt(torch.sum(ray_dirs**2, dim=1, keepdim=True))
    ray_dirs = ray_dirs / ray_lengths
    
    # 回転行列を作成する関数(自動微分のためにPyTorch演算を使用)
    def create_rotation_matrix(rx, ry, rz):
        # 各軸の回転を計算するためのサイン・コサイン
        cos_x, sin_x = torch.cos(rx), torch.sin(rx)
        cos_y, sin_y = torch.cos(ry), torch.sin(ry)
        cos_z, sin_z = torch.cos(rz), torch.sin(rz)
        
        # X軸周りの回転行列
        R_x = torch.zeros((3, 3), dtype=torch.float32)
        R_x[0, 0] = 1.0
        R_x[1, 1] = cos_x
        R_x[1, 2] = -sin_x
        R_x[2, 1] = sin_x
        R_x[2, 2] = cos_x
        
        # Y軸周りの回転行列
        R_y = torch.zeros((3, 3), dtype=torch.float32)
        R_y[0, 0] = cos_y
        R_y[0, 2] = sin_y
        R_y[1, 1] = 1.0
        R_y[2, 0] = -sin_y
        R_y[2, 2] = cos_y
        
        # Z軸周りの回転行列
        R_z = torch.zeros((3, 3), dtype=torch.float32)
        R_z[0, 0] = cos_z
        R_z[0, 1] = -sin_z
        R_z[1, 0] = sin_z
        R_z[1, 1] = cos_z
        R_z[2, 2] = 1.0
        
        # 合成回転行列(順序: Z → Y → X)
        return torch.matmul(torch.matmul(R_z, R_y), R_x)
    
    # 頂点を回転と平行移動で変換する関数
    def transform_vertices(vertices, rotation_matrix, translation):
        # 回転を適用
        rotated_vertices = torch.matmul(vertices, rotation_matrix.transpose(-1, -2))
        # 平行移動を適用
        transformed_vertices = rotated_vertices + translation
        return transformed_vertices
    
    # 非微分関数: レイキャスティングによるシルエット生成(ターゲット用)
    def render_silhouette(transformed_vertices, faces, camera_position, ray_dirs):
        # 衝突マスクの初期化
        hit_mask = torch.zeros(ray_dirs.shape[0], dtype=torch.float32)
        
        # 各三角形に対して、すべてのレイとの交差をチェック
        for face_idx in range(faces.shape[0]):
            face = faces[face_idx]
            
            # 三角形の頂点を取得
            v0 = transformed_vertices[face[0]]
            v1 = transformed_vertices[face[1]]
            v2 = transformed_vertices[face[2]]
            
            # 三角形のエッジを計算
            edge1 = v1 - v0
            edge2 = v2 - v0
            
            # Möller–Trumboreアルゴリズムによるレイと三角形の交差判定
            h = torch.cross(ray_dirs, edge2.unsqueeze(0).expand(ray_dirs.shape[0], -1))
            a = torch.sum(edge1.unsqueeze(0).expand(ray_dirs.shape[0], -1) * h, dim=1)
            
            # 平行に近いレイを除外
            epsilon = 1e-8
            mask = torch.abs(a) > epsilon
            
            if not mask.any():
                continue
                
            # 交差パラメータの計算
            f = torch.zeros_like(a)
            f[mask] = 1.0 / a[mask]
            
            s = camera_position.unsqueeze(0) - v0.unsqueeze(0)
            u = f * torch.sum(s * h, dim=1)
            
            # 交差が三角形内にあるかチェック
            mask = mask & (u >= 0.0) & (u <= 1.0)
            
            if not mask.any():
                continue
                
            q = torch.cross(s, edge1.unsqueeze(0).expand(ray_dirs.shape[0], -1))
            v = f * torch.sum(ray_dirs * q, dim=1)
            
            mask = mask & (v >= 0.0) & (u + v <= 1.0)
            
            if not mask.any():
                continue
                
            # 交差がカメラの前にあるか確認(tが正)
            t = f * torch.sum(edge2.unsqueeze(0).expand(ray_dirs.shape[0], -1) * q, dim=1)
            mask = mask & (t > epsilon)
            
            # ヒットマスクを更新
            hit_mask[mask] = 1.0
        
        # 画像サイズに整形
        return hit_mask.reshape(height, width)
    
    # ターゲットシルエットの生成
    print("ターゲットシルエットの生成...")
    target_rotation_matrix = create_rotation_matrix(
        target_rotation_x, target_rotation_y, target_rotation_z
    )
    target_transformed_vertices = transform_vertices(
        vertices, target_rotation_matrix, target_translation
    )
    
    with torch.no_grad():
        target_silhouette = render_silhouette(
            target_transformed_vertices, faces, camera_position, ray_dirs
        )
    
    # 最適化するパラメータの初期値設定
    translation = torch.tensor([-0.3, -0.2, 0.0], requires_grad=True)
    rotation_x = torch.tensor(0.0, requires_grad=True)
    rotation_y = torch.tensor(0.0, requires_grad=True)
    rotation_z = torch.tensor(0.0, requires_grad=True)
    
    # 学習率設定(回転は高めに設定)
    lr_translation = 0.01
    lr_rotation = 0.1
    iterations = 500  # イテレーション数を増やす
    
    # オプティマイザの作成(並行と回転で別々に設定)
    optimizer_translation = torch.optim.Adam([translation], lr=lr_translation)
    optimizer_rotation = torch.optim.Adam([rotation_x, rotation_y, rotation_z], lr=lr_rotation)
    
    # キューブの半サイズ(各軸方向の長さの半分)
    half_size = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32)
    
    # 微分可能なレンダラー(キューブのSDFを使用)
    def differentiable_render_cube(translation, rotation_matrix, camera_position, ray_dirs, sharpness=10.0):
        # カメラ位置から各レイ方向に沿ったサンプルポイントを生成
        # 複数のサンプル距離
        sample_distances = torch.linspace(0.1, 10.0, 50)
        
        # 全てのレイに対するサンプルポイントを計算
        n_rays = ray_dirs.shape[0]
        n_samples = sample_distances.shape[0]
        
        all_sample_points = camera_position.unsqueeze(0).unsqueeze(1) + \
                           ray_dirs.unsqueeze(1) * sample_distances.unsqueeze(0).unsqueeze(2)
        
        # all_sample_pointsの形状: [n_rays, n_samples, 3]
        # n_rays * n_samplesの2D配列に平坦化
        flat_sample_points = all_sample_points.view(-1, 3)
        
        # 各サンプルポイントに対してSDFを計算
        sdf_values = cube_sdf(flat_sample_points, translation, half_size, rotation_matrix)
        
        # SDFを元の形状に戻す [n_rays, n_samples]
        sdf_values = sdf_values.view(n_rays, n_samples)
        
        # 各レイに沿った最小SDF値を見つける(オブジェクトに最も近い点)
        min_sdf_values, _ = torch.min(sdf_values, dim=1)
        
        # 最小SDF値をシルエット値に変換(SDF値が負ならオブジェクト内、正ならオブジェクト外)
        silhouette = torch.sigmoid(-sharpness * min_sdf_values)
        
        return silhouette.reshape(height, width)
    
    # 最適化の進捗を記録
    loss_history = []
    param_history = []
    
    # アニーリングパラメータ
    annealing_temp_initial = 0.2
    annealing_temp = annealing_temp_initial
    annealing_decay = 0.98
    rotation_gradient_scale = 10.0  # 回転勾配スケール
    
    # 最適化ループ
    print("最適化を開始...")
    for i in tqdm(range(iterations)):
        optimizer_translation.zero_grad()
        optimizer_rotation.zero_grad()
        
        # 現在の回転行列を作成
        rotation_matrix = create_rotation_matrix(rotation_x, rotation_y, rotation_z)
        
        # 現在のシルエットをレンダリング
        current_silhouette = differentiable_render_cube(
            translation, rotation_matrix, camera_position, ray_dirs
        )
        
        # 損失を計算(MSE)
        loss = torch.mean((current_silhouette - target_silhouette) ** 2)
        
        # 履歴を保存
        loss_history.append(loss.item())
        param_history.append({
            'tx': translation[0].item(),
            'ty': translation[1].item(),
            'tz': translation[2].item(),
            'rx': rotation_x.item(),
            'ry': rotation_y.item(),
            'rz': rotation_z.item()
        })
        
        # 勾配を計算
        loss.backward()
        
        # 勾配のデバッグ出力
        if i == 0 or (i + 1) % 50 == 0:
            print(f"  Translation gradient: {translation.grad}")
            print(f"  Rotation gradient before scaling: ({rotation_x.grad}, {rotation_y.grad}, {rotation_z.grad})")
        
        # 回転の勾配をスケーリング
        if rotation_x.grad is not None:
            rotation_x.grad *= rotation_gradient_scale
        if rotation_y.grad is not None:
            rotation_y.grad *= rotation_gradient_scale
        if rotation_z.grad is not None:
            rotation_z.grad *= rotation_gradient_scale
        
        if i == 0 or (i + 1) % 50 == 0:
            print(f"  Rotation gradient after scaling: ({rotation_x.grad}, {rotation_y.grad}, {rotation_z.grad})")
        
        # パラメータ更新
        optimizer_translation.step()
        optimizer_rotation.step()
        
        # シミュレーテッドアニーリング(初期の150イテレーションのみ)
        if i % 10 == 0 and i < 150:
            with torch.no_grad():
                # 回転にランダムな摂動を加える
                rotation_x.add_(torch.randn(1).item() * annealing_temp)
                rotation_y.add_(torch.randn(1).item() * annealing_temp)
                rotation_z.add_(torch.randn(1).item() * annealing_temp)
                
                # 温度を下げる
                annealing_temp *= annealing_decay
        
        # 回転角を正規化(-πからπの範囲に制限)
        with torch.no_grad():
            rotation_x.data = torch.remainder(rotation_x.data + np.pi, 2 * np.pi) - np.pi
            rotation_y.data = torch.remainder(rotation_y.data + np.pi, 2 * np.pi) - np.pi
            rotation_z.data = torch.remainder(rotation_z.data + np.pi, 2 * np.pi) - np.pi
        
        # 進捗表示
        if (i + 1) % 50 == 0:
            print(f"Iteration {i+1}/{iterations}, Loss: {loss.item():.6f}")
            print(f"  Translation: ({translation[0].item():.4f}, {translation[1].item():.4f}, {translation[2].item():.4f})")
            print(f"  Rotation: ({rotation_x.item()*180/np.pi:.1f}°, {rotation_y.item()*180/np.pi:.1f}°, {rotation_z.item()*180/np.pi:.1f}°)")
    
    # 最終結果
    print("\n最適化完了!")
    print("ターゲットパラメータ:")
    print(f"  Translation: ({target_translation[0].item():.4f}, {target_translation[1].item():.4f}, {target_translation[2].item():.4f})")
    print(f"  Rotation: ({target_rotation_x.item()*180/np.pi:.1f}°, {target_rotation_y.item()*180/np.pi:.1f}°, {target_rotation_z.item()*180/np.pi:.1f}°)")
    print("\n最適化後のパラメータ:")
    print(f"  Translation: ({translation[0].item():.4f}, {translation[1].item():.4f}, {translation[2].item():.4f})")
    print(f"  Rotation: ({rotation_x.item()*180/np.pi:.1f}°, {rotation_y.item()*180/np.pi:.1f}°, {rotation_z.item()*180/np.pi:.1f}°)")
    print(f"最終損失: {loss_history[-1]:.6f}")
    
    # 結果の可視化
    plt.figure(figsize=(15, 10))
    
    # ターゲットシルエット
    plt.subplot(2, 3, 1)
    plt.imshow(target_silhouette.detach().numpy(), cmap='gray')
    plt.title('ターゲットシルエット')
    plt.axis('off')
    
    # 初期シルエット
    with torch.no_grad():
        initial_rotation_matrix = create_rotation_matrix(
            torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)
        )
        initial_silhouette = differentiable_render_cube(
            torch.tensor([-0.3, -0.2, 0.0]), 
            initial_rotation_matrix, 
            camera_position, 
            ray_dirs
        )
    
    plt.subplot(2, 3, 2)
    plt.imshow(initial_silhouette.detach().numpy(), cmap='gray')
    plt.title('初期シルエット')
    plt.axis('off')
    
    # 最適化後のシルエット
    with torch.no_grad():
        final_rotation_matrix = create_rotation_matrix(
            rotation_x, rotation_y, rotation_z
        )
        final_silhouette = differentiable_render_cube(
            translation, 
            final_rotation_matrix, 
            camera_position, 
            ray_dirs
        )
    
    plt.subplot(2, 3, 3)
    plt.imshow(final_silhouette.detach().numpy(), cmap='gray')
    plt.title('最適化後のシルエット')
    plt.axis('off')
    
    # 損失履歴
    plt.subplot(2, 3, 4)
    plt.plot(loss_history)
    plt.title('損失の推移')
    plt.xlabel('イテレーション')
    plt.ylabel('損失')
    plt.grid(True)
    
    # 平行移動履歴
    plt.subplot(2, 3, 5)
    tx = [p['tx'] for p in param_history]
    ty = [p['ty'] for p in param_history]
    tz = [p['tz'] for p in param_history]
    plt.plot(tx, label='tx')
    plt.plot(ty, label='ty')
    plt.plot(tz, label='tz')
    plt.axhline(y=target_translation[0].item(), color='r', linestyle='--', alpha=0.5, label='Target tx')
    plt.axhline(y=target_translation[1].item(), color='g', linestyle='--', alpha=0.5, label='Target ty')
    plt.axhline(y=target_translation[2].item(), color='b', linestyle='--', alpha=0.5, label='Target tz')
    plt.title('平行移動の推移')
    plt.xlabel('イテレーション')
    plt.ylabel('位置')
    plt.legend()
    plt.grid(True)
    
    # 回転履歴
    plt.subplot(2, 3, 6)
    rx = [p['rx']*180/np.pi for p in param_history]
    ry = [p['ry']*180/np.pi for p in param_history]
    rz = [p['rz']*180/np.pi for p in param_history]
    plt.plot(rx, label='rx')
    plt.plot(ry, label='ry')
    plt.plot(rz, label='rz')
    plt.axhline(y=target_rotation_x.item()*180/np.pi, color='r', linestyle='--', alpha=0.5, label='Target rx')
    plt.axhline(y=target_rotation_y.item()*180/np.pi, color='g', linestyle='--', alpha=0.5, label='Target ry')
    plt.axhline(y=target_rotation_z.item()*180/np.pi, color='b', linestyle='--', alpha=0.5, label='Target rz')
    plt.title('回転角の推移 (度)')
    plt.xlabel('イテレーション')
    plt.ylabel('角度')
    plt.legend()
    plt.grid(True)
    
    plt.tight_layout()
    plt.show()

Discussion