👌

PyTorchオンリーで実装する適応的ヒストグラム平坦化(CLAHE)

2025/02/23に公開

CLAHEが使いたいが、Open-CVだとtorch Datasetに対するAugmentationなどで使えない。
Korniaに実装が含まれているのだが、これだけのためにインストールするのもなぁ…と思って、実装部分のみを抜粋してコピペ完結型の1枚のコードにしていきたい。

https://github.com/kornia/kornia/


端的に、以下のプログラムのequalize_claheを呼べば良い。
引数は以下の通り。

  • input: [B, 1, H, W] の輝度画像。 0.0 ~ 1.0の値であることに注意 ← Uint8ではなくfloat
  • clip_limit: 正のfloatの値。Open-CVのデフォルトは2.0
  • grid_size: 正のintの値の組。Open-CVのデフォルトは(8, 8)(以下のコードのequalize_claheのデフォルトは(2, 2)なので注意)

clip limitとgrid sizeについては以下のような変化になる。

import math

import torch
from torch import Tensor
from torch.nn import functional as F


def marginal_pdf(values: Tensor, bins: Tensor, sigma: Tensor, epsilon: float = 1e-10) -> tuple[Tensor, Tensor]:
    """Estimate a soft histogram of ``values`` with a Gaussian kernel per bin.

    Args:
        values: samples to histogram; must broadcast against ``bins`` once the
            latter is given two leading singleton axes (the caller in this
            file passes a tensor with a trailing singleton axis).
        bins: bin centres.
        sigma: bandwidth of the Gaussian kernel.
        epsilon: small constant added to the normaliser to avoid dividing by zero.

    Returns:
        A pair ``(pdf, kernel_values)``: the kernel responses averaged over
        axis 1 and normalised along axis 1 of the averaged result, and the raw
        per-sample kernel responses.
    """
    centres = bins[None, None]
    scaled = (values - centres) / sigma
    kernel_values = torch.exp(-0.5 * scaled.pow(2))

    # Average the responses of all samples, then normalise so the bins sum to ~1.
    averaged = kernel_values.mean(dim=1)
    total = averaged.sum(dim=1, keepdim=True) + epsilon
    return averaged / total, kernel_values


def histogram(x: Tensor, bins: Tensor, bandwidth: Tensor, epsilon: float = 1e-10) -> Tensor:
    """Return the soft (kernel-based) histogram of ``x`` over ``bins``.

    Args:
        x: samples; a new axis is inserted at position 2 so that every sample
            broadcasts against all bins.
        bins: bin centres.
        bandwidth: Gaussian kernel bandwidth forwarded to ``marginal_pdf``.
        epsilon: normalisation guard forwarded to ``marginal_pdf``.

    Returns:
        The normalised histogram (first element of ``marginal_pdf``'s result).
    """
    samples = torch.unsqueeze(x, 2)
    return marginal_pdf(samples, bins, bandwidth, epsilon)[0]


def _torch_histc_cast(input: Tensor, bins: int, min: int, max: int) -> Tensor:
    dtype: torch.dtype = input.dtype
    if dtype not in (torch.float32, torch.float64):
        dtype = torch.float32
    return torch.histc(input.to(dtype), bins, min, max).to(input.dtype)


def _compute_tiles(
    imgs: torch.Tensor, grid_size: tuple[int, int], even_tile_size: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    batch: torch.Tensor = imgs  # B x C x H x W

    # compute stride and kernel size
    h, w = batch.shape[-2:]
    kernel_vert: int = math.ceil(h / grid_size[0])
    kernel_horz: int = math.ceil(w / grid_size[1])

    if even_tile_size:
        kernel_vert += 1 if kernel_vert % 2 else 0
        kernel_horz += 1 if kernel_horz % 2 else 0

    # add padding (with that kernel size we could need some extra cols and rows...)
    pad_vert = kernel_vert * grid_size[0] - h
    pad_horz = kernel_horz * grid_size[1] - w

    # add the padding in the last coluns and rows
    if pad_vert > batch.shape[-2] or pad_horz > batch.shape[-1]:
        raise ValueError("Cannot compute tiles on the image according to the given grid size")

    if pad_vert > 0 or pad_horz > 0:
        batch = F.pad(batch, [0, pad_horz, 0, pad_vert], mode="reflect")  # B x C x H' x W'

    # compute tiles
    c: int = batch.shape[-3]
    tiles: torch.Tensor = (
        batch.unfold(1, c, c)  # unfold(dimension, size, step)
        .unfold(2, kernel_vert, kernel_vert)
        .unfold(3, kernel_horz, kernel_horz)
        .squeeze(1)
    ).contiguous()  # GH x GW x C x TH x TW
    return tiles, batch


def _compute_interpolation_tiles(padded_imgs: torch.Tensor, tile_size: tuple[int, int]) -> torch.Tensor:
    # tiles to be interpolated are built by dividing in 4 each already existing
    interp_kernel_vert: int = tile_size[0] // 2
    interp_kernel_horz: int = tile_size[1] // 2

    c: int = padded_imgs.shape[-3]
    interp_tiles: torch.Tensor = (
        padded_imgs.unfold(1, c, c)
        .unfold(2, interp_kernel_vert, interp_kernel_vert)
        .unfold(3, interp_kernel_horz, interp_kernel_horz)
        .squeeze(1)
    ).contiguous()  # 2GH x 2GW x C x TH/2 x TW/2
    return interp_tiles


def _my_histc(tiles: torch.Tensor, bins: int) -> torch.Tensor:
    """Histogram ``tiles`` into ``bins`` bins over the fixed range [0, 1]."""
    lower, upper = 0, 1
    return _torch_histc_cast(tiles, bins, lower, upper)


def _compute_luts(
    tiles_x_im: torch.Tensor, num_bins: int = 256, clip: float = 40.0, diff: bool = False
) -> torch.Tensor:
    """Build one equalization look-up table (LUT) per tile and channel.

    Args:
        tiles_x_im: tiles of shape B x GH x GW x C x TH x TW; values are
            histogrammed over the range [0, 1].
        num_bins: number of histogram bins, i.e. entries per LUT.
        clip: clip-limit factor. Each bin is capped at
            ``max(clip * TH * TW // num_bins, 1)`` and the clipped excess is
            redistributed over the bins. A value <= 0 disables clipping.
        diff: if True, use the soft kernel-based ``histogram`` instead of
            ``torch.histc`` and skip the final flooring of the LUTs.

    Returns:
        LUTs of shape B x GH x GW x C x num_bins with values clamped to
        ``[0, num_bins - 1]``.
    """
    b, gh, gw, c, th, tw = tiles_x_im.shape
    pixels: int = th * tw
    tiles: torch.Tensor = tiles_x_im.view(-1, pixels)  # test with view  # T x (THxTW)
    if not diff:
        if torch.jit.is_scripting():
            histos = torch.stack([_torch_histc_cast(tile, bins=num_bins, min=0, max=1) for tile in tiles])
        else:
            histos = torch.stack(list(map(_my_histc, tiles, [num_bins] * len(tiles))))
    else:
        bins: torch.Tensor = torch.linspace(0, 1, num_bins, device=tiles.device)
        # `histogram` already returns T x num_bins. Do not squeeze it: with a
        # single tile (T == 1) that would drop the tile axis and break the
        # per-tile reductions below.
        histos = histogram(tiles, bins, torch.tensor(0.001))
        histos *= pixels  # turn the normalised pdf back into pixel counts

    if clip > 0.0:
        max_val: float = max(clip * pixels // num_bins, 1)
        histos.clamp_(max=max_val)
        # number of counts removed by the clamp, per tile
        clipped: torch.Tensor = pixels - histos.sum(1)
        residual: torch.Tensor = torch.remainder(clipped, num_bins)
        # spread the excess evenly over all bins ...
        redist: torch.Tensor = (clipped - residual).div(num_bins)
        histos += redist[None].transpose(0, 1)
        # ... and give one extra count to the first `residual` bins
        # trick to avoid using a loop to assign the residual
        v_range: torch.Tensor = torch.arange(num_bins, device=histos.device)
        mat_range: torch.Tensor = v_range.repeat(histos.shape[0], 1)
        histos += mat_range < residual[None].transpose(0, 1)

    # the scaled cumulative histogram is the equalization mapping
    lut_scale: float = (num_bins - 1) / pixels
    luts: torch.Tensor = torch.cumsum(histos, 1) * lut_scale
    luts = luts.clamp(0, num_bins - 1)
    if not diff:
        luts = luts.floor()  # to get the same values as converting to int maintaining the type
    luts = luts.view((b, gh, gw, c, num_bins))
    return luts


def _map_luts(interp_tiles: torch.Tensor, luts: torch.Tensor) -> torch.Tensor:
    # gh, gw -> 2x the number of tiles used to compute the histograms
    # th, tw -> /2 the sizes of the tiles used to compute the histograms
    num_imgs, gh, gw, c, _, _ = interp_tiles.shape

    # precompute idxs for non corner regions (doing it in cpu seems slightly faster)
    j_idxs = torch.empty(0, 4, dtype=torch.long)
    if gh > 2:
        j_floor = torch.arange(1, gh - 1).view(gh - 2, 1).div(2, rounding_mode="trunc")
        j_idxs = torch.tensor([[0, 0, 1, 1], [-1, -1, 0, 0]] * ((gh - 2) // 2))  # reminder + j_idxs[:, 0:2] -= 1
        j_idxs += j_floor

    i_idxs = torch.empty(0, 4, dtype=torch.long)
    if gw > 2:
        i_floor = torch.arange(1, gw - 1).view(gw - 2, 1).div(2, rounding_mode="trunc")
        i_idxs = torch.tensor([[0, 1, 0, 1], [-1, 0, -1, 0]] * ((gw - 2) // 2))  # reminder + i_idxs[:, [0, 2]] -= 1
        i_idxs += i_floor

    # selection of luts to interpolate each patch
    # create a tensor with dims: interp_patches height and width x 4 x num channels x bins in the histograms
    # the tensor is init to -1 to denote non init hists
    luts_x_interp_tiles: torch.Tensor = torch.full(  # B x GH x GW x 4 x C x 256
        (num_imgs, gh, gw, 4, c, luts.shape[-1]), -1, dtype=interp_tiles.dtype, device=interp_tiles.device
    )
    # corner regions
    luts_x_interp_tiles[:, 0 :: gh - 1, 0 :: gw - 1, 0] = luts[:, 0 :: max(gh // 2 - 1, 1), 0 :: max(gw // 2 - 1, 1)]
    # border region (h)
    luts_x_interp_tiles[:, 1:-1, 0 :: gw - 1, 0] = luts[:, j_idxs[:, 0], 0 :: max(gw // 2 - 1, 1)]
    luts_x_interp_tiles[:, 1:-1, 0 :: gw - 1, 1] = luts[:, j_idxs[:, 2], 0 :: max(gw // 2 - 1, 1)]
    # border region (w)
    luts_x_interp_tiles[:, 0 :: gh - 1, 1:-1, 0] = luts[:, 0 :: max(gh // 2 - 1, 1), i_idxs[:, 0]]
    luts_x_interp_tiles[:, 0 :: gh - 1, 1:-1, 1] = luts[:, 0 :: max(gh // 2 - 1, 1), i_idxs[:, 1]]
    # internal region
    luts_x_interp_tiles[:, 1:-1, 1:-1, :] = luts[
        :, j_idxs.repeat(max(gh - 2, 1), 1, 1).permute(1, 0, 2), i_idxs.repeat(max(gw - 2, 1), 1, 1)
    ]
    return luts_x_interp_tiles


def _compute_equalized_tiles(interp_tiles: torch.Tensor, luts: torch.Tensor) -> torch.Tensor:
    mapped_luts: torch.Tensor = _map_luts(interp_tiles, luts)  # Bx2GHx2GWx4xCx256

    # gh, gw -> 2x the number of tiles used to compute the histograms
    # th, tw -> /2 the sizes of the tiles used to compute the histograms
    num_imgs, gh, gw, c, th, tw = interp_tiles.shape

    # equalize tiles
    flatten_interp_tiles: torch.Tensor = (interp_tiles * 255).long().flatten(-2, -1)  # B x GH x GW x 4 x C x (THxTW)
    flatten_interp_tiles = flatten_interp_tiles.unsqueeze(-3).expand(num_imgs, gh, gw, 4, c, th * tw)
    preinterp_tiles_equalized = (
        torch.gather(mapped_luts, 5, flatten_interp_tiles)  # B x GH x GW x 4 x C x TH x TW
        .to(interp_tiles)
        .reshape(num_imgs, gh, gw, 4, c, th, tw)
    )

    # interp tiles
    tiles_equalized: torch.Tensor = torch.zeros_like(interp_tiles)

    # compute the interpolation weights (shapes are 2 x TH x TW because they must be applied to 2 interp tiles)
    ih = (
        torch.arange(2 * th - 1, -1, -1, dtype=interp_tiles.dtype, device=interp_tiles.device)
        .div(2.0 * th - 1)[None]
        .transpose(-2, -1)
        .expand(2 * th, tw)
    )
    ih = ih.unfold(0, th, th).unfold(1, tw, tw)  # 2 x 1 x TH x TW
    iw = (
        torch.arange(2 * tw - 1, -1, -1, dtype=interp_tiles.dtype, device=interp_tiles.device)
        .div(2.0 * tw - 1)
        .expand(th, 2 * tw)
    )
    iw = iw.unfold(0, th, th).unfold(1, tw, tw)  # 1 x 2 x TH x TW

    # compute row and column interpolation weights
    tiw = iw.expand((gw - 2) // 2, 2, th, tw).reshape(gw - 2, 1, th, tw).unsqueeze(0)  # 1 x GW-2 x 1 x TH x TW
    tih = ih.repeat((gh - 2) // 2, 1, 1, 1).unsqueeze(1)  # GH-2 x 1 x 1 x TH x TW

    # internal regions
    tl, tr, bl, br = preinterp_tiles_equalized[:, 1:-1, 1:-1].unbind(3)
    t = torch.addcmul(tr, tiw, torch.sub(tl, tr))
    b = torch.addcmul(br, tiw, torch.sub(bl, br))
    tiles_equalized[:, 1:-1, 1:-1] = torch.addcmul(b, tih, torch.sub(t, b))

    # corner regions
    tiles_equalized[:, 0 :: gh - 1, 0 :: gw - 1] = preinterp_tiles_equalized[:, 0 :: gh - 1, 0 :: gw - 1, 0]

    # border region (h)
    t, b, _, _ = preinterp_tiles_equalized[:, 1:-1, 0].unbind(2)
    tiles_equalized[:, 1:-1, 0] = torch.addcmul(b, tih.squeeze(1), torch.sub(t, b))
    t, b, _, _ = preinterp_tiles_equalized[:, 1:-1, gh - 1].unbind(2)
    tiles_equalized[:, 1:-1, gh - 1] = torch.addcmul(b, tih.squeeze(1), torch.sub(t, b))

    # border region (w)
    left, right, _, _ = preinterp_tiles_equalized[:, 0, 1:-1].unbind(2)
    tiles_equalized[:, 0, 1:-1] = torch.addcmul(right, tiw, torch.sub(left, right))
    left, right, _, _ = preinterp_tiles_equalized[:, gw - 1, 1:-1].unbind(2)
    tiles_equalized[:, gw - 1, 1:-1] = torch.addcmul(right, tiw, torch.sub(left, right))

    # same type as the input
    return tiles_equalized.div(255.0)


def equalize_clahe(
    input: torch.Tensor,
    clip_limit: float = 2.0,
    grid_size: tuple[int, int] = (2, 2),
    slow_and_differentiable: bool = False,
) -> torch.Tensor:
    """Apply contrast-limited adaptive histogram equalization (CLAHE) to a batch of images.

    Args:
        input: batch of images, B x C x H x W, with values expected in [0, 1]
            (histograms are computed over that range).
        clip_limit: clip-limit factor forwarded to the LUT computation.
        grid_size: number of tiles along the vertical and horizontal axes.
        slow_and_differentiable: if True, use the soft (kernel-based)
            histogram when building the LUTs.

    Returns:
        Equalized images, cropped back to the spatial size of ``input``.
    """
    height, width = input.shape[-2:]

    # Tile sides are forced to be even so that every histogram tile can be
    # divided into four interpolation tiles.
    hist_tiles, padded = _compute_tiles(input, grid_size, True)  # B x GH x GW x C x TH x TW, B x C x H' x W'
    tile_hw = (hist_tiles.shape[-2], hist_tiles.shape[-1])
    quarter_tiles = _compute_interpolation_tiles(padded, tile_hw)  # B x 2GH x 2GW x C x TH/2 x TW/2
    tile_luts = _compute_luts(hist_tiles, clip=clip_limit, diff=slow_and_differentiable)  # B x GH x GW x C x 256
    equalized = _compute_equalized_tiles(quarter_tiles, tile_luts)  # B x 2GH x 2GW x C x TH/2 x TW/2

    # Stitch the tiles back into images: B x C x (2GH, TH/2) x (2GW, TW/2) -> B x C x H' x W'.
    stitched = equalized.permute(0, 3, 1, 4, 2, 5).reshape_as(padded)

    # Drop the rows/columns that were only added as padding.
    return stitched[..., :height, :width]

以上のコードを使えばtorchのままCLAHEが使える。
ただし、カラー画像に対してはHSV色空間に変換してVチャンネルをCLAHEしてからまたRGBに直す方法が良い。

import torchvision

# Load the image as grayscale and add a batch axis; equalize_clahe expects
# float values in 0.0~1.0, not uint8.
img = torchvision.io.read_image("sample.png", mode="GRAY")
img = img[None].to(torch.float) / 255  # Uint8 → 0.0~1.0

img = equalize_clahe(img, clip_limit=2.0, grid_size=(8,8))

# 0.0~1.0 → Uint8. Round before casting: a plain cast truncates, so a value
# such as 254.9999 would be written as 254.
img = (img[0] * 255).round().clamp(0, 255)
torchvision.io.write_png(img.to(torch.uint8), "output.png")

Open-CVのCLAHEと比較してみる。

# Reference result from OpenCV's CLAHE with the same clip limit and grid size.
# NOTE(review): `cv` is not imported in this snippet — presumably
# `import cv2 as cv`; confirm.
img = cv.imread("sample.png", cv.IMREAD_GRAYSCALE)
clahe = cv.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
img = clahe.apply(img)
# NOTE: this writes to the same "output.png" path as the PyTorch example.
cv.imwrite("output.png", img)

見た感じ大体あってそうだ。

Discussion