PyTorch の argsort を sort に置き換えて ONNX にエクスポートする
1. はじめに
素晴らしいPyTorchのモデルをONNXにエクスポートしようとすると下記のようなエラーに悩まされることがあります。torch.argsort
を含むモデルです。ちなみに、PyTorch の公式エンジニアも、ONNXの公式エンジニアもまともに取り合ってくれていません。さらに、issueを解決したと考えられる回答の内容に全く中身がありません。困りますね。皆さん、青色の血が流れているのでしょうか。
Traceback (most recent call last):
File "demo.py", line 49, in <module>
torch.onnx.export(pose_ssstereo, (x, x), f"coex_{H}x{W}.onnx", opset_version=11)
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/__init__.py", line 275, in export
return utils.export(model, args, f, export_params, verbose, training,
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 88, in export
_export(model, args, f, export_params, verbose, training, input_names, output_names,
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 689, in _export
_model_to_graph(model, args, verbose, input_names,
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 463, in _model_to_graph
graph = _optimize_graph(graph, operator_export_type,
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 200, in _optimize_graph
graph = torch._C._jit_pass_onnx(graph, operator_export_type)
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/__init__.py", line 313, in _run_symbolic_function
return utils._run_symbolic_function(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 990, in _run_symbolic_function
symbolic_fn = _find_symbolic_in_registry(domain, op_name, opset_version, operator_export_type)
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 944, in _find_symbolic_in_registry
return sym_registry.get_registered_op(op_name, domain, opset_version)
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/symbolic_registry.py", line 116, in get_registered_op
raise RuntimeError(msg)
RuntimeError: Exporting the operator argsort to ONNX opset version 11 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.
忌まわしきエラー。
RuntimeError: Exporting the operator argsort to ONNX opset version 11 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.
Google検索しても np.argsort
に置き換えるという、使えない情報しか見当たりません。したがって、しかたなくモデルを改造します。
こちらが PyTorch と ONNX のissue例。
2. ワークアラウンド
しょーもないワークアラウンドですが、下記の通り改造することで ONNX へエクスポートできるようになります。今回は opset=11
でエクスポートして正常に変換できたパターンを記録として残します。
例えばこれを、
pool_ind_ = cost.argsort(2, True)[:, :, :k]
こう書き換えてからエクスポート torch.onnx.export(model, x, "xxxx.onnx", opset_version=11)
すると成功します。
_, ind = cost.sort(2, True)
pool_ind_ = ind[:, :, :k]
k
は上位何番目までを抽出するか、を指定する変数です。例えば k=10
の場合は上位10番目までのインデックスをスライスして抽出しています。これはあくまでサンプルとして3次元のテンソルに対するスライスの処理を記載しているだけですので、上位n番目までの抽出が不要な場合は [:, :, :k]
の部分を無視してかまいません。 argsort
と sort
の仕様は下記に記載されています。1つ目の引数はソートする次元、2つ目の引数は昇順or降順の指定です。Trueを指定した場合は降順に並び替えされます。sort
の戻り値は2つ有り、1つ目の戻り値はソートされた値そのもの、2つ目はソートされたあとのインデックス値です。
torch.argsort(input, dim=-1, descending=False) → LongTensor
- Parameters
input (Tensor) – the input tensor.
dim (int, optional) – the dimension to sort along
descending (bool, optional) – controls the sorting order (ascending or descending)
torch.sort(input, dim=-1, descending=False, stable=False, *, out=None) -> (Tensor, LongTensor)
- Parameters
input (Tensor) – the input tensor.
dim (int, optional) – the dimension to sort along
descending (bool, optional) – controls the sorting order (ascending or descending)
stable (bool, optional) – makes the sorting routine stable, which guarantees that the order of equivalent elements is preserved.
PyTorchの公式ドキュメントに sort
オペレーションのエクスポートが正式サポートされていることが記載されていることからも、ONNXへのエクスポートが成功するのは正しい挙動に見えます。
3. Appendix
下記のargsortを含むモデルでONNXエクスポートが正常に成功することを確認しました。
PyTorchへのプルリクエスト
Discussion