Open3
TorchScriptで推論高速化への道
はじめに
TorchScriptを使って、ESPnetで作ったモデルを推論を高速化するタスクができました。しかし、現在(2022/08/06)では、モデルはTorchScriptように書かれておらず、大部分を書き換えないとトレースできない状態でした。
そこで、作業内で気づいたことをまとめていきます。
TorchScript初級編
一番簡単なモデルの準備
import torch
class MyCell(torch.nn.Module):
def __init__(self):
super(MyCell, self).__init__()
self.linear = torch.nn.Linear(4, 4)
def forward(self, x, h):
new_h = torch.tanh(self.linear(x) + h)
return new_h, new_h
torch スクリプトにする場合は以下のようにトレースすることで、pytorchがデータのフローを見てスクリプト化する。(らしい)
torch.manual_seed(111)
my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)
print(traced_cell(x, h))
出力が合っているか見る。
- Traceしない場合
MyCell(
(linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[ 0.1259, 0.6980, 0.1837, 0.7427],
[ 0.5375, 0.3345, 0.4816, 0.6669],
[-0.0359, -0.0503, 0.2005, 0.1008]], grad_fn=<TanhBackward0>), tensor([[ 0.1259, 0.6980, 0.1837, 0.7427],
[ 0.5375, 0.3345, 0.4816, 0.6669],
[-0.0359, -0.0503, 0.2005, 0.1008]], grad_fn=<TanhBackward0>))
- Traceした場合
MyCell(
original_name=MyCell
(linear): Linear(original_name=Linear)
)
(tensor([[ 0.1259, 0.6980, 0.1837, 0.7427],
[ 0.5375, 0.3345, 0.4816, 0.6669],
[-0.0359, -0.0503, 0.2005, 0.1008]], grad_fn=<TanhBackward0>), tensor([[ 0.1259, 0.6980, 0.1837, 0.7427],
[ 0.5375, 0.3345, 0.4816, 0.6669],
[-0.0359, -0.0503, 0.2005, 0.1008]], grad_fn=<TanhBackward0>))
同じですね。今回は処理時間計算はパスします。
Modelの中にTorch Moduleじゃないソースがある場合
ESPnetのTorch のソースらは非常に複雑に絡み合っており、例えばModuleの中でModuleじゃないオブジェクトを利用するところがあった。この場合もTorchScriptにしてくれるのか気になったため試してみた。
- まずはtorch Moduleじゃないオブジェクト
import torch
class Tools(object):
def __init__(self, ids: torch.Tensor) -> None:
self.ids = ids
def __call__(self) -> torch.Tensor:
return self.ids
- Modelの定義
import torch
from tools import Tools
torch.manual_seed(111)
class MyCell(torch.nn.Module):
def __init__(self):
super(MyCell, self).__init__()
self.linear = torch.nn.Linear(4, 4)
self.tools = Tools(torch.rand(3, 4))
def forward(self, x, h):
new_h = torch.tanh(self.linear(x) + h)
new_h = new_h + self.tools()
return new_h, new_h
my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
# traced_cell = my_cell
print(traced_cell)
print(traced_cell(x, h))
3.結果
MyCell(
original_name=MyCell
(linear): Linear(original_name=Linear)
)
(tensor([[0.9494, 1.1489, 1.4638, 0.1276],
[1.1904, 0.2365, 0.6603, 1.0575],
[0.4537, 0.6700, 0.1617, 1.1576]], grad_fn=<AddBackward0>), tensor([[0.9494, 1.1489, 1.4638, 0.1276],
[1.1904, 0.2365, 0.6603, 1.0575],
[0.4537, 0.6700, 0.1617, 1.1576]], grad_fn=<AddBackward0>))
無事スクリプト化できている。