📚

PyTorchのEmbeddingとは

に公開

Embeddingって何?

  • カテゴリ変数(=IDやラベル)を、意味のあるベクトルに変換する方法。
  • One-Hotよりも効率的&情報豊富。
  • 機械学習・深層学習で「意味を学習できる」 のが特徴。
  • 主に「単語のベクトル化」や「カテゴリ特徴量の表現」に使う。

PyTorchでの基本構文

import torch
import torch.nn as nn

# 例: 3個のカテゴリに対して、4次元の埋め込みベクトルを学習
# つまり一つのカテゴリに対し、shapeが[4]のベクトルを生成する。

items = ['リンゴ', 'バナナ', 'キウィ']
# ↓ ラベル化(文字列から重複のない0, 3の整数へ変換)
labeled_items = [0, 1, 2]
item_cunt = len(labeled_items) # 3

embedding = nn.Embedding(num_embeddings=item_cunt, embedding_dim=4)
or 
embedding = nn.Embedding(item_cunt, 4)
print(embedding)

Embedding(10, 4)

num_embeddings = カテゴリの数(インデックスの最大値 + 1)
embedding_dim = 埋め込みベクトルの次元数

x = torch.tensor([0, 2])        # カテゴリのインデックス
output = embedding(x)
print(output)

tensor([[-1.9968, 0.2367, 0.5444, 0.1343],
[ 3.0731, 1.2503, 1.3096, 0.7230]], grad_fn=<EmbeddingBackward0>)

x の各インデックスに対応する 埋め込みベクトル(4次元) が返されます。
3個の要素を渡しているため、3個出力されます。

特徴

  • 軽い One-Hotよりも次元数が小さい(効率的)
  • 学習可能 ベクトルの値は誤差逆伝播で更新される
  • 柔軟 多次元埋め込み、複数カテゴリに対応できる

インデックスエラー

Embeddingに範囲外のインデックスを渡した場合インデックスエラーになります。

num_embeddings=5 → インデックスは 0, 1, 2, 3, 4 だけです。

import torch.nn as nn
import torch

embedding = nn.Embedding(5, 3)  # 有効なインデックスは 0〜4

x = torch.tensor([0, 1, 2, 5])  # ← 5 は範囲外!
embedding(x)  # エラー発生!

shape

入力テンソルのshapeの末尾に、hidden_dimを加えたものが、出力テンソルのshapeになる。

例えば入力テンソルのshapeが(3, 4), hidden_dim = 10なら、出力テンソルのshapeは、(3, 4, 10)になる。

em = nn.Embedding(7, 10)
print(em)

input1 = torch.Tensor([5, 4]).long()
output1 = em(input1)
print(f"input1: {input1.shape}, output1: {output1.shape}")

input2 = torch.stack([
    input1,
    input1,
    input1
])

output2 = em(input2)
print(f"input1: {input2.shape}, output2: {output2.shape}")

Discussion