✂️

PyTorchで画像を円形にマスクする

2024/01/10に公開

画像テンソル\texttt{img} \in \mathbb{R}^{H×W}に対して、中央を中心とした半径r = \min(H,W)/2 - \text{offset}の円で画像を切り抜く。

circlecrop.py
import torch
import torchvision

img = torchvision.io.read_image("sample.jpg")  # shape [3,H,W]
_, h, w = img.shape
offset = 30

x, y = torch.meshgrid(torch.arange(-w//2,w//2), torch.arange(-h//2,h//2), indexing='ij')
crop_r = min(h,w)//2 - offset
mask = torch.where(x**2 + y**2 < crop_r**2, 1, 0).view(1,h,w).repeat(3,1,1)
img = img * mask

torchvision.io.write_jpeg(img.to(torch.uint8), "output.jpg")

xyのメッシュグリッドを式で評価するスタイルなので、 torch.where(x**2 + y**2 < crop_r**2, 1, 0)部分を工夫すれば任意の形状に切り抜ける。

Discussion