📚

精度の高い自然言語処理モデルでshapを試そうとしたが最終的には断念した件

2021/08/02に公開

前回のshap記事で書いたコードを改善しようとして最終的には失敗したのですが、後世に役立てばと思いまとめます。

なお、shap の概要については下記記事をご参照ください。

https://www.datarobot.com/jp/blog/explain-machine-learning-models-using-shap/

前回の概要

前回のshap記事のコードはWord2Vecの日本語学習済みモデルをEmbeddingしCNNを使い、エポック数は2回、optimizer は Adam でした。
バリデーションデータのf1スコアは約72%でした。

試行結果

試行内容 結果
試行1 optimizer を AdaBelief に変更し、エポック数を30にした f1スコアが約86%に向上
試行2 CNN3層からConVLSTM2D 1層に変更。エポック数は5回に抑えた 学習時間が大幅に増大。f1スコアは約81%
試行3 BERT用のshapツール、TranSHAPを発見。英語のデータセットを使用したサンプルコードを試す bert-base-cased を使ったBERTで shap value を算定できた
試行4 結果は失敗。日本語のデータセットを使ったBERTでTranSHAPを試す。 shap value を算定しようとするもエラーで断念。

その他気付き

Epoch 数を増やしすぎると shap value 算定時に データの中に floatかNaNがあるというエラーが出るので、エポック数はほどほどにする必要がありました。

過学習が影響しているんでしょうか・・・。

試行1

日本語データセット、Word2Vecの日本語学習済みモデルでEmbedding、KerasのConv2D は前回記事のコードと同様とし、optimizer を AdaBeliefに変更、エポック数を30 に増やしました。

コード

shap の結果は以下のとおりです。

特徴量ごとに SHAP Value を可視化

1つのサンプルを force_plot で可視化

複数のサンプルの Expected Value を出力

Predicted vector is [4.98997181e-08 2.09270320e-06 9.99997139e-01 1.30058515e-08
 2.71443401e-08 4.02676619e-07 2.56113086e-09 2.04109396e-09
 2.17494119e-07] = Class 2 = kaden-channel
Input features/words:
['タブレット端末' '何' '買う' '人' '多い' '悩む' '人' '揃い' 'タブレット' '持つ' 'クリスマス' '照明'
 'タブレット' '遠隔' '操作' '点灯' '話題' '2' '0' '0' 'キロ' '離れる' '照明' '点灯' '話題' '時'
 '使用' 'タブレット' '集まる' '使用' 'ソニー' 'tablet' 's' 'ipad' '驚き' '世界中' '広がる' '揃い'
 'タブレット' '冬' 'ヒット' 'sony tablet' 's' 'no' 'ipad' 'with' 'sony' 'tablet'
 '関連' '記事' '毒舌' 'ブーム' '続く' 'テレビ' '一番' '多い' '出る' '有吉' '2012年' '話題' 'モチーフ'
 'ライト' 'led' 'w' '充電' '式' '売れ筋' 'チェック' 'ネット' '高位' '家電' '売れる' '量販店' '販売'
 '額' '話題' 'スマホ' '買う' 'facebook' 'android' 'ユーザー' '使う' 'アプリ' 'facebook'
 '話題' '家庭' '充電' '着脱' '式' 'バッテリー' '搭載' '新発売' '売れ筋' 'チェック']
True class is 2 = kaden-channel
Explainer expected value is [-1.69660641 -2.40835054 -2.32374329 -2.87834988 -2.0009611  -1.98163116
 -2.19596157 -1.90211317 -1.73528568], i.e. class 0 is the most common.

複数のサンプルを同時に force_plot で可視化

試行2

Conv2D 3層を ConvLSTM2D 1層に変更しました。

ConvLSTM2D 1層 だけでも学習時間が大幅に増加したので、エポック数も5回に抑えました。

コード

特徴量ごとに SHAP Value を可視化

1つのサンプルを force_plot で可視化

複数のサンプルの Expected Value を出力

Predicted vector is [2.1706934e-03 2.5475287e-04 1.4284113e-04 1.1484002e-03 5.3635560e-04
 9.9162525e-01 1.3585549e-04 4.5576988e-05 3.9403252e-03] = Class 5 = peachy
Input features/words:
['毛穴' 'カバー' '明るい' 'ピンク色' '肌' '抑える' '肌' '仕上げる' '冬' '気' '紫外線' 'カット' '嬉しい'
 '今回' 'peachy' 'パウダー' 'ファンデーション' '33' '系' 'ベース' 'クリーム' 'uv' 'セット' '3' '名'
 '様' 'プレゼント' '機会' '乾燥' '潤い' '美肌' '手' '入れる' 'ベース' 'メイク' 'セット' '3' '名' '様'
 'プレゼント' '賞品' '応募' '数' 'ベース' 'メイク' 'セット' 'パウダー' 'ファンデーション' '33' '系' 'ベース'
 'クリーム' 'uv' '3' '名' '様' 'プレゼント' '応募期間' '2011年' '11月1日' '火' '2011年' '月'
 '応募方法' '下記' '応募' 'ボタン' '応募フォーム' '必要' '事項' '記入' '上' '応募' '当選' '発表' '当選'
 '厳正' '抽選' '上' '決定' '当選' '発送' '発表' '賞品' '発送' '頃' '予定' '都合' '賞品' '発送' '遅れる'
 '場合' '了承' 'プレゼント' '終了' '関連' '情報' '好き' 'キャンペーン' 'サイト']
True class is 5 = peachy
Explainer expected value is [-2.2171125  -2.14492297 -1.99543403 -3.01041016 -1.76218854 -2.30711388
 -1.72017186 -1.6678746  -2.40595878], i.e. class 7 is the most common.

複数のサンプルを同時に force_plot で可視化

試行3

TranSHAPというBERT向けの shap を使用。

TranSHAPのサンプルコードを実行しました。

データセットは英語です。

コード

LIME

shap value を可視化

試行4

日本語のデータセットを使い BERT で TranSHAP を試すも LIME、shap いずれもエラーが出て失敗。

コード

なお、コードは下記のサイトのものを改修して使用しています。

https://qiita.com/takubb/items/fd972f0ac3dba909c293

TranSHAPとBERTでトークンが一致していないのか?LIMEの場合は下記のコードでエラーが発生しました。

なお、shap もほぼ同様のエラーが発生。

shap.initjs()
to_use = texts[-2:]
for i, example in enumerate(to_use):
    logging.info(f"Example {i+1}/{len(to_use)} start")
    temp = predictor.split_string(example)
    exp = explainer.explain_instance(text_instance=example, classifier_fn=predictor.predict, num_features=len(temp))
    logging.info(f"Example {i + 1}/{len(to_use)} done")
    words = exp.as_list()
    #sum_ = 0.6
    #exp.local_exp = {x: [(xx, yy / (sum(hh for _, hh in exp.local_exp[x])/sum_)) for xx, yy in exp.local_exp[x]] for x in exp.local_exp}
    exp.show_in_notebook(text=True, labels=(exp.available_labels()[0],))
ValueError                                Traceback (most recent call last)
<ipython-input-161-868865d5a269> in <module>()
      4     logging.info(f"Example {i+1}/{len(to_use)} start")
      5     temp = predictor.split_string(example)
----> 6     exp = explainer.explain_instance(text_instance=example, classifier_fn=predictor.predict, num_features=len(temp))
      7     logging.info(f"Example {i + 1}/{len(to_use)} done")
      8     words = exp.as_list()

2 frames
/content/TransSHAP/explainers/LIME_for_text.py in predict(self, data)
     28             #    x = [xx for xxx in x for xx in xxx]
     29             for w in x:
---> 30                 id = ref_temp.index(w)
     31                 new[id] = w
     32                 ref_temp[id] = ""

ValueError: 'みんながノリがよくなった改革とは' is not in list

日本語のデータセットで TranSHAP を使うためにはもう一工夫が必要のようです。

以上になります、最後までお読みいただきありがとうございました。

Discussion