ERNIE-LayoutのREADME
ERNIE-LayoutのREADMEを読んで試してみます。
README
-
モデル説明
事前学習により、視覚な文章を理解できるようになってきている。
しかし、既存の手法の多くは、レイアウトを意識した理解が最適とは言えない。
本論文では、ワークフローを通してレイアウトを学習し、テキスト、レイアウト、画像の特徴を組み合わせたより良い表現を学習する、新しい文書学習ソリューションERNIE-Layoutを提案します。
具体的には、まず直列化段階で入力シーケンスの並び替えを行い、次に相関的な事前学習タスクである読み順予測を提示し、文書の適切な読み順を学習する。
また、モデルのレイアウト認識を向上させるため、マルチモーダル変換器に空間認識型の離散化アテンションを、事前学習段階に置換領域予測タスクを統合する。
実験の結果、ERNIE-Layoutは様々な下流タスクにおいて優れた性能を達成し、主要な情報抽出、文書画像分類、文書質問応答データセットにおいて最高精度となることが示された。
本論文は、EMNLP 2022(Findings)に採択されました。文書理解の商用利用範囲を広げるため、ERNIE-Layoutの多言語モデルをPaddleNLPで公開します。 -
すぐに使ってみる
- HuggingFace web demo
HuggingFaceのウェブデモはこちら - Taskflow
- Input Format
デフォルトではPaddleOCRを使用しますが、word_boxesを介して独自のOCR結果を使用することも可能です。データ形式はList[str, List[float, float, float, float]][ {"doc": "./book.png", "prompt": ["What is the name of the author of 'The Adventure Zone: The Crystal Kingdom’?", "What type of book cover does The Adventure Zone: The Crystal Kingdom have?", "For Rage, who is the author listed as?"]}, {"doc": "./resume.png", "prompt": ["五百丁本次想要担任的是什么职位?", "五百丁是在哪里上的大学?", "大学学的是什么专业?"]} ]
[ {"doc": doc_path, "prompt": prompt, "word_boxes": word_boxes} ]
- Support single and batch input
- Image from http link
from pprint import pprint from paddlenlp import Taskflow docprompt = Taskflow("document_intelligence", lang="en") docprompt([{"doc": "https://bj.bcebos.com/paddlenlp/taskflow/document_intelligence/images/book.png", "prompt": ["What is the name of the author of 'The Adventure Zone: The Crystal Kingdom’?", "What type of book cover does The Adventure Zone: The Crystal Kingdom have?", "For Rage, who is the author listed as?"]}]) [{'prompt': "What is the name of the author of 'The Adventure Zone: The Crystal Kingdom’?", 'result': [{'value': 'Clint McElroy. Carey Pietsch, Griffn McElroy, Travis McElroy', 'prob': 0.99, 'start': 22, 'end': 39}]}, {'prompt': 'What type of book cover does The Adventure Zone: The Crystal Kingdom have?', 'result': [{'value': 'Paperback', 'prob': 1.0, 'start': 51, 'end': 51}]}, {'prompt': 'For Rage, who is the author listed as?', 'result': [{'value': 'Bob Woodward', 'prob': 1.0, 'start': 91, 'end': 93}]}]
- Image from local path
from pprint import pprint from paddlenlp import Taskflow docprompt = Taskflow("document_intelligence") pprint(docprompt([{"doc": "./resume.png", "prompt": ["五百丁本次想要担任的是什么职位?", "五百丁是在哪里上的大学?", "大学学的是什么专业?"]}])) [{'prompt': '五百丁本次想要担任的是什么职位?', 'result': [{'end': 7, 'prob': 1.0, 'start': 4, 'value': '客户经理'}]}, {'prompt': '五百丁是在哪里上的大学?', 'result': [{'end': 37, 'prob': 1.0, 'start': 31, 'value': '广州五百丁学院'}]}, {'prompt': '大学学的是什么专业?', 'result': [{'end': 44, 'prob': 0.82, 'start': 38, 'value': '金融学(本科)'}]}]
- Parameter Description
- batch_size: 各バッチの入力数、デフォルトは1。
- lang: PaddleOCRの言語、enは英語、デフォルトはch。
- topn: 最も確率の高い上位n個の結果を返す、デフォルトは1.
- Image from http link
- HuggingFace web demo
-
モデル性能
- Dataset
Dataset Task Language Note FUNSD Key Information Extraction English - XFUND-ZH Key Information Extraction Chinese - DocVQA-ZH Document Question Answering Chinese オリジナルデータセットです。4,187枚の学習画像、500枚の検証画像、500枚のテスト画像です。 RVL-CDIP (sampled) Document Image Classification English RVL-CDIPデータセットは、16クラス、40万枚のグレースケール画像からなり、1クラスあたり25,000枚の画像で構成されています。元のデータセットが大きく、学習に時間がかかるため、ダウンサンプリングを行った。このデータセットは、6,400枚の学習画像、800枚の検証画像、800枚のテスト画像で検証した。 - Results
Model FUNSD RVL-CDIP (sampled) XFUND-ZH DocVQA-ZH LayoutXLM-Base 86.72 90.88 86.24 66.01 ERNIE-LayoutX-Base 89.31 90.29 88.58 69.57 -
Evaluation Methods
- 上記のタスクは、いずれもグリッドサーチ法に基づくハイパーパラメータ探索を行うものである。FUNSDとXFUND-ZHの評価ステップ間隔は共に100、評価指標はF1-Scoreである。RVL-CDIPの評価ステップは2000、評価指標はAccuracyである。DocVQA-ZHの評価ステップ間隔は10000、評価指標はANLSです。
- Hyper Parameters search ranges
Hyper Parameters FUNSD RVL-CDIP (sampled) XFUND-ZH DocVQA-ZH learning_rate 5e-6, 1e-5, 2e-5, 5e-5 5e-6, 1e-5, 2e-5, 5e-5 5e-6, 1e-5, 2e-5, 5e-5 5e-6, 1e-5, 2e-5, 5e-5 batch_size 1, 2, 4 8, 16, 24 1, 2, 4 8, 16, 24 warmup_ratio - 0, 0.05, 0.1 - 0, 0.05, 0.1 FUNSDとXFUNDのlr_scheduler_typeの戦略は一定なので、warmup_ratioは除外しています。
- FUNSDとXFUND-ZHの微調整にはmax_stepsが適用され、それぞれ10000ステップと20000ステップ、num_train_epochsはDocVQA-ZHとRVL-CDIPでそれぞれ6と20に設定されています。
-
Best Hyper Parameter
Model FUNSD RVL-CDIP (sampled) XFUND-ZH DocVQA-ZH LayoutXLM-Base 1e-5, 2, _ 1e-5, 8, 0.1 1e-5, 2, _ 2e-5. 8, 0.1 ERNIE-LayoutX-Base 2e-5, 4, _ 1e-5, 8, 0. 1e-5, 4, _ 2e-5. 8, 0.05 -
ファインチューニング
- Installation
pip install -r requirements.txt
4.1 Key Information Extraction
- FUNSD Train
python -u run_ner.py \ --model_name_or_path ernie-layoutx-base-uncased \ --output_dir ./ernie-layoutx-base-uncased/models/funsd/ \ --dataset_name funsd \ --do_train \ --do_eval \ --max_steps 10000 \ --eval_steps 100 \ --save_steps 100 \ --save_total_limit 1 \ --load_best_model_at_end \ --pattern ner-bio \ --preprocessing_num_workers 4 \ --overwrite_cache false \ --use_segment_box \ --doc_stride 128 \ --target_size 1000 \ --per_device_train_batch_size 4 \ --per_device_eval_batch_size 4 \ --learning_rate 2e-5 \ --lr_scheduler_type constant \ --gradient_accumulation_steps 1 \ --seed 1000 \ --metric_for_best_model eval_f1 \ --greater_is_better true \ --overwrite_output_dir
- XFUND-ZH Train
python -u run_ner.py \ --model_name_or_path ernie-layoutx-base-uncased \ --output_dir ./ernie-layoutx-base-uncased/models/xfund_zh/ \ --dataset_name xfund_zh \ --do_train \ --do_eval \ --lang "ch" \ --max_steps 20000 \ --eval_steps 100 \ --save_steps 100 \ --save_total_limit 1 \ --load_best_model_at_end \ --pattern ner-bio \ --preprocessing_num_workers 4 \ --overwrite_cache false \ --use_segment_box \ --doc_stride 128 \ --target_size 1000 \ --per_device_train_batch_size 4 \ --per_device_eval_batch_size 4 \ --learning_rate 1e-5 \ --lr_scheduler_type constant \ --gradient_accumulation_steps 1 \ --seed 1000 \ --metric_for_best_model eval_f1 \ --greater_is_better true \ --overwrite_output_dir
4.2 Document Question Answering
- DocVQA-ZH Train
python3 -u run_mrc.py \ --model_name_or_path ernie-layoutx-base-uncased \ --output_dir ./ernie-layoutx-base-uncased/models/docvqa_zh/ \ --dataset_name docvqa_zh \ --do_train \ --do_eval \ --lang "ch" \ --num_train_epochs 6 \ --lr_scheduler_type linear \ --warmup_ratio 0.05 \ --weight_decay 0 \ --eval_steps 10000 \ --save_steps 10000 \ --save_total_limit 1 \ --load_best_model_at_end \ --pattern "mrc" \ --use_segment_box false \ --return_entity_level_metrics false \ --overwrite_cache false \ --doc_stride 128 \ --target_size 1000 \ --per_device_train_batch_size 8 \ --per_device_eval_batch_size 8 \ --learning_rate 2e-5 \ --preprocessing_num_workers 32 \ --save_total_limit 1 \ --train_nshard 16 \ --seed 1000 \ --metric_for_best_model anls \ --greater_is_better true \ --overwrite_output_dir
4.3 Document Image Classification
- RVL-CDIP Train
python3 -u run_cls.py \ --model_name_or_path ernie-layoutx-base-uncased \ --output_dir ./ernie-layoutx-base-uncased/models/rvl_cdip_sampled/ \ --dataset_name rvl_cdip_sampled \ --do_train \ --do_eval \ --num_train_epochs 20 \ --lr_scheduler_type linear \ --max_seq_length 512 \ --warmup_ratio 0.05 \ --weight_decay 0 \ --eval_steps 2000 \ --save_steps 2000 \ --save_total_limit 1 \ --load_best_model_at_end \ --pattern "cls" \ --use_segment_box \ --return_entity_level_metrics false \ --overwrite_cache false \ --doc_stride 128 \ --target_size 1000 \ --per_device_train_batch_size 8 \ --per_device_eval_batch_size 8 \ --learning_rate 1e-5 \ --preprocessing_num_workers 32 \ --train_nshard 16 \ --seed 1000 \ --metric_for_best_model acc \ --greater_is_better true \ --overwrite_output_dir
-
デプロイ
5.1 Inference Model Export
ファインチューニング後、Model Export Scriptで推論モデルをエクスポートすると、推論モデルは指定したoutput_pathに保存されます。- Export the model fine-tuned on FUNSD
python export_model.py --task_type ner --model_path ./ernie-layoutx-base-uncased/models/funsd/ --output_path ./ner_export
- Export the model fine-tuned on DocVQA-ZH
python export_model.py --task_type mrc --model_path ./ernie-layoutx-base-uncased/models/docvqa_zh/ --output_path ./mrc_export
- Export the model fine-tuned on RVL-CDIP(sampled)
python export_model.py --task_type cls --model_path ./ernie-layoutx-base-uncased/models/rvl_cdip_sampled/ --output_path ./cls_export
- Parameter Description
- model_path:dygraphモデルパラメータの保存ディレクトリ。デフォルトは、"./checkpoint/"
- output_path:静的グラフモデルパラメータの保存先ディレクトリ。デフォルトは、"./export"
- Directory
export/ ├── inference.pdiparams ├── inference.pdiparams.info └── inference.pdmodel
5.2 Python Deploy
ERNIE-Layout Python Deploy Guideを確認してください
References
- ERNIE-Layout: Layout-Knowledge Enhanced Multi-modal Pre-training for Document Understanding
- ICDAR 2019 Competition on Scene Text Visual Question Answering
- XFUND dataset
- FUNSD dataset
- RVL-CDIP dataset
- Competition of Insurance Document Visual Cognition Question Answering
試してみる
google colabで動かしてみます。
推論してみるだけなのでCPUで。
インストール
!python -m pip install paddlepaddle==2.4.2 -i https://mirror.baidu.com/pypi/simple
!pip install paddlenlp
!pip install paddleocr
推論
以下の画像から「氏名」「生年月日」「本籍」「住所」「交付日」を取得してみます。
photo by Tiny Banquet Committee
from pprint import pprint
from paddlenlp import Taskflow
docprompt = Taskflow("document_intelligence", lang="ch")
docprompt([{"doc": "https://storage.googleapis.com/zenn-user-upload/389608c2b63b-20230306.jpg", "prompt": ["氏名は?", "生年月日は?", "本籍は?", "住所は?", "交付日は?"]}])
[{'prompt': '氏名は?',
'result': [{'value': '又吉', 'prob': 0.99, 'start': 2, 'end': 3}]},
{'prompt': '生年月日は?',
'result': [{'value': '平成22 年2月28日', 'prob': 1.0, 'start': 4, 'end': 12}]},
{'prompt': '本籍は?',
'result': [{'value': '力少才市大字字×夕力285-4', 'prob': 0.88, 'start': 16, 'end': 28}]},
{'prompt': '住所は?',
'result': [{'value': '東京都台东区秋葉原学生服茶「番長亍一儿の下', 'prob': 1.0, 'start': 31, 'end': 51}]},
{'prompt': '交付日は?',
'result': [{'value': '平成24年01月15日', 'prob': 1.0, 'start': 54, 'end': 61}]}]
lang: PaddleOCRの言語、enは英語、デフォルトはch。
ということでしたので、「ch」で試しましたが、PaddleOCRは、「japan」に対応しているので、試してみます。
docprompt = Taskflow("document_intelligence", lang="japan")
docprompt([{"doc": "https://storage.googleapis.com/zenn-user-upload/389608c2b63b-20230306.jpg", "prompt": ["氏名は?", "生年月日は?", "本籍は?", "住所は?", "交付日は?"]}])
[{'prompt': '氏名は?',
'result': [{'value': '又ー吉', 'prob': 0.99, 'start': 2, 'end': 4}]},
{'prompt': '生年月日は?',
'result': [{'value': '平成22年2月28日', 'prob': 1.0, 'start': 5, 'end': 14}]},
{'prompt': '本籍は?',
'result': [{'value': 'クジラ県カツオ市大字サンマ字メダカ285-4', 'prob': 1.0, 'start': 18, 'end': 37}]},
{'prompt': '住所は?',
'result': [{'value': '東京都台東区秋葉原学生服喫茶番長テーブルの下', 'prob': 1.0, 'start': 40, 'end': 61}]},
{'prompt': '交付日は?',
'result': [{'value': '平成24年01月15日', 'prob': 0.99, 'start': 64, 'end': 74}]}]
いいですね!
氏名に謎の「ー」が惜しいですが。
Discussion
hello, have you tried to train this model? if so,could you share how the dataset was made?
Thank you for your comment.
I have only conducted inference on the pre-training model, not learning.