🤖

xgenにmulti tern conversationを学習させる

2023/07/14に公開

multi tern conversationのデータセットを使った学習は、wizard-vicunaのスコアが高い要因の一つと思う。日本語データセットでも試してみた。

wizard-vicunaについて調べた記事はこれ
https://zenn.dev/if001/articles/2922c685e30c84

データセット

日本語のmulti turnのデータセットとして、こちらを利用させて頂いた。
https://dialog-system-live-competition.github.io/dslc3/data.html

これらのうちオープンドメインでの雑談が収録されたデータを使用した
https://dialog-system-live-competition.github.io/dslc3/opentrack.html

オープンドメインでの雑談は、2つのトピックを与え、あるトピックからあるトピックへの会話の遷移が行われる。

データには会話のセットに加え、会話を評価したデータが含まれている。
評価は以下の3つを5段階で評価したもの

  • 自然性:対話が自然かどうか
  • 話題追随:システムはユーザが選択した話題に関して適切に応答できたかどうか
  • 話題提供:システムはユーザが選択した話題に関して新たな情報を提供できたかどうか

このデータセットから評価の高いもののみを抽出、整形する。
以下は、評価値とトピック、抽出したデータの一例

HBY_30,4,5,5,"羽生善治","eスポーツ"
      {
        "S": "こんにちは。よろしくお願いします。早速ですけど、何か趣味ってありますか?",
        "U": "ボードゲームかな?"
      },
      {
        "S": "ボードゲームですか。どんなのをされるんですか?",
        "U": "将棋とかやるね"
      },
      {
        "S": "そうなんですね。私は、カラオケが好きです。どんな曲を歌われますか?",
        "U": "いや将棋ばっかりしてるし、面白いからね"
      },
      {
        "S": "そうなんですね。私は、あまりしないです。友達がやってるのを隣で見てます。",
        "U": "将棋のルールは理解してるの?"
      },
      {
        "S": "ルールは、わかります。 でも、得意ではないですね。",
        "U": "勝てるようになってきてから面白くなってくるんだけどね。"
      },
      {
        "S": "そうなんですか。駒の動かし方とかですかね?",
        "U": "後は定石を覚えることも重要だよ。"
      },
      {
        "S": "なるほど。 定石覚えるのは、難しそうですね。",
        "U": "でも一番は対戦をこなすことかな。"
      },
      {
        "S": "なるほど。将棋は奥が深いですね。 ちなみに将棋ウォーズっていうアプリはやっていますか?",
        "U": "よくオンライン対戦してるアプリだわ。"
      },
      {
        "S": "おー、オンライン対戦できるんですね。知らなかったです。",
        "U": "やっぱり今はオンラインだよ。というかeスポーツになっても良いと思う。"
      },
      {
        "S": "確かにeスポーツになっちゃってもいい気がしますね。",
        "U": "インターネット世代も多くなってきたし、eスポーツとして成り立つと思う。"
      },
      {
        "S": "確かにそうですね。 スマホの普及も大きいですし。",
        "U": "シンプルな分、スマホと相性が良いのは間違いないよね。"
      },
      {
        "S": "ほんとそうですね。 チャットも気軽にできますしね。",
        "U": "オンラインで指導将棋もできるし、敷居は低いんじゃないかな。"
      },
      {
        "S": "なるほど。 確かに手軽さはあるかもですね。",
        "U": "eスポーツ化すれば若い人増えないかな。"
      },
      {
        "S": "そうですね。 もっといろんな人に普及してほしいです。",
        "U": "そしたら一緒に対戦しましょうね。"
      },
      {
        "S": "そうですね。 はい、ぜひ。 おやすみなさい。あ、すみません。そろそろ時間ですね。今日はありがとうございました。",
        "U": "こちらこそありがとう。"
      }

整形はfastchatのmult turnのデータセットを参考に以下のようにした。
https://github.com/lm-sys/FastChat/blob/main/data/dummy_conversation.json

[
  {
    "id": "xxxx",
    "conversations": [
      {
        "S": "xxx",
        "U": "xxx"
       },
       ...
    ]
  },
  ...
]

学習

xgen-7b-8k-baseをloraで学習する
https://huggingface.co/Salesforce/xgen-7b-8k-base

学習データの前処理

学習させるデータのフォーマットは、[{”id”: “xxx”, “conversations”:[]}] のように作ったが、データセットの数が68個になったので、少し水増しする。

以下のような1つのデータから会話のターンを分けて、2つのデータとする。

      {
        "S": "こんにちは。よろしくお願いします。早速ですけど、何か趣味ってありますか?",
        "U": "ボードゲームかな?"
      },
      {
        "S": "ボードゲームですか。どんなのをされるんですか?",
        "U": "将棋とかやるね"
      }

// data 1
以下はユーザーとアシスタントの会話です。アシスタントは親切で丁寧に詳細を回答します。

### システム: こんにちは。よろしくお願いします。早速ですけど、何か趣味ってありますか?
### ユーザー: "ボードゲームかな?<eos>


// data2
以下はユーザーとアシスタントの会話です。アシスタントは親切で丁寧に詳細を回答します。

### システム: こんにちは。よろしくお願いします。早速ですけど、何か趣味ってありますか?
### ユーザー: "ボードゲームかな?<eos>

### システム: ボードゲームですか。どんなのをされるんですか?
### ユーザー: 将棋とかやるね<eos>

これにより、train_data 7200val_data 960 となった。

eosの場所はFastChatのvicunaのfinetuningを参考にする。
FastChatでは、以下のように会話の終わりごとにeos入れている。

A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
USER: Who are you? ASSISTANT: I am Vicuna, a language model trained by researchers from Large Model Systems Organization (LMSYS).</s>
USER: What can you do? ASSISTANT: I can chat with you.</s>

fastChatによるvicunaのtemplateについて以下が参考になる
https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L300
https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L78

tokenizer

trust_remote_codeで以下が使われる
https://huggingface.co/Salesforce/xgen-7b-8k-base/blob/main/tokenization_xgen.py

special tokensの確認。

<|endoftext|>
50256

padが設定されてなかったのでpadを付けておく

tokenizer = AutoTokenizer.from_pretrained(
                                            "Salesforce/xgen-7b-8k-base",
                                            trust_remote_code=True,
                                            pad_token='<|pad|>'
)

padをつけずにtrainデータをpad不要になるように整形したほうが良いかも...

train

ひとまず3 epoch

{'loss': 2.1664, 'learning_rate': 3.529411764705882e-05, 'epoch': 2.84}
...
{'train_runtime': 9993.9253, 'train_samples_per_second': 2.161, 'train_steps_per_second': 0.017, 'train_loss': 3.182529642468407, 'epoch': 2.99}

もう少しhyer parameter調整できそうな感じ

結果

以下2つの結果
答えはどうにしろ、なんとなく会話が続いてそう。
実際の結果はeosが出力されず会話が続いたので、アシスタントが返答した時点でカットした。

1

以下はユーザーとアシスタントとの会話です。アシスタントは親切で丁寧に詳細を回答します。

### ユーザー: 
日本で1番高い山を教えてください。

### アシスタント:  
日本で1番高い山は富士山です。
以下はユーザーとアシスタントとの会話です。アシスタントは親切で丁寧に詳細を回答します。

### ユーザー: 
日本で1番高い山を教えてください。

### アシスタント:  
日本で1番高い山は富士山です。

### ユーザー: 
富士山の特徴を教えてください。

### アシスタント:  
富士山は山頂まで登ることができます。
以下はユーザーとアシスタントとの会話です。アシスタントは親切で丁寧に詳細を回答します。

### ユーザー: 
日本で1番高い山を教えてください。

### アシスタント:  
日本で1番高い山は富士山です。

### ユーザー: 
富士山の特徴を教えてください。

### アシスタント:  
富士山は山頂まで登ることができます。

### ユーザー: 
山頂まではどのくらいの距離ですか?

### アシスタント:  
山頂までの距離は約4kmです。

2

以下はユーザーとアシスタントとの会話です。アシスタントは親切で丁寧に詳細を回答します。

### ユーザー: 
何か趣味はありますか?

### アシスタント:  
自分の趣味は音楽です。
以下はユーザーとアシスタントとの会話です。アシスタントは親切で丁寧に詳細を回答します。

### ユーザー: 
何か趣味はありますか?

### アシスタント:  
自分の趣味は音楽です。

### ユーザー: 
どんな音楽が好きですか?

### アシスタント:  
ジャズが好きです。
以下はユーザーとアシスタントとの会話です。アシスタントは親切で丁寧に詳細を回答します。

### ユーザー: 
何か趣味はありますか?

### アシスタント:  
自分の趣味は音楽です。

### ユーザー: 
どんな音楽が好きですか?

### アシスタント:  
ジャズが好きです。

### ユーザー: 
ジャズのどんなところが好きですか?

### アシスタント:  
ジャズを聴いているときに好きなところがあります。

JGLUE

ついでにJGLUEでQAタスクのみ確認してみる。

今回学習させたtemplate用に以下を追加
https://github.com/if001/lm-evaluation-harness/blob/a37d7faa5d82628bc80bff261c8c3bf2eb31e08f/lm_eval/tasks/ja/jcommonsenseqa.py#L189-L222

今回学習させたもの

|         Task         |Version| Metric |Value |   |Stderr|
|----------------------|------:|--------|-----:|---|-----:|
|jcommonsenseqa-1.1-0.5|    1.1|acc     |0.4879|±  |0.0149|
|                      |       |acc_norm|0.4522|±  |0.0149|

|     Task     |Version|  Metric   | Value |   |Stderr|
|--------------|------:|-----------|------:|---|------|
|jsquad-1.1-0.5|    1.1|exact_match|57.5416|   |      |
|              |       |f1         |70.4550|   |      |

dolly-jaで学習したもの

|         Task         |Version| Metric |Value |   |Stderr|
|----------------------|------:|--------|-----:|---|-----:|
|jcommonsenseqa-1.1-0.3|    1.1|acc     |0.5532|±  |0.0149|
|                      |       |acc_norm|0.5130|±  |0.0149|

|     Task     |Version|  Metric   | Value |   |Stderr| 6.5h
|--------------|------:|-----------|------:|---|------|
|jsquad-1.1-0.3|    1.1|exact_match|59.5452|   |      |
|              |       |f1         |72.2966|   |      |

今回学習させたものとdolly-jaで学習させたものを比較すると、jQAに関しては0.49と0.55で、jsquadに関しては57.5と59.5で、どちらも今回学習させたもののほうが低いスコアとなった。

所感

文章を生成させた感じ、それっぽく続く会話はできていそう。
ただスコアとしてはdolly-jaのものよりも低い。

multi tern conversationのデータセットを使った学習は、高スコアになる要因ではありそうだが、
とりあえずmulti ternを使っただけではうまくいかなさそう。

vicunaのデータセットと今回使ったデータセットを比べると以下の違いある。

  • 今回のデータセットでは会話の長さ、特にシステムの返答が短い
  • トピックの幅が狭い?
  • データセットの水増しを行ったので重複した会話が含まれている

wizard vicunaの場合、input/instruction/ouput形式からcoversationsのデータセットを作っていたので、同じくdolly-jaからconversationsのデータセットが作れれば、スコアも上がるかなぁ....という感じ

Discussion