Zenn
😺

U-2-NET vs openAI o4 image generation 背景透過

2025/03/30に公開

U-2-NET を使った背景透過

U-2-NET とは

U²-Net(U-square Net)は、顕著性マップ(Salient Object Detection, SOD)を生成するために設計された深層学習モデルです。従来の U-Net 構造を拡張し、エンコーダ・デコーダの各ブロック内にさらに小さな U-Net 構造を内包することで、局所的かつ広域的な特徴を同時に捉えることが可能です。この二重構造により、高精度な画像分割を比較的軽量なモデルで実現しています。U²-Net は事前学習不要でも優れた性能を発揮する点も特徴です。
https://github.com/xuebinqin/U-2-Net

U-2-NET の用途

U²-Net は主に画像の前景抽出や背景除去に利用されます。特に人物や物体のセグメンテーションにおいて優れており、画像編集・バーチャル背景生成・プロダクト写真の加工・医療画像解析など、様々な分野で活用されています。また、軽量版の U²-NetP はモバイル端末やリアルタイム処理など、リソース制限のある環境にも適しています。

背景透過ツールを作る

仮想環境を用意します

python -m venv .venv

仮想環境に入ります

source .venv/bin/activate

必要なライブライ群を requirements.txt を使ってインストールするために、requirements.txt を用意します。
(ちなみに Python 3.12.6 使ってます。)

requirements.txt
contourpy==1.3.1
cycler==0.12.1
filelock==3.18.0
fonttools==4.56.0
fsspec==2025.3.0
imageio==2.37.0
Jinja2==3.1.6
kiwisolver==1.4.8
lazy_loader==0.4
MarkupSafe==3.0.2
matplotlib==3.10.1
mpmath==1.3.0
networkx==3.4.2
numpy==2.2.4
opencv-python==4.11.0.86
packaging==24.2
pillow==11.1.0
pyparsing==3.2.3
python-dateutil==2.9.0.post0
scikit-image==0.25.2
scipy==1.15.2
setuptools==78.1.0
six==1.17.0
sympy==1.13.1
tifffile==2025.3.13
torch==2.6.0
torchvision==0.21.0
typing_extensions==4.13.0

ライブラリをインストールします。

pip install -r requirements.txt

requirements.txt を使わずにライブラリ個別にインストールしたい場合はこちらのコマンドを実行して下さい。

pip install torch torchvision numpy opencv-python Pillow scikit-image matplotlib

U-2-NET リポジトリをクローンします。

git clone https://github.com/xuebinqin/U-2-Net.git

U-2-Netディレクトリが作成されました。

.
├── U-2-Net
└── requirements.txt

使用するモデルは別途ダウンロードする必要があります。リポジトリの README にもありますが、こちらにもリンクを貼っておきます。
通常版(176.3 M):u2net.pth
軽量版(4 M):u2netp.pth
U-2-Net/saved_modelsディレクトリ配下にモデル名と同名のディレクトリを用意して、先ほどダウンロードしたモデルを格納します。

U-2-Ne
├── saved_models
│   ├── face_detection_cv2
│   │   └── haarcascade_frontalface_default.xml
│   ├── u2net
│   │   └── u2net.pth # 通常版
│   └── u2netp
│       └── u2netp.pth # 軽量版

U-2-Netディレクトリ配下に下記のファイルを追加して下さい。背景透過処理するコードです。(通常版)

U-2-Net/transparent.py
import os
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms#, utils
# import torch.optim as optim

import numpy as np
from PIL import Image
import glob

from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset

from model import U2NET # full size version 173.6 MB
from model import U2NETP # small version u2net 4.7 MB

# normalize the predicted SOD probability map
def normPRED(d):
    ma = torch.max(d)
    mi = torch.min(d)

    dn = (d-mi)/(ma-mi)

    return dn

def save_output(image_name,pred,d_dir):

    predict = pred
    predict = predict.squeeze()
    predict_np = predict.cpu().data.numpy()

    im = Image.fromarray(predict_np*255).convert('RGB')
    img_name = image_name.split(os.sep)[-1]
    image = io.imread(image_name)
    imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR)

    pb_np = np.array(imo)

    aaa = img_name.split(".")
    bbb = aaa[0:-1]
    imidx = bbb[0]
    for i in range(1,len(bbb)):
        imidx = imidx + "." + bbb[i]

    imo.save(d_dir+imidx+'.png')

def make_mask(model_name):

    # --------- 1. get image path and name ---------
    # model_name='u2net'#u2netp
    # model_name='u2netp'#u2netp



    image_dir = os.path.join(os.getcwd(), 'img', 'transparent')
    prediction_dir = os.path.join(os.getcwd(), 'img', model_name + '_results' + os.sep)
    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, model_name + '.pth')

    img_name_list = glob.glob(image_dir + os.sep + '*')
    print(img_name_list)

    # --------- 2. dataloader ---------
    #1. dataloader
    test_salobj_dataset = SalObjDataset(img_name_list = img_name_list,
                                        lbl_name_list = [],
                                        transform=transforms.Compose([RescaleT(320),
                                                                      ToTensorLab(flag=0)])
                                        )
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if(model_name=='u2net'):
        print("...load U2NET---173.6 MB")
        net = U2NET(3,1)
    elif(model_name=='u2netp'):
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3,1)

    if torch.cuda.is_available():
        net.load_state_dict(torch.load(model_dir))
        net.cuda()
    else:
        net.load_state_dict(torch.load(model_dir, map_location='cpu'))
    net.eval()

    # --------- 4. inference for each image ---------
    for i_test, data_test in enumerate(test_salobj_dataloader):

        print("inferencing:",img_name_list[i_test].split(os.sep)[-1])

        inputs_test = data_test['image']
        inputs_test = inputs_test.type(torch.FloatTensor)

        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
        else:
            inputs_test = Variable(inputs_test)

        d1,d2,d3,d4,d5,d6,d7= net(inputs_test)

        # normalization
        pred = d1[:,0,:,:]
        pred = normPRED(pred)

        # save results to test_results folder
        if not os.path.exists(prediction_dir):
            os.makedirs(prediction_dir, exist_ok=True)
        save_output(img_name_list[i_test],pred,prediction_dir)

        del d1,d2,d3,d4,d5,d6,d7

def transparent(model_name):
    image_dir = os.path.join(os.getcwd(), 'img', 'transparent')
    prediction_dir = os.path.join(os.getcwd(), 'img', model_name + '_results' + os.sep)
    img_name_list = glob.glob(image_dir + os.sep + '*')

    for in_file in img_name_list:
        # in_fileパスから拡張子なしファイル名を取得
        file_name_without_ext = os.path.splitext(os.path.basename(in_file))[0]
        # mask_fileを指定
        mask_file = os.path.join(prediction_dir, file_name_without_ext + ".png")
        # out_fileを指定
        out_file = os.path.join(prediction_dir, file_name_without_ext + "_transparent.png")
        # オリジナル画像
        img = Image.open(in_file).convert("RGB")
        # マスク画像(白黒)
        mask = Image.open(mask_file).convert("L")

        # アルファチャンネル(透過)を追加
        img.putalpha(mask)
        img.save(out_file)

if __name__ == "__main__":
    model_name='u2net'#u2netp
    # model_name='u2netp'#u2netp
    make_mask(model_name)
    transparent(model_name)

背景透過をかけたい画像を用意します。
U-2-Netディレクトリ配下にimg/transparentディレクトリを用意して、画像ファイルを格納します。ここに格納されている画像は一括して背景透過処理されます。
今回は4つのファイルを実行してみます。png 形式でも jpg 形式でも OK ですが、背景透過後のファイルは png 形式になってしまいます。(jpg 形式のままにしたい場合は transparent.py を見直して下さい)

U-2-Net
├── img
│   └── transparent
│       ├── cat.png
│       ├── demo_1_dark.png
│       ├── devteru_500_500.png
│       └── diego-ph-fIq0tET6llw-unsplash.jpg

実行結果(得意/不得意がわかった)

4つの画像の結果を見てきましょう

1. 猫

背景にある星が潰れちゃいました。でも髭はちゃんときり抜かれている!!
u2net_cat

2. パズルピース

背景とパズルピースの色が同じ黒のオリジナル画像。ピースの凸凹が潰れちゃってます。
u2net_piece

3. ロケット

ロケットの細かい切り抜きができていないのと、右上の雲の中がしっかりマスクされていない。オリジナルの右上の雲自体が幾らか透過処理されていたからかも。
u2net_rocket

4. 写真(人の手と電球)

U-2-NET の得意とするところ。綺麗にマスクできている。電球の中は透過されていない。(指の間はさすがに無理か。。。)
u2net_photo

openAI o4 image generation で背景透過処理した画像と比較

背景処理の方法は簡単です。こんな感じ。
o4

1. 猫

4o image generation の方は背景に散りばめられている星もちゃんと切り抜かれています。すばらしい!!
compare_cat

2. パズルピース

4o image genaration ではパズルピースの中も背景として認識されています。ピースの凸凹も完璧!!
compare_piece

3. ロケット

4o image generation ではロケットの細かい切り抜きも、右上の薄い雲も完璧です。
compare_rocket

4. 写真(人の手と電球)

4o image genaration のほうは、なぜか縦に伸びてしまっている。電球の中が透過されないのはどちらも同じ。
compare_photo
拡大してみてみると違いがはっきりと出ます。
左が 4o image generation、右が U-2-NET です。
左は写真ではなくて生成 AI が作った画像であることがわかります(DALL-E っぽさがでてる)
それに対して、U-2-NET は写真の合成ですから、作り物感は感じないですよね。
compare_photo_zoom

まとめ

  • U²-Net は軽量で高精度な背景透過モデルで、セットアップとコードによりローカルで画像の透過処理が可能。

  • U²-Net は画像によって得意・不得意があり、特に写真や人物は得意だが、細かい輪郭や背景と同系色の物体には弱い。

  • OpenAI の 4o image generation は、より精度の高い背景透過が可能で、特に細部や複雑な形状の切り抜きに優れている。

  • OpenAI の 4o image generation で写真の背景透過処理をすると、生成 AI で作成されだ画像として出力される。(もはや写真ではない)

Discussion

ログインするとコメントできます