🍩

Converting Business Card Images into Structured Data with Donut, an End-to-End, OCR-Free Document Understanding Model

Published 2024/09/17

Overview

I'm Kumazawa from mutex Inc.
Donut, an end-to-end, OCR-free document understanding model, was released in 2022, and we have been experimenting with it internally.

Donut is a Transformer-based vision encoder-decoder model specialized for document understanding tasks that extract text from images.
In this post, I'll use Donut to run the following simple experiment:

  1. Randomly generate business card images
  2. Fine-tune the Donut base model
  3. Check the results

Randomly generating business card images

Business card layout

Each business card contains the following fields:

  • Company name
  • Name
  • Email address
  • Phone number
  • Address

Since Donut outputs text in an XML-like tagged format, we create the following class:

Business card data class
business_card.py
from dataclasses import dataclass


@dataclass
class BusinessCard:
    image_path: str
    company: str
    name: str
    email: str
    phone_number: str
    address: str

    @property
    def xml(self) -> str:
        return (
            "<s>"
            f"<s_company>{self.company}</s_company>"
            f"<s_name>{self.name}</s_name>"
            f"<s_email>{self.email}</s_email>"
            f"<s_phone_number>{self.phone_number}</s_phone_number>"
            f"<s_address>{self.address}</s_address>"
            "</s>"
        )

    @classmethod
    def get_xml_tags(cls) -> list[str]:
        return [
            "<s>",
            "<s_company>",
            "</s_company>",
            "<s_name>",
            "</s_name>",
            "<s_email>",
            "</s_email>",
            "<s_phone_number>",
            "</s_phone_number>",
            "<s_address>",
            "</s_address>",
            "</s>",
        ]
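
For downstream use, the model's tagged output eventually has to be parsed back into fields. The class above stops at serialization, but a minimal regex-based sketch of the reverse direction (the parse_xml helper below is hypothetical and assumes each tag appears exactly once) could look like this:

import re


def parse_xml(xml: str, image_path: str = "") -> BusinessCard:
    # Hypothetical helper, not part of the project code: pull out the text
    # between each <s_field>...</s_field> pair (empty string if missing).
    def field(tag: str) -> str:
        match = re.search(rf"<s_{tag}>(.*?)</s_{tag}>", xml, re.DOTALL)
        return match.group(1) if match else ""

    return BusinessCard(
        image_path=image_path,
        company=field("company"),
        name=field("name"),
        email=field("email"),
        phone_number=field("phone_number"),
        address=field("address"),
    )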

Creating the dataset

We use the following to build the dataset:

  • Background images
    • Lorem Picsum
    • Returns a random background image at the size specified in the URL
  • Business card contents
    • Faker
    • Convenient because it can also generate Japanese dummy data

The text color is chosen as black or white, whichever contrasts more with the average background color.

Business card image generation
generate_business_card.py
import json
from dataclasses import asdict
from io import BytesIO
from pathlib import Path
from random import randint
from urllib.request import urlopen

import numpy as np
import requests
from faker import Faker
from PIL import Image, ImageDraw, ImageFont

from src.domain.business_card import BusinessCard

DATASET_PATH = Path("dataset/train")
LABEL_PATH = DATASET_PATH / "label.json"
IMAGE_DIR = DATASET_PATH / "image"
IMAGE_SIZE = (700, 500)
PICSUM_URL = f"https://picsum.photos/{IMAGE_SIZE[0]}/{IMAGE_SIZE[1]}"
FONT_PATH = "https://github.com/googlefonts/morisawa-biz-ud-mincho/raw/main/fonts/ttf/BIZUDPMincho-Regular.ttf"
DATASET_LENGTH = 1000

TIMEOUT = 1000


def color_invert(r: int, g: int, b: int) -> str:
    # ITU-R BT.601 luma; use black text on light backgrounds, white on dark ones
    mono = (0.299 * r) + (0.587 * g) + (0.114 * b)
    if mono >= 127:
        return "#000000"

    return "#FFFFFF"


def fetch_image_from_url(url: str) -> Image.Image | None:
    try:
        image = Image.open(BytesIO(requests.get(url, timeout=TIMEOUT).content))
    except (TypeError, ValueError, ConnectionError, OSError, BufferError):
        print(f"Failed to get image: {url}")
        return None
    if image.mode != "RGB":
        print(f"Convert image mode to RGB: {url}")
        image = image.convert("RGB")
    return image


def dummy_business_card(i: int) -> BusinessCard | None:
    image = fetch_image_from_url(PICSUM_URL)
    if image is None:
        return None

    mean_color = np.mean(np.array(image), axis=(0, 1)).astype(int)

    text_color = color_invert(*mean_color)
    draw = ImageDraw.Draw(image)

    faker = Faker("ja_JP")

    text_x = randint(50, 100)

    # 左上の適当な位置とサイズを選び、会社名を書く
    company_point_x, company_point_y, company_size = (
        text_x,
        randint(50, 100),
        randint(20, 30),
    )
    company_font = ImageFont.truetype(
        urlopen(FONT_PATH),
        company_size,
    )
    company = faker.company()
    draw.text(
        (company_point_x, company_point_y),
        company,
        font=company_font,
        fill=text_color,
    )
    company_bounding_box = draw.textbbox(
        (company_point_x, company_point_y),
        company,
        font=company_font,
    )

    # 会社名の下に適当な位置とサイズで名前を書く
    name_point_x, name_point_y, name_size = (
        text_x,
        company_bounding_box[3] + randint(10, 20),
        randint(30, 50),
    )
    name_font = ImageFont.truetype(
        urlopen(FONT_PATH),
        name_size,
    )
    name = faker.name()
    draw.text(
        (name_point_x, name_point_y),
        name,
        font=name_font,
        fill=text_color,
    )

    detail_font_size = 20
    # 左下の適当な位置とサイズでメールアドレスを書く
    email_point_x, email_point_y = (
        text_x,
        randint(320, 360),
    )
    email_font = ImageFont.truetype(
        urlopen(FONT_PATH),
        detail_font_size,
    )
    email = faker.email()
    draw.text(
        (email_point_x, email_point_y),
        email,
        font=email_font,
        fill=text_color,
    )
    email_bounding_box = draw.textbbox(
        (email_point_x, email_point_y),
        email,
        font=email_font,
    )

    # メールアドレスの下に電話番号を書く
    phone_point_x, phone_point_y = (
        text_x,
        email_bounding_box[3] + 5,
    )
    phone_font = ImageFont.truetype(
        urlopen(FONT_PATH),
        detail_font_size,
    )
    phone = faker.phone_number()
    draw.text(
        (phone_point_x, phone_point_y),
        phone,
        font=phone_font,
        fill=text_color,
    )
    phone_bounding_box = draw.textbbox(
        (phone_point_x, phone_point_y),
        phone,
        font=phone_font,
    )

    # 電話番号の下に会社URLを書く
    url_point_x, url_point_y = (
        text_x,
        phone_bounding_box[3] + 5,
    )
    url_font = ImageFont.truetype(
        urlopen(FONT_PATH),
        detail_font_size,
    )
    url = faker.url()
    draw.text(
        (url_point_x, url_point_y),
        url,
        font=url_font,
        fill=text_color,
    )
    url_bounding_box = draw.textbbox(
        (url_point_x, url_point_y),
        url,
        font=url_font,
    )

    # 会社URLの下に住所を書く
    address_point_x, address_point_y = (
        text_x,
        url_bounding_box[3] + 5,
    )
    address_font = ImageFont.truetype(
        urlopen(FONT_PATH),
        detail_font_size,
    )
    address = faker.address()
    draw.text(
        (address_point_x, address_point_y),
        address,
        font=address_font,
        fill=text_color,
    )

    # 画像を保存
    image.save(f"{IMAGE_DIR}/{i}.png")

    return BusinessCard(
        image_path=f"{IMAGE_DIR}/{i}.png",
        company=company,
        name=name,
        email=email,
        phone_number=phone,
        address=address,
    )


if __name__ == "__main__":
    # Ensure the output directory exists before saving images
    IMAGE_DIR.mkdir(parents=True, exist_ok=True)

    cards = []

    i = 0
    while i < DATASET_LENGTH:
        card = dummy_business_card(i)

        if card is None:
            continue

        print(f"Generated: {card}")
        cards.append(asdict(card))

        i += 1

    print(f"Generated {len(cards)} cards")

    with LABEL_PATH.open("w") as f:
        f.write(json.dumps(cards, ensure_ascii=False, indent=4))
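
Incidentally, train.py further below also expects a validation set under dataset/validation; the same script can produce one by pointing DATASET_PATH at dataset/validation (with a smaller DATASET_LENGTH).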

Here is one of the images we actually generated. The phone number is masked, since a generated number could happen to belong to a real person.

Generated business card image

Fine-tuning the Donut base model

Next, we fine-tune the Donut base model.

Model definition
model.py
import re
from typing import cast, override

from nltk import edit_distance
from pytorch_lightning import LightningModule
from torch import Tensor, optim
from transformers import (
    DonutProcessor,
    LogitsProcessorList,
    PreTrainedModel,
    RepetitionPenaltyLogitsProcessor,
    VisionEncoderDecoderModel,
    XLMRobertaTokenizer,
)

from src.domain.business_card import BusinessCard
from src.domain.inference_processor import InferenceLogitsProcessor


class Model(LightningModule):
    def __init__(
        self,
        processor: DonutProcessor,
        model: VisionEncoderDecoderModel,
        lr: float | None = None,
        epochs: int | None = None,
    ) -> None:
        super().__init__()
        self.processor = processor
        self.tokenizer = cast(XLMRobertaTokenizer, processor.tokenizer)

        bos_token_id, eos_token_id = cast(
            list[int],
            self.tokenizer.convert_tokens_to_ids(
                ["<s>", "</s>"],
            ),
        )
        model.config.pad_token_id = self.tokenizer.pad_token_id
        model.config.decoder_start_token_id = bos_token_id
        model.config.eos_token_id = eos_token_id
        model.config.decoder.max_length = 1000
        newly_added_num = self.tokenizer.add_special_tokens(
            {
                "additional_special_tokens": [
                    tags
                    for tags in BusinessCard.get_xml_tags()
                    if tags not in self.tokenizer.all_special_tokens
                ],
            },
        )

        if newly_added_num > 0:
            cast(PreTrainedModel, model.decoder).resize_token_embeddings(len(self.tokenizer))

        self.model = model
        self._lr = lr
        self._epochs = epochs
        self.training_step_losses = []
        self.validation_step_losses = []
        self.validation_step_scores = []

    @property
    def lr(self) -> float:
        if self._lr is None:
            msg = "Learning rate is not set."
            raise ValueError(msg)
        return self._lr

    @property
    def epochs(self) -> int:
        if self._epochs is None:
            msg = "Epochs is not set."
            raise ValueError(msg)
        return self._epochs

    @override
    def configure_optimizers(self) -> optim.Optimizer:
        return optim.Adam(self.parameters(), lr=self.lr)

    @override
    def training_step(
        self,
        batch: tuple[Tensor, Tensor, list[str]],
        _batch_idx: int,
    ) -> Tensor:
        pixel_values, labels, _ = batch

        outputs = self.model(pixel_values, labels=labels[:, 1:])
        loss = cast(Tensor, outputs.loss)

        self.training_step_losses.append(loss.item())

        return loss

    @override
    def validation_step(
        self,
        batch: tuple[Tensor, Tensor, list[str]],
        _batch_idx: int,
    ) -> Tensor:
        pixel_values, labels, targets = batch
        outputs = self.model(pixel_values, labels=labels)
        loss = cast(Tensor, outputs.loss)

        self.validation_step_losses.append(loss.item())

        predictions = self.inference(pixel_values)

        scores = []

        for pred, answer in zip(predictions, targets, strict=True):
            score = edit_distance(pred, answer) / max(len(pred), len(answer))
            scores.append(score)

            self.print(f"Prediction: {pred}")
            self.print(f"    Answer: {answer}")
            self.print(f" Normed ED: {score}")

        average_score = sum(scores) / len(scores)
        self.validation_step_scores.append(average_score)

        self.log(
            "val_edit_distance",
            average_score,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        return loss

    def inference(self, pixel_values: Tensor) -> list[str]:
        outputs = self.model.generate(
            pixel_values,
            max_length=self.model.config.decoder.max_length,
            pad_token_id=self.model.config.pad_token_id,
            eos_token_id=self.model.config.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[self.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
            logits_processor=LogitsProcessorList(
                [InferenceLogitsProcessor(self.tokenizer), RepetitionPenaltyLogitsProcessor(1.06)],
            ),
        )

        pattern = re.compile(r"(<s_[a-zA-Z0-9_]+>)\s")

        predictions = []
        for seq in self.tokenizer.batch_decode(outputs.sequences):
            seq_ = seq.replace(
                self.tokenizer.pad_token,
                "",
            )
            seq_ = re.sub(pattern, r"\1", seq_)
            predictions.append(seq_)

        return predictions

Here, InferenceLogitsProcessor is used at inference time to force the XML tags to appear in the correct order during decoding.
As the validation metric, we use the normalized edit distance between the prediction and the ground truth.
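
Concretely, the score is the Levenshtein distance divided by the length of the longer string, so 0 means an exact match and values near 1 mean almost nothing matched. A toy example of the computation used in validation_step above:

from nltk import edit_distance

pred = "<s_name>鈴木 桃子</s_name>"
answer = "<s_name>鈴本 桃子</s_name>"  # one character differs

score = edit_distance(pred, answer) / max(len(pred), len(answer))
print(score)  # 1 edit over 22 characters ≈ 0.045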

InferenceLogitsProcessor
inference_processor.py
from typing import cast

from torch import FloatTensor, LongTensor, Tensor
from transformers import (
    LogitsProcessor,
    XLMRobertaTokenizer,
)

from src.domain.business_card import BusinessCard


class InferenceLogitsProcessor(LogitsProcessor):
    def __init__(self, tokenizer: XLMRobertaTokenizer) -> None:
        self.tokenizer = tokenizer
        self.special_tokens = BusinessCard.get_xml_tags()
        self.special_token_ids = cast(
            list[int],
            tokenizer.convert_tokens_to_ids(self.special_tokens),
        )

    def _last_tag(self, ids: Tensor) -> str:
        last_special_token_id = next(
            (token_id for token_id in reversed(ids.tolist()) if token_id in self.special_token_ids),
        )
        return self.tokenizer.convert_ids_to_tokens(last_special_token_id)

    @staticmethod
    def _candidate_tags(last_tag: str) -> list[str]:
        # Each tag may only be followed by exactly one other tag; once "</s>"
        # has been emitted the sequence is complete, so nothing more is allowed.
        return {
            "<s>": ["<s_company>"],
            "<s_company>": ["</s_company>"],
            "</s_company>": ["<s_name>"],
            "<s_name>": ["</s_name>"],
            "</s_name>": ["<s_email>"],
            "<s_email>": ["</s_email>"],
            "</s_email>": ["<s_phone_number>"],
            "<s_phone_number>": ["</s_phone_number>"],
            "</s_phone_number>": ["<s_address>"],
            "<s_address>": ["</s_address>"],
            "</s_address>": ["</s>"],
            "</s>": [],
        }[last_tag]

    def __call__(self, input_ids: LongTensor, scores: FloatTensor) -> FloatTensor:
        for i_row in range(len(input_ids)):
            ids = input_ids[i_row]

            # Find the most recent tag token and look up which tags may follow it
            last_tag_label = self._last_tag(ids)
            candidates = self._candidate_tags(last_tag_label)

            # Ban every tag token that is not a valid successor
            forbidden = [
                token_id
                for token_id in self.special_token_ids
                if self.tokenizer.convert_ids_to_tokens(token_id) not in candidates
            ]

            scores[i_row, forbidden] = -float("inf")

        return scores
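
Note that only the tag tokens are ever suppressed here; ordinary text tokens keep their original scores, so the content between the tags is left entirely to the model while the tag order itself is fully constrained.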

We also convert the images to grayscale before training.

Dataset
dataset.py
import json
from pathlib import Path
from typing import cast

from PIL import Image
from torch import Tensor
from torch.utils.data import Dataset as TorchDataset
from torchvision.transforms.v2.functional import pil_to_tensor, to_grayscale, to_pil_image

from src.domain.business_card import BusinessCard
from src.domain.model import Model


class Dataset(TorchDataset[tuple[Tensor, Tensor, str]]):
    def __init__(
        self,
        data: list[tuple[Image.Image, BusinessCard]],
        model: Model,
        *,
        training: bool = True,
    ) -> None:
        self.data = data
        self.model = model
        self.training = training

    def __getitem__(self, index: int) -> tuple[Tensor, Tensor, str]:
        image, business_card = self.data[index]

        pixel_values = self._image_to_tensor(image, random_padding=self.training)
        labels = self._target_string_to_tensor(business_card.xml)

        return pixel_values, labels, business_card.xml

    def __len__(self) -> int:
        return len(self.data)

    @classmethod
    def load(
        cls,
        path: Path,
        model: Model,
        *,
        training: bool = True,
    ) -> "Dataset":
        with (path / "label.json").open() as f:
            labels_json = cast(list[dict], json.load(f))

        return cls(
            [
                (
                    Image.open(label_json["image_path"]),
                    BusinessCard(
                        image_path=label_json["image_path"],
                        company=label_json["company"],
                        name=label_json["name"],
                        email=label_json["email"],
                        phone_number=label_json["phone_number"],
                        address=label_json["address"],
                    ),
                )
                for label_json in labels_json
            ],
            model,
            training=training,
        )

    def _gray_scaling_image(self, image: Image.Image) -> Image.Image:
        return to_pil_image(to_grayscale(pil_to_tensor(image)))

    def _image_to_tensor(self, image: Image.Image, *, random_padding: bool) -> Tensor:
        preprocess_image = self._gray_scaling_image(image)
        pixel_values = cast(
            Tensor,
            self.model.processor(
                preprocess_image.convert("RGB"),
                random_padding=random_padding,
                return_tensors="pt",
            ).pixel_values,
        )

        return pixel_values.squeeze()

    def _target_string_to_tensor(self, target: str) -> Tensor:
        ignore_id = -100
        input_ids = cast(
            Tensor,
            self.model.tokenizer(
                target,
                add_special_tokens=False,
                max_length=self.model.model.config.decoder.max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
                return_special_tokens_mask=True,
            ).input_ids,
        ).squeeze(0)

        labels = input_ids.clone()
        labels[labels == self.model.tokenizer.pad_token_id] = ignore_id

        return labels
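
As a side note, the ignore_id of -100 matches the default ignore_index of PyTorch's cross-entropy loss, so the padded label positions are simply excluded from the loss.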

Here is the actual training code.
Donut is a large model, and the GPU memory on our server does not allow a large batch size, so we set the batch size to 1.

Server specifications:

  • CPU: Core i9 13900K
  • GPU: GeForce RTX 4090
  • OS: Ubuntu Server 22.04
  • RAM: 64GB
Training
train.py
from datetime import datetime
from pathlib import Path
from typing import cast

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
from transformers import (
    DonutProcessor,
    VisionEncoderDecoderConfig,
    VisionEncoderDecoderModel,
)

from src.domain.dataset import Dataset
from src.domain.model import Model

BASE_MODEL = Path("model_output/base_model")
LABEL_PATH = Path("dataset/label.json")
IMAGE_SIZE = (700, 500)
TRAIN_PATH = Path("dataset/train")
VALIDATION_PATH = Path("dataset/validation")


def train() -> None:
    batch_size = 1
    lr = 1e-6
    epoch_num = 3
    model_output_path = Path(
        f"model_output/donut_{batch_size}_{lr}_{round(datetime.now().timestamp())}",
    )

    config = VisionEncoderDecoderConfig.from_pretrained(BASE_MODEL)
    base_model = cast(
        VisionEncoderDecoderModel,
        VisionEncoderDecoderModel.from_pretrained(
            BASE_MODEL,
            config=config,
        ),
    )
    processor = cast(DonutProcessor, DonutProcessor.from_pretrained(BASE_MODEL))
    model = Model(processor, base_model, lr, epoch_num)

    training_dataset = Dataset.load(TRAIN_PATH, model, training=True)

    validation_dataset = Dataset.load(VALIDATION_PATH, model, training=False)

    train_dataloader = DataLoader(
        training_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=16,
    )

    val_dataloader = DataLoader(
        validation_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=16,
    )

    trainer = Trainer(
        accelerator="gpu",
        devices=1,
        max_epochs=epoch_num,
        check_val_every_n_epoch=1,
        gradient_clip_val=1.0,
        precision="16-mixed",
        num_sanity_val_steps=0,
        callbacks=[
            ModelCheckpoint(
                dirpath=model_output_path,
                filename="every_{epoch}_{v_num}",
                every_n_epochs=1,
                save_top_k=-1,
            ),
        ],
    )

    trainer.fit(model, train_dataloader, val_dataloader)

    model.model.save_pretrained(model_output_path)
    model.processor.save_pretrained(model_output_path)


if __name__ == "__main__":
    train()
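
The post stops at training, so for reference, here is a minimal inference sketch reusing the Model class above; the model directory and image path are placeholders, not outputs from this run:

from pathlib import Path
from typing import cast

import torch
from PIL import Image
from torchvision.transforms.v2.functional import pil_to_tensor, to_grayscale, to_pil_image
from transformers import DonutProcessor, VisionEncoderDecoderModel

from src.domain.model import Model

MODEL_PATH = Path("model_output/donut_1_1e-06_0000000000")  # placeholder directory

processor = cast(DonutProcessor, DonutProcessor.from_pretrained(MODEL_PATH))
base = cast(VisionEncoderDecoderModel, VisionEncoderDecoderModel.from_pretrained(MODEL_PATH))
model = Model(processor, base)
model.eval()

# Match the training preprocessing: grayscale first, then the Donut processor
image = Image.open("dataset/validation/image/0.png").convert("RGB")
image = to_pil_image(to_grayscale(pil_to_tensor(image)))
pixel_values = cast(torch.Tensor, processor(image.convert("RGB"), return_tensors="pt").pixel_values)

with torch.no_grad():
    print(model.inference(pixel_values)[0])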

Checking the results

Training results

We prepared 1,000 training images and 100 validation images and trained for three epochs.
The average normalized edit distance on validation at the third epoch was 0.11 (lower is better; 0 means an exact match), which is quite accurate.
Below are three results picked from the third epoch's validation run.
For readability, each is shown as a diff, with the prediction as the before side and the ground truth as the after side. Phone numbers are masked as 000-0000-0000.

Some of the results
Example 1 (no differences)
  <s>
    <s_company>山下鉱業株式会社</s_company>
    <s_name>鈴木 桃子</s_name>
    <s_email>gotokumiko@example.com</s_email>
    <s_phone_number>000-0000-0000</s_phone_number>
    <s_address>宮崎県横添市鶴見区下吉羽5丁目7番20号</s_address>
  </s>
Example 2 (a difference in the address)
  <s>
    <s_company>株式会社中島農林</s_company>
    <s_name>山本 篤司</s_name>
    <s_email>osamuogawa@example.com</s_email>
    <s_phone_number>000-0000-0000</s_phone_number>
-   <s_address>岡山県東大和市戸探町12丁目23番19号 ハイツ方京342</s_address>
+   <s_address>岡山県東大和市戸塚町12丁目23番1号 ハイツ方京342</s_address>
  </s>
Example 3 (several differences; in the phone number, a 2 was mistaken for a 7)
  <s>
    <s_company>有限会社石井農林</s_company>
    <s_name>山本 里佳</s_name>
-   <s_email>oshidashota@example.com</s_email>
+   <s_email>yoshidashota@example.com</s_email>
-   <s_phone_number>000-0000-0000</s_phone_number>
+   <s_phone_number>000-0000-0000</s_phone_number>
-   <s_address>大阪府木並区東神田11丁目10番11号 由光パーク138</s_address>
+   <s_address>大阪府杉並区東神田11丁目10番11号 日光パーク138</s_address>
  </s>

As is typical with character recognition, it often confuses similar-looking digits and characters, as well as complex kanji with many strokes.
The base model we used is not specialized for Japanese, and some kanji are simply absent from its vocabulary, so there are patterns that cannot be read correctly at all.

Let's also look at the actual inference result for the image shown earlier in this post.

<s>
  <s_company>合同会社佐藤運wod</s_company>
  <s_name>藤田 里佳</s_name>
  <s_email>tsubasakobayashi@example.com</s_email>
  <s_phone_number>000-0000-0000</s_phone_number>
  <s_address>東京都香取市入谷39丁目21番9号</s_address>
</s>

The company name has gone wrong here; this turned out to be because subwords such as 運輸 were missing from the vocabulary of the base model we used.

As an aside, partway through this work we also found that a certain single character was missing from Donut's subword vocabulary. For this experiment, we added it to the base model ourselves before evaluating.
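
Checking which characters fall outside the tokenizer's vocabulary, and adding them before fine-tuning, can be sketched along these lines (the loop below is illustrative, not the code we actually used):

from typing import cast

from transformers import DonutProcessor, VisionEncoderDecoderModel, XLMRobertaTokenizer

processor = cast(DonutProcessor, DonutProcessor.from_pretrained("model_output/base_model"))
tokenizer = cast(XLMRobertaTokenizer, processor.tokenizer)

# Characters the tokenizer can only represent as <unk>
text = "合同会社佐藤運輸"
missing = [
    ch
    for ch in text
    if tokenizer.unk_token_id in tokenizer(ch, add_special_tokens=False).input_ids
]
print(missing)

# Add them as regular tokens and resize the decoder embeddings before fine-tuning
if missing:
    model = cast(
        VisionEncoderDecoderModel,
        VisionEncoderDecoderModel.from_pretrained("model_output/base_model"),
    )
    tokenizer.add_tokens(missing)
    model.decoder.resize_token_embeddings(len(tokenizer))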

Summary

Given that only 1,000 training samples, on admittedly simple images, produced this level of accuracy, Donut seems like a genuinely useful model.
Unlike conventional OCR, which merely reads out raw text, it can emit structured data directly, which is a big part of its appeal.
In a separate internal experiment, using around 10,000 real invoice images instead of dummy data, it also handled a task with a more complex data format and smaller text.
When the next interesting model comes along, I'd like to put it through its paces as well.

Finally, the full source code is available in the following repository.
