Open2

SRPose (PoseEstimation) の環境構築 (mmpose)

PINTO
Dockerfile
FROM ubuntu:22.04

# Build-time only: silences debconf prompts during apt installs. Declared as an
# ARG (not ENV) so it does not leak into the runtime environment of containers.
ARG DEBIAN_FRONTEND=noninteractive

# System packages and Python 3.9 (deadsnakes PPA, registered as the default
# `python` via update-alternatives), then pip and the pinned Python
# dependencies. apt lists and pip caches are removed / disabled in this same
# layer so they never end up in the image.
#  - get-pip.py is downloaded with TLS verification enabled and `-f`, so an
#    HTTP error fails the build instead of feeding an error page to Python.
#  - PyTorch is pinned like every other dependency (CPU wheels) so rebuilds
#    are reproducible; 2.1.0 / 0.16.0 / 2.1.0 is a matching release set.
RUN apt-get update \
    && apt-get upgrade -y \
    && apt-get install -y \
        software-properties-common \
    && add-apt-repository -y ppa:deadsnakes/ppa \
    && apt-get install -y \
        python3.9 \
        python3.9-venv \
    && update-alternatives \
        --install /usr/bin/python python /usr/bin/python3.9 130 \
    && apt-get install -y \
        build-essential \
        curl \
        git \
        libffi-dev \
        libgl1-mesa-dev \
        libssl-dev \
        nano \
        python3.9-dev \
        sudo \
        wget \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/* \
    && sed -i 's/# set linenumbers/set linenumbers/g' /etc/nanorc \
    && curl -fsSL -o /tmp/get-pip.py https://bootstrap.pypa.io/get-pip.py \
    && python /tmp/get-pip.py \
    && rm -f /tmp/get-pip.py \
    && pip install --no-cache-dir -U pip \
    && pip install --no-cache-dir \
        numpy==1.23.2 \
        matplotlib==3.4.1 \
        opencv-python==4.5.2.52 \
        pycocotools==2.0.6 \
        scikit-image==0.18.2 \
        scipy==1.9.0 \
        tensorboard==2.5.0 \
        protobuf==3.20.* \
        tqdm==4.60.0 \
        yacs==0.1.8 \
        einops==0.3.0 \
        Cython==3.0.4 \
        onnx==1.14.1 \
        onnxruntime==1.16.1 \
        onnxsim==0.4.33 \
        mmcv==1.6.1 \
        xtcocotools==1.14.3 \
        json_tricks==3.17.3 \
        timm==0.9.8 \
        munkres==1.1.4 \
    && pip install --no-cache-dir git+https://github.com/svenkreiss/poseval.git \
    && pip install --no-cache-dir \
        torch==2.1.0 \
        torchvision==0.16.0 \
        torchaudio==2.1.0 \
        --index-url https://download.pytorch.org/whl/cpu

ENV USERNAME=user
ARG WKDIR=/workdir

# Create the unprivileged user with passwordless sudo, and hand it the working
# directory. The directory is created and chowned here, while still root, so
# the build itself never needs `sudo`.
# NOTE(review): root and the user get fixed, publicly known passwords; only
# acceptable for a local development container.
RUN echo "root:root" | chpasswd \
    && adduser --disabled-password --gecos "" "${USERNAME}" \
    && echo "${USERNAME}:${USERNAME}" | chpasswd \
    && echo "%${USERNAME}    ALL=(ALL)   NOPASSWD:    ALL" > "/etc/sudoers.d/${USERNAME}" \
    && chmod 0440 "/etc/sudoers.d/${USERNAME}" \
    && mkdir -p "${WKDIR}" \
    && chown "${USERNAME}:${USERNAME}" "${WKDIR}"
WORKDIR ${WKDIR}
USER ${USERNAME}

# Build and install the CrowdPose and COCO Python APIs into the user's site
# directory (--user). These steps run as the unprivileged user, so nothing may
# be installed system-wide. cocoapi's Makefile `install` target performs a
# system-wide `setup.py install`, so it is replaced by building the extension
# and installing with --user directly. Shallow clones keep the layer small.
RUN git clone --depth 1 https://github.com/uyoung-jeong/CrowdPose.git \
    && cd CrowdPose/crowdpose-api/PythonAPI \
    && python setup.py install --user

RUN git clone --depth 1 https://github.com/cocodataset/cocoapi.git \
    && cd cocoapi/PythonAPI \
    && python setup.py build_ext install --user \
    && rm -rf build
PINTO
pytorch2onnx.py
import os
import sys
import argparse

import mmcv
import numpy as np
import torch
from mmcv.runner import load_checkpoint
sys.path.append('/workspaces/SRPose')
from mmpose.models import build_posenet

try:
    import onnx
    import onnxruntime as rt
except ImportError as e:
    raise ImportError(f'Please install onnx and onnxruntime first. {e}')

# try:
#     from mmcv.onnx.symbolic import register_extra_symbolics
# except ModuleNotFoundError:
#     raise NotImplementedError('please update mmcv to version>=1.0.4')


def _convert_batchnorm(module):
    """Convert the syncBNs into normal BN3ds."""
    module_output = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        module_output = torch.nn.BatchNorm3d(
            module.num_features, module.eps,
            module.momentum, module.affine,
            module.track_running_stats
        )
        if module.affine:
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
            # keep requires_grad unchanged
            module_output.weight.requires_grad = module.weight.requires_grad
            module_output.bias.requires_grad = module.bias.requires_grad
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        module_output.add_module(name, _convert_batchnorm(child))
    del module
    return module_output


def pytorch2onnx(
    model,
    input_shape,
    opset_version=11,
    show=False,
    output_file='tmp.onnx',
    verify=False
):
    """Convert pytorch model to onnx model.

    The model is moved to CPU, put in eval mode and traced with a random
    input of ``input_shape``.

    Args:
        model (:obj:`nn.Module`): The pytorch model to be exported.
        input_shape (tuple[int]): The input tensor shape of the model.
        opset_version (int): Opset version of onnx used. Default: 11.
        show (bool): Determines whether to print the onnx model architecture.
            Default: False.
        output_file (str): Output onnx model name. Default: 'tmp.onnx'.
        verify (bool): Determines whether to verify the onnx model, i.e. run
            the ONNX checker and compare every output of the model against
            ONNX Runtime on the same input. Default: False.

    Raises:
        AssertionError: When ``verify`` is set and the exported graph does not
            have exactly one non-initializer input, the number of outputs
            differs, or any output differs beyond ``atol=1e-5``.
    """
    model.cpu().eval()

    one_img = torch.randn(input_shape)

    # register_extra_symbolics(opset_version)
    torch.onnx.export(
        model,
        one_img,
        output_file,
        export_params=True,
        # Weights are also listed as graph inputs; the verification below
        # relies on this to tell the real image input apart from them.
        keep_initializers_as_inputs=True,
        verbose=show,
        opset_version=opset_version)

    print(f'Successfully exported ONNX model: {output_file}')
    if not verify:
        return

    # check by onnx
    onnx_model = onnx.load(output_file)
    onnx.checker.check_model(onnx_model)

    # PyTorch reference outputs. The forward may return a single tensor or a
    # list/tuple of tensors (assumed flat, in the same order as the ONNX outputs).
    with torch.no_grad():
        pytorch_results = model(one_img)
    if not isinstance(pytorch_results, (list, tuple)):
        pytorch_results = [pytorch_results]

    # The only graph input that is not an initializer is the image input.
    input_all = [node.name for node in onnx_model.graph.input]
    input_initializer = [
        node.name for node in onnx_model.graph.initializer
    ]
    net_feed_input = list(set(input_all) - set(input_initializer))
    assert len(net_feed_input) == 1
    sess = rt.InferenceSession(
        output_file, providers=['CPUExecutionProvider'])
    onnx_results = sess.run(None, {net_feed_input[0]: one_img.numpy()})

    assert len(pytorch_results) == len(onnx_results), \
        'The number of outputs differs between Pytorch and ONNX'
    for pt_result, onnx_result in zip(pytorch_results, onnx_results):
        assert np.allclose(
            pt_result.detach().cpu().numpy(), onnx_result,
            atol=1.e-5), 'The outputs are different between Pytorch and ONNX'
    print('The numerical values are same between Pytorch and ONNX')


def parse_args():
    parser = argparse.ArgumentParser(description='Convert MMPose models to ONNX')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--show', action='store_true', help='show onnx graph')
    parser.add_argument('--output-file', type=str, default='tmp.onnx')
    parser.add_argument('--opset-version', type=int, default=11)
    parser.add_argument(
        '--verify',
        action='store_true',
        help='verify the onnx model output against pytorch output')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[1, 3, 256, 192],
        help='input size')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    assert args.opset_version == 11, 'MMPose only supports opset 11 now'

    cfg = mmcv.Config.fromfile(args.config)
    # build the model
    model = build_posenet(cfg.model)
    model = _convert_batchnorm(model)

    # onnx.export does not support kwargs
    if hasattr(model, 'forward_dummy'):
        model.forward = model.forward_dummy
    else:
        raise NotImplementedError(
            'Please implement the forward method for exporting.')

    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')

    # conver model to onnx file
    pytorch2onnx(
        model,
        args.shape,
        opset_version=args.opset_version,
        show=args.show,
        output_file=f'{os.path.splitext(os.path.basename(args.checkpoint))[0]}_{args.shape[0]}x{args.shape[1]}x{args.shape[2]}x{args.shape[3]}.onnx',
        verify=args.verify)