Describable Textures Dataset (DTD) で重複クラス分類
Mircea Cimpoi, Subhransu Maji, Iasonas Kokkinos, Sammy Mohamed, Andrea Vedaldi,
"Describing Textures in the Wild"
Proceedings of the IEEE Conf. on Computer Vision and Pattern Recognition (CVPR)
2014-11-14
https://arxiv.org/abs/1311.3618
テクスチャ解析などで使われるデータセット。
1つの画像に可変長のラベルが対応していて、整形が少し手間だったので使用時のメモを置きます。
上記サイトより、DTDデータセットの例
データの整形
次のディレクトリ構成を仮定する。
project/
├─ dataset/
│ ├─ banded/ banded_0002.jpg ...
│ ├─ blotchy/ blotchy_0003.jpg ...
│ ...
│ └─ zigzagged/ zigzagged_0008.jpg ...
└─ label.txt
これはやる必要はないが、普段正方形のデータしか使わないので予めResize CenterCropで384×384pxに統一した。
import glob
import torch
import torchvision
from torchvision.transform import functional as F
pathes = glob("dataset/*/*.jpg")
for p in pathes:
img = torchvision.io.read_image(p)
img = F.resize(img, 384)
img = F.center_crop(img, (384,384))
torchvision.io.write_jpeg(img, p)
ラベルの整形
もとのラベルファイルは次のような感じになっている。
banded/banded_0002.jpg banded
banded/banded_0006.jpg banded striped
woven/woven_0032.jpg braided crosshatched fibrous woven
woven/woven_0033.jpg bumpy woven
woven/woven_0036.jpg woven
woven/woven_0038.jpg crosshatched woven
woven/woven_0039.jpg woven
...
これをpandasで整形して次のCSVに変えた。重複クラスは最大5つある。
filename, label0, label1, label2, label3, label4
banded_0002.jpg, banded, , , ,
banded_0006.jpg, banded, striped, , ,
crystalline_0207.jpg, blotchy, bumpy, crystalline, ,
crystalline_0208.jpg, crystalline, , , ,
crystalline_0209.jpg, blotchy, crystalline, , ,
crystalline_0213.jpg, bumpy, crystalline, , ,
crystalline_0214.jpg, blotchy, crystalline, , ,
...
Datasetの作成
pytorchのDataset側で画像とラベルを出力する。
前提として、次の画像とラベルのリストは次のように与えることにする。
CLASS_LABEL = {
"banded": 0,
"blotchy": 1,
...
"wrinkled": 45,
"zigzagged": 46,
}
pathes = glob("dataset_dot/*/*.jpg")
df = pd.read_csv("dataset_dot_label.csv").set_index("filename", drop=True)
labels = []
for p in pathes:
fname = p.split("/")[-1]
labels.append([CLASS_LABEL[l] for l in df.loc[fname].values if pd.isnull(l) == False])
x_train, x_test, y_train, y_test = ttsplit(pathes, labels, shuffle=True, random_state=1, test_size=1/12)
train_dataset = DotDataset(x_train, y_train, tf=train_transforms)
test_dataset = DotDataset(x_test, y_test, tf=test_transforms)
上のコードでは、まずpandasのデータフレームのindexをファイル名filename
にして、df.loc
で行にアクセスできるようにして、次に得られた行からNaNでない値のみをリストにする。この例ではlabels
は可変長のリストを要素に持つリストになっている。
その後テストデータと学習データに分けてDatasetに入れてる。
class DotDataset(torch.utils.data.Dataset):
def __init__(self, img_list, label_list, tf=nn.Identity):
self.img_list = img_list
self.label_list = label_list
self.transforms = tf
return None
def __len__(self):
return len(self.img_list)
def __getitem__(self, idx):
img = torchvision.io.read_image(self.img_list[idx]).to(dtype=torch.float)
img = self.transforms(img)
target = torch.zeros(47, dtype=torch.float)
for t in self.label_list[idx]:
target[t] = 1.0
return img, target
target
のone-hot化を自力で行っている(oneからfour-hotなので)以外はいつもの画像分類用のそれである。
以上でDTDを分類タスクに使えるようにできた。損失はいつものCrossEntropyLossでいけるはず。
Discussion