Open2
SRPose (PoseEstimation) の環境構築 (mmpose)
Dockerfile
# Base image: Ubuntu 22.04. Python 3.9 is not in its archive, so it is pulled
# from the deadsnakes PPA in the first RUN step.
FROM ubuntu:22.04

# Build-time only: silences debconf/apt prompts during the build without
# leaving the variable set in containers started from the image.
ARG DEBIAN_FRONTEND=noninteractive

# OS packages, Python 3.9 (registered as `python`), nano line numbers and pip.
# get-pip.py is downloaded to a file with certificate verification enabled
# (`-f` fails on HTTP errors) instead of being piped unverified into Python,
# so a failed or tampered download aborts the build.
RUN apt-get update \
    && apt-get upgrade -y \
    && apt-get install -y \
        software-properties-common \
    && add-apt-repository -y ppa:deadsnakes/ppa \
    && apt-get install -y \
        python3.9 \
        python3.9-venv \
    && update-alternatives \
        --install /usr/bin/python python /usr/bin/python3.9 130 \
    && apt-get install -y \
        build-essential \
        ca-certificates \
        curl \
        git \
        libffi-dev \
        libgl1-mesa-dev \
        libssl-dev \
        nano \
        python3.9-dev \
        sudo \
        wget \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/* \
    && sed -i 's/# set linenumbers/set linenumbers/g' /etc/nanorc \
    && curl -fsSL https://bootstrap.pypa.io/get-pip.py -o /tmp/get-pip.py \
    && python /tmp/get-pip.py --no-cache-dir \
    && rm -f /tmp/get-pip.py \
    && pip install --no-cache-dir -U pip

# Pinned Python dependencies. Kept in their own layer so that changing the
# list does not redo the OS package installation above. The protobuf specifier
# is quoted so the shell can never expand the `*` as a glob.
RUN pip install --no-cache-dir \
        numpy==1.23.2 \
        matplotlib==3.4.1 \
        opencv-python==4.5.2.52 \
        pycocotools==2.0.6 \
        scikit-image==0.18.2 \
        scipy==1.9.0 \
        tensorboard==2.5.0 \
        "protobuf==3.20.*" \
        tqdm==4.60.0 \
        yacs==0.1.8 \
        einops==0.3.0 \
        Cython==3.0.4 \
        onnx==1.14.1 \
        onnxruntime==1.16.1 \
        onnxsim==0.4.33 \
        mmcv==1.6.1 \
        xtcocotools==1.14.3 \
        json_tricks==3.17.3 \
        timm==0.9.8 \
        munkres==1.1.4

# NOTE(review): poseval is installed from the repository's default branch and
# the PyTorch packages below are unversioned, so rebuilds can pick up different
# code. Pin a commit / versions once the known-good ones are confirmed.
RUN pip install --no-cache-dir git+https://github.com/svenkreiss/poseval.git

# CPU-only PyTorch wheels from the official PyTorch index.
RUN pip install --no-cache-dir torch torchvision torchaudio \
        --index-url https://download.pytorch.org/whl/cpu
# Login name of the unprivileged account; also used as its password and as the
# name of its sudoers drop-in file.
ENV USERNAME=user

# Create the account, then set both passwords in a single chpasswd call
# (root -> "root", ${USERNAME} -> its own name) and grant the user's group
# password-less sudo.
# NOTE(review): these passwords are trivially guessable; acceptable only for a
# local development container.
RUN adduser --disabled-password --gecos "" "${USERNAME}" \
    && printf '%s\n' "root:root" "${USERNAME}:${USERNAME}" | chpasswd \
    && printf '%s\n' "%${USERNAME} ALL=(ALL) NOPASSWD: ALL" >> "/etc/sudoers.d/${USERNAME}" \
    && chmod 0440 "/etc/sudoers.d/${USERNAME}"
# Working directory inside the image; override with `--build-arg WKDIR=...`.
ARG WKDIR=/workdir

# Create the directory and hand it to the unprivileged user before switching
# away from the build's current (root) user, so no `sudo` is needed inside a
# RUN step and the build does not depend on the sudoers configuration.
RUN mkdir -p "${WKDIR}" \
    && chown "${USERNAME}:${USERNAME}" "${WKDIR}"

# Everything from here on runs as the unprivileged user inside ${WKDIR}.
USER ${USERNAME}
WORKDIR ${WKDIR}
# Build and install the CrowdPose Python API into the current user's
# site-packages. A shallow clone keeps the repository history out of the
# image layer; the checked-out sources are the same.
# NOTE(review): this tracks the default branch; pin a commit for reproducible
# builds once a known-good revision is confirmed.
RUN git clone --depth 1 https://github.com/uyoung-jeong/CrowdPose.git \
    && cd CrowdPose/crowdpose-api/PythonAPI \
    && python setup.py install --user
# Build and install the COCO Python API (pycocotools) into the current user's
# site-packages, then drop the intermediate build directory.
# A single user-level install replaces the previous `make install` followed by
# `setup.py install --user`, which built and installed the package twice.
# NOTE(review): `make install` presumably targets the system site-packages,
# which an unprivileged build user is not expected to be able to write to --
# confirm against the cocoapi Makefile.
# NOTE(review): this tracks the default branch; pin a commit for reproducible
# builds once a known-good revision is confirmed.
RUN git clone --depth 1 https://github.com/cocodataset/cocoapi.git \
    && cd cocoapi/PythonAPI \
    && python setup.py build_ext install --user \
    && rm -rf build
pytorch2onnx.py
import os
import sys
import argparse
import mmcv
import numpy as np
import torch
from mmcv.runner import load_checkpoint
sys.path.append('/workspaces/SRPose')
from mmpose.models import build_posenet
# onnx and onnxruntime are imported up front so that a missing installation is
# reported immediately, before any model is built or exported.
try:
    import onnx
    import onnxruntime as rt
except ImportError as e:
    # Chain explicitly so the traceback marks the original failure as the
    # direct cause of this friendlier message.
    raise ImportError(f'Please install onnx and onnxruntime first. {e}') from e
# try:
# from mmcv.onnx.symbolic import register_extra_symbolics
# except ModuleNotFoundError:
# raise NotImplementedError('please update mmcv to version>=1.0.4')
def _convert_batchnorm(module):
"""Convert the syncBNs into normal BN3ds."""
module_output = module
if isinstance(module, torch.nn.SyncBatchNorm):
module_output = torch.nn.BatchNorm3d(
module.num_features, module.eps,
module.momentum, module.affine,
module.track_running_stats
)
if module.affine:
module_output.weight.data = module.weight.data.clone().detach()
module_output.bias.data = module.bias.data.clone().detach()
# keep requires_grad unchanged
module_output.weight.requires_grad = module.weight.requires_grad
module_output.bias.requires_grad = module.bias.requires_grad
module_output.running_mean = module.running_mean
module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
for name, child in module.named_children():
module_output.add_module(name, _convert_batchnorm(child))
del module
return module_output
def pytorch2onnx(
        model,
        input_shape,
        opset_version=11,
        show=False,
        output_file='tmp.onnx',
        verify=False):
    """Export a PyTorch model to ONNX and optionally cross-check the result.

    The model is moved to CPU, switched to eval mode and traced with a random
    tensor of ``input_shape``. With ``verify`` enabled, the exported file is
    validated by the ONNX checker and its first output is compared against
    the PyTorch output for the same random tensor.

    Args:
        model (:obj:`nn.Module`): The pytorch model to be exported.
        input_shape (tuple[int]): The input tensor shape of the model.
        opset_version (int): Opset version of onnx used. Default: 11.
        show (bool): Determines whether to print the onnx model architecture.
            Default: False.
        output_file (str): Output onnx model name. Default: 'tmp.onnx'.
        verify (bool): Determines whether to verify the onnx model.
            Default: False.

    Raises:
        AssertionError: If verification is requested and the exported graph
            does not have exactly one non-initializer input, or the outputs
            differ by more than ``atol=1e-5``.
    """
    model.cpu().eval()
    dummy_input = torch.randn(input_shape)

    # register_extra_symbolics(opset_version)
    torch.onnx.export(
        model,
        dummy_input,
        output_file,
        export_params=True,
        keep_initializers_as_inputs=True,
        verbose=show,
        opset_version=opset_version)
    print(f'Successfully exported ONNX model: {output_file}')

    if not verify:
        return

    # Structural validation of the exported file.
    exported = onnx.load(output_file)
    onnx.checker.check_model(exported)

    # Reference output from the PyTorch model.
    expected = model(dummy_input).detach().numpy()

    # Initializers are exported as graph inputs too, so the real feed input
    # is whatever graph input is not an initializer.
    graph_inputs = {node.name for node in exported.graph.input}
    initializer_names = {node.name for node in exported.graph.initializer}
    feed_names = list(graph_inputs.difference(initializer_names))
    assert len(feed_names) == 1

    session = rt.InferenceSession(output_file)
    feed = {feed_names[0]: dummy_input.detach().numpy()}
    actual = session.run(None, feed)[0]

    # Only the first output is compared.
    assert np.allclose(
        expected, actual,
        atol=1.e-5), 'The outputs are different between Pytorch and ONNX'
    print('The numerical values are same between Pytorch and ONNX')
def parse_args():
    """Parse the command-line options of the conversion script.

    Returns:
        argparse.Namespace: Holds ``config`` and ``checkpoint`` (positional
        paths), the flags ``show`` and ``verify``, ``output_file`` (default
        'tmp.onnx'), ``opset_version`` (default 11) and ``shape`` (list of
        ints, default ``[1, 3, 256, 192]``).
    """
    parser = argparse.ArgumentParser(
        description='Convert MMPose models to ONNX')

    # Required positional paths.
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')

    # Export options.
    parser.add_argument(
        '--show', action='store_true', help='show onnx graph')
    parser.add_argument('--output-file', type=str, default='tmp.onnx')
    parser.add_argument('--opset-version', type=int, default=11)
    parser.add_argument(
        '--verify',
        action='store_true',
        help='verify the onnx model output against pytorch output')

    # One or more integers describing the dummy input tensor.
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[1, 3, 256, 192],
        help='input size')

    return parser.parse_args()
# Script entry point: build the pose network from the config, load the
# checkpoint and export the model to ONNX.
if __name__ == '__main__':
    args = parse_args()

    assert args.opset_version == 11, 'MMPose only supports opset 11 now'

    cfg = mmcv.Config.fromfile(args.config)
    # build the model
    model = build_posenet(cfg.model)
    model = _convert_batchnorm(model)

    # onnx.export does not support kwargs
    if hasattr(model, 'forward_dummy'):
        model.forward = model.forward_dummy
    else:
        raise NotImplementedError(
            'Please implement the forward method for exporting.')

    # Called for its side effect of loading the weights into ``model``; the
    # returned checkpoint is not needed.
    load_checkpoint(model, args.checkpoint, map_location='cpu')

    # Honour an explicit --output-file. When the option is left at its
    # 'tmp.onnx' default, keep deriving the name from the checkpoint and the
    # input shape, e.g. ``<checkpoint>_1x3x256x192.onnx``. Joining every
    # dimension gives the same name for 4D shapes and avoids an IndexError
    # for shapes with fewer than four dimensions.
    if args.output_file != 'tmp.onnx':
        output_file = args.output_file
    else:
        stem = os.path.splitext(os.path.basename(args.checkpoint))[0]
        dims = 'x'.join(str(dim) for dim in args.shape)
        output_file = f'{stem}_{dims}.onnx'

    # convert model to onnx file
    pytorch2onnx(
        model,
        args.shape,
        opset_version=args.opset_version,
        show=args.show,
        output_file=output_file,
        verify=args.verify)