🐡

AIを使ってイラストの情報量を可視化できないか試してみた

2024/05/16に公開

はじめに

皆様は「イラストの情報量」という概念をご存じでしょうか?
「イラストの情報量」は、情報工学でいうところの「情報量(エントロピー)」とは異なり、「絵やイラストを見たときに読み取れる情報の数」のようなものを指すことが多いです。
(曖昧な書き方をしたのは、「イラストの情報量」という概念は結構人によって解釈が異なり、定義も色々あるので断言が難しいため)

この「イラストの情報量」という概念は、視線誘導(絵を用いて伝えたい部分により視線が留まりやすいようにする)を考える際に頻出する単語です。

https://ichi-up.net/2018/038

https://www.clipstudio.net/oekaki/archives/152701

この情報量の制御ができるようになると、より魅力的な絵を描けるようになると言われています。
今回は、この「イラストの情報量」をAIを使ってわかりやすく可視化できないか?というお話です。

この記事における「イラストの情報量」と「可視化」の定義

イラストの情報量

さて、可視化をするためにはまず人によって解釈の異なる「イラストの情報量」という単語に厳密な定義を置かなければいけません。
よく見る定義としては、色数であったり、明暗のコントラストや描きこまれた線の数、絵の中で書かれている物体の数等があります。

どの要素を重要視するかは人によって異なると思いますが、今回は「絵の中に書かれている物体の数」に着目し、かつある程度他の要素も取り入られるように「SegmentAnythingを用いてSegmentされた領域がどの程度密集しているか(密集している部分は情報量が多い)」という定義を用います。

SegmentAnythingは過去の記事でも何度か紹介した、Metaが出しているゼロショットのSegmentation技術で、物体や光、色の境界などをベースに画像を複数の領域に分割してくれるAI技術となります。

詳しく知りたい方は以下の記事をどうぞ
https://zenn.dev/mattyamonaca/articles/dcacb4f6dcd58f

可視化

つづいて、可視化の定義も決めていく必要があります。
なるべく一目でどこに情報が集中しているかがをわかりやすく見えるようにしたいので、この記事では以下のようなヒートマップ(jetフォーマット)による可視化を行います。

出力イメージ

情報量(セグメントされた領域の密集度)が高いほど赤く、低いほど青くなっていくイメージです。

実装

では、定義が決まったので実装していきましょう。
処理手順はこちら。

  1. 分析対象の画像読み込み
  2. 読み込んだ画像をSAMで領域分け
  3. 領域の密集度を計算
  4. 密集度に則ってヒートマップを作成
  5. 読み込んだ画像にヒートマップを重ねて出力

この中で工夫が必要となる工程は3. 領域の密集度の計算です。
まず、セグメント分けされた画像というのがどのようなものなのかを確認してみましょう。
この画像を対象として、情報量を可視化していきます。

SAMの実装部分は以前も紹介したことがあるので割愛します(この記事の最後にセグメンテーション部分も含めたコード全量を置いてあるのでそちらを確認してください)
セグメンテーションした結果がこちら。

このように色分けされた領域がどの程度密集しているか?をヒートマップに変換していきます。

まず、32×32のウィンドウを用意し領域分けされた画像の左上に配置します。

この枠の中に存在する領域の数を、この枠内における情報量と定義します。
たとえば、この赤枠の中に10個の領域がある場合は、この赤枠に属する全てのピクセルに「10」という値が割り振られます。

つづいて、この赤枠をすこし右にずらします(厳密には1ピクセルずつずらす)

最初と同様に、この赤枠内の領域の数を数え上げ、その値を情報量としてずらした赤枠に属するピクセルすべてに割り振ります。
このとき、最初の赤枠ですでに情報量が割り当てられたピクセルに、新たに情報量を割り振るような操作を行う必要があります。
このような場合は、平均を取る事で情報量を新たに設定します。
(例えば、最初の赤枠に含まれる領域数が10、右にずらした赤枠に含まれる領域数が9だった場合、両方の赤枠に所属するピクセルには9.5という情報量が付与される)
このような操作を、赤枠が右下にくるまで実施することで、すべてのピクセルに情報量が割り当てられ、ヒートマップの作成が可能になります。

最終的な結果がこちら
しっかり領域が密集している部分が赤くなっていますね。


実装コードはこちら

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import sys
from PIL import Image
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

def load_and_process_image(image):
    # 画像を読み込む
    #image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    image_array = np.array(image)
    image = cv2.cvtColor(image_array, cv2.COLOR_RGB2BGR)
    # オリジナルのサイズを取得
    original_height, original_width, _ = image.shape

    
    # 横幅を256に固定し、縦幅を比率を保ってリサイズ
    aspect_ratio = original_height / original_width
    new_width = 256
    new_height = int(new_width * aspect_ratio)
    
    # リサイズ
    resized_image = cv2.resize(image, (new_width, new_height))
    
    return resized_image

def visualize_clusters(image_cls):
    # クラスタインデックスの最大値を取得
    num_clusters = np.max(image_cls)
    # クラスタごとにランダムな色を生成
    colors = np.random.randint(0, 255, (num_clusters + 1, 3))
    # カラー画像を作成
    height, width = image_cls.shape
    color_image = np.zeros((height, width, 3), dtype=np.uint8)
    for i in range(1, num_clusters + 1):
        color_image[image_cls == i] = colors[i]
    # 画像を表示
    #plt.imshow(color_image)
    #plt.axis('off')
    #plt.show()
    cv2.imwrite("seg_mask.png", color_image)

def create_heatmap_overlay(image, result, alpha=0.5, save_path='heatmap_overlay.png'):
    height, width = image.shape[:2]  # オリジナル画像の高さと幅を取得

    # データ型をチェック・変換
    image = np.array(image, dtype=np.uint8)

    # カラースケール画像を作成
    color_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # カラーマップでヒートマップを生成
    plt.imshow(result, cmap='jet', alpha=alpha)
    plt.axis('off')

    # ヒートマップを画像として保存
    plt.savefig('heatmap.png', bbox_inches='tight', pad_inches=0)
    plt.close()

    # ヒートマップを読み込み
    heatmap = cv2.cvtColor(cv2.imread('heatmap.png'), cv2.COLOR_BGR2RGB)
    heatmap = cv2.resize(heatmap, (width, height))

    # ヒートマップとオリジナル画像を合成
    overlay = cv2.addWeighted(color_image, 1-alpha, heatmap, alpha, 0)

    # オーバーレイを画像として保存
    cv2.imwrite(save_path, overlay)

    return overlay

def cluster_analysis(image, window_size = 64):
    height, width = image.shape[:2]
    result = np.zeros((height, width))
    cluster_count_sum = np.zeros((height, width))
    window_count = np.zeros((height, width))

    for i in range(height - window_size + 1):
        for j in range(width - window_size + 1):
            window = image[i:i + window_size, j:j + window_size]
            unique_clusters = np.unique(window)
            cluster_count = len(unique_clusters)

            cluster_count_sum[i:i + window_size, j:j + window_size] += cluster_count
            window_count[i:i + window_size, j:j + window_size] += 1

    result = cluster_count_sum / window_count
    return result

def create_image_from_segmentations(segmentations):
    height, width = segmentations[0].shape
    image = np.zeros((height, width), dtype=int)

    for cluster_index, seg in enumerate(segmentations):
        image[seg] = cluster_index + 1

    return image

def create_image_from_anns(anns):
    """
    'show_anns()'関数で表示されるセグメンテーションに基づいて、
    クラスタインデックスを割り振る画像を生成する。
    :param anns: セグメンテーション情報のリスト
    :return: クラスタインデックスで埋められた2D配列
    """
    if len(anns) == 0:
        return None

    # セグメンテーションをサイズ順にソート
    sorted_anns = sorted(anns, key=lambda x: x['area'], reverse=True)
    height, width = sorted_anns[0]['segmentation'].shape

    # インデックスを保持するための配列を初期化
    image_cls = np.zeros((height, width), dtype=int)

    # 各クラスタに対してセグメンテーションマスクを適用
    for idx, ann in enumerate(sorted_anns):
        m = ann['segmentation']
        image_cls[m] = idx + 1  # クラスタインデックスを1から始める

    return image_cls

def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=lambda x: x['area'], reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:,:,3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask
    ax.imshow(img)
    return img

def process(
        model_dir, 
        image, window_size ,
        points_per_side,
        pred_iou_thresh,
        stability_score_thresh,
        min_mask_region_area,
        ):
    width = image.width
    height = image.height
    image = load_and_process_image(image)

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    sam_checkpoint = f"{model_dir}/sam_vit_h_4b8939.pth"
    model_type = "vit_h"
    device = "cuda"

    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)

    mask_generator = SamAutomaticMaskGenerator(
        model=sam,
        points_per_side=points_per_side,
        pred_iou_thresh=pred_iou_thresh,
        stability_score_thresh=stability_score_thresh,
        crop_n_layers=1,
        crop_n_points_downscale_factor=2,
        min_mask_region_area=min_mask_region_area,
    )

    masks = mask_generator.generate(image)
    segs = [mask["segmentation"] for mask in masks]

    # クラスタ情報の配列を生成
    image_cls = create_image_from_anns(masks)
    visualize_clusters(image_cls)
    result = cluster_analysis(image_cls, window_size)
    result = cv2.resize(result, (width, height))

最後に

今回はイラストの情報量を自分なりに定義し、ヒートマップで可視化する手法を記事にしました。
さて、記事中で触れたように「イラストの情報量」は人や状況によって定義がかなり異なります。
そこで、今回の記事と同じような方法での可視化を様々な定義に基づいて行えるツールを公開しました。

この記事の公開時点で、以下の5つの方法でヒートマップによる可視化が行えます!

  • オブジェクトの密集度(この記事で紹介したもの)
  • 深度の情報
  • 明度の情報
  • コントラスト(明暗の差)の情報
  • 顕著性マップ

興味がある方は是非遊んでみてください!

https://github.com/mattyamonaca/AMeThyst

Discussion