🐥

【Team JINIAC】小学校算数データセットの作成(事後学習用データセット)

2024/08/22に公開

この記事でわかること

この記事では、日本語LLMの数学能力を向上させるために、小学校算数の教科書をベースにした事後学習用データセットを作成する方法について解説します。具体的には、以下の内容がわかります。

  • 小学校就学前から小学6年生までの算数問題をドリル形式で学習するデータセットの作成方法
  • LV(学年)ごとのカリキュラムに基づいたデータセットの構成
  • LV(学年)ごとのデータセット例
  • 事前学習済モデル(JINIAC-5B)に対して、本データセットで事後学習させてみた結果の概要

はじめに

私たち(Team JINIAC)では、日本語LLMの数学能力を向上させるために、小学校算数の教科書をベースにした事後学習用データセットを作成しました。

具体的には、小学校就学前から小学6年生までの算数問題をドリル形式で学習していくというものです。

作業のアウトライン

次のカリキュラムに従って、LV0からLV6までを順に学習していくような事後学習用データセットを作成しました。

  • まず、次表にある内容に沿って、ランダムな数字を入力して計算結果を出力させるcsv形式のデータセットを作成します。今回はExcelによりcsvファイルを作成しました。
  • 次に、csvファイルを読み込んで欠損値をフィルタリングし、事後学習用データセットとなるjsonlファイルに変換します。
  • データ量は、LVごとに2万件を目安としました。しかしながら、もっと大量のデータセットを学習させた方が性能が向上した可能性があり、この点についてはより様々なケースを試す余地があると考えています。

LV(学年)ごとの算数データの内容

学年 LV instruction_1 instruction_2
幼児 0 数字の順番を覚えましょう
小学1年生 1 整数の足し算をしよう 整数の引き算をしよう
小学2年生 2 1桁×1桁の掛け算をしよう 2桁×2桁の掛け算をしよう
小学3年生 3 ものを分けるという概念を理解しよう(整数の割り算をしよう) 時刻と時間を覚えよう
小学4年生 4 小数の計算(足し算、引き算、掛け算、割り算)をしよう 割り算をしよう(分数の概念を理解しよう)
小学5年生 5 割合の概念を理解しよう 百分率の概念を理解しよう
小学6年生 6 円の面積を計算しよう 円周の長さを計算しよう

LV(学年)ごとのデータセット例

LV.0

LV instruction input output
0 数字の数え方を覚えましょう。 0の次の数字はなんでしょう。 1
0 数字の数え方を覚えましょう。 1の次の数字はなんでしょう。 2
0 数字の数え方を覚えましょう。 2の次の数字はなんでしょう。 3
0 数字の数え方を覚えましょう。 3の次の数字はなんでしょう。 4
0 数字の数え方を覚えましょう。 4の次の数字はなんでしょう。 5
0 数字の数え方を覚えましょう。 5の次の数字はなんでしょう。 6
0 数字の数え方を覚えましょう。 6の次の数字はなんでしょう。 7
0 数字の数え方を覚えましょう。 7の次の数字はなんでしょう。 8
0 数字の数え方を覚えましょう。 8の次の数字はなんでしょう。 9
0 数字の数え方を覚えましょう。 9の次の数字はなんでしょう。 10

LV.1

LV instruction input output
1 整数の足し算をしよう。 2+0= 2
1 整数の足し算をしよう。 2+2= 4
1 整数の足し算をしよう。 3+3= 6
1 整数の足し算をしよう。 3+3= 6
1 整数の足し算をしよう。 3+5= 8
1 整数の足し算をしよう。 4+0= 4
1 整数の足し算をしよう。 0+1= 1
1 整数の足し算をしよう。 2+5= 7
1 整数の足し算をしよう。 3+4= 7
1 整数の足し算をしよう。 2+1= 3

LV.2

LV instruction input output
2 1桁×1桁の掛け算をしよう 3×2= 6
2 1桁×1桁の掛け算をしよう 0×2= 0
2 1桁×1桁の掛け算をしよう 3×7= 21
2 1桁×1桁の掛け算をしよう 7×4= 28
2 1桁×1桁の掛け算をしよう 1×4= 4
2 1桁×1桁の掛け算をしよう 1×6= 6
2 1桁×1桁の掛け算をしよう 7×7= 49
2 1桁×1桁の掛け算をしよう 8×4= 32
2 1桁×1桁の掛け算をしよう 6×3= 18
2 1桁×1桁の掛け算をしよう 6×1= 6

LV.3

LV instruction input output
3 ものを分けるという概念を理解しよう(整数の割り算をしよう) お菓子が27個あります。これを2人で分けると、ひとりあたり何個ずつになりますか。また、お菓子は何個余りますか。 1人あたり13個づつで、1個のお菓子が余ります。
3 ものを分けるという概念を理解しよう(整数の割り算をしよう) お菓子が39個あります。これを4人で分けると、ひとりあたり何個ずつになりますか。また、お菓子は何個余りますか。 1人あたり9個づつで、3個のお菓子が余ります。
3 ものを分けるという概念を理解しよう(整数の割り算をしよう) お菓子が4個あります。これを2人で分けると、ひとりあたり何個ずつになりますか。また、お菓子は何個余りますか。 1人あたり2個づつで、0個のお菓子が余ります。
3 ものを分けるという概念を理解しよう(整数の割り算をしよう) お菓子が24個あります。これを10人で分けると、ひとりあたり何個ずつになりますか。また、お菓子は何個余りますか。 1人あたり2個づつで、4個のお菓子が余ります。
3 ものを分けるという概念を理解しよう(整数の割り算をしよう) お菓子が68個あります。これを4人で分けると、ひとりあたり何個ずつになりますか。また、お菓子は何個余りますか。 1人あたり17個づつで、0個のお菓子が余ります。

LV.4

LV instruction input output
4 1桁×1桁の割り算をしよう 4÷9= 0.4444444444
4 1桁×1桁の割り算をしよう 1÷8= 0.125
4 1桁×1桁の割り算をしよう 9÷2= 4.5
4 1桁×1桁の割り算をしよう 2÷4= 0.5
4 1桁×1桁の割り算をしよう 1÷5= 0.2
4 1桁×1桁の割り算をしよう 3÷6= 0.5
4 1桁×1桁の割り算をしよう 6÷3= 2
4 1桁×1桁の割り算をしよう 1÷6= 0.1666666667
4 1桁×1桁の割り算をしよう 4÷2= 2
4 1桁×1桁の割り算をしよう 8÷2= 4

LV.5

LV instruction input output
5 割合の概念を理解しよう 7kgの2倍は、何kgでしょうか。 14kgです。
5 割合の概念を理解しよう 1kgの6倍は、何kgでしょうか。 6kgです。
5 割合の概念を理解しよう 7kgの2倍は、何kgでしょうか。 14kgです。
5 割合の概念を理解しよう 2kgの2倍は、何kgでしょうか。 4kgです。
5 割合の概念を理解しよう 3kgの2倍は、何kgでしょうか。 6kgです。
5 割合の概念を理解しよう 2kgの5倍は、何kgでしょうか。 10kgです。
5 割合の概念を理解しよう 5kgの0倍は、何kgでしょうか。 0kgです。
5 割合の概念を理解しよう 4kgの2倍は、何kgでしょうか。 8kgです。
5 割合の概念を理解しよう 1kgの5倍は、何kgでしょうか。 5kgです。
5 割合の概念を理解しよう 6kgの1倍は、何kgでしょうか。 6kgです。

LV.6

LV instruction input output
6 円の面積は、半径×半径×円周率で求められます。円周率を3.14として、円の面積を計算しよう。 半径が2mの円の面積は何㎡でしょうか。 12.56㎡です。
6 円の面積は、半径×半径×円周率で求められます。円周率を3.14として、円の面積を計算しよう。 半径が2mの円の面積は何㎡でしょうか。 12.56㎡です。
6 円の面積は、半径×半径×円周率で求められます。円周率を3.14として、円の面積を計算しよう。 半径が5mの円の面積は何㎡でしょうか。 78.5㎡です。
6 円の面積は、半径×半径×円周率で求められます。円周率を3.14として、円の面積を計算しよう。 半径が7mの円の面積は何㎡でしょうか。 153.86㎡です。
6 円の面積は、半径×半径×円周率で求められます。円周率を3.14として、円の面積を計算しよう。 半径が9mの円の面積は何㎡でしょうか。 254.34㎡です。
6 円の面積は、半径×半径×円周率で求められます。円周率を3.14として、円の面積を計算しよう。 半径が2mの円の面積は何㎡でしょうか。 12.56㎡です。
6 円の面積は、半径×半径×円周率で求められます。円周率を3.14として、円の面積を計算しよう。 半径が4mの円の面積は何㎡でしょうか。 50.24㎡です。
6 円の面積は、半径×半径×円周率で求められます。円周率を3.14として、円の面積を計算しよう。 半径が1mの円の面積は何㎡でしょうか。 3.14㎡です。
6 円の面積は、半径×半径×円周率で求められます。円周率を3.14として、円の面積を計算しよう。 半径が5mの円の面積は何㎡でしょうか。 78.5㎡です。
6 円の面積は、半径×半径×円周率で求められます。円周率を3.14として、円の面積を計算しよう。 半径が4mの円の面積は何㎡でしょうか。 50.24㎡です。

作成したデータの欠損値フィルタリングとjsonl形式への変換

次に、csvファイルを読み込んで欠損値をフィルタリングし、事後学習用データセットとなるjsonlファイルに変換する工程について説明します。今回は、Google Colaboratoryを用いました。

Google Driveのマウント

まず、Google Driveをマウントして作業ディレクトリを指定します。

from google.colab import drive
drive.mount('/content/drive')

# 保存用ディレクトリの指定
submit_dir = "/content/drive/MyDrive/math/"

CSVファイルの読み込みと欠損値のフィルタリング

次に、pandasライブラリを使用してCSVファイルを読み込み、欠損値をフィルタリングします。

import pandas as pd
import numpy as np

# CSVファイルを読み込む
file_path = '/content/drive/MyDrive/math/math/LV4事後学習用math.csv'
df = pd.read_csv(file_path)
df = df.iloc[:, :4]

# 欠損値の数を列ごとに確認
print(df.isnull().sum())

# DataFrame内で'#DIV/0!'が含まれている場所を見つける
div_zero_locs = np.where(df == '#DIV/0!')

# 結果を表示
for row, col in zip(*div_zero_locs):
    print(f'Row: {row}, Column: {col}')

'#DIV/0!'が含まれている行の削除

次に、'#DIV/0!'が含まれている行を削除します。

# '#DIV/0!'が含まれている行を削除
df = df.replace('#DIV/0!', pd.NA)
df = df.dropna()

# 元のCSVファイルを上書き
df.to_csv(file_path, index=False)

CSVファイルをjsonl形式に変換

最後に、CSVファイルをjsonl形式に変換します。

import os
import pandas as pd

# 指定されたパス
path = '/content/drive/MyDrive/math/math'

# パス内のすべてのcsvファイルを取得
files = [f for f in os.listdir(path) if f.endswith('.csv')]

# 各ファイルを一つずつ処理
for file in files:
    # ファイルを読み込み、データフレームに変換
    df = pd.read_csv(os.path.join(path, file))

    # 最初の4列を取得
    df = df.iloc[:, :4]

    # ファイルの1列目の数値を取得
    num = df.iloc[0, 0]

    # jsonlファイル名を作成
    jsonl_file = f'math_LV{num}.jsonl'

    # データフレームをjsonlファイルとして保存
    df.to_json(os.path.join(path, jsonl_file), orient='records', lines=True, force_ascii=False)

# dfはあなたのDataFrameです
is_null = df.isnull().any().any()

if is_null:
    print("DataFrameにはnullのセルが含まれています。")
else:
    print("DataFrameにはnullのセルは含まれていません。")

# データフレームをCSVファイルとして保存
df.to_csv('/content/drive/MyDrive/math/math.csv', index=False)

以上が、小学校算数をベースにした事後学習用データセットの作成方法と各LVのデータセット例です。

事前学習済モデル(JINIAC-5B)に対して、本データセットを学習させた結果の概要

このデータセットを事後学習させたところ、数の数え方や簡単な足し算、引き算については正解を出せるケースが増加しました。

PROMPT = """\
以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。

### 指示:
数字の数え方を覚えましょう。
153663の次の数字はなんでしょう。

### 応答:

"""

inputs = tokenizer.encode(PROMPT, add_special_tokens=False, return_tensors="pt").to(model.device)
with torch.no_grad():
    tokens = model.generate(
        input_ids=inputs,
        max_new_tokens=64,
        do_sample=True,
        #num_beams=2,
        #num_beam_groups=3,
        #no_repeat_ngram_size=2,
        temperature=0.8,
        #top_k=50,
        #top_p=0.9,
        #diversity_penalty=0.0,
        #repetition_penalty=1.05,
        #bad_words_ids=,
        #force_words_ids=,
        #constraints=,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id
        #early_stopping=True
    )
output = tokenizer.decode(tokens[0], skip_special_tokens=False)
print(output)

以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。

指示:

数字の数え方を覚えましょう。
153663の次の数字はなんでしょう。

応答:

153664</s>

一方で、掛け算、割り算等については、正解を出すことはあまりできませんでした。

PROMPT = """\
以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。

### 指示:
ものを分けるという概念を理解しよう(整数の割り算をしよう)
お菓子が30個あります。これを4人で分けると、ひとりあたり何個ずつになりますか。また、お菓子は何個余りますか。
### 応答:

"""

inputs = tokenizer.encode(PROMPT, add_special_tokens=False, return_tensors="pt").to(model.device)

output = tokenizer.decode(tokens[0], skip_special_tokens=False)
print(output)

以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。

指示:

ものを分けるという概念を理解しよう(整数の割り算をしよう)
お菓子が30個あります。これを4人で分けると、ひとりあたり何個ずつになりますか。また、お菓子は何個余りますか。

応答:

1人あたり7個づつで、3個のお菓子が余ります。</s>

今回は、時間の都合上、様々なパターンでの学習を行うことができませんでした。例えば、より大量のデータセットを学習させた場合や、カリキュラム学習の工程上の工夫などの改善を図ることで、さらに高い性能を目指せる可能性があります。

まとめ

今回の記事では、小学校算数の教科書をベースにした事後学習用データセットの作成方法について解説しました。具体的には、各学年ごとのカリキュラムに基づいたデータセットの構成や、欠損値のフィルタリングとjsonl形式への変換方法について説明しました。

このようなデータセットを活用することで、日本語LLMの数学能力を向上させることが期待できます。特に、数の数え方や簡単な足し算、引き算については正解を出せるケースが増加しました。

💡 この成果は、NEDO(国立研究開発法人新エネルギー・産業技術総合開発機構)の助成事業「ポスト5G情報通信システム基盤強化研究開発事業」(JPNP20017)の結果得られたものです。

東大松尾・岩澤研究室 | LLM開発 プロジェクト[GENIAC]

Discussion