Open3

TorchScriptで推論高速化への道

kitchykitchy

はじめに

TorchScriptを使って、ESPnetで作ったモデルを推論を高速化するタスクができました。しかし、現在(2022/08/06)では、モデルはTorchScriptように書かれておらず、大部分を書き換えないとトレースできない状態でした。

そこで、作業内で気づいたことをまとめていきます。

kitchykitchy

TorchScript初級編

こちらの記事を参考にさせていただきました。

一番簡単なモデルの準備

import torch

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

torch スクリプトにする場合は以下のようにトレースすることで、pytorchがデータのフローを見てスクリプト化する。(らしい)

torch.manual_seed(111)

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)
print(traced_cell(x, h))

出力が合っているか見る。

  • Traceしない場合
MyCell(
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[ 0.1259,  0.6980,  0.1837,  0.7427],
        [ 0.5375,  0.3345,  0.4816,  0.6669],
        [-0.0359, -0.0503,  0.2005,  0.1008]], grad_fn=<TanhBackward0>), tensor([[ 0.1259,  0.6980,  0.1837,  0.7427],
        [ 0.5375,  0.3345,  0.4816,  0.6669],
        [-0.0359, -0.0503,  0.2005,  0.1008]], grad_fn=<TanhBackward0>))
  • Traceした場合
MyCell(
  original_name=MyCell
  (linear): Linear(original_name=Linear)
)
(tensor([[ 0.1259,  0.6980,  0.1837,  0.7427],
        [ 0.5375,  0.3345,  0.4816,  0.6669],
        [-0.0359, -0.0503,  0.2005,  0.1008]], grad_fn=<TanhBackward0>), tensor([[ 0.1259,  0.6980,  0.1837,  0.7427],
        [ 0.5375,  0.3345,  0.4816,  0.6669],
        [-0.0359, -0.0503,  0.2005,  0.1008]], grad_fn=<TanhBackward0>))

同じですね。今回は処理時間計算はパスします。

kitchykitchy

Modelの中にTorch Moduleじゃないソースがある場合

ESPnetのTorch のソースらは非常に複雑に絡み合っており、例えばModuleの中でModuleじゃないオブジェクトを利用するところがあった。この場合もTorchScriptにしてくれるのか気になったため試してみた。

  1. まずはtorch Moduleじゃないオブジェクト
import torch
class Tools(object):
    def __init__(self, ids: torch.Tensor) -> None:
        self.ids = ids

    def __call__(self) -> torch.Tensor:
        return self.ids
  1. Modelの定義
import torch

from tools import Tools

torch.manual_seed(111)

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.tools = Tools(torch.rand(3, 4))

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        new_h = new_h + self.tools()
        return new_h, new_h

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
# traced_cell = my_cell
print(traced_cell)
print(traced_cell(x, h))

3.結果

MyCell(
  original_name=MyCell
  (linear): Linear(original_name=Linear)
)
(tensor([[0.9494, 1.1489, 1.4638, 0.1276],
        [1.1904, 0.2365, 0.6603, 1.0575],
        [0.4537, 0.6700, 0.1617, 1.1576]], grad_fn=<AddBackward0>), tensor([[0.9494, 1.1489, 1.4638, 0.1276],
        [1.1904, 0.2365, 0.6603, 1.0575],
        [0.4537, 0.6700, 0.1617, 1.1576]], grad_fn=<AddBackward0>))

無事スクリプト化できている。