🟰

Wikipediaの全記事を学習させて、日本語で遊んでみる

2024/11/09に公開

元ネタ

https://www.youtube.com/watch?v=sK3HqLwag_w&list=WL&index=38

見てておもろそうだなーと思ったので、FastAPIの復習も兼ねて作ってみました。
動画に全て内容が載っていますが一応記載すると、Wikipediaの全記事内の文章をword2vecに学習させて単語を数値化することで、単語同士の足し算や引き算を行うことができます。

できたもの

https://github.com/Suntory-Y-Water/word-vectors-api
APIにして公開したら面白そうだな~と思いましたが、Renderだと学習データの容量が重くてNGな点とできる限り無料で運用したかったのでやめました🤔
学習させたmodelデータもありますのでぜひ遊んでみてください。

遊んでみる

セットアップ

パッケージのインストール

poetry install

開発サーバーの起動

uvicorn api.main:app --reload --port 8000 --host 0.0.0.0

起動を確認後、http://localhost:8000/docsにアクセスします。

実行例

水瀬いのりから近い単語を得るためにhttp://localhost:8000/api/vectors?positive=水瀬いのりで実行します。
すると声優に関連のワードを得ることができますね。

{
  "positive_words": [
    {
      "word": "水瀬いのり"
    }
  ],
  "negative_words": [],
  "word_list": [
    {
      "word": "水樹",
      "similarity": 0.745496988296509
    },
    {
      "word": "裕香",
      "similarity": 0.732845783233643
    },
    {
      "word": "奈々",
      "similarity": 0.722927331924439
    },
    {
      "word": "愛美",
      "similarity": 0.713139891624451
    },
    {
      "word": "咲",
      "similarity": 0.69900643825531
    }
  ]
}

動画でも紹介されている「円 + 韓国 - 日本」で計算すると、やはりウォンが出力できます。

// http://localhost:8000/api/vectors?positive=円,韓国&negative=日本
{
  "positive_words": [
    {
      "word": "円"
    },
    {
      "word": "韓国"
    }
  ],
  "negative_words": [
    {
      "word": "日本"
    }
  ],
  "word_list": [
    {
      "word": "ウォン",
      "similarity": 0.6519988775253296
    },
    {
      "word": "億",
      "similarity": 0.5680368542671204
    },
    {
      "word": "ドル",
      "similarity": 0.5642450451850891
    },
    {
      "word": "万",
      "similarity": 0.5566189885139465
    },
    {
      "word": "額",
      "similarity": 0.5470199584960938
    }
  ]
}

APIは足し算したい単語であればpositive、引き算したい単語であればnegativeにカンマ区切りで設定します。
クエリパラメータに入力された値を一度分かち書きしてから、モデルに入力を行います。

router/vectors.py
@router.get(
    "/vectors",
    response_model=vectors_schema.WordVectorResponse,
    responses={
        400: {"model": BadRequestError, "description": "入力値が不正な場合のエラー"},
        404: {"model": NotFoundError, "description": "類似する単語が見つからなかった場合のエラー"},
        500: {"description": "サーバー内部エラー"},
    },
    response_description="単語ベクトルの計算結果",
    summary="単語ベクトルの計算",
    description="単語同士の足し算・引き算を行い、類似する単語を返却します",
)
async def calculate_vectors(
    positive: str | None = Query(
        default=None, description="足し算したい単語(カンマ区切り。例:Python,プログラミング)"
    ),
    negative: str | None = Query(default=None, description="引き算したい単語(カンマ区切り。例:パソコン)"),
    topn: int = Query(default=5, ge=1, le=20),
) -> vectors_schema.WordVectorResponse:
    try:
        if positive is None and negative is None:
            raise TypedHTTPException(
                error_response=BadRequestError(message="positiveまたはnegativeワードのいずれかは必須です。")
            )

        _positive_terms = positive.split(",") if positive else []
        _negative_terms = negative.split(",") if negative else []

        # 各単語を分かち書きする
        wakati_positive_terms = mecab.tokenize_list(_positive_terms)
        wakati_negative_terms = mecab.tokenize_list(_negative_terms)

        similar_words = vector.most_similar(positive=wakati_positive_terms, negative=wakati_negative_terms, topn=topn)

        if not similar_words:
            raise TypedHTTPException(
                error_response=NotFoundError(message="類似する単語が見つかりませんでした。別の単語で試してください。")
            )
        # 省略
services/tokenizer.py
import MeCab
from typing import List


class Tokenizer:
    def __init__(self):
        self.mecab = MeCab.Tagger("-Owakati")

    def tokenize(self, text: str) -> List[str]:
        # MeCabで形態素解析
        node = self.mecab.parseToNode(text)
        tokens = []

        while node:
            # 表層形を取得(実際の単語)
            surface = node.surface
            # 品詞情報を取得
            features = node.feature.split(",")

            # 空白と記号は除外
            if surface and features[0] not in ["記号", "BOS/EOS"]:
                tokens.append(surface)

            node = node.next

        return tokens

    def tokenize_list(self, texts: List[str]) -> List[str]:
        all_tokens = []
        for text in texts:
            tokens = self.tokenize(text)
            all_tokens.extend(tokens)
        return all_tokens

positiveまたはnegativeが設定されていない場合は400エラーを返却します。
入力値によっては計算結果を得ることができないこともあったため、そのときは404エラーを返却します。

おわりに

なにかおもしろい案が思いついたらアプリケーションにしていきたいですね。
動画で紹介されていたときと比べて年月が経過している影響か分かりませんが、同じ入力をしても同様の結果を得ることができなかったです。
記事の内容や学習させたデータの数によって変わるものなんでしょうかね?

他、参考にさせていたもの

https://zenn.dev/haru330/articles/503c217c3cda1e

https://qiita.com/hoppiece_/items/72753b7ac08f0bd4993f

Discussion