🤪

PyTorch の argsort を sort に置き換えて ONNX にエクスポートする 

2021/08/15に公開

1. はじめに

素晴らしいPyTorchのモデルをONNXにエクスポートしようとすると下記のようなエラーに悩まされることがあります。torch.argsort を含むモデルです。ちなみに、PyTorch の公式エンジニアも、ONNXの公式エンジニアもまともに取り合ってくれていません。さらに、issueを解決したと考えられる回答の内容に全く中身がありません。困りますね。皆さん、青色の血が流れているのでしょうか。

Traceback (most recent call last):
  File "demo.py", line 49, in <module>
    torch.onnx.export(pose_ssstereo, (x, x), f"coex_{H}x{W}.onnx", opset_version=11)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/__init__.py", line 275, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 88, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 689, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 463, in _model_to_graph
    graph = _optimize_graph(graph, operator_export_type,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 200, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/__init__.py", line 313, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 990, in _run_symbolic_function
    symbolic_fn = _find_symbolic_in_registry(domain, op_name, opset_version, operator_export_type)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 944, in _find_symbolic_in_registry
    return sym_registry.get_registered_op(op_name, domain, opset_version)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/symbolic_registry.py", line 116, in get_registered_op
    raise RuntimeError(msg)
RuntimeError: Exporting the operator argsort to ONNX opset version 11 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.

忌まわしきエラー。

RuntimeError: Exporting the operator argsort to ONNX opset version 11 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.

Google検索しても np.argsort に置き換えるという、使えない情報しか見当たりません。したがって、しかたなくモデルを改造します。

こちらが PyTorch と ONNX のissue例。
https://github.com/pytorch/pytorch/issues/33412
https://github.com/onnx/onnx/issues/3519
https://github.com/open-mmlab/mmdetection/issues/4738
https://github.com/open-mmlab/mmdetection/issues/4247#issuecomment-740468931

2. ワークアラウンド

しょーもないワークアラウンドですが、下記の通り改造することで ONNX へエクスポートできるようになります。今回は opset=11 でエクスポートして正常に変換できたパターンを記録として残します。

例えばこれを、

pool_ind_ = cost.argsort(2, True)[:, :, :k]

こう書き換えてからエクスポート torch.onnx.export(model, x, "xxxx.onnx", opset_version=11) すると成功します。

_, ind = cost.sort(2, True)
pool_ind_ = ind[:, :, :k]

k は上位何番目までを抽出するか、を指定する変数です。例えば k=10 の場合は上位10番目までのインデックスをスライスして抽出しています。これはあくまでサンプルとして3次元のテンソルに対するスライスの処理を記載しているだけですので、上位n番目までの抽出が不要な場合は [:, :, :k] の部分を無視してかまいません。 argsortsort の仕様は下記に記載されています。1つ目の引数はソートする次元、2つ目の引数は昇順or降順の指定です。Trueを指定した場合は降順に並び替えされます。sort の戻り値は2つ有り、1つ目の戻り値はソートされた値そのもの、2つ目はソートされたあとのインデックス値です。

torch.argsort(input, dim=-1, descending=False) → LongTensor

- Parameters
input (Tensor) – the input tensor.
dim (int, optional) – the dimension to sort along
descending (bool, optional) – controls the sorting order (ascending or descending)
torch.sort(input, dim=-1, descending=False, stable=False, *, out=None) -> (Tensor, LongTensor)

- Parameters
input (Tensor) – the input tensor.
dim (int, optional) – the dimension to sort along
descending (bool, optional) – controls the sorting order (ascending or descending)
stable (bool, optional) – makes the sorting routine stable, which guarantees that the order of equivalent elements is preserved.

PyTorchの公式ドキュメントに sort オペレーションのエクスポートが正式サポートされていることが記載されていることからも、ONNXへのエクスポートが成功するのは正しい挙動に見えます。

https://pytorch.org/docs/stable/onnx.html#id15

3. Appendix

下記のargsortを含むモデルでONNXエクスポートが正常に成功することを確認しました。
https://github.com/antabangun/coex
https://github.com/antabangun/coex/issues/2
PyTorchへのプルリクエスト
https://github.com/pytorch/pytorch/pull/63283

Discussion