🍃

PyTorchとEfficientNetV2で作る画像分類モデル

2023/03/09に公開

はじめに

こんにちは。機械学習初心者の中沢(@shnakazawa_ja)です[1]

世の中にはテーブルデータを対象とした機械学習モデルのチュートリアルは多くありますが、画像に対するものは少なく、コードまで提供されているものは更に少なくなります。そこで、今回から数記事に分けて基本的なコンピュータビジョンモデリングの手法をPythonの深層学習用フレームワークPyTorchで実装していきます。



今回は Classification (画像分類) を扱います。
モデルのアーキテクチャにはEfficientNetV2を使います。EfficientNetV2は画像分類、物体検出、セマンティックセグメンテーションなど、幅広いコンピュータビジョンタスクで優れた性能を達成しています。
本稿では学習済みのモデルをライブラリから取得します[2]

実装はKaggle Notebook上で行うことで誰もが再現できるコードを目指します。主な読者として

  • 仕事や研究で画像を扱う必要が出てきた方
  • Titanicや住宅価格予測のチュートリアルは終えたが、画像コンペへの取り組み方が分からない方
  • Kaggleの画像分類コンペで使えるベースライン(のベース)を探している方

を想定しています。簡単のため、画像の前処理や精度を高めるための工夫などについてはスコープ外とします。

コードはGitHubおよびKaggleにて公開しています

それでは、本題に入っていきましょう。

Classfication (画像分類) と題材

画像分類とは、ある画像がどのカテゴリに属するか分類する/画像の中に写っているものが何かを判別するタスクになります。
今回は題材としてKaggleのCassava Leaf Disease Classificationコンペを取り上げます。

本コンペではキャッサバの葉の写真から「健康か病気か」、「病気の場合は何の病気か」を判別することが目的となっています。
トレーニング用のデータとして約2万枚の写真と、それぞれの健康状態のラベルが与えられています[3]

また、本コンペのポイントとしてコードコンペである点も挙げられます。
テストデータの予測結果ではなく、コードを書いたKaggle Notebookの提出によってスコアがつけられるという仕組みで、実行時間に制限があるため、多くの方は学習用と推論用(提出用)の2ノートブックを作成されていたようです[4]

Notebookの紹介

ここからコードを紹介していきます。本稿ではコードは一部のみの抜粋に留めます。記事が長くなりすぎますしね。適宜GitHubKaggleに飛んでいただければと思います。

コンフィグ設定

ファイルパス、ハイパーパラメータなどの設定を独立したセルで管理することにより、コードの可読性が向上し、ワークフローの整理・デバッグが容易になります。

関数の定義

ノートブック全体で何度も利用する関数はノートブックの最上部で定義しておくと便利です。ノートブックが読みやすくなるだけでなく、デバッグやメンテナンスが容易になります。

探索的データ分析(Exploratory Data analysis: EDA)

さて、データが手元に来たらまず最初にすべきはどんなデータがあるのか把握することです。さすがに2万枚全て見るのは骨が折れますので、いくつか大切なポイントを抽出してみましょう。例えば、

  • 代表的な画像を自分の目で見る
    • 特徴やバリエーションを把握
    • 「自分の目で見て分類できるか?」
  • ラベルごとの画像枚数
  • 画像のサイズや色、拡張子の確認
  • 輝度のヒストグラム等から外れ値画像の有無を調べる

などが挙げられるでしょうか。テーブルデータの場合と比べてEDAはシンプルになります。

初手ではこの程度にしておき、画像の前処理やAugmentationを行うタイミングで、逐一目的を絞って見ています[5]

データの読み込み

df = pd.read_csv(DATA_DIR + 'train.csv')
df.head()

ラベルごとの画像枚数

df['label'].value_counts()

>>3    13158
>>4     2577
>>2     2386
>>1     2189
>>0     1087
>>Name: label, dtype: int64

偏りのあるデータセットであることがわかります。

画像のサイズと拡張子の確認

img_shape = set()
img_ext = set()
img_names = Path(DATA_DIR+'train_images/').glob('*')
pbar = tqdm(img_names, total=len(df))
for img_name in pbar:
    img = load_img(img_name.as_posix())
    img_shape.add(img.shape)
    img_ext.add(img_name.suffix)
print(f'Image shapes are {img_shape}.')
print(f'Image extensions are {img_ext}.')

>>Image shapes are {(600, 800, 3)}.
>>Image extensions are {'.jpg'}.

全て600x800ピクセルのカラー、jpg画像だと判りました。

輝度ヒストグラムのプロット

img_names = Path(DATA_DIR+'train_images/').glob('*')
plt.figure(figsize=(10,10))
pbar = tqdm(img_names, total=len(df))
for img_name in pbar:
    img = load_img(img_name.as_posix())
    hist = cv2.calcHist([img],[0],None,[256],[0,256])
    plt.plot(hist)
plt.show()

画面の半分以上がグレーの画像が1枚あることが判りました。また、サチり気味な写真があることも読み取れます。

モデルコンポーネントの定義

本稿ではPyTorchを用いて画像分類モデルを実装します。PyTorchでの機械学習は一般に以下のような流れで行います。

  1. 画像の変形・Augmentationの定義
  2. Datasetの定義
  3. Dataloaderの定義
  4. Modelの定義
  5. 学習の実行

画像の変形とAugmentation

まず最初に、画像に適用する変形とAugmentation(データ拡張)の定義を行います。

画像サイズの調整はもちろん、ぼかしや回転、平行移動を行うことで画像の枚数を擬似的に増やすことができます。Augmentationは学習用の画像にのみ適用し、評価用の画像はサイズの変更のみに留めます。また注意点として、適用すべきAugmentation手法はタスク次第で、例えば「向きに意味があるような画像」であれば回転は不適切な適用となります。

本稿ではAlbumentationsというパッケージを用います[6][7]

# Image Augmentation
def transform_train():
    transform = [
        A.Resize(512,512,p=1),
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(p=0.5),
        A.CoarseDropout(p=0.5),
        ToTensorV2(p=1.0)
    ]
    return A.Compose(transform)


# Validation (and test) images should only be resized.
def transform_valid():
    transform = [
        A.Resize(512,512,p=1),
        ToTensorV2(p=1.0)
    ]
    return A.Compose(transform)

Datasetの定義

次に、自身の持つデータの形式に合わせDatasetを定義します[8]

# Dataset
class CassavaDataset(Dataset):
    def __init__(self, df, data_root, transforms=None, give_label=True):
        """Datasetオブジェクトがインスタンス化される際に1度だけ実行される
        """ 
        super().__init__()
        self.df = df.reset_index(drop=True).copy() # 画像の情報をまとめたテーブル
        self.data_root = data_root # 画像があるディレクトリ
        self.transforms = transforms # 画像をどう変形するか
        self.give_label = give_label # ラベルある? (TestのときにはFalseになる)
        
        if give_label == True:
            self.labels = self.df['label'].values

    def __len__(self):
        """データセットのレコード数を返す関数
        """ 
        return self.df.shape[0]
    
    def __getitem__(self, index):
        """指定されたindexに対応するサンプルをデータセットから読み込んで返す関数
        """ 
        # get labels
        if self.give_label:
            target = self.labels[index]

        # Load images
        img  = load_img(f'{self.data_root}/{self.df.loc[index]["image_id"]}').astype(np.float32)
        # img /= 255.0 # Normalization

        # Transform images
        if self.transforms:
            img = self.transforms(image=img)['image']

        if self.give_label == True:
            return img, target
        else:
            return img

DataLoaderの定義

次に、DataLoaderを定義します。これはお約束の書き方があります。
M1 Macユーザーは、 DataLoader()の引数にmultiprocessing_context='fork'の明記が必要なのでご注意ください[9]

# DataLoader
def create_dataloader(df, trn_idx, val_idx):
    train_ = df.loc[trn_idx,:].reset_index(drop=True)
    valid_ = df.loc[val_idx,:].reset_index(drop=True)

    # Dataset
    train_datasets = CassavaDataset(train_, DATA_DIR+'train_images/', transforms=transform_train())
    valid_datasets = CassavaDataset(valid_, DATA_DIR+'train_images/', transforms=transform_valid())

    # Data Loader
    train_loader = DataLoader(train_datasets, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True, multiprocessing_context='fork')
    valid_loader = DataLoader(valid_datasets, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False, multiprocessing_context='fork')

    return train_loader, valid_loader

Modelの定義 - ライブラリの利用

今回はEfficientNetV2を採用します。
PyTorchでのモデル定義にはいくつかの方法がありますが、今回はコンピュータビジョンタスクのためのライブラリtimmに収録されている事前学習済みモデルを使用します[10][11]

使用法はとっても簡単。

class EfficientNet_V2(nn.Module):
    def __init__(self, n_out):
        super(EfficientNet_V2, self).__init__()
        # Define model
	# ここの引数を変えるだけで別のモデルも使える!
        self.effnet = timm.create_model('efficientnetv2_s', pretrained=True, num_classes=n_out) 

    def forward(self, x):
        return self.effnet(x)

以上!これだけで強力な画像処理モデルが使えてしまいます。

学習と評価を行う

では、いよいよ学習に入っていきましょう。

高精度な分類モデルの作成には、最適なモデルアーキテクチャとハイパーパラメータの選択が重要です。そのため、最初に交差検証(Cross-validation) を行い、最大の性能を得られる最適なアーキテクチャ、最適なハイパーパラメータの組み合わせを特定します[12]。その後、最適な組み合わせで、全てのデータを用いてモデルを再トレーニングします[13]

交差検証(Cross-validation)

交差検証を行うために、画像データセットを特定の割合で学習用・評価用に分割します。
今回の画像には5種類のラベルが与えられていました。このデータセットをランダムに分割してしまうと、学習用と評価用のデータセットの間に枚数の偏りが出てしまう可能性があります。学習用データと評価用データのデータ分布を揃えるため、scikit-learnのStratifiedKFold()を用います。これによって各ラベルの割合をを保った状態でデータを分割できます。

また、学習が上手く進んだかの確認のため、学習曲線も同時に出力します。

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
print(f'Using {device} device')

# Cross-validation
folds = StratifiedKFold(n_splits=FOLD_NUM, shuffle=True, random_state=SEED)\
        .split(np.arange(df.shape[0]), df['label'].to_numpy())

# For Visualization
train_acc_list = []
valid_acc_list = []
train_loss_list = []
valid_loss_list = []


for fold, (trn_idx, val_idx) in enumerate(folds):
    print(f'==========Cross-Validation Fold {fold+1}==========')
    # Load Data
    train_loader, valid_loader = create_dataloader(df, trn_idx, val_idx)

    # Load model, loss function, and optimizing algorithm
    model = EfficientNet_V2(NUM_CLASSES).to(device)
    loss_fn = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=LR)
            
    # For Visualization
    train_accs = []
    valid_accs = []
    train_losses = []
    valid_losses = []

    # Start training
    best_acc = 0
    for epoch in range(EPOCHS):
        time_start = time.time()
        print(f'==========Epoch {epoch+1} Start Training==========')
        model.train()
        
        epoch_loss = 0
        epoch_accuracy = 0
    
        pbar = tqdm(enumerate(train_loader), total=len(train_loader))
        for step, (img, label) in pbar:
            img = img.to(device).float()
            label = label.to(device).long()

            output = model(img)
            loss = loss_fn(output, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            acc = (output.argmax(dim=1) == label).float().mean()
            epoch_accuracy += acc / len(train_loader)
            epoch_loss += loss / len(train_loader)

        print(f'==========Epoch {epoch+1} Start Validation==========')
        with torch.no_grad():
            epoch_val_accuracy = 0
            epoch_val_loss = 0
            val_labels = []
            val_preds = []

            pbar = tqdm(enumerate(valid_loader), total=len(valid_loader))
            for step, (img, label) in pbar:
                img = img.to(device).float()
                label = label.to(device).long()

                val_output = model(img)
                val_loss = loss_fn(val_output, label)

                acc = (val_output.argmax(dim=1) == label).float().mean()
                epoch_val_accuracy += acc / len(valid_loader)
                epoch_val_loss += val_loss / len(valid_loader)

                val_labels += [label.detach().cpu().numpy()]
                val_preds += [torch.argmax(val_output, 1).detach().cpu().numpy()]
            
            val_labels = np.concatenate(val_labels)
            val_preds = np.concatenate(val_preds)   
        
        # print result from this epoch
        exec_t = int((time.time() - time_start)/60)
        print(
            f'Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f} / Exec time {exec_t} min\n'
        )

        # For visualization
        train_accs.append(epoch_accuracy.cpu().numpy())
        valid_accs.append(epoch_val_accuracy.cpu().numpy())
        train_losses.append(epoch_loss.detach().cpu().numpy())
        valid_losses.append(epoch_val_loss.detach().cpu().numpy())
    
    train_acc_list.append(train_accs)
    valid_acc_list.append(valid_accs)
    train_loss_list.append(train_losses)
    valid_loss_list.append(valid_losses)
    del model, optimizer, train_loader, valid_loader, train_accs, valid_accs, train_losses, valid_losses
    gc.collect()
    torch.cuda.empty_cache()
    
# 学習曲線のプロット。定義はDefine Helper Functions内にて。
show_validation_score(train_acc_list, train_loss_list, valid_acc_list, valid_loss_list)

上図はEPOCH=10としたときの結果。まだ学習が収束していません。
実践ではvalidationのlossが下がらなくなるまでEpochを重ねますが、本稿はこのまま進めます。

全データで学習しモデルを保存する

交差検証で最適なエポック数とハイパーパラメータを見つけたら、その情報を使ってモデルを全てのデータで訓練し、結果を保存します。
モデルの保存には2つの方法があり、今回は推奨されているstate_dict()を用いた方法を使います。

コードは交差検証とほぼ同じなので省略します[14]

コンペで意識することはほぼないでしょうが、実務で使う際には新しいデータが手に入ったり、タスクの要件が変わったりすると、再調整や再トレーニングが必要になる可能性があることを覚えておく必要があります。

推論を行う

最後に、新しいデータを与えて推論を行ってみましょう。

submission_df = pd.DataFrame()
submission_df['image_id'] = list(os.listdir(DATA_DIR+'test_images/'))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
print(f'Using {device} device')

# Load Data
test_datasets = CassavaDataset(submission_df, DATA_DIR+'test_images/', transforms=transform_valid(), give_label=False)

# Data Loader
test_loader = DataLoader(test_datasets, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False, multiprocessing_context='fork')

# Load model, loss function, and optimizing algorithm
model = EfficientNet_V2(NUM_CLASSES).to(device)
model.load_state_dict(torch.load(MODEL_DIR+'classification.pth'))

# Start Inference
print(f'==========Start Inference==========')
with torch.no_grad():
    test_preds = []
    pbar = tqdm(enumerate(test_loader), total=len(test_loader))
    for step, img in pbar:
        img = img.to(device).float()
        test_output = model(img)
        test_preds += [torch.argmax(test_output, 1).detach().cpu().numpy()]
    test_preds = np.concatenate(test_preds)
submission_df['label'] = test_preds
submission_df.head()

そして提出。

submission_df.to_csv('submission.csv', index=False)

現在の正解率は0.7程度と、「全部3!」と答えた場合より少し良い程度です。そもそも学習サイクル (Epoch数) が足りていませんね。

精度を上げるために

画像の読み込みから推論まで一通り動き、分類モデルはようやくスタートラインに立ったというところです。
ここから精度を上げる様々な工夫を加えていくことになります。例えば

  • 画像前処理
    • 画像の次元削減、特徴量の抽出、不要な背景の除去
    • 誤ラベル付けされた画像の修正または削除
    • Augmentationの見直し
      • 生成するパターンを増やす
      • Augmentationの結果、変な画像が生成されていないかを確認する
  • 不均衡データの対応
    • アップサンプリング/ダウンサンプリング
    • 損失関数の変更
  • モデルの最適化
    • Epoch数を増やす
    • 'efficientnetv2_l'などのより大きなアーキテクチャへ変更
    • 異なるアーキテクチャの適用
    • インプットサイズの拡大
    • ハイパーパラメータの調整
  • 予測結果からのフィードバック
    • 混同行列分析の作成
    • 誤ラベル付けされたデータ、されやすいラベルの特定
  • 複数モデルのアンサンブル
  • バグがないかの確認
  • その他[15]

などが考えられますね。各アプローチがモデルの精度にどういう影響を与えるかを評価することで、特定のタスクやデータセットに対して、モデルの精度を向上させるための最適なアプローチを特定していきます[16]

次回はDetection Transformer (DETR)を用いたObject Detection (物体認識) モデルの実装を行います。
本稿がみなさまの分類モデル構築、ひいてはコンピュータビジョンモデリング入門に役立つことを祈っています。

謝辞

本シリーズの執筆にあたり、えんがわAI研究所 所長まっちゃんさん、相場雅彰さんにご協力いただきました。

参考情報

修正履歴

  • 2023/3/16
    • シリーズ他記事へのリンクを追加
    • コード微修正
    • コード修正に伴う出力結果の差し替え&文章の書き換え
  • 2023/3/10
    • シリーズ他記事へのリンクを追加
脚注
  1. まだまだ勉強中のため、本稿には誤りや無駄も多いと思います。厳しいご意見をいただけると嬉しいです。 ↩︎

  2. 他の方法については、以降の記事・ノートブックで取り上げます。 ↩︎

  3. ラベル付けが雑で間違っている写真も多く、それ故「運ゲーコンペ」と揶揄されたりもしているようです。ハイスコアを狙うために「ラベルを付け直す」作業をした方もおられた模様。 ↩︎

  4. 実務的にも、学習と推論は分けたほうが使いやすいケースが多いかもしれません。 ↩︎

  5. 例えば、「画像の向きは重要か?」など。 ↩︎

  6. Qiita - Albumentationsのaugmentationをひたすら動かす。Albumenationsを用いることでどのような処理が行えるかはこちらの記事が参考になります。 ↩︎

  7. このパッケージはバージョンによって書き方が結構変わる。本稿では1.3.0を使用。 ↩︎

  8. 分類タスクでフォルダを自由に操作できる場合はtorch.datasets.ImageFolderを使う方が簡単ですが、ここでは自前で定義します。 ↩︎

  9. M1 mac でmultiprocessに失敗する問題の対処法 ↩︎

  10. 他の方法については、以降の記事・ノートブックで取り上げます。 ↩︎

  11. ライブラリ収録モデルを用いる利点は「簡単に使える」に尽きますね。Kaggleのインターネットオフコンペでも簡単に使える点は大きな強みかと思います。 ↩︎

  12. 本稿では簡単のために1パターンしか示しませんが、実践ではMLFlow等の実験管理ツールを使いつつ、異なるモデルアーキテクチャ、異なるハイパーパラメータから得られる結果を比較します。 ↩︎

  13. 全てのデータで学習し直すべきか、交差検証の過程で作られた複数のモデルを用いる(推論結果は複数モデルのアンサンブルとする)べきかは議論があります。実践的には推論時間や実装の手間、利用可能な計算リソースを考えていずれのアプローチを取るか決定することになるかと思います。参考 ↩︎

  14. 実践では、管理のしやすさも考えると、交差検証と合わせて一つのスクリプト・セルにしてしまうのが良いかと思います。if文でどちらのパターンで学習するかを切り替えるイメージ。 ↩︎

  15. KaggleのDiscussionは知見の宝庫。 ↩︎

  16. ここでもMLFlow等の実験管理ツールが活躍します。 ↩︎

Aidemy Tech Blog

Discussion