🐼

JGLUEのSTSとNLIにおける文ペア数を調べてみた

2023/01/14に公開

JGLUEのSTSとNLIにおける文ペア数を調べてみた

概要

自然言語処理モデルの評価用データセットJGLUEのSTSとNLIのラベルの各組み合わせで文ペア数を調べた.

目次

  1. 概要
  2. JGLUE
  3. STS
  4. NLI
  5. 環境
  6. コード
    1. NLIとSTSをマージ
    2. ヒートマップで各組み合わせの文ペア数を出力
    3. 各組み合わせの文ペア数の結果
  7. おまけ

JGLUE

2022年に言語処理学会で『JGLUE: 日本語言語理解ベンチマーク』という論文が紹介された.
JGLUEのSTSとNLIの文ペアはYahoo Japan Captions Datasetを用いている.

次に公開されているデータセットを示す.

https://github.com/yahoojapan/JGLUE

STS

STS(Semantic Textual Similarity)とは,文ペアの意味的の類似度を求めるタスクである.
JGLUEでは,正解の類似度0(意味が完全に異なる)〜5(意味が等価)の間の値として付与されている.

NLI

NLI(Natural Language Inference,自然言語推論)とは,前提文と仮説文の文ペアがあり,
前提文が仮説文に対して持つ推論関係を認識するタスクである.
RTE(Recognize Textual Entailment)とも呼ばれる.
JGLUEでは,「含意(entailment)」「矛盾(contradiction)」「中立(neutral)」の3つの推論関係で定義され,付与されている.

環境

Colab上で実行した.

コード

NLIとSTSにおける各組み合わせで,文ペア数をヒートマップで出力した.
ライブラリにはpandas, numpy, matplotlib, seabornを使用した.

NLIとSTSをマージ

今回は学習用(train)と評価用(valid)を合わせた.
NLIとSTSのラベル名が同じなので,別名で再設定した.
また,NLIとSTSの両方がアノテートされた文ペアのみを用いるため,yjcaptions_idでマージした.

import pandas as pd

JGLUE_DIR = '/.../JGLUE/datasets/'
NLI_DIR = JGLUE_DIR + 'jnli-v1.1/'
STS_DIR = JGLUE_DIR + 'jsts-v1.1/'
JGLUE_FILENAME = { 'train' : 'train-v1.1.json' , 'valid' : 'valid-v1.1.json' }
JGLUE_COLS = [ 'yjcaptions_id' , 'sentence1' , 'sentence2' , 'label' ]

# NLI
NLI_COLUMNS = [ 'id', 'sentenceA' , 'sentenceB' , 'label_nli' ]

nli_train_df = pd.read_json( NLI_DIR + JGLUE_FILENAME['train'] , lines=True , encoding='utf-8' )
nli_train_df = nli_train_df.loc[ : , JGLUE_COLS ]
nli_train_df.columns = NLI_COLUMNS

nli_valid_df = pd.read_json( NLI_DIR + JGLUE_FILENAME['valid'] , lines=True , encoding='utf-8' )
nli_valid_df = nli_valid_df.loc[ : , JGLUE_COLS ]
nli_valid_df.columns = COLUMNS

# STS
STS_COLUMNS = [ 'id', 'sentenceA' , 'sentenceB' , 'label_sts' ]

sts_train_df = pd.read_json( STS_DIR + JGLUE_FILENAME['train'] , lines=True , encoding='utf-8' )
sts_train_df = sts_train_df.loc[ : , JGLUE_COLS ]
sts_train_df.columns = STS_COLUMNS

sts_valid_df = pd.read_json( STS_DIR + JGLUE_FILENAME['valid'] , lines=True , encoding='utf-8' )
sts_valid_df = sts_valid_df.loc[ : , JGLUE_COLS ]
sts_valid_df.columns = STS_COLUMNS

# merge

nli_df = pd.concat( [ nli_train_df, nli_valid_df ] , axis=0 )
sts_df = pd.concat( [ sts_train_df, sts_train_df ] , axis=0 )
df = pd.merge( nli_df, sts_df[['id', 'label_sts']] , on='id' )

ヒートマップで各組み合わせの文ペア数を出力

value_counts()では数が0の組み合わせが出てこないため,同じサイズ・同じ設定で0で埋めたDataFrameを作成して,補完した.
STSの値の最小間隔が0.2なので,0〜5で0.2ずつ増えていく配列を生成.

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

N_LABEL = 3

label_df = df[[ 'label_rte' , 'label_sts' ]]

freq = label_df.value_counts(sort=True, ascending=True, dropna=False)
freq_df = pd.DataFrame(freq)
freq_df.columns = ['freq']
freq_df = freq_df.sort_index()

sts_range = np.round(np.arange(0, 5.2, 0.2) , 2)
sts_list = sts_range.tolist()
sts_index_range = sts_list * N_LABEL

rte_index_range = ['entailment']*len(sts_list)
rte_index_range.extend( ['neutral']*len(sts_list))
rte_index_range.extend( ['contradiction']*len(sts_list))

empty_df = pd.DataFrame(index=[rte_index_range, sts_index_range], columns=['freq'])
empty_df.fillna(0, inplace=True)
empty_df.update(freq_df)

freq_df = empty_df

# ここサボって強引にした
freq_2d = [
    freq_df.loc['entailment'].T.values.tolist()[0],
    freq_df.loc['neutral'].T.values.tolist()[0],
    freq_df.loc['contradiction'].T.values.tolist()[0]
]

df_ = pd.DataFrame(data=freq_2d, index=['entailment', 'neutral', 'contradiction'], columns=np.round(sts_range,2))
df_ = df_.T

plt.figure(figsize=(10, 20))
sns.heatmap(df_, cmap='Blues', vmax=500, annot=True)

各組み合わせの文ペア数の結果

heatmapで出力した結果

"NLIとSTSにおける各組み合わせの文ペア数"

おまけ

次のコードでNLIとSTSのそれぞれの文ペア数を求めた.

def print_freq(label_df):
    s = label_df
    freq = s.value_counts()
    print(freq)

label_df = df[[ 'label_rte' , 'label_sts' ]]

print_freq(label_df[0])
print_freq(label_df[1])
print_freq(label_df[2])

出力結果を整理したもの.

NLIの文ペア数

NLIラベル 文ペア数
entailment 2558
neutral 9768
contradiction 3734

STSの文ペア数

STSラベル 文ペア数
5.0 170
4.8 148
4.6 274
4.4 284
4.2 482
4.0 764
3.8 1072
3.6 1086
3.4 1180
3.2 1368
3.0 1192
2.8 1242
2.6 750
2.4 786
2.2 944
2.0 768
1.8 758
1.6 524
1.4 486
1.2 568
1.0 388
0.8 332
0.6 250
0.4 112
0.2 82
0.0 50

Discussion