Open29

LangChain の Conversation まわりのプロンプト生成実装理解

Coji MizoguchiCoji Mizoguchi

継承関係は

BaseModel => Chain => LLMChain => ConversationChain
BaseModel => Chain => AgentExecutor
BaseModel => Agent => ConversationalAgent

だった。

BaseModel は Pydantic のモデルで、TS でいう zod的なやつ。
ConversationChain と対比されるのは AgentExecutor 。わかりづらいい

Coji MizoguchiCoji Mizoguchi

Chain

class Chain(BaseModel, ABC):
プロパティ
  memory: Optional[Memory]
  callback_manager: BaseCallbackManager
  verbose: bool

@property
  _chain_type(self) -> str

@validator
  set_callback_manager(cls, callback_manager: Opional[BaseCallbackManager]) -> BaseCallbackManager
  set_verbose(cls, verbose: Optional[bool]) -> bool

@property
@abstractmethod
  input_keys() -> List[str]
  output_keys() => List[str]

@abstractmethod
  _call(self, inputs: Dict[str, str]) -> Dict[str, str]

メソッド
  _validate_inputs(self, inputs: Dict[str, str]) -> None
  _validate_outputs(self, outputs: Dict[str, str]) -> None
  _acall(self, inputs: Dict[str, str]) -> Dict[str, str]
  __call__(self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False) -> Dict[str, Any]
  acall(self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False) -> Dict[str, Any]
  prep_outputs(self, inputs: Dict[str, str], outputs: Dict[str, str], return_only_outputs: bool = False) -> Dict[str, str]
  prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]
  apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
  run(self, *args: str, **kwargs: str) -> str:
  arun(self, *args: str, **kwargs: str) -> str:
  dict(self, **kwargs: Any) -> Dict:
  save(self, file_path: Union[Path, str]) -> None:
Coji MizoguchiCoji Mizoguchi

LLMChain

class LLMChain(Chain, BaseModel)
  prompt: BasePromptTemplate
  llm: BaseLLM
  output_key: str = "text"

@property(オーバーライド)
  input_keys(self) -> List[str]
  output_keys(self) -> List[str]
  _chain_type(self) -> str: # llm_chain

メソッド(オーバーライド)
  apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]
  _call(self, inputs: Dict[str, Any]) -> Dict[str, str]
  _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]
  
メソッド
  generate(self, input_list: List[Dict[str, Any]]) -> LLMResult
  agenerate(self, input_list: List[Dict[str, Any]]) -> LLMResult
  prep_prompts(self, input_list: List[Dict[str, Any]]) -> Tuple[List[str], Optional[List[str]]]
  aprep_prompts(self, input_list: List[Dict[str, Any]]) -> Tuple[List[str], Optional[List[str]]]
  aapply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]
  create_outputs(self, response: LLMResult) -> List[Dict[str, str]]
  predict(self, **kwargs: Any) -> str
  apredict(self, **kwargs: Any) -> str
  predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]
  apply_and_parse(self, input_list: List[Dict[str, Any]]) -> Sequence[Union[str, List[str], Dict[str, str]]]
  _parse_result(self, result: List[Dict[str, str]]) -> Sequence[Union[str, List[str], Dict[str, str]]]
  aapply_and_parse(
        self, input_list: List[Dict[str, Any]]
    ) -> Sequence[Union[str, List[str], Dict[str, str]]]
  
@classmethod
  from_string(cls, llm: BaseLLM, template: str) -> Chain
Coji MizoguchiCoji Mizoguchi

ConversationChain

class ConversationChain(LLMChain, BaseModel)
  memory: Memory = Field(default_factory=ConversationBufferMemory)
  prompt: BasePromptTemplate = PROMPT
  input_key: str = "input"
  output_key: str = "response"

@property (オーバーライド)
  input_keys(self) -> List[str]

@root_validator
  validate_prompt_input_variables(cls, values: Dict) -> Dict
Coji MizoguchiCoji Mizoguchi

AgentExecutor

class AgentExecutor(Chain, BaseModel):
  agent: Agent
  tools: Sequence[BaseTool]
  return_intermediate_steps: bool = False
  max_iterations: Optional[int] = 15
  early_stopping_method: str = "force"

@property(オーバーライド)
  input_keys(self) -> List[str]
  output_keys(self) -> List[str]

@classmethod
  from_agent_and_tools(cls, agent: Agent, tools: Sequence[BaseTool], callback_manager: Optional[BaseCallbackManager] = None, **kwargs: Any) -> AgentExecutor

@root_validator()
  validate_tools(cls, values: Dict) -> Dict

メソッド(オーバーライド)
  _call(self, inputs: Dict[str, str]) -> Dict[str, Any]
  _acall(self, inputs: Dict[str, str]) -> Dict[str, str]

メソッド
  save(self, file_path: Union[Path, str]) -> None
  save_agent(self, file_path: Union[Path, str]) -> None
  _should_continue(self, iterations: int) -> bool
  _return(self, output: AgentFinish, intermediate_steps: list) -> Dict[str, Any]
  _areturn(self, output: AgentFinish, intermediate_steps: list) -> Dict[str, Any]
  _take_next_step(self, name_to_tool_map: Dict[str, BaseTool], color_mapping: Dict[str, str], inputs: Dict[str, str], intermediate_steps: List[Tuple[AgentAction, str]]) -> Union[AgentFinish, Tuple[AgentAction, str]]
  _atake_next_step(self, name_to_tool_map: Dict[str, BaseTool], color_mapping: Dict[str, str], inputs: Dict[str, str], intermediate_steps: List[Tuple[AgentAction, str]]) -> Union[AgentFinish, Tuple[AgentAction, str]]

Coji MizoguchiCoji Mizoguchi

お次は Agent。Chain とは関係なくて、 AgentExecutor によって使われるもののようです。だけどメンバに LLMChain 持ってる。謎。

Agent

class Agent(BaseModel):
  llm_chain: LLMChain
  allowed_tools: Optional[List[str]] = None
  return_values: List[str] = ["output"]

@property
  _stop(self) -> List[str]
  _construct_scratchpad(self, intermediate_steps: List[Tuple[AgentAction, str]]) -> str
  finish_tool_name(self) -> str
  input_keys(self) -> List[str]

@property
@abstractmethod
  observation_prefix(self) -> str
  llm_prefix(self) -> str
  _agent_type(self) -> str:

@classmethod
@abstractmethod
  create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate

@classmethod
  _validate_tools(cls, tools: Sequence[BaseTool]) -> None
  from_llm_and_tools(cls, llm: BaseLLM, tools: Sequence[BaseTool], callback_manager: Optional[BaseCallbackManager] = None, **kwargs: Any) -> Agent

@abstractmethod
  _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]

メソッド
  _fix_text(self, text: str) -> str
  _get_next_action(self, full_inputs: Dict[str, str]) -> AgentAction
  _aget_next_action(self, full_inputs: Dict[str, str]) -> AgentAction
  plan(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any) -> Union[AgentAction, AgentFinish]
  aplan(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any) -> Union[AgentAction, AgentFinish]
  get_full_inputs(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any) -> Dict[str, Any]
  prepare_for_new_call(self) -> None
  return_stopped_response(self, early_stopping_method: str, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any) -> AgentFinish
  dict(self, **kwargs: Any) -> Dict
  save(self, file_path: Union[Path, str]) -> None

@root_validator
  validate_prompt(cls, values: Dict) -> Dict
Coji MizoguchiCoji Mizoguchi

ConversationalAgent

class ConversationalAgent(Agent)
  ai_prefix: str = "AI"

@property
  _agent_type(self) -> str
  observation_prefix(self) -> str
  llm_prefix(self) -> str

@property(オーバーライド)
  finish_tool_name(self) -> str

@classmethod
  create_prompt(cls, tools: Sequence[BaseTool], prefix: str = PREFIX, suffix: str = SUFFIX, format_instructions: str = FORMAT_INSTRUCTIONS, ai_prefix: str = "AI", human_prefix: str = "Human", input_variables: Optional[List[str]] = None,) -> PromptTemplate
  from_llm_and_tools(cls, llm: BaseLLM, tools: Sequence[BaseTool], callback_manager: Optional[BaseCallbackManager] = None, prefix: str = PREFIX, suffix: str = SUFFIX, format_instructions: str = FORMAT_INSTRUCTIONS, ai_prefix: str = "AI", human_prefix: str = "Human", input_variables: Optional[List[str]] = None, **kwargs: Any) -> Agent:

メソッド
  _extract_tool_and_input(self, llm_output: str) -> Optional[Tuple[str, str]]:
Coji MizoguchiCoji Mizoguchi

ConversationChain の predict でプロンプトがどのように作られるのかの流れを追う

Coji MizoguchiCoji Mizoguchi

ConversationChain.predict は定義されておらず、親の LLMChain.predict が呼ばれる。
さっそく意味がわからないwww

    def predict(self, **kwargs: Any) -> str:
        """Format prompt with kwargs and pass to LLM.
        Args:
            **kwargs: Keys to pass to prompt template.
        Returns:
            Completion from LLM.
        Example:
            .. code-block:: python
                completion = llm.predict(adjective="funny")
        """
        return self(kwargs)[self.output_key]
Coji MizoguchiCoji Mizoguchi

インスタンスメソッド内での self() はクラスの __call__ が呼ばれるってことでいいのだろうか?

Coji MizoguchiCoji Mizoguchi

repl で試したら雰囲気それでokっぽい (引数受け取り間違ってそう)

>>> class Hoge:
...   def __call__(self, message: str):
...     print("__call__", str)
...   def test(self):
...     self("test!")
...
>>> hoge = Hoge()
>>> hoge.test()
__call__ <class 'str'>
Coji MizoguchiCoji Mizoguchi

__call__ があるのは ConversationChain の親の LLMChain の親である Chain.__call__だった。

    def __call__(
        self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
    ) -> Dict[str, Any]:
        """Run the logic of this chain and add to output if desired.
        Args:
            inputs: Dictionary of inputs, or single input if chain expects
                only one param.
            return_only_outputs: boolean for whether to return only outputs in the
                response. If True, only new keys generated by this chain will be
                returned. If False, both input keys and new keys generated by this
                chain will be returned. Defaults to False.
        """
        inputs = self.prep_inputs(inputs)
        self.callback_manager.on_chain_start(
            {"name": self.__class__.__name__},
            inputs,
            verbose=self.verbose,
        )
        try:
            outputs = self._call(inputs)
        except (KeyboardInterrupt, Exception) as e:
            self.callback_manager.on_chain_error(e, verbose=self.verbose)
            raise e
        self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
        return self.prep_outputs(inputs, outputs, return_only_outputs)
Coji MizoguchiCoji Mizoguchi

prep_inputs で inputs の dict 作ってから self._call(inputs) を呼んで返り値の outputs を prep_outputs に渡してして返してる

prep_inputs, inputs, outputs, prep_outputs はすべて Chain で定義されており子孫では定義されてない。
_call だけは abstract method で、LLMChain と AgentExectutor それぞれで定義されてる。

Coji MizoguchiCoji Mizoguchi

Chain.prep_inputs

Memory(会話履歴) の memory_variable が inputs には必要なケースがあるのでその対処と、Memory オブジェクトがある場合は履歴の読み込みをここでやった上で Chain._validate_inputs を呼んでおそらくチェックをしてる。副作用あるのね。。

    def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
        """Validate and prep inputs."""
        if not isinstance(inputs, dict):
            _input_keys = set(self.input_keys)
            if self.memory is not None:
                # If there are multiple input keys, but some get set by memory so that
                # only one is not set, we can still figure out which key it is.
                _input_keys = _input_keys.difference(self.memory.memory_variables)
            if len(_input_keys) != 1:
                raise ValueError(
                    f"A single string input was passed in, but this chain expects "
                    f"multiple inputs ({_input_keys}). When a chain expects "
                    f"multiple inputs, please call it by passing in a dictionary, "
                    "eg `chain({'foo': 1, 'bar': 2})`"
                )
            inputs = {list(_input_keys)[0]: inputs}
        if self.memory is not None:
            external_context = self.memory.load_memory_variables(inputs)
            inputs = dict(inputs, **external_context)
        self._validate_inputs(inputs)
        return inputs
Coji MizoguchiCoji Mizoguchi

Chain.prep_outputs

はい。ここでも副作用でメモリ(会話履歴)の保存をしてる。

    def prep_outputs(
        self,
        inputs: Dict[str, str],
        outputs: Dict[str, str],
        return_only_outputs: bool = False,
    ) -> Dict[str, str]:
        """Validate and prep outputs."""
        self._validate_outputs(outputs)
        if self.memory is not None:
            self.memory.save_context(inputs, outputs)
        if return_only_outputs:
            return outputs
        else:
            return {**inputs, **outputs}
Coji MizoguchiCoji Mizoguchi

ConversationChain の親である LLMChain の _call
apply 呼んで返り値のdict一個目返してるだけ!

    def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        return self.apply([inputs])[0]
Coji MizoguchiCoji Mizoguchi

apply は Chain でも定義されてるけど LLMChain にもあるのでそっちのほう
コメントが意味よくわからないけど、 self.generate して create_outputs してる、っぽい。

    def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
        """Utilize the LLM generate method for speed gains."""
        response = self.generate(input_list)
        return self.create_outputs(response)
Coji MizoguchiCoji Mizoguchi

LLMChain の generate
やっとプロンプトの生成してるっぽい所まで来た。ここを掘る必要がある。

    def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
        """Generate LLM result from inputs."""
        prompts, stop = self.prep_prompts(input_list)
        response = self.llm.generate(prompts, stop=stop)
        return response

同じく LLMChain の create_outputs
こっちは比較的簡単っぽい。python 特有のリスト内包表記を使ってるのでちょっと読みづらいけど、基本はLLMの出力結果から outputs の Dict[str,str] を作ってるだけだった。

Coji MizoguchiCoji Mizoguchi

LLMChain の prep_prompt

stop 謎いけど、それを無視するとプロンプトの input_variables を引数に prompt.format でプロンプトを生成していますね。複数のプロンプトを作っています。この辺はよくわからない。

    def prep_prompts(
        self, input_list: List[Dict[str, Any]]
    ) -> Tuple[List[str], Optional[List[str]]]:
        """Prepare prompts from inputs."""
        stop = None
        if "stop" in input_list[0]:
            stop = input_list[0]["stop"]
        prompts = []
        for inputs in input_list:
            selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
            prompt = self.prompt.format(**selected_inputs)
            _colored_text = get_colored_text(prompt, "green")
            _text = "Prompt after formatting:\n" + _colored_text
            self.callback_manager.on_text(_text, end="\n", verbose=self.verbose)
            if "stop" in inputs and inputs["stop"] != stop:
                raise ValueError(
                    "If `stop` is present in any inputs, should be present in all."
                )
            prompts.append(prompt)
        return prompts, stop

Coji MizoguchiCoji Mizoguchi

ConversationChain のデフォルトで使われるプロンプトの定義

_DEFAULT_TEMPLATE = """The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.
Current conversation:
{history}
Human: {input}
AI:"""
PROMPT = PromptTemplate(
    input_variables=["history", "input"], template=_DEFAULT_TEMPLATE
)
Coji MizoguchiCoji Mizoguchi

ConversationChain がプロンプトを作って LLM APIコールをするまでの流れ
サンプルとして npaka さんの記事にある ConversationBufferWindowMemory を使った場合の流れ。

呼び出し側

from langchain.llms import OpenAI
from langchain.chains import ConversationChain
from langchain.chains.conversation.memory import ConversationBufferWindowMemory

# ConversationChainの準備
conversation = ConversationChain(
    llm=OpenAI(),
    memory=ConversationBufferWindowMemory(k=2),
    verbose=True
)

# 推論の実行。
conversation.predict(input="こんにちは")
1. LLMChain.predict が { "input": "こんにちは" } で呼び出される
2. Chain.__call__ が  inputs={ "input": "こんにちは" } で呼び出される
2-1. Chain.pref_inputs が inputs = {"input": "こんにちは" } で呼び出される。返り値は { "input": "こんにちは", "history": "" } で戻る
3. LLMChain._call が inputs = {"input": "こんにちは", "history": "" } で呼び出される
4. LLMChain.apply が input_list: [{"input": "こんにちは", "history": ""}] で呼び出される
5. LLMChain.generate が input_list: [{"input": "こんにちは", "history": ""}] で呼び出される
5-1. LLMChain.prep_prompt が input_list: [{"input": "こんにちは", "history": ""}] で呼び出される。
5-1-1. PromptTemplate である self.promt の .format が  kwargs= {"input": "こんにちは", "history": ""} で呼び出される
5-1-1-1. DEFAULT_FORMATTER_MAPPINGに定義された string.formatter.format が (self.template, kwargs={"input": "こんにちは", "history": "" } で呼び出される
Coji MizoguchiCoji Mizoguchi

ちなみに最後できたプロンプトを使って OpenAI の completion を呼ぶ手前、langchain.llms.BaseOpenAI で get_sub_prompts関数が呼ばれてて、なにかバッチサイズに合わせるために分割して送ってそうなんだけど、よくわからない。

    def get_sub_prompts(
        self,
        params: Dict[str, Any],
        prompts: List[str],
        stop: Optional[List[str]] = None,
    ) -> List[List[str]]:
        """Get the sub prompts for llm call."""
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        if params["max_tokens"] == -1:
            if len(prompts) != 1:
                raise ValueError(
                    "max_tokens set to -1 not supported for multiple inputs."
                )
            params["max_tokens"] = self.max_tokens_for_prompt(prompts[0])
        sub_prompts = [
            prompts[i : i + self.batch_size]
            for i in range(0, len(prompts), self.batch_size)
        ]
        return sub_prompts