Open4

EfficientNet の Swish が含まれたPyTorchモデルのONNXエクスポート

PINTOPINTO
model = EfficientNet.from_name(model_name='efficientnet-b0')
model.set_swish(memory_efficient=False)
torch.onnx.export(model, torch.rand(10,3,240,240), "EfficientNet-B0.onnx")
PINTOPINTO
  • 2023/03/03現在、普通にエクスポートできるようになってた(torchvision の models からなら)
import torch
import torchvision.models as models

model = models.efficientnet_b0(pretrained=True)
model.eval()
onnx_file = f'efficiennnet_b0_11.onnx'
x = torch.randn([1,3,224,224])
torch.onnx.export(
    model,
    args=(x),
    f=onnx_file,
    opset_version=11,
    input_names=[
        'input',
    ],
    output_names=[
        'output',
    ],
)
import onnx
from onnxsim import simplify
model_onnx2 = onnx.load(onnx_file)
model_simp, check = simplify(model_onnx2)
onnx.save(model_simp, onnx_file)
PINTOPINTO
  • 2023/03/03現在、普通にエクスポートできるようになってた(efficientnet_pytorch からも同様に)
from efficientnet_pytorch import EfficientNet

model = EfficientNet.from_name(model_name='efficientnet-b0')
model.eval()
onnx_file = f'efficientnet_b0_11_efficientnet_pytorch.onnx'
x = torch.randn([1,3,224,224])
torch.onnx.export(
    model,
    args=(x),
    f=onnx_file,
    opset_version=11,
    input_names=[
        'input',
    ],
    output_names=[
        'output',
    ],
)
import onnx
from onnxsim import simplify
model_onnx2 = onnx.load(onnx_file)
model_simp, check = simplify(model_onnx2)
onnx.save(model_simp, onnx_file)