PyTorchとDetection Transformer (DETR)で作る物体認識モデル
はじめに
こんにちは。機械学習ビギナーの中沢(@shnakazawa_ja)です[1]。
世の中にはテーブルデータを対象とした機械学習モデルのチュートリアルは多くありますが、画像に対するものは少なく、コードまで提供されているものは更に少なくなります。そこで、本シリーズでは基本的なコンピュータビジョンモデリングの手法をPythonの深層学習用フレームワークPyTorchで実装していきます。
今回はObject detection (物体認識) を扱います。
モデルのアーキテクチャはDetection Transformer (DETR)を採用し、学習済みのモデルをtorch.hub.load()
を用いて取得します[2]。
実装はKaggle Notebook上で行うことで誰もが再現できるコードを目指します。想定読者は
- 仕事や研究で画像を扱う必要が出てきた方
- Titanicや住宅価格予測のチュートリアルは終えたが、画像コンペへの取り組み方が分からない方
- Kaggleの画像分類コンペで使えるベースライン(のベース)を探している方
といった方々です。画像の前処理や精度を高めるための工夫などについてはスコープ外とします。
コードはGitHubおよびKaggleにて公開しています[3]。
それでは、本題に入っていきましょう。
Object detection (物体認識) と題材
物体認識とは、画像の中にある物体の種類と位置を特定するタスクになります。「コンピュータビジョン」と聞いて一番に思い浮かべるのはこれかもしれません[4]。
今回はKaggleのGlobal Wheat Detectionコンペを題材とします。
本コンペの目的は小麦の写真からの穂の位置検出です。
トレーニング用のデータとして約3,500枚の写真と、それぞれの画像の上での穂の場所を示すバウンディングボックスの情報が与えられています[5]。
この情報を元に、新しい小麦の写真から穂を検出し、その位置を示すバウンディングボックスの座標を出力します。
Notebookの紹介
ここからコードを抜粋しつつ紹介していきます[6]。実際に触ってみないとコードの挙動や意図が掴めないと思うので、ぜひお手元で書き換えながら動かしてみてください。
コンフィグ設定
ファイルパス、ハイパーパラメータなどの設定を独立したセルで管理することにより、コードの可読性が向上し、ワークフローの整理・デバッグが容易になります。
DETRモジュールのインポート
今回用いるDETRでは特別な最適化アルゴリズムと損失関数が必要となるため、開発チームのGitHubリポジトリからダウンロードします。
if os.path.exists(DETR_DIR) == False:
!git clone https://github.com/facebookresearch/detr.git
import sys
sys.path.append(DETR_DIR)
from detr.models.matcher import HungarianMatcher
from detr.models.detr import SetCriterion
関数の定義
ノートブック全体で何度も利用する関数はノートブックの最上部で定義しておくと便利です。ノートブックが読みやすくなるだけでなく、デバッグやメンテナンスが容易になります。
探索的データ分析(Exploratory Data analysis: EDA)
データが手元に来たらまずどんなデータがあるのか把握します。
前回紹介したような内容に加え、物体認識タスクで重要になるポイントが物体の位置(バウンディングボックスの座標)をどのようにエンコードしているかの確認です。
一般的なフォーマットはcoco
, voc-pascal
, yolo
の3つで、それぞれ以下の方法で座標をエンコードしています。
-
coco
:[x, y, width, height]
.x
とy
はボックスの左上の座標を表す。 -
voc-pascal
:[x1, y1, x2, y2]
.x1
,y1
はボックスの左上の座標を、x2
,y2
は右下の座標を表す。 -
yolo
:[x, y, width, height]
.x
とy
はボックスの中心を表す。
自身の使いたいモデルに応じてフォーマットを整える必要があります。
今回用いるDETRはcocoフォーマットでデータをハンドリングしますが、データセットはどのフォーマットで用意されているでしょうか?
データの読み込みと成形
df = pd.read_csv(DATA_DIR + 'train.csv')
df.head()
これをみると、データセットは1行に1つのバウンディングボックスがあり、バウンディングボックスはcocoフォーマットで与えられていることがわかりました。DETRにそのまま放り込める形ですね。
しかし、このデータの持ち方は扱いづらいので成形します。まず、x
, y
, width
, height
をカラムに分割しましょう。
# ref: https://www.kaggle.com/code/tanulsingh077/end-to-end-object-detection-with-transformers-detr#Wheat-Detection-Competition-With-DETR
marking = df.copy()
bboxs = np.stack(marking['bbox'].apply(lambda x: np.fromstring(x[1:-1], sep=',')))
for i, column in enumerate(['x', 'y', 'w', 'h']):
marking[column] = bboxs[:,i]
marking.drop(columns=['bbox'], inplace=True)
marking.head()
次いで、1行1画像のデータフレームを作ります。
# GitHub/KaggleではDefine Helper Functions内で定義
def group_bboxes(df):
df_ = df.copy()
df_['bbox_count'] = 1
df_ = df_.groupby('image_id').count().reset_index()
df_['source'] = df[['image_id', 'source']].groupby('image_id').min().to_numpy()
return_df = df_[['image_id', 'source', 'bbox_count']]
return return_df
grouped_df = group_bboxes(marking)
NUM_QUERIES = max(grouped_df['bbox_count']) # モデルのハイパーパラメータとして使います。1つの画像に最大いくつのバウンディングボックスを置くか。
grouped_df.head()
これら2つのデータフレームを使ってPyTorchに画像データを読み込ませます。
モデルコンポーネントの定義
PyTorchでの機械学習は一般に以下のような流れで行います。
- 画像の変形・Augmentationの定義
- Datasetの定義
- Dataloaderの定義
- Modelの定義
- 学習の実行
画像の変形とAugmentation
まず最初に、画像に適用する変形とAugmentation(データ拡張)の定義を行います。
画像サイズの調整はもちろん、ぼかしや回転、平行移動を行うことで画像の枚数を擬似的に増やすことができます[7][8]。
前回同様、本稿でもAlbumentations[9]を用いて実装します。
物体認識タスクに置いては、画像のみでなくバウンディングボックスも合わせて変形させる必要があります。Albumentationsを使えば、それも同時に行うことができます[10]。
# Image Transformation & Augmentation
def transform_train():
transforms = [
A.Resize(512, 512, p=1),
A.OneOf([
A.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit= 0.2, val_shift_limit=0.2, p=0.9),
A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.9)
], p=0.9),
A.ToGray(p=0.01),
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.5),
A.Cutout(num_holes=8, max_h_size=64, max_w_size=64, fill_value=0, p=0.5),
ToTensorV2(p=1)
]
bbox_params = A.BboxParams(format='coco', min_area=0, min_visibility=0, label_fields=['labels'])
return A.Compose(transforms, bbox_params, p=1)
# Validation images undergo only resizing.
def transform_valid():
transforms = [
A.Resize(512,512,p=1),
ToTensorV2(p=1)
]
bbox_params = A.BboxParams(format='coco', min_area=0, min_visibility=0,label_fields=['labels'])
return A.Compose(transforms, bbox_params, p=1)
Datasetの定義
次に、自身の持つデータの形式に合わせDatasetを定義します。
# Dataset
class WheatDataset(Dataset):
def __init__(self, image_ids, dataframe, data_root, transforms=None):
super().__init__()
self.image_ids = image_ids
self.dataframe = dataframe # 1行1画像のデータフレームを与える
self.data_root = data_root # 画像データが入っているフォルダ
self.transforms = transforms # 画像サイズの変更やAugmentation処理
def __len__(self):
return self.image_ids.shape[0]
def __getitem__(self, index):
image_id = self.image_ids[index]
records = self.dataframe[self.dataframe['image_id'] == image_id]
# Load images
image = load_img(f'{self.data_root}{image_id}.jpg').astype(np.float32)
image /= 255.0 # normalization
# bbox and area
boxes = records[['x', 'y', 'w', 'h']].to_numpy()
area = boxes[:,2] * boxes[:,3]
area = torch.as_tensor(area, dtype=torch.float32)
labels = np.zeros(len(boxes), dtype=np.int32)
# Transform images
if self.transforms:
sample = {
'image': image,
'bboxes': boxes,
'labels': labels
}
sample = self.transforms(**sample)
image = sample['image']
boxes = sample['bboxes']
labels = sample['labels']
# Normalizing bboxes. 0~1の範囲に落とし込む。
_,h,w = image.shape
boxes = normalize_bbox(sample['bboxes'],rows=h,cols=w) # Define helper functions内で定義
target = {}
target['boxes'] = torch.as_tensor(boxes, dtype=torch.float32)
target['labels'] = torch.as_tensor(labels, dtype=torch.long)
target['image_id'] = torch.tensor([index])
target['area'] = area
return image, target, image_id
DataLoaderの定義
次に、DataLoaderを定義します。お約束の書き方で[11]。
collate_fn()
はデータセットから取り出されたた個々のデータをミニバッチにまとめる役割を担っています[12]。
def collate_fn(batch):
return tuple(zip(*batch))
def create_dataloader(df, marking, trn_idx, val_idx):
train_ = df.loc[trn_idx,:].reset_index(drop=True)
valid_ = df.loc[val_idx,:].reset_index(drop=True)
# Dataset
train_datasets = WheatDataset(train_['image_id'].to_numpy(), marking, DATA_DIR+'train/', transforms=transform_train())
valid_datasets = WheatDataset(valid_['image_id'].to_numpy(), marking, DATA_DIR+'train/', transforms=transform_valid())
# DataLoader
train_loader = DataLoader(train_datasets, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=NUM_WORKERS, shuffle=True, multiprocessing_context='fork')
valid_loader = DataLoader(valid_datasets, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=NUM_WORKERS, shuffle=False, multiprocessing_context='fork')
return train_loader, valid_loader
Modelの定義 - torch.hub.load()
今回はモデルのアーキテクチャとしてDetection Transformer (DETR)を採用します。
学習済みのモデルはPyTorch Hubから取得できます。こちらには様々なモデルが公開されており、torch.hub.load()
を使って公開されたモデルを取得することができます[13][14][15]。
DETRには2つのパラメータ(NUM_CLASSES
, NUM_QUERIES
)を明示的に与える必要があり、それぞれ「対象物の種類数(+1)」、「1画像内のバウンディングボックスの数の最大値」となります[16]。これらのパラメータを今回のターゲットのクラスとバウンディングボックスの数に合わせて修正します。
class DETRModel(nn.Module):
def __init__(self):
super(DETRModel,self).__init__()
self.num_classes = NUM_CLASSES
self.num_queries = NUM_QUERIES
# Donwload pre-trained model
self.model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
self.in_features = self.model.class_embed.in_features
self.model.class_embed = nn.Linear(in_features=self.in_features, out_features=self.num_classes)
self.model.num_queries = self.num_queries
def forward(self,imgs):
return self.model(imgs)
以上で準備が整いました。
学習と評価を行う
さて、いよいよ学習です。
交差検証(Cross-validation)
交差検証を行うために、画像データセットを特定の割合で学習用・評価用に分割します。
今回は通常のK-Fold法を使います。また、学習が上手く進んだかの確認のため、学習曲線も同時に出力します。
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Cross-validation
folds = KFold(n_splits=FOLD_NUM, shuffle=True, random_state=SEED)\
.split(np.arange(grouped_df.shape[0]), grouped_df['image_id'].to_numpy())
# For Visualization
train_loss_list = []
valid_loss_list = []
for fold, (trn_idx, val_idx) in enumerate(folds):
print(f'==========Cross Validation Fold {fold+1}==========')
# Define matcher, weight, and loss.
# See https://github.com/facebookresearch/detr/blob/3af9fa878e73b6894ce3596450a8d9b89d918ca9/models/detr.py#L304
matcher = HungarianMatcher()
weight_dict = weight_dict = {'loss_ce': 1, 'loss_bbox': 1 , 'loss_giou': 1}
losses = ['labels', 'boxes', 'cardinality']
# Load Data
train_loader, valid_loader = create_dataloader(grouped_df, marking, trn_idx, val_idx)
# Load model, loss function, optimizing algorithm
model = DETRModel().to(device)
criterion = SetCriterion(NUM_CLASSES-1, matcher, weight_dict, eos_coef = NULL_CLASS_COEF, losses=losses).to(device) # eos_coef is used in the output layer to affect the output corresponding to the absence of an object.
optimizer = optim.Adam(model.parameters(), lr=LR)
# For Visualization
train_losses = []
valid_losses = []
# Start training
best_loss = 10**5
for epoch in range(EPOCHS):
time_start = time.time()
print(f'==========Epoch {epoch+1} Start Training==========')
model.train()
# criterion.train()
train_loss = AverageMeter()
pbar = tqdm(enumerate(train_loader), total=len(train_loader))
for step, (imgs, targets, image_ids) in pbar:
# print('.', end='') # Sometimes progress bars do not emerge on the notebook. In in the case, remove hash
img_list = list(img.to(device) for img in imgs)
targets = [{k: v.to(device) for k, v in target.items()} for target in targets]
output = model(img_list)
loss_dict = criterion(output, targets)
weight_dict = criterion.weight_dict
losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
optimizer.zero_grad()
losses.backward()
optimizer.step()
train_loss.update(losses.item(), BATCH_SIZE)
print(f'==========Epoch {epoch+1} Start Validation==========')
with torch.no_grad():
valid_loss = AverageMeter()
preds = []
pbar = tqdm(enumerate(valid_loader), total=len(valid_loader))
for step, (imgs, targets, image_ids) in pbar:
img_list = list(img.to(device) for img in imgs)
targets = [{k: v.to(device) for k, v in target.items()} for target in targets]
output = model(img_list)
loss_dict = criterion(output, targets)
weight_dict = criterion.weight_dict
val_losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
valid_loss.update(val_losses.item(), BATCH_SIZE)
# print results from this epoch
exec_t = int((time.time() - time_start)/60)
print(
f'Epoch : {epoch+1} - loss : {train_loss.avg:.4f} - val_loss : {valid_loss.avg:.4f} / Exec time {exec_t} min\n'
)
# For visualization
train_losses.append(train_loss.avg)
valid_losses.append(valid_loss.avg)
train_loss_list.append(train_losses)
valid_loss_list.append(valid_losses)
del model, optimizer, train_loader, valid_loader, train_losses, valid_losses
gc.collect()
torch.cuda.empty_cache()
# 関数はDefine helper functions内で定義
show_validation_score(train_loss_list, valid_loss_list)
良い感じで学習が進んでいますが、Epoch数5では十分でなさそうです。
実践ではvalidationのlossが下がらなくなるまでEpochを重ねますが、本稿はこのまま進めます。
全データで学習しモデルを保存する
交差検証で最適なエポック数とハイパーパラメータを見つけたら、その情報を使ってモデルを全てのデータで学習し、学習が済んだモデルをstate_dict()で保存します[17]。
コードは交差検証とほぼ同じであるため本稿では割愛します[18]。
推論を行う
では、新しいデータを与えて推論を行ってみましょう。今回はテストデータは画像としてのみ与えられており、トレーニングデータのようなテーブル情報が存在しません。そのため、Datasetクラスを新しく定義し直します[19]。
def transform_test():
transforms = [
A.Resize(512,512,p=1),
ToTensorV2(p=1)
]
return A.Compose(transforms)
class WheatTestDataset(Dataset):
def __init__(self, image_ids, data_root, transforms=None):
super().__init__()
self.image_ids = image_ids
self.data_root = data_root
self.transforms = transforms
def __len__(self):
return self.image_ids.shape[0]
def __getitem__(self, index):
image_id = self.image_ids[index][:-4]
# Load images
image = load_img(f'{self.data_root}{image_id}.jpg').astype(np.float32)
image /= 255.0 # normalization
# Transform images
if self.transforms:
# image = self.transforms(image)['image']
sample = {'image': image}
sample = self.transforms(**sample)
image = sample['image']
return image, image_id
そして推論。
# 出力用データフレームを作成
submission_df = pd.DataFrame()
image_id_list = list(os.listdir(DATA_DIR+'test/'))
submission_df['image_file'] = image_id_list
submission_df['image_id'] = [image[:-4] for image in image_id_list]
submission_df['PredictionString'] = '' # Kaggleで求められているかたちに整える
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load Data
test_datasets = WheatTestDataset(submission_df['image_file'].to_numpy(), DATA_DIR+'test/', transforms=transform_test())
# Data Loader
test_loader = DataLoader(test_datasets, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=NUM_WORKERS, shuffle=False, multiprocessing_context='fork')
# Load model
model = DETRModel().to(device)
model.load_state_dict(torch.load(MODEL_DIR+'objectdetection.pth'))
# Start Inference
print(f'==========Start Inference==========')
with torch.no_grad():
pbar = tqdm(enumerate(test_loader), total=len(test_loader))
for step, (imgs, image_ids) in pbar:
img_list = list(img.to(device) for img in imgs)
output = model(img_list)
for i, image_id in enumerate(image_ids):
pred_str = ''
prediction_scores = output['pred_logits'][i].softmax(1).detach().cpu().numpy()[:,0]
predicted_boxes = output['pred_boxes'][i].detach().cpu().numpy()
# _, h, w = imgs[0].shape
h,w,_ = load_img(f'{DATA_DIR}test/{image_id}.jpg').shape # height, width, color of origianl image
# 0~1 -> 1024x1024へ
denormalized_boxes = denormalize_bbox(predicted_boxes, rows=h, cols=w) # Define helper functions内で定義
for box, p in zip(denormalized_boxes, prediction_scores):
if p > THRESHOLD:
score = p
pred_str += f'{score} {int(box[0])} {int(box[1])} {int(box[2])} {int(box[3])} '
submission_df.loc[submission_df['image_id'] == image_id, 'PredictionString'] = pred_str
print(submission_df.head())
推論結果の可視化
このままでは推論がうまくいっているのかよくわかりませんので、1つ適当な画像を可視化してみましょう。
target = submission_df.iloc[0]
img = load_img(DATA_DIR+'test/'+target['image_file']).copy()
prediction_results = target['PredictionString'].split(' ')
num_bbox = len(prediction_results)//5
for i in range(num_bbox):
x = int(prediction_results[i*5+1])
y = int(prediction_results[i*5+2])
w = int(prediction_results[i*5+3])
h = int(prediction_results[i*5+4])
cv2.rectangle(img,
(x, y),
(x+w, y+h),
[255,0,0], 3)
plt.imshow(img)
学習が十分に行えていないにも関わらず、なかなか良い精度で小麦の穂の検出ができているように思われます。
精度を上げるために
本稿では物体認識モデルの基本をさらいました。ここがスタートラインとなり、精度を上げる様々な工夫を加えていくことになります。精度をさらに向上させるためには、例えば、
- 画像の前処理
- より幅広いAugmentationを適用する
- より複雑・巨大な(あるいはシンプルで小さな)モデルアーキテクチャを採用する
- ハイパーパラメータチューニング
- 複数モデルのアンサンブル
などが簡単に思いつくところでしょうか。KaggleのDiscussionも非常に有用で、様々な工夫が紹介されています。
一方で、種々の工夫の有効性はタスクの特性に大きく依存します。あれじゃないこれじゃないと複数のテクニックを試し、ときには組み合わせながら、少しずつ精度の向上を目指していきます[20]。
前回の分類モデルに続き、今回は物体認識モデルの実装を行いました。本稿がみなさんのコンピュータビジョンモデル構築の助けになることを願っています。
謝辞
本シリーズの執筆にあたり、えんがわAI研究所 所長まっちゃんさん、相場雅彰さんにご協力いただきました。
参考情報
- End-to-End Object Detection with Transformers ... DETR原著論文
- PyTorch
- Albumentations
- GitHub/facebookresearch/detr
- End to End Object Detection with Transformers:DETR ... DETRの実装法を大変参考にさせていただいたNotebook
- GitHub/albumentations-team/albumentations#spatial-level-transforms
- Albumentations Documentation/Bounding boxes augmentation for object detection
- M1 mac でmultiprocessに失敗する問題の対処法
- PyTorchにおけるcollate_fnのデフォルト挙動のメモ
- 【機械学習】交差検証後の最終モデルの選び方
修正履歴
- 2023/3/16
- シリーズ他記事へのリンクを追加
- コード微修正
- コード修正に伴う出力結果の差し替え&文章の書き換え
-
まだまだ勉強中のため、本稿には誤りや無駄も多いと思います。厳しいご意見をいただけると嬉しいです。 ↩︎
-
他の方法については、別の記事・ノートブックで取り上げます。 ↩︎
-
GitHub/Kaggleで公開しているため、本稿ではコードは一部のみの抜粋に留めます。 ↩︎
-
個人差があります。 ↩︎
-
サンプル画像はイメージです。 ↩︎
-
適用すべき拡張はタスク次第。例えば「画像の向きに意味があるような画像」であれば回転は不適切な適用。 ↩︎
-
Augmentationをやりすぎて元の画像の特徴が消えてしまっていた、みたいなこともよくあるアンチパターン。 ↩︎
-
本稿では1.3.0を使用。このパッケージはバージョンによって書き方が結構変わる。 ↩︎
-
これらの2ページも参照:GitHub/albumentations-team/albumentations#spatial-level-transforms, Albumentations Documentation/Bounding boxes augmentation for object detection ↩︎
-
M1 Macユーザーは、
DataLoader()
の引数にmultiprocessing_context='fork'
の明記が必要なのでご注意ください。参考。 ↩︎ -
PyTorchにおけるcollate_fnのデフォルト挙動のメモ。collate_fnが何をしているかはこちらの記事がわかりやすくまとめてくださっています。 ↩︎
-
自分のモデルをアップロードすることもできます ↩︎
-
難点としてはインターネット接続が必要な点。Kaggleのインターネットオフコンペでは使えない。 ↩︎
-
他のモデル定義方法については、別の記事・ノートブックで取り上げます。 ↩︎
-
初期値は NUM_CLASSES = 92 (91(+1)), NUM_QUERIES = 100 ↩︎
-
全てのデータで学習し直すべきか、交差検証の過程で作られた複数のモデルを用いる(推論結果は複数モデルのアンサンブルとする)べきかは議論があります。実践的には推論時間や実装の手間、利用可能な計算リソースを考えていずれのアプローチを取るか決定することになるかと思います。参考 ↩︎
-
実践では、交差検証と合わせて一つのスクリプト・セルにしてしまうと管理が楽で良いかと思います。 ↩︎
-
入力の形式がほとんど変わらないのであれば、同一のDatasetクラスを用いながら引数で学習/推論を区別する方法もよく使われます。コードが大きく変わるのであれば別に定義し直した方が楽。 ↩︎
-
この試行錯誤がモデルづくりの一番の楽しみ。 ↩︎
Discussion