Closed4

「Large Language Mario」をclaude/gemini対応する。

生ビール生ビール

からあげ先生作notebookをAnthropic / Gemini/MiniCPM-Vに仮対応してみる。
https://zenn.dev/karaage0703/articles/5a02a0822fba8a

Anthropic APIバージョン

predict()部分のコードを以下にしただけです。
※なお、claude-3.5-sonnetに作ってもらっただけなので間違いあるかもしれません。

!pip install Anthropic
import base64
import json
from anthropic import Anthropic
from PIL import Image

def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')

def predict(state):
    # 今のマリオの状態をPNG画像として保存
    image = Image.fromarray(state)
    image.save('state.png')

    # 画像ファイルをエンコーディング
    image_path = "./state.png"
    base64_image = encode_image(image_path)

    api_key = userdata.get("ANTHROPIC_API_KEY")
    anthropic = Anthropic(api_key=api_key)

    # プロンプト
    prompt = """
        この画像はゲーム、スーパーマリオのプレイ画面です。
       画面に応じて、以下の7つのボタン操作ができます。ボタン操作は以下の7つからどれかを選んでください
       NOOPが操作しない。Aがジャンプ。Bがダッシュです。
      クリボーに当たるとゲームオーバーになるので、クリボーに当たりそうになった時は適切に回避して。
      また土管に当たって動きが見られない場合は少し左に動いた後に助走つけて右とジャンプを繰り出すといいです。

        0 = 'NOOP'
        1 = 'right'
        2 = 'right', 'A'
        3 = 'right', 'B'
        4 = 'right', 'A', 'B'
        5 = 'A'
        6 = 'left'

        以下の通りjson出力してください。日本語でお願いします。

       explanation: 画面の説明
        reason: ボタン操作の理由
        action: ボタン操作の種類

        回答は必ず有効なJSONフォーマットで出力してください。
    """

    response = anthropic.messages.create(
        model="claude-3-5-sonnet-20240620",
        max_tokens=300,
        temperature=1,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": prompt
                    },
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/png",
                            "data": base64_image
                        }
                    }
                ]
            }
        ]
    )

    try:
        content_dict = json.loads(response.content[0].text)
        action = content_dict.get('action')
        explanation = content_dict.get('explanation')
        reason = content_dict.get('reason')

        if action is None:
            action = 0

        return action, explanation, reason
    except json.JSONDecodeError:
        print("Error: Unable to parse JSON response")
        return 0, "エラー: レスポンスの解析に失敗しました", "エラーが発生したため、デフォルトのアクションを返します"

生ビール生ビール

Gemini APIバージョン(Google AI studio)

※google ai studio課金版で試してます。

import base64
import json
import google.generativeai as genai
from google.colab import userdata
from PIL import Image
import time

def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')

def predict(state):
    # 今のマリオの状態をPNG画像として保存
    image = Image.fromarray(state)
    image.save('state.png')

    # Google API Keyの設定
    GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY')
    genai.configure(api_key=GOOGLE_API_KEY)

    # 画像ファイルをアップロード
    image_path = "./state.png"
    uploaded_image = genai.upload_file(image_path)

    # アップロード完了を待つ
    while uploaded_image.state.name == "PROCESSING":
        print("Waiting for image processing.")
        time.sleep(1)
        uploaded_image = genai.get_file(uploaded_image.name)

    # モデルのインスタンスを作成
    model = genai.GenerativeModel("gemini-1.5-pro-latest")

    # プロンプト
    prompt = """
        この画像はゲーム、スーパーマリオのプレイ画面です。
       画面に応じて、以下の7つのボタン操作ができます。ボタン操作は以下の7つからどれかを選んでください
       NOOPが操作しない。Aがジャンプ。Bがダッシュです。
      クリボーに当たるとゲームオーバーになるので、クリボーに当たりそうになった時は適切に回避して。
      また土管に当たって動きが見られない場合は少し左に動いた後に助走つけて右とジャンプを繰り出すといいです。

        0 = 'NOOP'
        1 = 'right'
        2 = 'right', 'A'
        3 = 'right', 'B'
        4 = 'right', 'A', 'B'
        5 = 'A'
        6 = 'left'

        以下の通りjson出力してください。日本語でお願いします。

       explanation: 画面の説明
        reason: ボタン操作の理由
        action: ボタン操作の種類(数字で出力すること)

        回答は必ず有効なJSONフォーマットで出力してください。JSONのみを出力し、余分な説明は不要です。
    """

    # APIを実行
    content = [prompt, uploaded_image]
    response = model.generate_content(content)

    # デバッグ用に応答の全文を表示
    print("API Response:")
    print(response.text)

    try:
        # 応答テキストからJSONを抽出
        json_start = response.text.find('{')
        json_end = response.text.rfind('}') + 1
        if json_start != -1 and json_end != -1:
            json_str = response.text[json_start:json_end]
            content_dict = json.loads(json_str)
        else:
            raise ValueError("No JSON found in the response")

        action = content_dict.get('action')
        explanation = content_dict.get('explanation')
        reason = content_dict.get('reason')

        if action is None:
            action = 0
        else:
            action = int(action)  # actionを整数に変換

        return action, explanation, reason
    except (json.JSONDecodeError, ValueError) as e:
        print(f"Error: {str(e)}")
        return 0, "エラー: レスポンスの解析に失敗しました", "エラーが発生したため、デフォルトのアクションを返します"

# 使用例
# action, explanation, reason = predict(state)
# print(f"Action: {action}")
# print(f"Explanation: {explanation}")
# print(f"Reason: {reason}")
生ビール生ビール

MiniCPM-V 2.6版

※要HF_TOKEN

pip install -U flash_attn
!pip install transformers==4.40.0
!pip install sentencepiece==0.1.99
!pip install accelerate==0.30.1
!pip install bitsandbytes==0.43.1
!pip install -U timm
import torch
from transformers import AutoModel, AutoTokenizer
from PIL import Image
import cv2
import numpy as np

def predict(state):
    # モデルとトークナイザーの初期化(初回のみ実行)
    if not hasattr(predict, "model"):
        predict.model = AutoModel.from_pretrained(
            'openbmb/MiniCPM-V-2_6-int4',
            trust_remote_code=True,
        ).eval()
        predict.tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2_6-int4', trust_remote_code=True)

    # 画像の前処理
    image = Image.fromarray(cv2.cvtColor(state, cv2.COLOR_BGR2RGB))

    # プロンプトの設定
    prompt = """
    Analyze this Super Mario game screen and choose the best action.
    If you hit the Goomba, the game is over, so please dodge it well.
    Available actions:
    0 = 'NOOP'
    1 = 'RIGHT'
    2 = 'RIGHT A' (Jump right)
    3 = 'RIGHT B' (Run right)
    4 = 'RIGHT A B' (Jump and run right)
    5 = 'A' (Jump)
    6 = 'LEFT'

    Respond with:
    1. A brief description of the screen
    2. The recommended action (as a number 0-6)
    3. The reason for this action

    Format your response as follows:
    Description: [Your screen description]
    Action: [Number 0-6]
    Reason: [Your reason for the action]
    """

    # モデルに入力を渡して推論
    msgs = [{'role': 'user', 'content': [image, prompt]}]
    response = predict.model.chat(
        image=None,
        msgs=msgs,
        tokenizer=predict.tokenizer,
    )

    print("Raw response:")
    print(response)

    # レスポンスの解析
    lines = response.strip().split('\n')
    description = ""
    action = 0
    reason = ""

    for line in lines:
        if line.startswith("Description:"):
            description = line.split(":", 1)[1].strip()
        elif line.startswith("Action:"):
            action_str = line.split(":", 1)[1].strip()
            try:
                action = int(action_str)
            except ValueError:
                print(f"Invalid action value: {action_str}")
                action = 0
        elif line.startswith("Reason:"):
            reason = line.split(":", 1)[1].strip()

    print("\nParsed result:")
    print(f"Action: {action}")
    print(f"Description: {description}")
    print(f"Reason: {reason}")

    return action, description, reason

# 使用例
# state = cv2.imread('mario_game_screen.png')
# action, description, reason = predict(state)
# print(f"\nFinal output:")
# print(f"Action: {action}")
# print(f"Description: {description}")
# print(f"Reason: {reason}")
生ビール生ビール

Gemini 2.0-exp バージョン(Google AI studio)

とりあえず動いたレベルなので、改善あればぜひ改変してほしいです。

このスクラップは7日前にクローズされました