🐄

PyTorch3D入門

2022/12/26に公開

はじめに

ほぼ横浜の民 Advent Calendar 2022の2022年12月26日の記事です。

近年Computer Visionと深層学習は切り離せない関係にありますね。
最近注目を集めている3Dに関する技術[1]でも深層学習の活用が発展しています。

というわけで、3D Computer Visionで深層学習を活用する手段であるPyTorch3Dについて、公式のチュートリアルをなぞりながら雰囲気を掴んでみようと思います。
※「3D Computer Visionで深層学習の技術を活用する手段」は他にも多くあり、点群DNNではPyTorch Geometricが、アカデミアの領域ではJax/Flaxもよく使われている印象です。これらも時間を取って勉強したいところです。

基本的な考え方

参考資料

コンピュータビジョン最前線 Autumn 2022の下記2章に目を通すと概要の理解が深まるのでおススメです。

  • ニュウモン微分可能レンダリング―3次元ビジョンの新潮流! 3次元再構成からNeRFまで―(加藤大晴[2]
  • イマドキノNeural Fields―3次元,4次元,N次元ビジョンの信号表現のパラダイムシフト!?―(瀧川永遠希)

https://www.kyoritsu-pub.co.jp/book/b10013880.html

また、今回入門するPyTorch3Dについては、PyTorch公式のYouTubeが勉強になります。
https://www.youtube.com/watch?v=Pph1r-x9nyY

(この記事を読むよりもこれらで勉強することをおススメします。)

微分可能レンダリング

最低限押さえておくべきことは、下図のDifferentiable Redering(微分可能レンダリング)です。(上記Youtube 9:54付近)
今回扱うPyTorch3Dによる処理の全体の流れは下記の通りです。概ね2Dでの深層学習と同じですね。微分可能レンダリングによって連鎖律で損失から入力までを繋げられていることがポイントかと思います。

  • Mesh・Texture・カメラパラメタ・光源といった情報を入力としてレンダリング
  • レンダリング結果とGTを比較する、良い感じの損失を定義
  • レンダリング過程が微分可能なことを活用し\frac{\partial L}{\partial (\mathrm{Mesh})}などを算出
  • 損失が小さくなるように入力パラメタを更新

Neural Fields

3D Computer Visionの中でも特に盛り上がっているNeRF(Neural Radiance Fields) は、微分可能レンダリングを使ったうえで、3D表現をMesh・TextureからNeural Fieldsに変更することで表現力や拡張性を広げていることがポイントだと思います。表現力や拡張性の高さからか、Neural Fields分野の研究は非常に盛り上がっています[3][4]

PyTorch3D

微分可能レンダリングを使うためのライブラリはいくつかありますが、PyTorchを使い慣れている場合にはPyTorch3Dが有力な選択肢になるかと思います。その他のライブラリとしてはMitsuba Rederer 2nvdiffrastがあるようです。

チュートリアル

メッシュの変形

まずは、正解のメッシュを知っている状況で、メッシュを変形させて近づけていくケースを扱います。ここでは微分可能レンダリングは使われていません。微分可能な操作によって入力から出力までを繋いでいること、3Dデータの扱い方、あたりが分かれば良いかと思います。
チュートリアルページ

学習結果 GT

学習

メッシュ自体を更新する代わりに、頂点位置の変化量(初期値からの差分)を更新していきます。
そのため更新する変数は下記のように宣言しています。

deform_verts = torch.full(src_mesh.verts_packed().shape, 0.0, device=device, requires_grad=True)

学習のループは画像分類等におけるPyTorchでの実装とほとんど変わりないですね。
ポイントは下記あたりでしょうか。損失関数に表れる4項のうち、Chamfer distanceがGTに近づける損失関数、その他がメッシュの滑らかさを保証する正則化項にあたります。

for i in loop:
    # Initialize optimizer
    optimizer.zero_grad()
    
    # Deform the mesh
    new_src_mesh = src_mesh.offset_verts(deform_verts)
    
    # We sample 5k points from the surface of each mesh 
    sample_trg = sample_points_from_meshes(trg_mesh, 5000)
    sample_src = sample_points_from_meshes(new_src_mesh, 5000)
    
    # We compare the two sets of pointclouds by computing (a) the chamfer loss
    loss_chamfer, _ = chamfer_distance(sample_trg, sample_src)
    
    # and (b) the edge length of the predicted mesh
    loss_edge = mesh_edge_loss(new_src_mesh)
    
    # mesh normal consistency
    loss_normal = mesh_normal_consistency(new_src_mesh)
    
    # mesh laplacian smoothing
    loss_laplacian = mesh_laplacian_smoothing(new_src_mesh, method="uniform")
    
    # Weighted sum of the losses
    loss = loss_chamfer * w_chamfer + loss_edge * w_edge + loss_normal * w_normal + loss_laplacian * w_laplacian
    
    # Print the losses
    loop.set_description('total_loss = %.6f' % loss)
    
    # Save the losses for plotting
    chamfer_losses.append(float(loss_chamfer.detach().cpu()))
    edge_losses.append(float(loss_edge.detach().cpu()))
    normal_losses.append(float(loss_normal.detach().cpu()))
    laplacian_losses.append(float(loss_laplacian.detach().cpu()))
    
    # Plot mesh
    if i % plot_period == 0:
        plot_pointcloud(new_src_mesh, title="iter: %d" % i)
        
    # Optimization step
    loss.backward()
    optimizer.step()

メッシュ・テクスチャの最適化

続いて、もう少し実践的なケースとして複数視点からの画像を既知情報としてメッシュ・テクスチャを最適化するパターンを扱います。
チュートリアルページ

学習結果 GT

※ チュートリアルのソースコードのまま実行した結果を載せています。ハイパーパラメータを調整することで学習結果の品位が高くなると思われます。

データセット

データセット構築のためのコードは下記の通りです。細かい部分は置いておくと、ラスタライザ・シェーダーの設定をしたレンダラーにメッシュ、カメラ、光源を入力することで各カメラにおける画像を生成していることが分かります。

num_views = 20

elev = torch.linspace(0, 360, num_views)
azim = torch.linspace(-180, 180, num_views)

lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])

R, T = look_at_view_transform(dist=2.7, elev=elev, azim=azim)
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)

camera = FoVPerspectiveCameras(device=device, R=R[None, 1, ...], 
                                  T=T[None, 1, ...]) 

raster_settings = RasterizationSettings(
    image_size=128, 
    blur_radius=0.0, 
    faces_per_pixel=1, 
)

renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=camera, 
        raster_settings=raster_settings
    ),
    shader=SoftPhongShader(
        device=device, 
        cameras=camera,
        lights=lights
    )
)

meshes = mesh.extend(num_views)

target_images = renderer(meshes, cameras=cameras, lights=lights)

学習

先ほどと同様、頂点位置の変化量(初期値からの差分)を更新していきます。またテクスチャについては各頂点ごとのRGB値を更新していきます。

verts_shape = src_mesh.verts_packed().shape
deform_verts = torch.full(verts_shape, 0.0, device=device, requires_grad=True)

# We will also learn per vertex colors for our sphere mesh that define texture 
# of the mesh
sphere_verts_rgb = torch.full([1, verts_shape[0], 3], 0.5, device=device, requires_grad=True)

損失関数については、下記のように定義しています。"edge", "normal", "laplacian"については先ほどの例と同じですね。"rgb", "silhouette"については学習ループの中で後述します。

losses = {"rgb": {"weight": 1.0, "values": []},
          "silhouette": {"weight": 0.6, "values": []},
          "edge": {"weight": 1.0, "values": []},
          "normal": {"weight": 0.01, "values": []},
          "laplacian": {"weight": 1.0, "values": []},
         }

# Losses to smooth / regularize the mesh shape
def update_mesh_shape_prior_losses(mesh, loss):
    # and (b) the edge length of the predicted mesh
    loss["edge"] = mesh_edge_loss(mesh)
    
    # mesh normal consistency
    loss["normal"] = mesh_normal_consistency(mesh)
    
    # mesh laplacian smoothing
    loss["laplacian"] = mesh_laplacian_smoothing(mesh, method="uniform")

学習ループの中では、データセットの中からランダムに視線方向を選択し、その方向からのシルエット(マスク)、RGB値それぞれのRMSEを算出しています。

loop = tqdm(range(Niter))

for i in loop:
    # Initialize optimizer
    optimizer.zero_grad()
    
    # Deform the mesh
    new_src_mesh = src_mesh.offset_verts(deform_verts)
    
    # Add per vertex colors to texture the mesh
    new_src_mesh.textures = TexturesVertex(verts_features=sphere_verts_rgb) 
    
    # Losses to smooth /regularize the mesh shape
    loss = {k: torch.tensor(0.0, device=device) for k in losses}
    update_mesh_shape_prior_losses(new_src_mesh, loss)
    
    # Randomly select two views to optimize over in this iteration.  Compared
    # to using just one view, this helps resolve ambiguities between updating
    # mesh shape vs. updating mesh texture
    for j in np.random.permutation(num_views).tolist()[:num_views_per_iteration]:
        images_predicted = renderer_textured(new_src_mesh, cameras=target_cameras[j], lights=lights)

        # Squared L2 distance between the predicted silhouette and the target 
        # silhouette from our dataset
        predicted_silhouette = images_predicted[..., 3]
        loss_silhouette = ((predicted_silhouette - target_silhouette[j]) ** 2).mean()
        loss["silhouette"] += loss_silhouette / num_views_per_iteration
        
        # Squared L2 distance between the predicted RGB image and the target 
        # image from our dataset
        predicted_rgb = images_predicted[..., :3]
        loss_rgb = ((predicted_rgb - target_rgb[j]) ** 2).mean()
        loss["rgb"] += loss_rgb / num_views_per_iteration
    
    # Weighted sum of the losses
    sum_loss = torch.tensor(0.0, device=device)
    for k, l in loss.items():
        sum_loss += l * losses[k]["weight"]
        losses[k]["values"].append(float(l.detach().cpu()))
    
    # Print the losses
    loop.set_description("total_loss = %.6f" % sum_loss)
    
    # Plot mesh
    if i % plot_period == 0:
        visualize_prediction(new_src_mesh, renderer=renderer_textured, title="iter: %d" % i, silhouette=False)
        
    # Optimization step
    sum_loss.backward()
    optimizer.step()

NeRF

さらに一歩踏み込んで、Neural Radiance Fields(Neural Fieldsの中でもRadianceに着目したもの)による3D再構成を扱います。
チュートリアルページ

学習結果 GT

※ チュートリアルのソースコードのまま実行した結果を載せています。高品位な結果を得るにはn_iter=20000が推奨されています。またNeRF自体もチュートリアル向けに簡易化されています。

Implicit renderer

Neural Radiance Fieldsによる表現からレンダリングするためのImplicit rendererは下記のようなraymacherとraysamplerから構成されます。このチュートリアルではNDCMultinomialRaysamplerとMonteCarloRaysamplerの2種類のraysamplerを使います。それぞれの役割は以下です。

  • NDCMultinomialRaysampler: PyTorch3Dの座標系を持つ。学習結果のレンダリングに利用。
  • MonteCarloRaysampler: 各ピクセルに対応する光線上の点をサンプリングする。学習時に利用。

また前者についての実装は下記の通りです。

raysampler_grid = NDCMultinomialRaysampler(
    image_height=render_size,
    image_width=render_size,
    n_pts_per_ray=128,
    min_depth=0.1,
    max_depth=volume_extent_world,
)

raymarcher = EmissionAbsorptionRaymarcher()

renderer_grid = ImplicitRenderer(
    raysampler=raysampler_grid, raymarcher=raymarcher,
)

Neural Radiance Field Model

少し長いですが、下記にNeRFのモデル部分の実装を転記しています。大枠としては、
3次元空間上のある点の座標と方向を持つray_bundleを入力として、RGB値(_get_colors)と密度(_get_densities)を出力するモデルであることの理解が重要です。その上で学習の品位を上げるためにPositional Encoding[5]やRGB値と密度の両方に関係するlatent feature spaceへ畳み込むMLPを持っていることなどを徐々に理解していけば良いと思います。

class NeuralRadianceField(torch.nn.Module):
    def __init__(self, n_harmonic_functions=60, n_hidden_neurons=256):
        super().__init__()
        """
        Args:
            n_harmonic_functions: The number of harmonic functions
                used to form the harmonic embedding of each point.
            n_hidden_neurons: The number of hidden units in the
                fully connected layers of the MLPs of the model.
        """
        
        # The harmonic embedding layer converts input 3D coordinates
        # to a representation that is more suitable for
        # processing with a deep neural network.
        self.harmonic_embedding = HarmonicEmbedding(n_harmonic_functions)
        
        # The dimension of the harmonic embedding.
        embedding_dim = n_harmonic_functions * 2 * 3
        
        # self.mlp is a simple 2-layer multi-layer perceptron
        # which converts the input per-point harmonic embeddings
        # to a latent representation.
        # Not that we use Softplus activations instead of ReLU.
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim, n_hidden_neurons),
            torch.nn.Softplus(beta=10.0),
            torch.nn.Linear(n_hidden_neurons, n_hidden_neurons),
            torch.nn.Softplus(beta=10.0),
        )        
        
        # Given features predicted by self.mlp, self.color_layer
        # is responsible for predicting a 3-D per-point vector
        # that represents the RGB color of the point.
        self.color_layer = torch.nn.Sequential(
            torch.nn.Linear(n_hidden_neurons + embedding_dim, n_hidden_neurons),
            torch.nn.Softplus(beta=10.0),
            torch.nn.Linear(n_hidden_neurons, 3),
            torch.nn.Sigmoid(),
            # To ensure that the colors correctly range between [0-1],
            # the layer is terminated with a sigmoid layer.
        )  
        
        # The density layer converts the features of self.mlp
        # to a 1D density value representing the raw opacity
        # of each point.
        self.density_layer = torch.nn.Sequential(
            torch.nn.Linear(n_hidden_neurons, 1),
            torch.nn.Softplus(beta=10.0),
            # Sofplus activation ensures that the raw opacity
            # is a non-negative number.
        )
        
        # We set the bias of the density layer to -1.5
        # in order to initialize the opacities of the
        # ray points to values close to 0. 
        # This is a crucial detail for ensuring convergence
        # of the model.
        self.density_layer[0].bias.data[0] = -1.5        
                
    def _get_densities(self, features):
        """
        This function takes `features` predicted by `self.mlp`
        and converts them to `raw_densities` with `self.density_layer`.
        `raw_densities` are later mapped to [0-1] range with
        1 - inverse exponential of `raw_densities`.
        """
        raw_densities = self.density_layer(features)
        return 1 - (-raw_densities).exp()
    
    def _get_colors(self, features, rays_directions):
        """
        This function takes per-point `features` predicted by `self.mlp`
        and evaluates the color model in order to attach to each
        point a 3D vector of its RGB color.
        
        In order to represent viewpoint dependent effects,
        before evaluating `self.color_layer`, `NeuralRadianceField`
        concatenates to the `features` a harmonic embedding
        of `ray_directions`, which are per-point directions 
        of point rays expressed as 3D l2-normalized vectors
        in world coordinates.
        """
        spatial_size = features.shape[:-1]
        
        # Normalize the ray_directions to unit l2 norm.
        rays_directions_normed = torch.nn.functional.normalize(
            rays_directions, dim=-1
        )
        
        # Obtain the harmonic embedding of the normalized ray directions.
        rays_embedding = self.harmonic_embedding(
            rays_directions_normed
        )
        
        # Expand the ray directions tensor so that its spatial size
        # is equal to the size of features.
        rays_embedding_expand = rays_embedding[..., None, :].expand(
            *spatial_size, rays_embedding.shape[-1]
        )
        
        # Concatenate ray direction embeddings with 
        # features and evaluate the color model.
        color_layer_input = torch.cat(
            (features, rays_embedding_expand),
            dim=-1
        )
        return self.color_layer(color_layer_input)
    
  
    def forward(
        self, 
        ray_bundle: RayBundle,
        **kwargs,
    ):
        """
        The forward function accepts the parametrizations of
        3D points sampled along projection rays. The forward
        pass is responsible for attaching a 3D vector
        and a 1D scalar representing the point's 
        RGB color and opacity respectively.
        
        Args:
            ray_bundle: A RayBundle object containing the following variables:
                origins: A tensor of shape `(minibatch, ..., 3)` denoting the
                    origins of the sampling rays in world coords.
                directions: A tensor of shape `(minibatch, ..., 3)`
                    containing the direction vectors of sampling rays in world coords.
                lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
                    containing the lengths at which the rays are sampled.

        Returns:
            rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
                denoting the opacity of each ray point.
            rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
                denoting the color of each ray point.
        """
        # We first convert the ray parametrizations to world
        # coordinates with `ray_bundle_to_ray_points`.
        rays_points_world = ray_bundle_to_ray_points(ray_bundle)
        # rays_points_world.shape = [minibatch x ... x 3]
        
        # For each 3D world coordinate, we obtain its harmonic embedding.
        embeds = self.harmonic_embedding(
            rays_points_world
        )
        # embeds.shape = [minibatch x ... x self.n_harmonic_functions*6]
        
        # self.mlp maps each harmonic embedding to a latent feature space.
        features = self.mlp(embeds)
        # features.shape = [minibatch x ... x n_hidden_neurons]
        
        # Finally, given the per-point features, 
        # execute the density and color branches.
        
        rays_densities = self._get_densities(features)
        # rays_densities.shape = [minibatch x ... x 1]

        rays_colors = self._get_colors(features, ray_bundle.directions)
        # rays_colors.shape = [minibatch x ... x 3]
        
        return rays_densities, rays_colors
    
    def batched_forward(
        self, 
        ray_bundle: RayBundle,
        n_batches: int = 16,
        **kwargs,        
    ):
        """
        This function is used to allow for memory efficient processing
        of input rays. The input rays are first split to `n_batches`
        chunks and passed through the `self.forward` function one at a time
        in a for loop. Combined with disabling PyTorch gradient caching
        (`torch.no_grad()`), this allows for rendering large batches
        of rays that do not all fit into GPU memory in a single forward pass.
        In our case, batched_forward is used to export a fully-sized render
        of the radiance field for visualization purposes.
        
        Args:
            ray_bundle: A RayBundle object containing the following variables:
                origins: A tensor of shape `(minibatch, ..., 3)` denoting the
                    origins of the sampling rays in world coords.
                directions: A tensor of shape `(minibatch, ..., 3)`
                    containing the direction vectors of sampling rays in world coords.
                lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
                    containing the lengths at which the rays are sampled.
            n_batches: Specifies the number of batches the input rays are split into.
                The larger the number of batches, the smaller the memory footprint
                and the lower the processing speed.

        Returns:
            rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
                denoting the opacity of each ray point.
            rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
                denoting the color of each ray point.

        """

        # Parse out shapes needed for tensor reshaping in this function.
        n_pts_per_ray = ray_bundle.lengths.shape[-1]  
        spatial_size = [*ray_bundle.origins.shape[:-1], n_pts_per_ray]

        # Split the rays to `n_batches` batches.
        tot_samples = ray_bundle.origins.shape[:-1].numel()
        batches = torch.chunk(torch.arange(tot_samples), n_batches)

        # For each batch, execute the standard forward pass.
        batch_outputs = [
            self.forward(
                RayBundle(
                    origins=ray_bundle.origins.view(-1, 3)[batch_idx],
                    directions=ray_bundle.directions.view(-1, 3)[batch_idx],
                    lengths=ray_bundle.lengths.view(-1, n_pts_per_ray)[batch_idx],
                    xys=None,
                )
            ) for batch_idx in batches
        ]
        
        # Concatenate the per-batch rays_densities and rays_colors
        # and reshape according to the sizes of the inputs.
        rays_densities, rays_colors = [
            torch.cat(
                [batch_output[output_i] for batch_output in batch_outputs], dim=0
            ).view(*spatial_size, -1) for output_i in (0, 1)
        ]
        return rays_densities, rays_colors

学習

学習のメイン部分の実装は下記の通りです。大まかな流れはこれまでと変わりないことが分かるかと思います。

  • バッチサイズ分だけランダムに視点を選択
  • 各視点ごとにシルエットとRGB値についてHubar損失を算出
  • 合計値を全体の損失として各パラメタを更新
# The main optimization loop.
for iteration in range(n_iter):      
    # In case we reached the last 75% of iterations,
    # decrease the learning rate of the optimizer 10-fold.
    if iteration == round(n_iter * 0.75):
        print('Decreasing LR 10-fold ...')
        optimizer = torch.optim.Adam(
            neural_radiance_field.parameters(), lr=lr * 0.1
        )
    
    # Zero the optimizer gradient.
    optimizer.zero_grad()
    
    # Sample random batch indices.
    batch_idx = torch.randperm(len(target_cameras))[:batch_size]
    
    # Sample the minibatch of cameras.
    batch_cameras = FoVPerspectiveCameras(
        R = target_cameras.R[batch_idx], 
        T = target_cameras.T[batch_idx], 
        znear = target_cameras.znear[batch_idx],
        zfar = target_cameras.zfar[batch_idx],
        aspect_ratio = target_cameras.aspect_ratio[batch_idx],
        fov = target_cameras.fov[batch_idx],
        device = device,
    )
    
    # Evaluate the nerf model.
    rendered_images_silhouettes, sampled_rays = renderer_mc(
        cameras=batch_cameras, 
        volumetric_function=neural_radiance_field
    )
    rendered_images, rendered_silhouettes = (
        rendered_images_silhouettes.split([3, 1], dim=-1)
    )
    
    # Compute the silhouette error as the mean huber
    # loss between the predicted masks and the
    # sampled target silhouettes.
    silhouettes_at_rays = sample_images_at_mc_locs(
        target_silhouettes[batch_idx, ..., None], 
        sampled_rays.xys
    )
    sil_err = huber(
        rendered_silhouettes, 
        silhouettes_at_rays,
    ).abs().mean()

    # Compute the color error as the mean huber
    # loss between the rendered colors and the
    # sampled target images.
    colors_at_rays = sample_images_at_mc_locs(
        target_images[batch_idx], 
        sampled_rays.xys
    )
    color_err = huber(
        rendered_images, 
        colors_at_rays,
    ).abs().mean()
    
    # The optimization loss is a simple
    # sum of the color and silhouette errors.
    loss = color_err + sil_err
    
    # Log the loss history.
    loss_history_color.append(float(color_err))
    loss_history_sil.append(float(sil_err))
    
    # Every 10 iterations, print the current values of the losses.
    if iteration % 10 == 0:
        print(
            f'Iteration {iteration:05d}:'
            + f' loss color = {float(color_err):1.2e}'
            + f' loss silhouette = {float(sil_err):1.2e}'
        )
    
    # Take the optimization step.
    loss.backward()
    optimizer.step()

終わりに

最後まで読んでいただきありがとうございました。ぼかして書いた部分や他のパートなどカバーできていない内容もかなりありますが、時間切れです。気が向いたら更新していきます。

脚注
  1. コンピュータビジョン最前線のテーマが2巻続けて3Dに関するものです。
    コンピュータビジョン最前線 Autumn 2022
    コンピュータビジョン最前線 Winter 2022 ↩︎

  2. 加藤さんは定期的に講演もされているようで、公開されている資料も大変勉強になります。
    https://docs.google.com/presentation/d/135VTaXJaESu4rDRLe_WbvSe4zq1__dFxBYI_7H3EjBI/edit#slide=id.g12d4f17af70_0_0
    https://docs.google.com/presentation/d/1TfQW_5H1jm9xnWsrujp9NsboZ5I5nD_hq0ksbcIPGGQ/mobilepresent?slide=id.p ↩︎

  3. https://neuralfields.cs.brown.edu/eg22.html ↩︎

  4. https://neuralfields.cs.brown.edu/index.html ↩︎

  5. NeRFにおけるPositional Encodingは未知ポジションからのレンダリング結果が鈍ることへの対策として高周波成分を失わないために導入されていた気がします。一方で、Vision Transformer等でのPositional Encodingは位置情報を埋め込むためのものであると理解しています。このあたりの関係性はどこかで整理したいです。 ↩︎

Discussion