📊
PyTorchのテンソルをシリアライズする3つの方法とベンチマーク分析
はじめに
PyTorchで開発をしていると、テンソルデータをバイト列に変換(シリアライズ)して保存や転送を行う必要に迫られることがあります。本記事では、一般的な3つのシリアライズ方法について、その特徴とパフォーマンスを解説します。
シリアライズの3つの方法
numpy().tobytes()
1. def tensor_to_buffer_numpy(tensor: torch.Tensor) -> bytes:
if tensor.device.type != "cpu":
tensor = tensor.cpu()
return tensor.numpy().tobytes()
def buffer_to_tensor_numpy(buffer: bytes, shape: tuple[int, ...], dtype: np.dtype) -> torch.Tensor:
array = np.frombuffer(buffer, dtype=dtype).copy()
return torch.from_numpy(array.reshape(shape))
- NumPy配列を経由してバイト列に変換する方法
- メモリコピーが発生するが、NumPyとの相互運用性が高い
- シンプルで理解しやすい実装
data_ptr()
2. def tensor_to_buffer_ptr(tensor: torch.Tensor) -> bytes:
if tensor.device.type != "cpu":
tensor = tensor.cpu()
nbytes = tensor.nelement() * tensor.element_size()
ptr = tensor.data_ptr()
return ctypes.string_at(ptr, nbytes)
def buffer_to_tensor_ptr(buffer: bytes, shape: tuple[int, ...], dtype: torch.dtype) -> torch.Tensor:
return torch.frombuffer(bytearray(buffer), dtype=dtype).reshape(shape)
- テンソルの生のメモリポインタを直接アクセスする方法
- 最小限のメモリコピーで高速な変換が可能
- 低レベルな操作のため、注意深い実装が必要
torch.save
3. def tensor_to_buffer_save(tensor: torch.Tensor) -> bytes:
buffer = BytesIO()
torch.save(tensor, buffer)
return buffer.getvalue()
def buffer_to_tensor_save(buffer: bytes) -> torch.Tensor:
return torch.load(BytesIO(buffer))
- PyTorch標準のシリアライズ機能を使用
- メタデータ(dtype、device等)も含めて保存可能
- Pickle形式での保存となるため、セキュリティに注意が必要
ベンチマーク結果と考察
テンソルサイズ別シリアライズ性能(MB/s)
テンソルサイズ | numpy().tobytes() | data_ptr() | torch.save |
---|---|---|---|
100x1 | 724.3 | 1,305.5 | 18.8 |
100x100 | 42,424.0 | 59,640.4 | 1,315.9 |
5000x5000 | 20,446.0 | 20,163.4 | 3,865.6 |
小規模テンソル(100x1)での処理では、data_ptr()メソッドが約1.3GB/sと最も高速で、numpy().tobytes()の約724MB/sを大きく上回っています。
一方、torch.saveは約19MB/sと著しく低速です。
しかし、テンソルサイズが大きくなると(5000x5000)、numpy().tobytes()とdata_ptr()はともに約20GB/sとほぼ同等の性能を示し、torch.saveも約4GB/sまで性能が向上します。
パフォーマンス比較
分析から得られた主な知見は以下の通りです:
-
テンソルサイズによる性能特性
- 小規模テンソル(100要素以下)では
data_ptr()
が最も高速で、torch.save
は著しく遅い - 大規模テンソル(1000x1000以上)では
numpy().tobytes()
とdata_ptr()
の性能差が縮小 - メモリ使用量は
numpy().tobytes()
とdata_ptr()
が同等で、torch.save
は小規模テンソルでオーバーヘッドが大きい
- 小規模テンソル(100要素以下)では
-
実用的な選択基準
- 単純なデータ転送の場合は
data_ptr()
が最適 - NumPyとの相互運用が必要な場合は
numpy().tobytes()
- メタデータの保存が重要な場合は
torch.save
- 単純なデータ転送の場合は
-
考慮すべき注意点
- すべての方法でCPUテンソルへの変換が必要
-
data_ptr()
は低レベル操作のため、メモリ管理に注意が必要 -
torch.save
はPickleを使用するため、信頼できないソースからのデータ読み込みには注意
まとめ
本記事では、PyTorchテンソルのシリアライズについて3つの方法を比較しました。実際の使用時には、以下の選択基準を推奨します:
- 高速な単純転送が必要な場合:
data_ptr()
- NumPyとの相互運用性が重要な場合:
numpy().tobytes()
- メタデータの保存が必要な場合:
torch.save
参考文献
ソースコードとraw result
完全なソースコードとベンチマーク結果はこちらです。
Discussion