🫠

Describable Textures Dataset (DTD) で重複クラス分類

2024/01/12に公開

Mircea Cimpoi, Subhransu Maji, Iasonas Kokkinos, Sammy Mohamed, Andrea Vedaldi,
"Describing Textures in the Wild"
Proceedings of the IEEE Conf. on Computer Vision and Pattern Recognition (CVPR)
2014-11-14
https://arxiv.org/abs/1311.3618

https://www.robots.ox.ac.uk/~vgg/data/dtd/

テクスチャ解析などで使われるデータセット。
1つの画像に可変長のラベルが対応していて、整形が少し手間だったので使用時のメモを置きます。

上記サイトより、DTDデータセットの例

データの整形

次のディレクトリ構成を仮定する。

project/
├─ dataset/
│    ├─ banded/     banded_0002.jpg    ...
│    ├─ blotchy/    blotchy_0003.jpg   ...
│    ...
│    └─ zigzagged/  zigzagged_0008.jpg ...
└─ label.txt

これはやる必要はないが、普段正方形のデータしか使わないので予めResize CenterCropで384×384pxに統一した。

import glob
import torch
import torchvision
from torchvision.transform import functional as F

pathes = glob("dataset/*/*.jpg")
for p in pathes:
    img = torchvision.io.read_image(p)
    img = F.resize(img, 384)
    img = F.center_crop(img, (384,384))
    torchvision.io.write_jpeg(img, p)

ラベルの整形

もとのラベルファイルは次のような感じになっている。

labels_joint_anno.txt
banded/banded_0002.jpg banded 
banded/banded_0006.jpg banded striped 
woven/woven_0032.jpg braided crosshatched fibrous woven 
woven/woven_0033.jpg bumpy woven 
woven/woven_0036.jpg woven 
woven/woven_0038.jpg crosshatched woven 
woven/woven_0039.jpg woven 
...

これをpandasで整形して次のCSVに変えた。重複クラスは最大5つある。

filename,             label0,      label1,      label2,      label3, label4
banded_0002.jpg,      banded,      ,            ,            ,
banded_0006.jpg,      banded,      striped,     ,            ,
crystalline_0207.jpg, blotchy,     bumpy,       crystalline, ,
crystalline_0208.jpg, crystalline, ,            ,            ,
crystalline_0209.jpg, blotchy,     crystalline, ,            ,
crystalline_0213.jpg, bumpy,       crystalline, ,            ,
crystalline_0214.jpg, blotchy,     crystalline, ,            ,
...

Datasetの作成

pytorchのDataset側で画像とラベルを出力する。
前提として、次の画像とラベルのリストは次のように与えることにする。

CLASS_LABEL = {
    "banded": 0,
    "blotchy": 1,
    ...
    "wrinkled": 45,
    "zigzagged": 46,
}
    
pathes = glob("dataset_dot/*/*.jpg")
df = pd.read_csv("dataset_dot_label.csv").set_index("filename", drop=True)
labels = []
for p in pathes:
    fname = p.split("/")[-1]
    labels.append([CLASS_LABEL[l] for l in df.loc[fname].values if pd.isnull(l) == False])

x_train, x_test, y_train, y_test = ttsplit(pathes, labels, shuffle=True, random_state=1, test_size=1/12)

train_dataset = DotDataset(x_train, y_train, tf=train_transforms)
test_dataset = DotDataset(x_test, y_test, tf=test_transforms)

上のコードでは、まずpandasのデータフレームのindexをファイル名filenameにして、df.locで行にアクセスできるようにして、次に得られた行からNaNでない値のみをリストにする。この例ではlabelsは可変長のリストを要素に持つリストになっている。
その後テストデータと学習データに分けてDatasetに入れてる。

class DotDataset(torch.utils.data.Dataset):
    def __init__(self, img_list, label_list, tf=nn.Identity):
        self.img_list = img_list
        self.label_list = label_list
        self.transforms = tf 
        return None
    
    def __len__(self):
        return len(self.img_list)  

    def __getitem__(self, idx):
        img = torchvision.io.read_image(self.img_list[idx]).to(dtype=torch.float)
        img = self.transforms(img)
        target = torch.zeros(47, dtype=torch.float)
        for t in self.label_list[idx]:
            target[t] = 1.0
        return img, target

targetのone-hot化を自力で行っている(oneからfour-hotなので)以外はいつもの画像分類用のそれである。

以上でDTDを分類タスクに使えるようにできた。損失はいつものCrossEntropyLossでいけるはず。

Discussion