🎨

PyTorchでRGBからLab色空間へ変換する

2023/12/13に公開

Jooyeol Yun, Sanghyeon Lee, Minho Park, Jaegul Choo
"iColoriT: Towards Propagating Local Hint to the Right Region in Interactive Colorization by Leveraging Vision Transformer"
https://arxiv.org/abs/2207.06831

OpenCVや大規模なライブラリを介さずに、torch tensor内で完結してLab色空間がほしいとき、以下のコードが非常に役に立った。

https://github.com/pmh9960/iColoriT/blob/main/utils.py

この記事ではtorch.utils.data.Datasetなどで色変換を使う場合を想定して、torchvision.transforms風に使えるように整形した。Lab色空間では、L方向に輝度情報が集約されるため、例えば畳み込み演算によるブラーリングやエッジ抽出を1Channelで可能になる。また、L空間を適応的ヒストグラム平坦化(CLAHE; Contrast Limited Adaptive Histogram Equalization)することで色情報を保存しながら適応的輝度調整が可能になる。

// TODO CLAHEのtorch実装も行いたい

http://labs.eecs.tottori-u.ac.jp/sd/Member/oyamada/OpenCV/html/py_tutorials/py_imgproc/py_histograms/py_histogram_equalization/py_histogram_equalization.html

コードは引用元のライブラリで、MIT license。この論文の他にもColorization系の実装がほぼ同じ形でutil.pyとして実装されていたのでおそらく安心して動かしていいはず。

RGB to Lab

my_util.py
class RgbToLab(nn.Module):
    def __init__(self):
        super().__init__()
        self.l_cent = 50.
        self.l_norm = 100.
        self.ab_norm = 110.
        return None

    def rgb2xyz(self, rgb):  # rgb from [0,1]
        # xyz_from_rgb = np.array([
        # [0.412453, 0.357580, 0.180423],
        # [0.212671, 0.715160, 0.072169],
        # [0.019334, 0.119193, 0.950227]
        # ])
        mask = (rgb > .04045).type(torch.FloatTensor)
        if (rgb.is_cuda):
            mask = mask.cuda()
        rgb = (((rgb + .055) / 1.055)**2.4) * mask + rgb / 12.92 * (1 - mask)
        x = .412453 * rgb[:, 0, :, :] + .357580 * rgb[:, 1, :, :] + .180423 * rgb[:, 2, :, :]
        y = .212671 * rgb[:, 0, :, :] + .715160 * rgb[:, 1, :, :] + .072169 * rgb[:, 2, :, :]
        z = .019334 * rgb[:, 0, :, :] + .119193 * rgb[:, 1, :, :] + .950227 * rgb[:, 2, :, :]
        out = torch.cat((x[:, None, :, :], y[:, None, :, :], z[:, None, :, :]), dim=1)
        return out

    def xyz2lab(self, xyz):
        # 0.95047, 1., 1.08883 # white
        sc = torch.Tensor((0.95047, 1., 1.08883))[None, :, None, None]
        if (xyz.is_cuda):
            sc = sc.cuda()
        xyz_scale = xyz / sc
        mask = (xyz_scale > .008856).type(torch.FloatTensor)
        if (xyz_scale.is_cuda):
            mask = mask.cuda()
        xyz_int = xyz_scale**(1 / 3.) * mask + (7.787 * xyz_scale + 16. / 116.) * (1 - mask)
        L = 116. * xyz_int[:, 1, :, :] - 16.
        a = 500. * (xyz_int[:, 0, :, :] - xyz_int[:, 1, :, :])
        b = 200. * (xyz_int[:, 1, :, :] - xyz_int[:, 2, :, :])
        out = torch.cat((L[:, None, :, :], a[:, None, :, :], b[:, None, :, :]), dim=1)
        return out

    def rgb2lab(self, rgb):
        lab = self.xyz2lab(self.rgb2xyz(rgb))
        l_rs = (lab[:, [0], :, :] - self.l_cent) / self.l_norm
        ab_rs = lab[:, 1:, :, :] / self.ab_norm
        out = torch.cat((l_rs, ab_rs), dim=1)
        return out

    def forward(self, img):
        img = self.rgb2lab(img)
        return img

Lab to RGB

my_util.py
class LabToRgb(nn.Module):
    def __init__(self):
        super().__init__()
        self.l_cent = 50.
        self.l_norm = 100.
        self.ab_norm = 110.
        return None

    def lab2xyz(self, lab):
        y_int = (lab[:, 0, :, :] + 16.) / 116.
        x_int = (lab[:, 1, :, :] / 500.) + y_int
        z_int = y_int - (lab[:, 2, :, :] / 200.)
        if (z_int.is_cuda):
            z_int = torch.max(torch.Tensor((0,)).cuda(), z_int)
        else:
            z_int = torch.max(torch.Tensor((0,)), z_int)
        out = torch.cat(
            (x_int[:, None, :, :], y_int[:, None, :, :], z_int[:, None, :, :]), dim=1)
        mask = (out > .2068966).type(torch.FloatTensor)
        if (out.is_cuda):
            mask = mask.cuda()
        out = (out**3.) * mask + (out - 16. / 116.) / 7.787 * (1 - mask)
        sc = torch.Tensor((0.95047, 1., 1.08883))[None, :, None, None]
        sc = sc.to(out.device)
        out = out * sc
        return out

    def xyz2rgb(self, xyz):
        # array([[ 3.24048134, -1.53715152, -0.49853633],
        #        [-0.96925495,  1.87599   ,  0.04155593],
        #        [ 0.05564664, -0.20404134,  1.05731107]])
        r = 3.24048134 * xyz[:, 0, :, :] - 1.53715152 * xyz[:, 1, :, :] - 0.49853633 * xyz[:, 2, :, :]
        g = -0.96925495 * xyz[:, 0, :, :] + 1.87599 * xyz[:, 1, :, :] + .04155593 * xyz[:, 2, :, :]
        b = .05564664 * xyz[:, 0, :, :] - .20404134 * xyz[:, 1, :, :] + 1.05731107 * xyz[:, 2, :, :]
        rgb = torch.cat((r[:, None, :, :], g[:, None, :, :], b[:, None, :, :]), dim=1)
        # sometimes reaches a small negative number, which causes NaNs
        rgb = torch.max(rgb, torch.zeros_like(rgb))
        mask = (rgb > .0031308).type(torch.FloatTensor)
        if (rgb.is_cuda):
            mask = mask.cuda()
        rgb = (1.055 * (rgb**(1. / 2.4)) - 0.055) * mask + 12.92 * rgb * (1 - mask)
        return rgb

    def lab2rgb(self, lab_rs):
        l = lab_rs[:, [0], :, :] * self.l_norm + self.l_cent
        ab = lab_rs[:, 1:, :, :] * self.ab_norm
        lab = torch.cat((l, ab), dim=1)
        out = self.xyz2rgb(self.lab2xyz(lab))
        return out

    def forward(self, img):
        img = self.lab2rgb(img)
        return img

おまけ: 画像テンソルのベクトル化

最初次の実装をtorchで書き直すという練習をしていた。

https://kibata-ai-labo.com/programming/convert_srgb2lab/

しかしlinearRGB ↔ XYZ色空間の変形部np.dot(rgb, XYZ_MATRIX)や、XYZ ↔ Labの変換

# f(x)
def trans_function(x, threshold=0.008856):
    y = x**(1/3) if x>threshold else (841/108)*x+(16/116)
    return y

# XYZ -> L*a*b*
def xyz2lab(x, y, z, xn=0.9505, yn=1.0, zn=1.089):
    l = 116 * trans_function(x/yn) - 16
    a = 500 * ( trans_function(x/xn) - trans_function(y/yn) )
    b = 200 * ( trans_function(y/yn) - trans_function(z/zn) )
    return l, a, b

に困難があり、低速な実装しか組めなかった。
残念ながら使えるものにはならなかったが、テンソルのVectorizationや関数のマップを練習するいい機会になったので、以下にメモを残しておく。

画像の行列積化

\bold{\texttt{img}} \in \mathbb{R}^{3×H×W}\bold{M} \in \mathbb{R}^{3×3}の状況で画像の各channelへ行列演算を行うとき、\bold{\texttt{img}}の空間方向H,Wで平坦化することで、\bold{\texttt{img}} \in \mathbb{R}^{3×HW}にできる。これに以下の計算を行うことで、各画素のChannelの線形変換が可能になる。Channle方向への一般化も成り立つ。

\bold{\texttt{img}}_2 = \bold{M} \bold{\texttt{img}}_1
img_1 = torch.ones(3,256,256)
matrix = torch.onse(3,3)

img_2 = torch.mm(img.view(3,256*256), matrix).view(3,256,256)

テンソルへの非線形関数適用(遅い原因)

NumPyのnp.apply_along_axisを自作するもの。以下の回答を少し改造して可変引数関数に対応させた。(上記記事中のxyz2labをベクトル化画像に適用するため)

https://discuss.pytorch.org/t/apply-a-function-along-an-axis/130440

def apply_along_axis(function, x, axis=0):
    # for部分が速度的に悪さをする
    return torch.stack([
        function(*x_i) for x_i in torch.unbind(x, dim=axis)
    ], dim=axis)

f = lambda x,y,z: (y*z**2, z*x**2, x*y**2)  # 3変数→3変数の非線形変換
    
# 以下の変形でChannel方向の各要素を各引数とする変換が可能
img_2 = apply_along_axis(f,img.view(3,256*256), axis=0).view(3,256,256)

Discussion