PyTorchでRGBからLab色空間へ変換する
Jooyeol Yun, Sanghyeon Lee, Minho Park, Jaegul Choo
"iColoriT: Towards Propagating Local Hint to the Right Region in Interactive Colorization by Leveraging Vision Transformer"
https://arxiv.org/abs/2207.06831
OpenCVや大規模なライブラリを介さずに、torch tensor内で完結してLab色空間がほしいとき、以下のコードが非常に役に立った。
この記事ではtorch.utils.data.Dataset
などで色変換を使う場合を想定して、torchvision.transforms
風に使えるように整形した。Lab色空間では、L方向に輝度情報が集約されるため、例えば畳み込み演算によるブラーリングやエッジ抽出を1Channelで可能になる。また、L空間を適応的ヒストグラム平坦化(CLAHE; Contrast Limited Adaptive Histogram Equalization)することで色情報を保存しながら適応的輝度調整が可能になる。
// TODO CLAHEのtorch実装も行いたい
コードは引用元のライブラリで、MIT license。この論文の他にもColorization系の実装がほぼ同じ形でutil.pyとして実装されていたのでおそらく安心して動かしていいはず。
RGB to Lab
class RgbToLab(nn.Module):
def __init__(self):
super().__init__()
self.l_cent = 50.
self.l_norm = 100.
self.ab_norm = 110.
return None
def rgb2xyz(self, rgb): # rgb from [0,1]
# xyz_from_rgb = np.array([
# [0.412453, 0.357580, 0.180423],
# [0.212671, 0.715160, 0.072169],
# [0.019334, 0.119193, 0.950227]
# ])
mask = (rgb > .04045).type(torch.FloatTensor)
if (rgb.is_cuda):
mask = mask.cuda()
rgb = (((rgb + .055) / 1.055)**2.4) * mask + rgb / 12.92 * (1 - mask)
x = .412453 * rgb[:, 0, :, :] + .357580 * rgb[:, 1, :, :] + .180423 * rgb[:, 2, :, :]
y = .212671 * rgb[:, 0, :, :] + .715160 * rgb[:, 1, :, :] + .072169 * rgb[:, 2, :, :]
z = .019334 * rgb[:, 0, :, :] + .119193 * rgb[:, 1, :, :] + .950227 * rgb[:, 2, :, :]
out = torch.cat((x[:, None, :, :], y[:, None, :, :], z[:, None, :, :]), dim=1)
return out
def xyz2lab(self, xyz):
# 0.95047, 1., 1.08883 # white
sc = torch.Tensor((0.95047, 1., 1.08883))[None, :, None, None]
if (xyz.is_cuda):
sc = sc.cuda()
xyz_scale = xyz / sc
mask = (xyz_scale > .008856).type(torch.FloatTensor)
if (xyz_scale.is_cuda):
mask = mask.cuda()
xyz_int = xyz_scale**(1 / 3.) * mask + (7.787 * xyz_scale + 16. / 116.) * (1 - mask)
L = 116. * xyz_int[:, 1, :, :] - 16.
a = 500. * (xyz_int[:, 0, :, :] - xyz_int[:, 1, :, :])
b = 200. * (xyz_int[:, 1, :, :] - xyz_int[:, 2, :, :])
out = torch.cat((L[:, None, :, :], a[:, None, :, :], b[:, None, :, :]), dim=1)
return out
def rgb2lab(self, rgb):
lab = self.xyz2lab(self.rgb2xyz(rgb))
l_rs = (lab[:, [0], :, :] - self.l_cent) / self.l_norm
ab_rs = lab[:, 1:, :, :] / self.ab_norm
out = torch.cat((l_rs, ab_rs), dim=1)
return out
def forward(self, img):
img = self.rgb2lab(img)
return img
Lab to RGB
class LabToRgb(nn.Module):
def __init__(self):
super().__init__()
self.l_cent = 50.
self.l_norm = 100.
self.ab_norm = 110.
return None
def lab2xyz(self, lab):
y_int = (lab[:, 0, :, :] + 16.) / 116.
x_int = (lab[:, 1, :, :] / 500.) + y_int
z_int = y_int - (lab[:, 2, :, :] / 200.)
if (z_int.is_cuda):
z_int = torch.max(torch.Tensor((0,)).cuda(), z_int)
else:
z_int = torch.max(torch.Tensor((0,)), z_int)
out = torch.cat(
(x_int[:, None, :, :], y_int[:, None, :, :], z_int[:, None, :, :]), dim=1)
mask = (out > .2068966).type(torch.FloatTensor)
if (out.is_cuda):
mask = mask.cuda()
out = (out**3.) * mask + (out - 16. / 116.) / 7.787 * (1 - mask)
sc = torch.Tensor((0.95047, 1., 1.08883))[None, :, None, None]
sc = sc.to(out.device)
out = out * sc
return out
def xyz2rgb(self, xyz):
# array([[ 3.24048134, -1.53715152, -0.49853633],
# [-0.96925495, 1.87599 , 0.04155593],
# [ 0.05564664, -0.20404134, 1.05731107]])
r = 3.24048134 * xyz[:, 0, :, :] - 1.53715152 * xyz[:, 1, :, :] - 0.49853633 * xyz[:, 2, :, :]
g = -0.96925495 * xyz[:, 0, :, :] + 1.87599 * xyz[:, 1, :, :] + .04155593 * xyz[:, 2, :, :]
b = .05564664 * xyz[:, 0, :, :] - .20404134 * xyz[:, 1, :, :] + 1.05731107 * xyz[:, 2, :, :]
rgb = torch.cat((r[:, None, :, :], g[:, None, :, :], b[:, None, :, :]), dim=1)
# sometimes reaches a small negative number, which causes NaNs
rgb = torch.max(rgb, torch.zeros_like(rgb))
mask = (rgb > .0031308).type(torch.FloatTensor)
if (rgb.is_cuda):
mask = mask.cuda()
rgb = (1.055 * (rgb**(1. / 2.4)) - 0.055) * mask + 12.92 * rgb * (1 - mask)
return rgb
def lab2rgb(self, lab_rs):
l = lab_rs[:, [0], :, :] * self.l_norm + self.l_cent
ab = lab_rs[:, 1:, :, :] * self.ab_norm
lab = torch.cat((l, ab), dim=1)
out = self.xyz2rgb(self.lab2xyz(lab))
return out
def forward(self, img):
img = self.lab2rgb(img)
return img
おまけ: 画像テンソルのベクトル化
最初次の実装をtorch
で書き直すという練習をしていた。
しかしlinearRGB ↔ XYZ色空間の変形部np.dot(rgb, XYZ_MATRIX)
や、XYZ ↔ Labの変換
# f(x)
def trans_function(x, threshold=0.008856):
y = x**(1/3) if x>threshold else (841/108)*x+(16/116)
return y
# XYZ -> L*a*b*
def xyz2lab(x, y, z, xn=0.9505, yn=1.0, zn=1.089):
l = 116 * trans_function(x/yn) - 16
a = 500 * ( trans_function(x/xn) - trans_function(y/yn) )
b = 200 * ( trans_function(y/yn) - trans_function(z/zn) )
return l, a, b
に困難があり、低速な実装しか組めなかった。
残念ながら使えるものにはならなかったが、テンソルのVectorizationや関数のマップを練習するいい機会になったので、以下にメモを残しておく。
画像の行列積化
img_1 = torch.ones(3,256,256)
matrix = torch.onse(3,3)
img_2 = torch.mm(img.view(3,256*256), matrix).view(3,256,256)
テンソルへの非線形関数適用(遅い原因)
NumPyのnp.apply_along_axis
を自作するもの。以下の回答を少し改造して可変引数関数に対応させた。(上記記事中のxyz2lab
をベクトル化画像に適用するため)
def apply_along_axis(function, x, axis=0):
# for部分が速度的に悪さをする
return torch.stack([
function(*x_i) for x_i in torch.unbind(x, dim=axis)
], dim=axis)
f = lambda x,y,z: (y*z**2, z*x**2, x*y**2) # 3変数→3変数の非線形変換
# 以下の変形でChannel方向の各要素を各引数とする変換が可能
img_2 = apply_along_axis(f,img.view(3,256*256), axis=0).view(3,256,256)
Discussion