📚
PyTorchのEmbeddingとは
Embeddingって何?
- カテゴリ変数(=IDやラベル)を、意味のあるベクトルに変換する方法。
- One-Hotよりも効率的&情報豊富。
- 機械学習・深層学習で「意味を学習できる」 のが特徴。
- 主に「単語のベクトル化」や「カテゴリ特徴量の表現」に使う。
PyTorchでの基本構文
import torch
import torch.nn as nn
# 例: 3個のカテゴリに対して、4次元の埋め込みベクトルを学習
# つまり一つのカテゴリに対し、shapeが[4]のベクトルを生成する。
items = ['リンゴ', 'バナナ', 'キウィ']
# ↓ ラベル化(文字列から重複のない0, 3の整数へ変換)
labeled_items = [0, 1, 2]
item_cunt = len(labeled_items) # 3
embedding = nn.Embedding(num_embeddings=item_cunt, embedding_dim=4)
or
embedding = nn.Embedding(item_cunt, 4)
print(embedding)
Embedding(10, 4)
num_embeddings = カテゴリの数(インデックスの最大値 + 1)
embedding_dim = 埋め込みベクトルの次元数
x = torch.tensor([0, 2]) # カテゴリのインデックス
output = embedding(x)
print(output)
tensor([[-1.9968, 0.2367, 0.5444, 0.1343],
[ 3.0731, 1.2503, 1.3096, 0.7230]], grad_fn=<EmbeddingBackward0>)
x の各インデックスに対応する 埋め込みベクトル(4次元) が返されます。
3個の要素を渡しているため、3個出力されます。
特徴
- 軽い One-Hotよりも次元数が小さい(効率的)
- 学習可能 ベクトルの値は誤差逆伝播で更新される
- 柔軟 多次元埋め込み、複数カテゴリに対応できる
インデックスエラー
Embeddingに範囲外のインデックスを渡した場合インデックスエラーになります。
num_embeddings=5 → インデックスは 0, 1, 2, 3, 4 だけです。
import torch.nn as nn
import torch
embedding = nn.Embedding(5, 3) # 有効なインデックスは 0〜4
x = torch.tensor([0, 1, 2, 5]) # ← 5 は範囲外!
embedding(x) # エラー発生!
shape
入力テンソルのshapeの末尾に、hidden_dimを加えたものが、出力テンソルのshapeになる。
例えば入力テンソルのshapeが(3, 4), hidden_dim = 10なら、出力テンソルのshapeは、(3, 4, 10)になる。
例
em = nn.Embedding(7, 10)
print(em)
input1 = torch.Tensor([5, 4]).long()
output1 = em(input1)
print(f"input1: {input1.shape}, output1: {output1.shape}")
input2 = torch.stack([
input1,
input1,
input1
])
output2 = em(input2)
print(f"input1: {input2.shape}, output2: {output2.shape}")
Discussion