MetaScript: Generating Handwritten Characters


Introduction

I tried out MetaScript, a model that generates handwriting-style characters.
The repository was developed for Chinese and does not support Japanese.
Since the pretrained model files are only distributed via Baidu Netdisk (百度网盘), for this evaluation I trained the model myself on the same data used in the paper and repository, and then ran inference.

https://github.com/xxyQwQ/metascript

https://pan.baidu.com/s/1UGHPKFVSvRj2QY_PbSjJGQ?pwd=1024

Environment

Windows 11
RTX 5070 Ti
CUDA 12.8

Data

As in the original paper, I used HWDB1.1 with the same preprocessing applied.
https://nlpr.ia.ac.cn/databases/handwriting/Download.html
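
The repository ships its own preprocessing script, so the exact pipeline should be taken from there. As a rough illustration of what decoding HWDB's .gnt container involves, here is a minimal sketch of my own (not repository code). It assumes the documented HWDB1.x record layout: a little-endian uint32 sample size, a 2-byte GB2312 tag code, uint16 width and height, then a width x height grayscale bitmap.

import struct

import numpy as np


def read_gnt(path):
    # Illustrative decoder for CASIA HWDB .gnt files; the repository's own
    # preprocessing script is the authoritative version.
    with open(path, 'rb') as file:
        while True:
            header = file.read(10)
            if len(header) < 10:
                break
            # little-endian: uint32 sample size, 2 raw tag bytes, uint16 width, uint16 height
            _, tag, width, height = struct.unpack('<I2sHH', header)
            bitmap = np.frombuffer(file.read(width * height), dtype=np.uint8)
            # the 2-byte tag is the GB2312 code of the character
            yield tag.decode('gb2312'), bitmap.reshape(height, width)

Each yielded pair can then be written out as a per-character image, which is roughly the shape of data the training script expects.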

Training Code

import os
import sys
import time

import hydra
import numpy as np
from PIL import Image
from omegaconf import OmegaConf

import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader

from utils.logger import Logger
from utils.dataset import CharacterDataset
from utils.function import plot_sample
from model.generator import SynthesisGenerator
from model.discriminator import MultiscaleDiscriminator


@hydra.main(version_base=None, config_path='./config', config_name='training')
def main(config):
    # load configuration
    dataset_path = str(config.parameter.dataset_path)
    checkpoint_path = str(config.parameter.checkpoint_path)
    device = torch.device('cuda') if config.parameter.device == 'gpu' else torch.device('cpu')
    batch_size = int(config.parameter.batch_size)
    num_workers = int(config.parameter.num_workers)
    reference_count = int(config.parameter.reference_count)
    num_iterations = int(config.parameter.num_iterations)
    report_interval = int(config.parameter.report_interval)
    save_interval = int(config.parameter.save_interval)

    # create logger
    sys.stdout = Logger(os.path.join(checkpoint_path, 'training.log'))
    config.parameter.checkpoint_path = checkpoint_path
    config.parameter.device = str(device)
    print(OmegaConf.to_yaml(config))

    # load dataset
    dataset = CharacterDataset(dataset_path, reference_count=reference_count)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True, drop_last=True)
    print('image number: {}\n'.format(len(dataset)))

    # create model
    generator_model = SynthesisGenerator(reference_count=reference_count).to(device)
    generator_model.train()

    discriminator_model = MultiscaleDiscriminator(dataset.writer_count, dataset.character_count).to(device)
    discriminator_model.train()

    # create optimizer
    generator_optimizer = Adam(generator_model.parameters(), lr=config.parameter.generator.learning_rate, betas=(0.0, 0.999), weight_decay=1e-4)
    discriminator_optimizer = Adam(discriminator_model.parameters(), lr=config.parameter.discriminator.learning_rate, betas=(0.0, 0.999), weight_decay=1e-4)

    # start training
    current_iteration = 0
    current_time = time.time()

    while current_iteration < num_iterations:
        for reference_image, writer_label, template_image, character_label, script_image in dataloader:
            current_iteration += 1

            reference_image = reference_image.to(device)
            writer_label = writer_label.to(device)
            template_image = template_image.to(device)
            character_label = character_label.to(device)
            script_image = script_image.to(device)

            # generator
            generator_optimizer.zero_grad()

            result_image, template_structure, reference_style = generator_model(reference_image, template_image)

            loss_generator_adversarial = 0
            loss_generator_classification = 0
            for prediction_reality, prediction_writer, prediction_character in discriminator_model(result_image):
                loss_generator_adversarial += F.binary_cross_entropy(prediction_reality, torch.ones_like(prediction_reality))
                loss_generator_classification += F.cross_entropy(prediction_writer, writer_label) + F.cross_entropy(prediction_character, character_label)

            result_structure = generator_model.structure(result_image)
            loss_generator_structure = 0
            for i in range(len(result_structure)):
                loss_generator_structure += 0.5 * torch.mean(torch.square(template_structure[i] - result_structure[i]))

            result_style = generator_model.style(result_image.repeat_interleave(reference_count, dim=1))
            loss_generator_style = 0.5 * torch.mean(torch.square(reference_style - result_style))

            loss_generator_reconstruction = F.l1_loss(result_image, script_image)

            generator_weight = config.parameter.generator.loss_function
            loss_generator = (
                generator_weight.weight_adversarial * loss_generator_adversarial
                + generator_weight.weight_classification * loss_generator_classification
                + generator_weight.weight_structure * loss_generator_structure
                + generator_weight.weight_style * loss_generator_style
                + generator_weight.weight_reconstruction * loss_generator_reconstruction
            )
            loss_generator.backward()
            generator_optimizer.step()

            # discriminator
            discriminator_optimizer.zero_grad()

            loss_discriminator_adversarial = 0
            loss_discriminator_classification = 0
            for prediction_reality, prediction_writer, prediction_character in discriminator_model(result_image.detach()):
                loss_discriminator_adversarial += F.binary_cross_entropy(prediction_reality, torch.zeros_like(prediction_reality))
                loss_discriminator_classification += F.cross_entropy(prediction_writer, writer_label) + F.cross_entropy(prediction_character, character_label)

            for prediction_reality, prediction_writer, prediction_character in discriminator_model(script_image):
                loss_discriminator_adversarial += F.binary_cross_entropy(prediction_reality, torch.ones_like(prediction_reality))
                loss_discriminator_classification += F.cross_entropy(prediction_writer, writer_label) + F.cross_entropy(prediction_character, character_label)

            discriminator_weight = config.parameter.discriminator.loss_function
            loss_discriminator = (
                discriminator_weight.weight_adversarial * loss_discriminator_adversarial
                + discriminator_weight.weight_classification * loss_discriminator_classification
            )
            loss_discriminator.backward()
            discriminator_optimizer.step()

            # report
            if current_iteration % report_interval == 0:
                last_time = current_time
                current_time = time.time()
                iteration_time = (current_time - last_time) / report_interval

                print('iteration {} / {}:'.format(current_iteration, num_iterations))
                print('time: {:.6f} seconds per iteration'.format(iteration_time))
                print('generator loss: {:.6f}, adversarial loss: {:.6f}, classification loss: {:.6f}, structure loss: {:.6f}, style loss: {:.6f}, reconstruction loss: {:.6f}'.format(loss_generator.item(), loss_generator_adversarial.item(), loss_generator_classification.item(), loss_generator_structure.item(), loss_generator_style.item(), loss_generator_reconstruction.item()))
                print('discriminator loss: {:.6f}, adversarial loss: {:.6f}, classification loss: {:.6f}\n'.format(loss_discriminator.item(), loss_discriminator_adversarial.item(), loss_discriminator_classification.item()))

            # save
            if current_iteration % save_interval == 0:
                save_path = os.path.join(checkpoint_path, 'iteration_{}'.format(current_iteration))
                os.makedirs(save_path, exist_ok=True)

                image_path = os.path.join(save_path, 'sample.png')
                generator_path = os.path.join(save_path, 'generator.pth')
                discriminator_path = os.path.join(save_path, 'discriminator.pth')

                image = plot_sample(reference_image, template_image, script_image, result_image)[0]
                Image.fromarray((255 * image).astype(np.uint8)).save(image_path)
                torch.save(generator_model.state_dict(), generator_path)
                torch.save(discriminator_model.state_dict(), discriminator_path)

                print('save sample image in: {}'.format(image_path))
                print('save generator model in: {}'.format(generator_path))
                print('save discriminator model in: {}\n'.format(discriminator_path))

            if current_iteration >= num_iterations:
                break


if __name__ == '__main__':
    main()
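
All of the settings above come from config/training.yaml, which Hydra loads at startup. The actual file ships with the repository; the sketch below only illustrates the keys the script reads, and every value is a placeholder rather than the setting I trained with.

parameter:
  dataset_path: ./dataset        # placeholder: the preprocessed HWDB1.1 directory
  checkpoint_path: ./checkpoint  # placeholder
  device: gpu                    # 'gpu' selects CUDA, anything else falls back to CPU
  batch_size: 32                 # placeholder
  num_workers: 4                 # placeholder
  reference_count: 4             # placeholder
  num_iterations: 200000         # placeholder
  report_interval: 100           # placeholder
  save_interval: 10000           # placeholder
  generator:
    learning_rate: 0.0001        # placeholder
    loss_function:               # placeholder weights
      weight_adversarial: 1.0
      weight_classification: 1.0
      weight_structure: 1.0
      weight_style: 1.0
      weight_reconstruction: 1.0
  discriminator:
    learning_rate: 0.0001        # placeholder
    loss_function:               # placeholder weights
      weight_adversarial: 1.0
      weight_classification: 1.0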

Inference Code

import os
import sys
import glob
import pickle

import hydra
import numpy as np
from PIL import Image
from tqdm import tqdm
from omegaconf import OmegaConf

import torch
from torchvision import transforms

from utils.logger import Logger
from utils.function import SquarePad, ColorReverse, RecoverNormalize, SciptTyper
from model.generator import SynthesisGenerator


@hydra.main(version_base=None, config_path='./config', config_name='inference')
def main(config):
    # load configuration
    model_path = str(config.parameter.model_path)
    reference_path = str(config.parameter.reference_path)
    checkpoint_path = str(config.parameter.checkpoint_path)
    device = torch.device('cuda') if config.parameter.device == 'gpu' else torch.device('cpu')
    reference_count = int(config.parameter.reference_count)
    target_text = str(config.parameter.target_text)

    # create logger
    sys.stdout = Logger(os.path.join(checkpoint_path, 'inference.log'))
    config.parameter.checkpoint_path = checkpoint_path
    config.parameter.device = str(device)
    print(OmegaConf.to_yaml(config))

    # create model
    generator_model = SynthesisGenerator(reference_count=reference_count).to(device)
    generator_model.eval()
    generator_model.load_state_dict(torch.load(model_path, map_location=device), strict=False)

    # create transform
    # input_transform: grayscale tensor with colors inverted, padded to a square,
    # resized to 128x128, and normalized to [-1, 1]
    input_transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.ToTensor(),
        ColorReverse(),
        SquarePad(),
        transforms.Resize((128, 128), antialias=True),
        transforms.Normalize((0.5,), (0.5,))
    ])
    # output_transform: undo the normalization, downscale to 64x64, invert the
    # colors back, and return a PIL image
    output_transform = transforms.Compose([
        RecoverNormalize(),
        transforms.Resize((64, 64), antialias=True),
        ColorReverse(),
        transforms.ToPILImage()
    ])
    # align_transform: 64x64 grayscale used for the saved reference strip and
    # the punctuation glyphs
    align_transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((64, 64), antialias=True),
    ])

    # fetch reference
    reference_list = []
    file_list = glob.glob('{}/*.png'.format(reference_path))
    for file in tqdm(file_list, desc='fetching reference'):
        image = Image.open(file)
        reference_list.append(image)
    while len(reference_list) < reference_count:
        reference_list.extend(reference_list)
    reference_list = reference_list[:reference_count]
    reference_image = [np.array(align_transform(image)) for image in reference_list]
    reference_image = np.concatenate(reference_image, axis=1)
    Image.fromarray(reference_image).save(os.path.join(checkpoint_path, 'reference.png'))
    reference = [input_transform(image) for image in reference_list]
    reference = torch.cat(reference, dim=0).unsqueeze(0).to(device)
    print('fetch {} reference images\n'.format(reference_count))

    # load dictionary
    with open('./assets/dictionary/character.pkl', 'rb') as file:
        character_map = pickle.load(file)
    character_remap = {value: key for key, value in character_map.items()}
    with open('./assets/dictionary/punctuation.pkl', 'rb') as file:
        punctuation_map = pickle.load(file)
    punctuation_remap = {value: key for key, value in punctuation_map.items()}
    print('load dictionary from archive\n')

    # generate script
    script_typer = SciptTyper()
    for word in tqdm(target_text, desc='generating script'):
        if word in character_remap.keys():
            image = Image.open(os.path.join('./assets/character', '{}.png'.format(character_remap[word])))
            template = input_transform(image).unsqueeze(0).to(device)
            with torch.no_grad():
                result, _, _ = generator_model(reference, template)
            result = output_transform(result.squeeze(0).detach().cpu())
            script_typer.insert_word(result, word_type='character')
        elif word in punctuation_remap.keys():
            image = Image.open(os.path.join('./assets/punctuation', '{}.png'.format(punctuation_remap[word])))
            result = align_transform(image)
            script_typer.insert_word(result, word_type='punctuation')
        else:
            raise ValueError('word {} is not supported'.format(word))
    print('generate {} words from text\n'.format(len(target_text)))
    
    # save result
    result_image = script_typer.plot_result()
    result_image.save(os.path.join(checkpoint_path, 'result.png'))
    print('save inference result in: {}\n'.format(checkpoint_path))


if __name__ == '__main__':
    main()
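
Inference raises ValueError on the first character it cannot render, so it can help to pre-check the target text against the dictionaries beforehand. The snippet below is my own helper, not part of the repository; it relies only on the pickle files and dictionary structure used in the inference code above, where the dictionary values are the renderable characters.

import pickle

# list the characters in the target text that neither dictionary can render
with open('./assets/dictionary/character.pkl', 'rb') as file:
    characters = set(pickle.load(file).values())
with open('./assets/dictionary/punctuation.pkl', 'rb') as file:
    punctuations = set(pickle.load(file).values())

target_text = '勇敢的希梅尔会这么做的。'
unsupported = [word for word in target_text if word not in characters | punctuations]
print('unsupported characters:', unsupported if unsupported else 'none')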


Inference Results

Input

勇敢的希梅尔会这么做的。

Output

Training itself succeeded and inference ran, but the quality of the output images was low, so the hyperparameters and training data likely need more attention. The sample images generated during training looked somewhat better, which suggests that whether a character appears in the training dataset also matters.
