📖

PyTorchのモデルをtfliteに変換する方法

2022/07/22に公開

概要

PyTorchで学習したモデルをTensorFlow Liteモデルに変換する方法。

直接変換することはできないので

  • PyTroch → onnx → tensorflow → tflite

の順に変換していく。

環境

  • Google Colaboratory

実装

PyTorchのモデル設定

設定値

「content」直下に学習済みの.pthファイルをアップロードしておく。

batch_size = 128
channel = 3
width = 32
height = 32

path_pth = "/content/model.pth"
path_onnx = "/content/model.onnx"
path_tf = "/content/tf_model"
path_tflite = "/content/model.tflite"

モデルの定義

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

PyTorch => onnx

conver_model = Net()
conver_model.load_state_dict(torch.load(path_pth, map_location='cpu'))
conver_model.eval()

sample_input = torch.rand((batch_size, channel, width, height))

torch.onnx.export(
    conver_model,
    sample_input,
    "model.onnx",
    opset_version=12,
    input_names=['input'],
    output_names=['output']
)

onnx => tensorflow

!git clone https://github.com/onnx/onnx-tensorflow.git
%cd onnx-tensorflow
!pip install -e .
import onnx

onnx_model = onnx.load(path_onnx)
from onnx_tf.backend import  prepare

tf_rep = prepare(onnx_model)
tf_rep.export_graph(path_tf)

tensorflow => tflite

converter = tf.lite.TFLiteConverter.from_saved_model(path_tf)
tflite_model = converter.convert()

with open(path_tflite, 'wb') as f:
    f.write(tflite_model)

Discussion