Published on 2022/07/22

# Overview

How to convert a model trained in PyTorch into a TensorFlow Lite model.

The conversion proceeds in the following order:

• PyTorch → ONNX → TensorFlow → TFLite

# Implementation

## PyTorch model setup

### Settings

Upload the trained `.pth` file directly under `/content` (the default working directory in Google Colab).

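```python
# Input shape used for the ONNX export (PyTorch uses NCHW order)
batch_size = 128
channel = 3
width = 32
height = 32

# File paths for each stage of the conversion
path_pth = "/content/model.pth"        # trained PyTorch weights (uploaded)
path_onnx = "/content/model.onnx"      # intermediate ONNX model
path_tf = "/content/tf_model"          # TensorFlow SavedModel directory
path_tflite = "/content/model.tflite"  # final TFLite model
```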

### Model definition

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
```
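`fc1` expects `16 * 5 * 5` input features, which only holds for a 32×32 input. The arithmetic can be checked with a small pure-Python sketch. `conv2d_out` and `net_flatten_features` are hypothetical helpers written for illustration; they assume dilation 1 and PyTorch's default floor rounding (`ceil_mode=False`).

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of Conv2d / MaxPool2d (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


def net_flatten_features(height: int = 32, width: int = 32) -> int:
    """Number of features reaching fc1 in Net for a given input height/width."""
    h, w = height, width
    # conv1: 5x5 kernel, stride 1, no padding
    h, w = conv2d_out(h, 5), conv2d_out(w, 5)
    # pool: 2x2 window, stride 2
    h, w = conv2d_out(h, 2, 2), conv2d_out(w, 2, 2)
    # conv2: 5x5 kernel, stride 1, no padding
    h, w = conv2d_out(h, 5), conv2d_out(w, 5)
    # pool: 2x2 window, stride 2
    h, w = conv2d_out(h, 2, 2), conv2d_out(w, 2, 2)
    out_channels = 16  # output channels of conv2
    return out_channels * h * w


print(net_flatten_features())  # 400
```

For a 28×28 input it returns 256 instead of 400, so `width` and `height` in the settings must match the input size the model was trained on, or `fc1` will reject the flattened tensor.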

## PyTorch => onnx

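```python
conver_model = Net()
# Load the trained weights (assumes model.pth stores a state_dict;
# map_location lets a GPU-trained checkpoint load on CPU)
conver_model.load_state_dict(torch.load(path_pth, map_location="cpu"))
conver_model.eval()

# Dummy input in NCHW order: (batch, channel, height, width)
sample_input = torch.rand((batch_size, channel, height, width))

torch.onnx.export(
    conver_model,
    sample_input,
    path_onnx,
    opset_version=12,
    input_names=['input'],
    output_names=['output']
)
```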
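Before moving on, it can help to check that each converted model produces the same outputs as the PyTorch model on the same input. Below is a minimal sketch of such a check in NumPy. `compare_outputs` is a hypothetical helper, and the `rtol`/`atol` defaults are assumed values, not thresholds taken from the original steps.

```python
import numpy as np


def compare_outputs(reference, candidate, rtol=1e-3, atol=1e-5):
    """Return (ok, max_abs_diff) for two model outputs of the same shape.

    The tolerances are assumptions; tighten or loosen them for your model.
    """
    reference = np.asarray(reference, dtype=np.float32)
    candidate = np.asarray(candidate, dtype=np.float32)
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")
    max_abs_diff = float(np.max(np.abs(reference - candidate)))
    # Element-wise |candidate - reference| <= atol + rtol * |reference|
    ok = bool(np.allclose(candidate, reference, rtol=rtol, atol=atol))
    return ok, max_abs_diff


# Demo with synthetic arrays standing in for real model outputs
ref = np.array([[0.1, 2.0, -1.0]], dtype=np.float32)
ok, diff = compare_outputs(ref, ref + 1e-6)
print(ok, diff)
```

In the notebook, the reference array could come from `conver_model(sample_input).detach().numpy()`, and the ONNX output from ONNX Runtime (not installed in the steps here; assumed via `pip install onnxruntime`): `onnxruntime.InferenceSession(path_onnx, providers=["CPUExecutionProvider"]).run(None, {"input": sample_input.numpy()})[0]`. The same helper can be reused after the TensorFlow and TFLite steps.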

## onnx => tensorflow

```
!git clone https://github.com/onnx/onnx-tensorflow.git
%cd onnx-tensorflow
!pip install -e .
```
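```python
import onnx
from onnx_tf.backend import prepare

# Load the ONNX model exported in the previous step
onnx_model = onnx.load(path_onnx)

# Convert to a TensorFlow representation and save it as a SavedModel
tf_rep = prepare(onnx_model)
tf_rep.export_graph(path_tf)
```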

## tensorflow => tflite

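```python
import tensorflow as tf

# Convert the SavedModel into a TFLite FlatBuffer
converter = tf.lite.TFLiteConverter.from_saved_model(path_tf)
tflite_model = converter.convert()

# Write the converted model to disk
with open(path_tflite, 'wb') as f:
    f.write(tflite_model)
```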
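To run the converted model, the input has to match the shape it was exported with. Assuming the TFLite model keeps the ONNX input layout (NCHW) and the fixed batch of 128 used for `sample_input` (worth confirming with `interpreter.get_input_details()[0]["shape"]`), a minimal preprocessing sketch could look like this. `to_nchw_batch` is hypothetical, and scaling pixels to [0, 1] is an assumption about how the model was trained.

```python
import numpy as np


def to_nchw_batch(images, batch_size=128):
    """Stack HWC uint8 images into a float32 NCHW batch, zero-padded to batch_size.

    Assumes the model expects inputs scaled to [0, 1]; adjust to match training.
    """
    if len(images) == 0:
        raise ValueError("need at least one image")
    if len(images) > batch_size:
        raise ValueError("more images than the fixed batch size")
    # HWC in [0, 255] -> CHW float32 in [0, 1] (assumed training normalization)
    chw = [np.transpose(np.asarray(img, dtype=np.float32) / 255.0, (2, 0, 1))
           for img in images]
    # Zero-filled padding rows fill the rest of the fixed-size batch
    batch = np.zeros((batch_size,) + chw[0].shape, dtype=np.float32)
    batch[: len(chw)] = np.stack(chw)
    return batch


images = [np.zeros((32, 32, 3), dtype=np.uint8) for _ in range(4)]
batch = to_nchw_batch(images)
print(batch.shape)  # (128, 3, 32, 32)
```

The batch can then be fed through `tf.lite.Interpreter`: create it with `model_path=path_tflite`, call `allocate_tensors()`, pass the array to `set_tensor()` using the index from `get_input_details()`, call `invoke()`, and keep only the first `n` rows of `get_tensor()` (with the index from `get_output_details()`), where `n` is the number of real, non-padding images.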
