🕯️

PyTorchで形に制約を加えたパラメータを実装する方法

2023/08/29に公開

PyTorchで、三角行列や対角行列などの形になるよう制約を加えてパラメータを実装する方法を紹介します。ここで紹介した方法を派生すれば、任意の形を保ったパラメータを実装できるはずです。

# 下三角行列の形のパラメータ     # 対角行列の形のパラメータ
[[a11,   0,   0,   0],         [[a11,   0,   0,   0],
 [a21, a22,   0,   0],          [  0, a22,   0,   0],  
 [a31, a32, a33,   0],          [  0,   0, a33,   0],
 [a41, a42, a43, a44]]          [  0,   0,   0, a44]]

以下では、三角行列対角行列の形のパラメータの実装方法を、例としてそれぞれ紹介します。

三角行列の形のパラメータ

例として、実装するモデルに次の式が登場したとします。

A = LL^{\top}

Aは正方行列、Lは下三角行列です。いずれもn\times n行列とします。Lをパラメータとして実装し、Aを計算で求めたい場合は、以下のプログラムで実現可能です。

import torch
import torch.nn as nn

class TriModel(nn.Module):
    def __init__(self, n) -> None:
        super().__init__()
        self.n = n
        num_L_elem = n * (n + 1) // 2  # Lの下三角部分の要素数
        self.L_elem = nn.Parameter(torch.randn(num_L_elem))  # Lの要素をパラメータとして保持
	
	# モデルに紐づく定数を定義する。この定数は学習されない。
        # 定数は self.zero_mat で利用可能。
        self.zero_mat_ = torch.zeros(n, n)
        self.register_buffer("zero_mat", self.zero_mat_)

    def _make_L(self) -> torch.Tensor:
        """下三角行列Lを作る"""
        L = self.zero_mat.clone()  # 下三角行列のベースとなる零行列をコピー
        L_indices = torch.tril_indices(self.n, self.n)  # Lの下三角部分のインデックスをまとめて取得
        L[*L_indices] = self.L_elem  # self.L_elemのパラメータから、下三角行列を作成
        return L

    def _calc_A(self) -> torch.Tensor:
        """Aを計算する"""
        L = self._make_L()
        A = torch.matmul(L, L.t())
        return A

    (以下省略)

今回の例ではパラメータを下三角行列の形にしましたが、上三角行列の形にしたい場合は、_make_Lメソッドにあるtorch.tril_indices(...)[1]torch.triu_indices(...)[2]に書き換えます。

対角行列の形のパラメータ

先程と同様、L_indicesのようにインデックスをまとめた変数を用意すれば、対角行列(Dとします)の形のパラメータを実装できます。以下では、先程とは異なり、インデックスを直接指定する方法で実装しています。

import torch
import torch.nn as nn

class DiagModel(nn.Module):
    def __init__(self, n) -> None:
        super().__init__()
        self.n = n
        num_D_elem = n  # Dの対角部分の要素数
        self.D_elem = nn.Parameter(torch.randn(num_D_elem))  # Dの要素をパラメータとして保持
	
	# モデルに紐づく定数を定義
        self.zero_mat_ = torch.zeros(n, n)
        self.register_buffer("zero_mat", self.zero_mat_)

    def _make_D(self) -> torch.Tensor:
        """対角行列Dを作る"""
        D = self.zero_mat.clone()  # 対角行列のベースとなる零行列をコピー
        D[range(self.n), range(self.n)] = self.D_elem
        return D

    (以下省略)

参考文献

  1. python - Enforcing a structure in a nn.Parameter (matrix) parameter in Pytorch - Stack Overflow
  2. PyTorch Moduleに紐づく定数のtensorを定義する|gota_morishita
  3. Why model.to(device) wouldn't put tensors on a custom layer to the same device? - PyTorch Forums
  4. What is the difference between register_buffer and register_parameter of nn.Module - PyTorch Forums
  5. python - Replace diagonal elements with vector in PyTorch - Stack Overflow
脚注
  1. torch.tril_indices — PyTorch 2.0 documentation ↩︎

  2. torch.triu_indices — PyTorch 2.0 documentation ↩︎

Discussion