🌟

torch nn.Linear の仕様メモ

2025/02/03に公開

https://pytorch.org/docs/stable/generated/torch.nn.Linear.html

input は (*, in_features), weight は (out_features, in_features)
したがって weight は doc に従って transpose(転置)されて matmul される.

https://stackoverflow.com/questions/53465608/pytorch-shape-of-nn-linear-weights

weight を (in_features, out_features) にしておけば転置せずに matmul でシンプルになるとは思うが, 後方互換性のためと, パフォーマンスにそれほど影響はないと思われるため転置するままになっている. https://github.com/pytorch/pytorch/issues/2159#issuecomment-390068272

matmul で最適化された BLAS 実装を使う場合, 転置かどうかはフラグだけ持っておいて(中身のメモリデータは転置を取らない), メモリアクセスも転置を考慮して最適化している(はず)なので, 少なくとも推論においてはパフォーマンスはそれほど影響はないと思われる.

Discussion