📎

データ拡張ライブラリ Albumentationsの簡単な使い方

2024/03/26に公開

今回はAlbumentationsの使い方について解説します。

1. Albumentations

1.1 Albumentationsとは

Albumentationsは機械学習分野で人気の、画像データ拡張ライブラリです。
主にコンピュータビジョン分野でよく利用されます。

データ拡張とは、学習用データに様々な処理を行い、データの種類や数を増やすことです。これにより、より汎用的なモデルを作る事ができます。

1.2 使用法

簡単な使用法を説明します。

1.2.1 Import

各種ライブラリをインポートします。

# import
import matplotlib.pyplot as plt
import albumentations as A
import numpy as np
import cv2
1.2.2 Load

データ拡張する画像をロードします。

# load
img_path = '/kaggle/input/a-simple-dog/dog.png'

img = cv2.imread(img_path, cv2.IMREAD_COLOR)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.uint8)
plt.imshow(img)
1.2.3 データ拡張を定義

Albumentationオブジェクトを定義します。
ここで様々なデータ拡張を定義します。

# define data augmentation
transform = A.Compose([
    A.HorizontalFlip(p=0.5), # 水平反転(確率50%)
    A.RandomCrop(width=512, height=512), # 指定サイズでランダムな箇所切り取り
    A.Rotate(limit=45), # ランダムに回転(最大45°)
])
1.2.4 適用

データ拡張を適用します。

# apply
def augment_image(image):
    augmented = transform(image=image)
    return augmented['image']

ここで、返り値は変換後の画像と、使用した手法やオブジェクトをまとめた辞書で返されます。
変換後の画像は'image'keyで取り出す事ができます。

実際の適用

# import
!pip install -U albumentations
import matplotlib.pyplot as plt
import albumentations as A
import numpy as np
import cv2

# load
img_path = '/kaggle/input/a-simple-dog/dog.png'

img = cv2.imread(img_path, cv2.IMREAD_COLOR)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.uint8)
plt.imshow(img)

# define data augmentation
transform = A.Compose([
    A.HorizontalFlip(p=0.5), # 水平反転(確率50%)
    A.RandomCrop(width=512, height=512), # 指定サイズでランダムな箇所切り取り
    A.Rotate(limit=45), # ランダムに回転(最大45°)
])

# apply
def augment_image(image):
    augmented = transform(image=image)
    return augmented['image']

# show
plt.imshow(augment_image(img))

・出力

上記のように、Albumentationオブジェクトに対してimage渡すことで、定義したデータ拡張を実行してくれます。

1.2.5 自作関数

Albumentationに実装されていない、自作関数によるデータ拡張は以下のように行います。

# Define augmentation
def augmentation(img):
    img_copy = img.copy()
    composition = A.Compose([
            A.Lambda(image=custom_masking, p=0.3)
        ])
    return composition(image=img_copy)['image']

# Homemade function
def custom_masking(image, **kwargs):
    # Your masking logic
    mask_height_Hz = 10  # max mask length
    mask_height_Time = 20  # max mask length
    max_row = 3  # max mask row num
    for _ in range(max_row):
        start_row = np.random.randint(0, image.shape[0] - mask_height_Hz)
        image[start_row:start_row + mask_height_Hz, :] = 0  # Apply mask
        start_col = np.random.randint(0, image.shape[1] - mask_height_Time)
        image[:, start_col:start_col+mask_height_Time] = 0  # Apply mask
    return image


plt.imshow(augmentation(img))

・出力

ここでは横軸と縦軸をランダムにマスクする関数を作成し、適用しました。

このように、Albumentaionを使用することで、簡単に画像データの拡張を行う事ができます。

まとめ

今回はAlbumentationの簡単な使い方について解説しました。
次回は様々なデータ拡張を紹介しようと考えています。

公式ドキュメントに、利用可能な手法が掲載されているので、興味が湧いた方は試してみて下さい。

Discussion