📖
PyTorchのモデルをtfliteに変換する方法
概要
PyTorchで学習したモデルをTensorFlow Liteモデルに変換する方法。
直接変換することはできないので
- PyTroch → onnx → tensorflow → tflite
の順に変換していく。
環境
- Google Colaboratory
実装
PyTorchのモデル設定
設定値
「content」直下に学習済みの.pthファイルをアップロードしておく。
batch_size = 128
channel = 3
width = 32
height = 32
path_pth = "/content/model.pth"
path_onnx = "/content/model.onnx"
path_tf = "/content/tf_model"
path_tflite = "/content/model.tflite"
モデルの定義
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = torch.flatten(x, 1) # flatten all dimensions except batch
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
PyTorch => onnx
conver_model = Net()
conver_model.load_state_dict(torch.load(path_pth, map_location='cpu'))
conver_model.eval()
sample_input = torch.rand((batch_size, channel, width, height))
torch.onnx.export(
conver_model,
sample_input,
"model.onnx",
opset_version=12,
input_names=['input'],
output_names=['output']
)
onnx => tensorflow
!git clone https://github.com/onnx/onnx-tensorflow.git
%cd onnx-tensorflow
!pip install -e .
import onnx
onnx_model = onnx.load(path_onnx)
from onnx_tf.backend import prepare
tf_rep = prepare(onnx_model)
tf_rep.export_graph(path_tf)
tensorflow => tflite
converter = tf.lite.TFLiteConverter.from_saved_model(path_tf)
tflite_model = converter.convert()
with open(path_tflite, 'wb') as f:
f.write(tflite_model)
Discussion