🫠

Describable Textures Dataset (DTD) で重複クラス分類

に公開

Mircea Cimpoi, Subhransu Maji, Iasonas Kokkinos, Sammy Mohamed, Andrea Vedaldi,
"Describing Textures in the Wild"
Proceedings of the IEEE Conf. on Computer Vision and Pattern Recognition (CVPR)
2014-11-14
https://arxiv.org/abs/1311.3618

https://www.robots.ox.ac.uk/~vgg/data/dtd/

テクスチャ解析などで使われるデータセット。
1つの画像に可変長のラベルが対応していて、整形が少し手間だったので使用時のメモを置きます。

上記サイトより、DTDデータセットの例

データの整形

次のディレクトリ構成を仮定する。

project/
├─ dataset/
│    ├─ banded/     banded_0002.jpg    ...
│    ├─ blotchy/    blotchy_0003.jpg   ...
│    ...
│    └─ zigzagged/  zigzagged_0008.jpg ...
└─ label.txt

これはやる必要はないが、普段正方形のデータしか使わないので予めResize CenterCropで384×384pxに統一した。

import glob
import torch
import torchvision
from torchvision.transform import functional as F

pathes = glob("dataset/*/*.jpg")
for p in pathes:
    img = torchvision.io.read_image(p)
    img = F.resize(img, 384)
    img = F.center_crop(img, (384,384))
    torchvision.io.write_jpeg(img, p)

ラベルの整形

もとのラベルファイルは次のような感じになっている。

labels_joint_anno.txt
banded/banded_0002.jpg banded 
banded/banded_0006.jpg banded striped 
woven/woven_0032.jpg braided crosshatched fibrous woven 
woven/woven_0033.jpg bumpy woven 
woven/woven_0036.jpg woven 
woven/woven_0038.jpg crosshatched woven 
woven/woven_0039.jpg woven 
...

これをpandasで整形して次のCSVに変えた。重複クラスは最大5つある。

filename,             label0,      label1,      label2,      label3, label4
banded_0002.jpg,      banded,      ,            ,            ,
banded_0006.jpg,      banded,      striped,     ,            ,
crystalline_0207.jpg, blotchy,     bumpy,       crystalline, ,
crystalline_0208.jpg, crystalline, ,            ,            ,
crystalline_0209.jpg, blotchy,     crystalline, ,            ,
crystalline_0213.jpg, bumpy,       crystalline, ,            ,
crystalline_0214.jpg, blotchy,     crystalline, ,            ,
...

Datasetの作成

pytorchのDataset側で画像とラベルを出力する。
前提として、次の画像とラベルのリストは次のように与えることにする。

CLASS_LABEL = {
    "banded": 0,
    "blotchy": 1,
    ...
    "wrinkled": 45,
    "zigzagged": 46,
}
    
pathes = glob("dataset_dot/*/*.jpg")
df = pd.read_csv("dataset_dot_label.csv").set_index("filename", drop=True)
labels = []
for p in pathes:
    fname = p.split("/")[-1]
    labels.append([CLASS_LABEL[l] for l in df.loc[fname].values if pd.isnull(l) == False])

x_train, x_test, y_train, y_test = ttsplit(pathes, labels, shuffle=True, random_state=1, test_size=1/12)

train_dataset = DotDataset(x_train, y_train, tf=train_transforms)
test_dataset = DotDataset(x_test, y_test, tf=test_transforms)

上のコードでは、まずpandasのデータフレームのindexをファイル名filenameにして、df.locで行にアクセスできるようにして、次に得られた行からNaNでない値のみをリストにする。この例ではlabelsは可変長のリストを要素に持つリストになっている。
その後テストデータと学習データに分けてDatasetに入れてる。

class DotDataset(torch.utils.data.Dataset):
    def __init__(self, img_list, label_list, tf=nn.Identity):
        self.img_list = img_list
        self.label_list = label_list
        self.transforms = tf 
        return None
    
    def __len__(self):
        return len(self.img_list)  

    def __getitem__(self, idx):
        img = torchvision.io.read_image(self.img_list[idx]).to(dtype=torch.float)
        img = self.transforms(img)
        target = torch.zeros(47, dtype=torch.float)
        for t in self.label_list[idx]:
            target[t] = 1.0
        return img, target

targetのone-hot化を自力で行っている(oneからfour-hotなので)以外はいつもの画像分類用のそれである。

以上でDTDを分類タスクに使えるようにできた。損失はいつものCrossEntropyLossでいけるはず。

Discussion