[2023年4月版]Auto-GPTのコードを解き明かす。驚異のAIハック!~その1

2023/04/17に公開

Auto-GPT のコードをざっくりと読み解いてみる

先日、ローカルに構築したAuto-GPTですが、これはPythonで記述されています。
どのようなことをやっているのか、元のPythonのコードに対してGPT-4にコメントをつけてもらいながら読み解いていきたいと思います。

1. main.py

コマンドから叩くmain.pyはたった1行、

./main.py
from scripts.main import main

となっており、プログラムの本体は scripts ディレクトリ以下となります。

というわけでscritps/main.pyが本体の入り口となります。
script 以下のプログラムを読み解いていきたいとおもいます。

2. scripts/main.py はじまり

プログラムのメインルーチンがここに記述されています。
前半の def 関数名 の関数定義をすっ飛ばして後半の check_openai_api_key() から始まるメインの処理をみていきましょう。
GPT-4にコメントをつけてもらいつつ、解説してまいります。

# OpenAI APIキーを確認
check_openai_api_key()
# 引数を解析
parse_arguments()
# ログレベルを設定(デバッグモードの場合はDEBUG、それ以外の場合はINFO)
logger.set_level(logging.DEBUG if cfg.debug_mode else logging.INFO)
# AI名を初期化
ai_name = ""
# プロンプトを構築
prompt = construct_prompt()
# 変数を初期化
full_message_history = []
result = None
next_action_count = 0
# 定型文の定数を作成
user_input = "Determine which next command to use, and respond using the format specified above:"

# メモリを初期化し、空になっていることを確認
# これは、pineconeメモリのインデックス作成や参照に特に重要
memory = get_memory(cfg, init=True)
print('Using memory of type: ' + memory.__class__.__name__)

ここまででAIの初期設定ですね。
OpenAIのAPIキーを確認し、引数の解析、前回の記憶の取得などを行っております。

3. メインループ

While Trueの内部がメイン処理のループです。
以下はコードのコメントに番号を振って大きな処理のフローのポイントを表しています。

# 0. 対話ループ
while True:
    # 1. AIにメッセージを送信し、レスポンスを取得
    with Spinner("Thinking... "):
        assistant_reply = chat.chat_with_ai(
            prompt,
            user_input,
            full_message_history,
            memory,
            # config.py#L.43: self.fast_token_limit = int(os.getenv("FAST_TOKEN_LIMIT", 4000))
            cfg.fast_token_limit) # TODO: このハードコードはGPT3.5を使用します。引数にしてください

    # 2. アシスタントの思考を表示
    print_assistant_thoughts(assistant_reply)

    # 3. コマンド名と引数を取得
    try:
        command_name, arguments = cmd.get_command(attempt_to_fix_json_by_finding_outermost_brackets(assistant_reply))
        if cfg.speak_mode:
            speak.say_text(f"I want to execute {command_name}")
    except Exception as e:
        logger.error("Error: \n", str(e))
    # 4. ユーザーからの入力を受け付ける
    if not cfg.continuous_mode and next_action_count == 0:
        ### ユーザーからのコマンド実行承認を取得 ###
        # キー入力を取得:ユーザーにEnterキーを押して続行するか、Escキーを押して終了するかを促す
        user_input = ""
        logger.typewriter_log(
            "NEXT ACTION: ",
            Fore.CYAN,
            f"COMMAND = {Fore.CYAN}{command_name}{Style.RESET_ALL}  ARGUMENTS = {Fore.CYAN}{arguments}{Style.RESET_ALL}")
        print(
            f"Enter 'y' to authorise command, 'y -N' to run N continuous commands, 'n' to exit program, or enter feedback for {ai_name}...",
            flush=True)
        while True:
            console_input = utils.clean_input(Fore.MAGENTA + "Input:" + Style.RESET_ALL)
            if console_input.lower().rstrip() == "y":
                user_input = "GENERATE NEXT COMMAND JSON"
                break
            elif console_input.lower().startswith("y -"):
                try:
                    next_action_count = abs(int(console_input.split(" ")[1]))
                    user_input = "GENERATE NEXT COMMAND JSON"
                except ValueError:
                    print("Invalid input format. Please enter 'y -n' where n is the number ofcontinuous tasks.")
                    continue
                break
            elif console_input.lower() == "n":
                user_input = "EXIT"
                break
            else:
                user_input = console_input
                command_name = "human_feedback"
                break

        # 5. ユーザーの入力に応じてコマンドを実行
        if user_input == "GENERATE NEXT COMMAND JSON":
            logger.typewriter_log(
            "-=-=-=-=-=-=-= COMMAND AUTHORISED BY USER -=-=-=-=-=-=-=",
            Fore.MAGENTA,
            "")
        elif user_input == "EXIT":
            print("Exiting...", flush=True)
            break
        else:
            # コマンドを表示
            logger.typewriter_log(
                "NEXT ACTION: ",
                Fore.CYAN,
                f"COMMAND = {Fore.CYAN}{command_name}{Style.RESET_ALL}  ARGUMENTS = {Fore.CYAN}{arguments}{Style.RESET_ALL}")

        if command_name is not None and command_name.lower().startswith( "error" ):
            result = f"Command {command_name} threw the following error: " + arguments
        elif command_name == "human_feedback":
            result = f"Human feedback: {user_input}"
        else:
            result = f"Command {command_name} returned: {cmd.execute_command(command_name, arguments)}"
            if next_action_count > 0:
                next_action_count -= 1

        # 6. コマンドから得られた結果をメッセージ履歴に追加
        memory_to_add = f"Assistant Reply: {assistant_reply} " \
                        f"\nResult: {result} " \
                        f"\nHuman Feedback: {user_input} "

        memory.add(memory_to_add)
        if result is not None:
            full_message_history.append(chat.create_chat_message("system", result))
            logger.typewriter_log("SYSTEM: ", Fore.YELLOW, result)
        else:
            # 結果が得られなかった場合、「コマンドを実行できません」というメッセージをメッセージ履歴に追加
            full_message_history.append(
                chat.create_chat_message(
                    "system", "Unable to execute command"))
            logger.typewriter_log("SYSTEM: ", Fore.YELLOW, "Unable to execute command")

ということで、大きな流れは以下の通りです。

  1. AIにプロンプトをなげて思考処理をさせる。
  2. AIの結果からアシスタントの思考を表示
  3. AIの結果から次に実行する予定のコマンド名と引数を取得
  4. ユーザーからの入力を受け付ける
  5. ユーザーの入力に応じてコマンドを実行
  6. コマンドから得られた結果をメッセージ履歴に追加

以下、大まかな流れに沿ってコードを読み解いていきます。

各処理解説

それでは上段であげた大きな処理の流れに沿って各処理のソースを読み解いていきます。

3-1. AIにプロンプトをなげて思考処理をさせる。

with Spinner("Thinking... "): の箇所は Spinnerクラスを使って思考中の"クルクル"を表示し、With句内の処理が終わったら"クルクル"を消すだけです。
メインの処理は chat.chat_with_aiにあります。

        assistant_reply = chat.chat_with_ai(
            prompt,
            user_input,
            full_message_history,
            memory,
            # config.py#L.43: self.fast_token_limit = int(os.getenv("FAST_TOKEN_LIMIT", 4000))
            cfg.fast_token_limit) # TODO: このハードコードはGPT3.5を使用します。引数にしてください

プロンプトとユーザーの入力、メッセージのやり取り履歴、記憶、トークン数を引数に chat_with_ai を呼び出しています。
それでは chat_with_ai関数の中身を見てみましょう。

scripts/chat.py
def chat_with_ai(
        prompt,
        user_input,
        full_message_history,
        permanent_memory,
        token_limit):
    while True:
        try:
            """
            OpenAI APIと対話し、プロンプト、ユーザー入力、メッセージ履歴、および永続的なメモリを送信する。

            引数:
            prompt (str): AIにルールを説明するプロンプト。
            user_input (str): ユーザーからの入力。
            full_message_history (list): ユーザーとAIの間で送信されたすべてのメッセージのリスト。
            permanent_memory (Obj): 永続的なメモリを含むメモリオブジェクト。
            token_limit (int): API呼び出しで許可されるトークンの最大数。

            戻り値:
            str: AIの応答。
            """
            # FAST_LLM_MODEL=gpt-3.5-turboがここでは使用されている
            model = cfg.fast_llm_model # TODO: モデルをハードコードから引数に変更する

このmodelで指定されているcfg.fast_llm_model は config.py#L.41 で self.fast_llm_model = os.getenv("FAST_LLM_MODEL", "gpt-3.5-turbo") と定義されてます。
デフォルトでは gpt-3.5-turboです。
続けます。

            # デフォルトでは cfg.fast_token_limit = int(os.getenv("FAST_TOKEN_LIMIT", 4000)) が渡されてくる
            logger.debug(f"Token limit: {token_limit}")
            # 応答用に1000トークンを予約
            send_token_limit = token_limit - 1000
            # 関連する記憶を取得
            relevant_memory = permanent_memory.get_relevant(str(full_message_history[-9:]), 10)
            logger.debug(f'Memory Stats: {permanent_memory.get_stats()}')

# 関連する記憶を取得では、permanent_memoryオブジェクトから関連する記憶を取得してます。
get_relevant 関数は、メッセージ履歴の全体(full_message_history)の最後の9つのメッセージを文字列に変換し(str(full_message_history[-9:]))、その中で最も関連性の高い10個の記憶を取得します。
この関連する記憶は、後でAPI呼び出しに使用されるコンテキストの生成に利用されます。
またのちほど関連する記憶の取得処理ついて触れます。
では、続けます。

            next_message_to_add_index, current_tokens_used, insertion_index, current_context = generate_context(prompt, relevant_memory, full_message_history, model)
            while current_tokens_used > 2500:
                # トークン数が2500を下回るまでメモリを削除
                relevant_memory = relevant_memory[1:]
                # コンテキストを生成
                next_message_to_add_index, current_tokens_used, insertion_index, current_context = generate_context(
                    prompt, relevant_memory, full_message_history, model)

上記の箇所では、対話のコンテキストを生成し、トークン数が制限(2500トークン)を超えないようにしています。

まず、generate_context関数にて、プロンプト、関連するメモリ、完全なメッセージ履歴、およびモデルを元に、次に追加するメッセージのインデックス、現在使用されているトークン数、挿入インデックス、および現在のコンテキストを生成します。

次に、現在使用されているトークン数が2500を超えている場合、関連するメモリからメモリを削除して(relevant_memory = relevant_memory[1:])、2500トークン以下になるまで繰り返します。その後、再度generate_context関数を呼び出して、更新された関連するメモリを元に、新しいコンテキストを生成します。

この処理は、API呼び出しのコンテキストに含めるメッセージを適切に制限し、トークン数が許容範囲内に収まるようにすることを目的としています。

続けます。

            # ユーザー入力のトークン数を計算(後で追加される)
            current_tokens_used += token_counter.count_message_tokens([create_chat_message("user", user_input)], model)

            # 追加するメッセージがなくなるまで繰り返す
            while next_message_to_add_index >= 0:
                # 追加するメッセージを取得
                message_to_add = full_message_history[next_message_to_add_index]
                # 追加するメッセージのトークン数を計算
                tokens_to_add = token_counter.count_message_tokens([message_to_add], model)
                if current_tokens_used + tokens_to_add > send_token_limit:
                    break

                # 現在のコンテキストの先頭に最も新しいメッセージを追加(二つのシステムプロンプトの後)
                current_context.insert(insertion_index, full_message_history[next_message_to_add_index])

                # 現在使用されているトークン数を計算
                current_tokens_used += tokens_to_add

                # 完全なメッセージ履歴の次の最も新しいメッセージに移動
                next_message_to_add_index -= 1

上記の箇所は先ほどと同様に、ユーザー入力のトークン数を計算し、コンテキストに追加するメッセージを選択しています。
以下、細かく注釈を入れます。

current_tokens_used += token_counter.count_message_tokens([create_chat_message("user", user_input)], model)

ユーザー入力のトークン数を計算し、現在使用されているトークン数(current_tokens_used)に追加しています。

while next_message_to_add_index >= 0:

コンテキストに追加するメッセージを制限するループの開始。

message_to_add = full_message_history[next_message_to_add_index]
tokens_to_add = token_counter.count_message_tokens([message_to_add], model)
if current_tokens_used + tokens_to_add > send_token_limit:
    break

ここで、追加するメッセージのトークン数を計算し、現在使用されているトークン数との合計が送信用トークン制限(send_token_limit)を超える場合、ループから抜けます。

current_context.insert(insertion_index, full_message_history[next_message_to_add_index])

この行では、現在のコンテキストの先頭に最も新しいメッセージを追加しています。メッセージは、2つのシステムプロンプトの後に挿入されます。

current_tokens_used += tokens_to_add

ここで、追加したメッセージのトークン数を現在使用されているトークン数に加算しています。

next_message_to_add_index -= 1

この行では、完全なメッセージ履歴の次の最も新しいメッセージに移動しています。このプロセスは、コンテキストに追加できるメッセージがなくなるまで繰り返されます。

このようにこれまでの会話の履歴とユーザー入力を合わせて全体のトークン量に合わせてプロンプトに渡す履歴を調整しているのがここまでの処理です。
AIからの回答の精度を高めるためには履歴は必要ですが、すべてを渡すことはできません。
"関連する記憶"からトークン数の上限に収まるように履歴を調整し、続いてユーザー入力のメッセージを含めてトークン数に合わせてプロンプトに追加するメッセージ数を調整していってます。
ここも非常に重要なキモですね…

さて続けます。

            # ユーザー入力を追加(上記でトークン数を計算済み)
            current_context.extend([create_chat_message("user", user_input)])

            # 残りのトークンを計算
            tokens_remaining = token_limit - current_tokens_used

            # 現在のコンテキストをデバッグ表示の処理は割愛

            # AIにcompletionさせる
            # TODO: 他の場所で定義されたモデルを使用し、モデルに温度やその他の関心のある設定を含める
            assistant_reply = create_chat_completion(
                model=model,
                messages=current_context,
                max_tokens=tokens_remaining,
            )

            # 完全なメッセージ履歴を更新
            full_message_history.append(
                create_chat_message(
                    "user", user_input))
            full_message_history.append(
                create_chat_message(
                    "assistant", assistant_reply))

            return assistant_reply
        except openai.error.RateLimitError:
            #TODO: langchainに切り替えると、これは組み込まれています
            print("Error: ", "API Rate Limit Reached. Waiting 10 seconds...")
            time.sleep(10)

はい、ここまでです。

この関数での処理を紐解いていくと

  1. トークン数の調整しつつプロンプトの生成
  2. AIにcompletionさせる
  3. 会話履歴を記憶

の3つのパートに分かれますが、
処理の8~9割くらいがプロンプトの作成=トークン数の調整であります。
単に生成したコマンドやユーザーからの入力をAIに投げるだけであれば苦労しないのですが、投げるプロンプトの"トークン数"に上限があるのでそれを超えないようにプロンプトを切り詰めている、というのが悩みのタネであるということです。
上記の解説中にも述べましたが、ここがAIの、というかAPIでAIに処理を行わせる際のキーポイントですね…

さて、プロンプト構築のキモ中のキモ、関連する記憶について深堀してみます。

"関連する記憶" の正体に迫る。

関連する記憶を取得する処理は、permanent_memoryオブジェクトのget_relevant関数にあります。

permanent_memoryオブジェクトは.envMEMORY_BACKENDで指定する値でいくつかの実装が分かれており、local, redis, pinecone の3つがあります

検索ロジックが実装されている 'local' と 'Redis' の実装を載せます。
まず local.py の実装です。

script/memory/local.py
    def get_relevant(self, text: str, k: int) -> List[Any]:
        """
        与えられたテキストに対して最も関連性のあるデータをリストで返す関数
        Args:
            text: str
            k: int

        Returns: List[str]
        """
        # 与えられたテキストの埋め込みを取得
        embedding = get_ada_embedding(text)
        # すべてのデータの埋め込みと与えられたテキストの埋め込みの行列とベクトルの積を計算し、類似度スコアを得る
        scores = np.dot(self.data.embeddings, embedding)
        # スコア配列を降順にソートし、上位k個のインデックスを取得
        top_k_indices = np.argsort(scores)[-k:][::-1]
        # 上位k個のインデックスに対応するテキストをリストにまとめて返す
        return [self.data.texts[i] for i in top_k_indices]

引数のテキストはプロンプトからの入力です。入力値から get_ada_embedding でベクトルを求めて、data.embeddings の履歴データのベクトルすべてとの内積って降順ソートで上位k番目までのリストを返却してます。
ここでいろいろ突っ込む前に、redis 版の実装も載せます。

script/memory/redismem.py
def get_relevant(
    self,
    data: str,
    num_relevant: int = 5
) -> Optional[List[Any]]:
    """
    与えられたデータと関連性のあるメモリ内のデータをすべて返す。
    Args:
        data: 比較対象のデータ。
        num_relevant: 返す関連データの数。

    Returns: 最も関連性のあるデータのリスト。
    """
    # 与えられたデータの埋め込みを取得
    query_embedding = get_ada_embedding(data)
    # KNN(k近傍法)を使用したクエリを作成
    base_query = f"*=>[KNN {num_relevant} @embedding $vector AS vector_score]"
    # クエリオブジェクトを作成し、必要なフィールドを指定、ソート順を設定し、ダイアレクトを指定
    query = Query(base_query).return_fields(
        "data",
        "vector_score"
    ).sort_by("vector_score").dialect(2)
    # query_embedding を NumPy 配列に変換し、np.float32 型にキャストしてバイト列に変換
    query_vector = np.array(query_embedding).astype(np.float32).tobytes()

    try:
        # Redis サーバーで検索を実行し、結果を取得
        results = self.redis.ft(f"{self.cfg.memory_index}").search(
            query, query_params={"vector": query_vector}
        )
    except Exception as e:
        # 例外が発生した場合はエラーメッセージを出力し、None を返す
        print("Error calling Redis search: ", e)
        return None
    # 検索結果からデータフィールドを取り出し、リストにまとめて返す
    return [result.data for result in results.docs]

こちらも入力値から get_ada_embedding でベクトルを求めて、redis の履歴データのベクトルからKNNで検索かけて関連する記憶を取得してます。

これにはいろいろ仕掛けが必要でして…
まず準備として、

  1. ベクトル算出のロジック
  2. 検索される側である履歴データの保存

を、そもそもどうやっているかを紹介します。

ベクトルの算出ロジック

ベクトルで比較できるようにするには、あらゆるテキストを同じロジックでベクトルを計算しておく必要があります。
このロジックは memory にある基底クラスである base.py で定義されています。

memory/base.py
ef get_ada_embedding(text):
    # 改行をスペースに置換
    text = text.replace("\n", " ")
    # Azure を使用する場合の処理
    if cfg.use_azure:
        # Azure でデプロイされている "text-embedding-ada-002" モデルを使って、テキストの埋め込みを取得
        return openai.Embedding.create(
            input=[text], engine=cfg.get_azure_deployment_id_for_model("text-embedding-ada-002")
        )["data"][0]["embedding"]
    # Azure を使用しない場合の処理
    else:
        # "text-embedding-ada-002" モデルを使って、テキストの埋め込みを取得
        return openai.Embedding.create(
            input=[text], model="text-embedding-ada-002"
        )["data"][0]["embedding"]

この get_ada_embedding 関数は、与えられたテキストに対して埋め込み(ベクトル表現)を生成し、それを返す機能を提供します。Azureを使うかどうかでAPIの呼び出し方が変わりますが、ロジックは同じです。

ユーザーの入力や生成された応答など履歴を収めるデータのどちらに対して text-embedding-ada-002 モデルを使用して'埋め込み'を生成し、ベクトル同士の比較を行うことで類似検索を実現しています。

履歴の保存

続いて履歴を保存するときの処理です。
local.py での実装は以下のようになっています。

memory/local.py
def add(self, text: str):
    """
    テキストをテキストリストに追加し、埋め込みを埋め込み行列に追加する
    Args:
        text: str
    Returns: None
    """
    # エラーメッセージが含まれている場合は処理をスキップ
    if 'Command Error:' in text:
        return ""
    # テキストをリストに追加
    self.data.texts.append(text)
    # テキストの埋め込みを取得
    embedding = get_ada_embedding(text)
    # 埋め込みを NumPy 配列に変換し、型を np.float32 にキャスト
    vector = np.array(embedding).astype(np.float32)
    # 配列の次元を変更して1行の行列にする
    vector = vector[np.newaxis, :]
    # 埋め込み行列に新しい埋め込みベクトルを追加
    self.data.embeddings = np.concatenate(
        [
            self.data.embeddings,
            vector,
        ],
        axis=0,
    )
    # ファイルにデータを保存
    with open(self.filename, 'wb') as f:
        out = orjson.dumps(
            self.data,
            option=SAVE_OPTIONS
        )
        f.write(out)
    return text

get_ada_embedding関数でtext-embedding-ada-002モデルでのベクトルを評価してもらい、それを履歴データとして保存しています。

redisでの保尊ロジックは以下です。

まず、事前に"スキーマ"を定義します。

memory/redismem.py
SCHEMA = [
    TextField("data"),
    VectorField(
        "embedding",
        "HNSW",
        {
            "TYPE": "FLOAT32",
            "DIM": 1536,
            "DISTANCE_METRIC": "COSINE"
        }
    ),
]

ここでテキスト型のdataフィールドと、ベクトル型のembeddingフィールドを定義しています。
またベクトル型フィールドの検索オプションとして、"HNSW"の検索ロジック、詳しいデータ型は FLOAT32、次元は 1536、"類似度"を測るのは"コサイン"を定義しています。
それぞれのオプションなどについてはベクトル類似性検索のRedisのマニュアルをご覧ください。
このようにあらかじめスキーマを定義しておいて、履歴を保存しています。

memory/redismem.py
def add(self, data: str) -> str:
    """
    メモリにデータポイントを追加する。
    Args:
        data: 追加するデータ。
    Returns: データが追加されたことを示すメッセージ。
    """
    # エラーメッセージが含まれている場合は処理をスキップ
    if 'Command Error:' in data:
        return ""
    # データの埋め込みを取得し、NumPy 配列に変換して型を np.float32 にキャストし、バイト列に変換
    vector = get_ada_embedding(data)
    vector = np.array(vector).astype(np.float32).tobytes()
    # データと埋め込みをディクショナリに格納
    data_dict = {
        b"data": data,
        "embedding": vector
    }
    # Redis パイプラインを作成
    pipe = self.redis.pipeline()
    # メモリインデックスとベクトル番号をキーにして、データディクショナリを格納
    pipe.hset(f"{self.cfg.memory_index}:{self.vec_num}", mapping=data_dict)
    # 挿入されたデータに関する情報をメッセージに格納
    _text = f"Inserting data into memory at index: {self.vec_num}:\n"\
        f"data: {data}"
    # ベクトル番号をインクリメント
    self.vec_num += 1
    # ベクトル番号を Redis に保存
    pipe.set(f'{self.cfg.memory_index}-vec_num', self.vec_num)
    # パイプラインを実行
    pipe.execute()
    return _text

local.pyのadd メソッドと流れは一緒です。データの埋め込みを取得し、それを Redis に格納しています。
また、ベクトル番号をインクリメントし、それを Redis に保存しています。

このようにしてすべての履歴も保存時にtext-embedding-ada-002モデルでのベクトル評価が行われています。

ベクトルの埋め込みとKNN検索、やってるじゃん

Redis でもベクトル類似性検索ができるのは驚きました。
今回はKNN=k最近傍法で使って検索かけてます。いつからそんな機能が…
また、ベクトルの埋め込みと類似性検索についてはここがわかりやすかったです。

この AutoGPTはRedisはコサインで判定、localでは内積と、距離の判定ロジックが違うのも面白いですね。
Redisとlocalでどのように違うか、後で比べてみたいと思います。

今日はここまで。

まだまだスタートしたばかりですが、いきなり核心の処理を読み解いてしまったような…
ちょっと胃がもたれてきたので、いったん本日はここまでといたします。

明日も引き続き、読み解いていきたいと思います。
それではよろしくお願いいたします。

Discussion