nnUNetとMONAI事前学習済みモデルの使い方
はじめに
先日、Kaggle での RSNA2025 のコンペが終了しました。自分も参加してたのですがまとまらずサブもせずに終わってしまいました。動脈瘤自体は血管にありますが、小さく RoI のクロップによる解像度の確保がキーポイントだったと思います。
そこで、セグメンテーションによる血管領域の RoI クロップを試みていましたが、timm を使った 2.5D モデルで組んだ UNet では思いの外精度が伸びず泥沼にハマりました。
共有されてるソリューションなどを見ると、MONAI や nnUNet といった医療画像での事前学習済みモデル、特に脳の血管セグメンテーションで学習されたモデルを使うとうまくいっているようです。
確かにデータにもかなりノイズがあったため、過学習しないように精度をある程度出すのは難しかったので理にかなってると思いますし結果も出てますし 🍣。惨敗でした。
MONAI は使ってましたが、MONAI で学習済みモデルを使うことや nnUNet は自分にはなかった選択肢だったので、復習と将来の自分に向けて nnUNet の使い方 と MONAI で事前学習済みモデルを使う方法をまとめておきます。 [1] [2]
想定読者:
- 未来の自分
- nnUNet の使い方がわからない人
- MONAI で事前学習済みモデルの使い方がわからない人
筆者は、以下の環境で動作確認しています。
- host os: ubuntu24.04
- nvidia-driver: 575.57.08
- CUDA: 12.9
- python: 3.11.13 (installed via uv)
- uv: 0.9.1
- nnUNet: 2.6.2
- MONAI: 1.5.0
nnUNet
nnUNet とは
UNet をデータセットに合わせて、自動で最適化してくれるツールです。 DKFZ の Fabian Isensee 氏が開発しています。現在は v2 がリリースされていて、pypi からインストールすることもできます。
cli も提供されていて、データセットの構造を整えれば、あとはコマンドで学習や推論が可能で、推論のみコードに組み込むことも可能です。
ここでは、学習までの一通りの流れと、学習済みモデルを推論パイプラインに組み込んで使える部分までを紹介します。
現在の最新バージョンは v2.6.2 です。
ソースを改変しながら使いたい人は editable install をすることになると思いますが、clone した後は tag がきられてるので switch/checkout してからインストールをお勧めします。
インストール
pip でインストールできます。 ここでは手軽な v2 をインストールする方法を紹介します。
pip install nnunetv2
uv だとこうなります。
uv add nnunetv2
editable install したい場合は、参考資料の [3] を参照してください。
使い方
install すると、cli で nnUNetv2_XXX というコマンドが使えるようになります。
学習を開始する前にコマンドが使うデータセットや前処理済みのデータセット、学習結果のアセットを置くディレクトリを環境変数で指定します。
ソースコードを見てる限り、名前はなんでもいいと思いますが、自分は nnUNet_raw
, nnUNet_preprocessed
, nnUNet_results
としています。
export nnUNet_raw="/path/to/nnUNet_raw"
export nnUNet_preprocessed="/path/to/nnUNet_preprocessed"
export nnUNet_results="/path/to/nnUNet_results"
データセットの構造は以下のようにします。
ここで、DATASETXXX_MyDataset
の XXX
は dataset_id で、学習や推論のコマンドで指定します。 MyDataset
は任意の名前で構いません。
dataset.json の書き方は v1 と v2 で異なるようなので後述します.
nnUNet_raw
└───DATASETXXX_MyDataset
├───imagesTr
│ case_000_0000.nii.gz
│ case_001_0000.nii.gz
│ ...
├───imagesTs
│ case_000_0000.nii.gz
│ case_001_0000.nii.gz
│ ...
└───labelsTr
│ case_000.nii.gz
│ case_001.nii.gz
│ ...
└───dataset.json
dataset.json は以下のように書きます。
{
"name": "",
"description": "",
"reference": "",
"licence": "",
"release": "",
"tensorImageSize": "3D",
"channel_names": {
"0": "MRI"
},
"labels": {
"background": "0",
"anterior": "1",
"posterior": "2"
},
"numTraining": 260,
"numTest": 130,
"file_ending": ".nii.gz"
}
大事なのは以下の属性です。
- channel_names:
- 画像のチャネル名を指定します。CT なら "CT" 、MRI なら "MRI" とします。複数チャネルある場合は "0": "T1", "1": "T2" のようにします。
- これが、nii.gz ファイルの末尾の
_0000
,_0001
の部分と対応します。この場合は、_0000
が MRI のチャネルになります。
- labels:
- セグメンテーションのラベルを指定します。背景が 0 で、以降のラベルを 1, 2, ... とします。
- key と value の書き方が v1 と逆になっているので注意してください。AI による補完で逆に書かれやすく、わかりにくいエラーに悩まされます。
- numTraining:
- 学習データの数を指定します。
- numTest:
- テストデータの数を指定します。
- file_ending:
- 画像ファイルの拡張子を指定します。
その後、学習を行うためには以下の二つのコマンドを実行します
- nnUNetv2_plan_and_preprocess
- nnUNetv2_train
学習や前処理のプランと前処理の実行は nnUNetv2_plan_and_preprocess
のコマンドで行います。
nnUNetv2_plan_and_preprocess --verify_dataset_integrity -d <dataset_id>
画像の spacing を変えたい時などは以下のようにオプションを追加します。
# spacing を指定 ここでは(D,H,W)想定
# defaultはNoneでデータセットのspacingのmedianが使われる
nnUNetv2_plan_and_preprocess \
--verify_dataset_integrity \
--verbose \
-d 004 \
-overwrite_target_spacing 0.8 0.45 0.44
他に指定したいオプションがある時は、v2.6.2 の時点では参考資料の [5] に様々なオプションが定義されているので、コードを参照するのがいいと思います。
学習の実行は nnUNetv2_train
のコマンドで行います。
--npz
オプションは最後の validation の epoch での softmax 出力を保存するオプションです。後で nnUNetv2_find_best_configuration
などをする場合に必要になるので、基本的にはつけておくのが良いと思います。 ディスク容量を食うので注意してください。
nnUNetv2_train <dataset_name_or_id> <configuration> <fold> --npz
3D での UNet を使う場合は、以下のように、configuration に 3d_fullres
を指定します。fold はクロスバリデーションの分割数を指定します。0-4 の 5 分割で学習されます。
nnUNetv2_train 004 3d_fullres 0 --npz
学習に使うモデルのアーキテクチャを変える時は planner を変更します。
nnUNetv2_plan_and_preprocess --verify_dataset_integrity -d <dataset_id> -pl nnUNetPlannerResEncM
# or
nnUNetv2_plan_and_preprocess --verify_dataset_integrity -d <dataset_id> -pl nnUNetPlannerResEncL
# or
nnUNetv2_plan_and_preprocess --verify_dataset_integrity -d <dataset_id> -pl nnUNetPlannerResEncXL
nnUNetv2_train 004 3d_fullres 0 --npz
使える Planner は参考資料の [6] を参照してください。
class の名前を -pl
に指定することで、plan_experiments 関数内で class の type を取得して使われます。
または、nnUNet_preprocessed
の中にある nnUNetPlans.json
を編集しても変えられるようです。(こちらは未検証)
50 epoch ごとに checkpoint が保存されます。さらに学習を続けたい時は --c
オプションを使うようです。
nnUNetv2_train 004 3d_fullres 0 --npz --c
他の configuration などは、参考資料の [4] の Model training を参照してください。
このコマンドが成功すると、nnUNet_results
の中に学習結果が保存されます。
v2.6.2 の時点では学習済みのモデルを推論パイプラインに組み込む時は以下のように、nnUNetPredictor
を使って書きます。
import torch
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
# network属性にロードされたモデルが入ってる
predictor = nnUNetPredictor()
# initialize_from_trained_model_folder で学習済みモデルをロード.
# これはメソッドなのでPredictorのインスタンスを生成した後に呼び出す、かつ返り値がNoneなのに注意
predictor.initialize_from_trained_model_folder(
# /path/to/nnUnet_results/DatasetXXX_MyDataset/<trainer>__<planner>__<configuration>/
model_training_output_dir="nnUNet_results/Dataset004_Hippocampus/nnUNetTrainer__nnUNetPlans__3d_fullres",
# 読み込むモデルのfoldを指定
use_folds=(0,),
# 学習済みモデルのチェックポイントファイル名を指定, checkpoint_best.pth, checkpoint_latest.pthなどがある
checkpoint_name="checkpoint_latest.pth",
)
# PlainConvUNetなら属性はencoder, decoder
print(predictor.network)
assert isinstance(predictor.network, torch.nn.Module)
print(dict(predictor.network.named_children()).keys())
print(predictor.network.encoder)
print(predictor.network.decoder)
MONAI
MONAI とは
医療画像に特化した PyTorch ベースのツールです。transform/network/layers/data/losses/metrics/bundle などを提供しています。
transforms や network, layers, losses, metrics などは使ったことがあるひとも多いと思いますが、事前学習済みモデルの使い方は知りませんでした。
そこで今回は bundle を使って、BraTS で事前学習された SegResNet を使う方法を紹介します。
nnUNet よりは使うのは簡単でしたので、はじめにコード全体を示します。
雑なコードですが、お許しください。
使い方
import torch
from monai.bundle import download
from monai.networks.nets import SegResNet
def download_brats_bundle(name="brats_mri_segmentation", target_dir="./bundles"):
"""
MONAI Bundle をダウンロードする(モデル + チェックポイント含む)
"""
# download 関数を使って、bundle を取得(zip などで展開される)
download(name=name, bundle_dir=target_dir)
def load_checkpoint_from_bundle(ckpt_path):
"""
ダウンロードしたバンドルからチェックポイントファイル (.pt / .pth) を探して読み込む
"""
checkpoint = torch.load(ckpt_path, map_location=torch.device("cpu"))
return checkpoint
def build_model(in_channels=4, out_channels=3):
"""
モデル定義(SegResNet 例)を構築
"""
model = SegResNet(
spatial_dims=3,
in_channels=in_channels,
out_channels=out_channels,
init_filters=16,
dropout_prob=0.2,
)
return model
def main():
bundle_id = "brats_mri_segmentation"
target_dir = "./bundles"
download_brats_bundle(bundle_id, target_dir)
checkpoint = load_checkpoint_from_bundle(
f"{target_dir}/{bundle_id}/models/model.pt"
)
# SegResNet の場合、in_channels=4, out_channels=3 で学習されている
model = build_model(in_channels=1, out_channels=1)
# チャンネル数が合わないので、timm-likeに平均化して変換
# meanを使うことで事前学習の重み情報をある程度活用できる
checkpoint_new = {}
for k, v in checkpoint.items():
if k == "convInit.conv.weight":
checkpoint_new[k] = v.mean(dim=1, keepdim=True) # 4->1チャネルに変換
elif k == "conv_final.2.conv.weight":
checkpoint_new[k] = v.mean(dim=0, keepdim=True) # 3->1チャネルに変換
elif k == "conv_final.2.conv.bias":
checkpoint_new[k] = v.mean(dim=0, keepdim=True) # 3->1チャネルに変換
else:
checkpoint_new[k] = v
print(model.load_state_dict(checkpoint_new, strict=False))
model.eval()
print("Model loaded and ready.")
print(model)
# 推論例(ダミー入力)
dummy = torch.randn(1, 1, 240, 240, 160)
with torch.no_grad():
out = model(dummy)
print("Output shape:", out.shape)
以上のように、bundle.download
関数でバンドルをダウンロードし、torch.load
でチェックポイントを読み込み、モデルを定義して load_state_dict
で重みをロードします。
今回使用した、brats_mri_segmentation バンドルは、4 チャネルの MRI 画像を入力に、3 クラスのセグメンテーションを出力する SegResNet です。
事前学習済みの重み自体は License: Apache-2.0 で公開されているようですがデータセットの BraTS2018 自体のライセンスは非商用利用で引用する場合のみに使用が限られているので注意してください。
新たに学習させたものを商用利用や配布などはライセンスに違反すると思われます。 参考資料[7]
今回使用したモデルの重み自体のライセンスは MONAI チームが BraTS のデータセットの公開元に許諾を得たのでしょうか?わかりません。
まとめ
nnUNet の使い方と MONAI で事前学習済みモデルを使う方法を紹介しました。
筆者はむしゃくしゃしてやったと述べておりますが、これで次に同じような場面でできなかったなどという言い訳はできなくなりました。
医療画像自体は、わからないことが多く非常に勉強になりました。これで次は勝てますね 🍣
使用したコードは整理した後、公開予定です。
TODO: 公開後リンク追記
参考資料
- [1] nnUNet, MIC-DKFZ, GitHub, https://github.com/MIC-DKFZ/nnUNet
- [2] MONAI, Project-MONAI, GitHub, https://github.com/Project-MONAI/MONAI
- [3] nnUNet Installation instruction, https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/installation_instructions.md
- [4] nnUNet how to use nnunet, https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/how_to_use_nnunet.md
- [5] https://github.com/MIC-DKFZ/nnUNet/blob/8c4184d46b60059ff7dc8f74cd535e13554bdeca/nnunetv2/experiment_planning/plan_and_preprocess_entrypoints.py
- [6] https://github.com/MIC-DKFZ/nnUNet/tree/8c4184d46b60059ff7dc8f74cd535e13554bdeca/nnunetv2/experiment_planning/experiment_planners
- [7] https://huggingface.co/MONAI/brats_mri_segmentation/blob/main/docs/README.md
Discussion