LangChain の Conversation まわりのプロンプト生成実装理解
LangChain の ConversationChain と ConversationAgent の違いがよくわからなかったので調べます。特にプロンプト周り。
見ているソースは 2023/2/24 時点の master ブランチでの最新コミットです。
継承関係は
BaseModel => Chain => LLMChain => ConversationChain
BaseModel => Chain => AgentExecutor
BaseModel => Agent => ConversationalAgent
だった。
BaseModel は Pydantic のモデルで、TS でいう zod的なやつ。
ConversationChain と対比されるのは AgentExecutor 。わかりづらいい
まず ConversationChain と AgentExecutor の共通の親である Chain から
class Chain(BaseModel, ABC):
プロパティ
memory: Optional[Memory]
callback_manager: BaseCallbackManager
verbose: bool
@property
_chain_type(self) -> str
@validator
set_callback_manager(cls, callback_manager: Opional[BaseCallbackManager]) -> BaseCallbackManager
set_verbose(cls, verbose: Optional[bool]) -> bool
@property
@abstractmethod
input_keys() -> List[str]
output_keys() => List[str]
@abstractmethod
_call(self, inputs: Dict[str, str]) -> Dict[str, str]
メソッド
_validate_inputs(self, inputs: Dict[str, str]) -> None
_validate_outputs(self, outputs: Dict[str, str]) -> None
_acall(self, inputs: Dict[str, str]) -> Dict[str, str]
__call__(self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False) -> Dict[str, Any]
acall(self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False) -> Dict[str, Any]
prep_outputs(self, inputs: Dict[str, str], outputs: Dict[str, str], return_only_outputs: bool = False) -> Dict[str, str]
prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]
apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
run(self, *args: str, **kwargs: str) -> str:
arun(self, *args: str, **kwargs: str) -> str:
dict(self, **kwargs: Any) -> Dict:
save(self, file_path: Union[Path, str]) -> None:
class LLMChain(Chain, BaseModel)
prompt: BasePromptTemplate
llm: BaseLLM
output_key: str = "text"
@property(オーバーライド)
input_keys(self) -> List[str]
output_keys(self) -> List[str]
_chain_type(self) -> str: # llm_chain
メソッド(オーバーライド)
apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]
_call(self, inputs: Dict[str, Any]) -> Dict[str, str]
_acall(self, inputs: Dict[str, Any]) -> Dict[str, str]
メソッド
generate(self, input_list: List[Dict[str, Any]]) -> LLMResult
agenerate(self, input_list: List[Dict[str, Any]]) -> LLMResult
prep_prompts(self, input_list: List[Dict[str, Any]]) -> Tuple[List[str], Optional[List[str]]]
aprep_prompts(self, input_list: List[Dict[str, Any]]) -> Tuple[List[str], Optional[List[str]]]
aapply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]
create_outputs(self, response: LLMResult) -> List[Dict[str, str]]
predict(self, **kwargs: Any) -> str
apredict(self, **kwargs: Any) -> str
predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]
apply_and_parse(self, input_list: List[Dict[str, Any]]) -> Sequence[Union[str, List[str], Dict[str, str]]]
_parse_result(self, result: List[Dict[str, str]]) -> Sequence[Union[str, List[str], Dict[str, str]]]
aapply_and_parse(
self, input_list: List[Dict[str, Any]]
) -> Sequence[Union[str, List[str], Dict[str, str]]]
@classmethod
from_string(cls, llm: BaseLLM, template: str) -> Chain
class ConversationChain(LLMChain, BaseModel)
memory: Memory = Field(default_factory=ConversationBufferMemory)
prompt: BasePromptTemplate = PROMPT
input_key: str = "input"
output_key: str = "response"
@property (オーバーライド)
input_keys(self) -> List[str]
@root_validator
validate_prompt_input_variables(cls, values: Dict) -> Dict
class AgentExecutor(Chain, BaseModel):
agent: Agent
tools: Sequence[BaseTool]
return_intermediate_steps: bool = False
max_iterations: Optional[int] = 15
early_stopping_method: str = "force"
@property(オーバーライド)
input_keys(self) -> List[str]
output_keys(self) -> List[str]
@classmethod
from_agent_and_tools(cls, agent: Agent, tools: Sequence[BaseTool], callback_manager: Optional[BaseCallbackManager] = None, **kwargs: Any) -> AgentExecutor
@root_validator()
validate_tools(cls, values: Dict) -> Dict
メソッド(オーバーライド)
_call(self, inputs: Dict[str, str]) -> Dict[str, Any]
_acall(self, inputs: Dict[str, str]) -> Dict[str, str]
メソッド
save(self, file_path: Union[Path, str]) -> None
save_agent(self, file_path: Union[Path, str]) -> None
_should_continue(self, iterations: int) -> bool
_return(self, output: AgentFinish, intermediate_steps: list) -> Dict[str, Any]
_areturn(self, output: AgentFinish, intermediate_steps: list) -> Dict[str, Any]
_take_next_step(self, name_to_tool_map: Dict[str, BaseTool], color_mapping: Dict[str, str], inputs: Dict[str, str], intermediate_steps: List[Tuple[AgentAction, str]]) -> Union[AgentFinish, Tuple[AgentAction, str]]
_atake_next_step(self, name_to_tool_map: Dict[str, BaseTool], color_mapping: Dict[str, str], inputs: Dict[str, str], intermediate_steps: List[Tuple[AgentAction, str]]) -> Union[AgentFinish, Tuple[AgentAction, str]]
お次は Agent。Chain とは関係なくて、 AgentExecutor によって使われるもののようです。だけどメンバに LLMChain 持ってる。謎。
class Agent(BaseModel):
llm_chain: LLMChain
allowed_tools: Optional[List[str]] = None
return_values: List[str] = ["output"]
@property
_stop(self) -> List[str]
_construct_scratchpad(self, intermediate_steps: List[Tuple[AgentAction, str]]) -> str
finish_tool_name(self) -> str
input_keys(self) -> List[str]
@property
@abstractmethod
observation_prefix(self) -> str
llm_prefix(self) -> str
_agent_type(self) -> str:
@classmethod
@abstractmethod
create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate
@classmethod
_validate_tools(cls, tools: Sequence[BaseTool]) -> None
from_llm_and_tools(cls, llm: BaseLLM, tools: Sequence[BaseTool], callback_manager: Optional[BaseCallbackManager] = None, **kwargs: Any) -> Agent
@abstractmethod
_extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]
メソッド
_fix_text(self, text: str) -> str
_get_next_action(self, full_inputs: Dict[str, str]) -> AgentAction
_aget_next_action(self, full_inputs: Dict[str, str]) -> AgentAction
plan(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any) -> Union[AgentAction, AgentFinish]
aplan(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any) -> Union[AgentAction, AgentFinish]
get_full_inputs(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any) -> Dict[str, Any]
prepare_for_new_call(self) -> None
return_stopped_response(self, early_stopping_method: str, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any) -> AgentFinish
dict(self, **kwargs: Any) -> Dict
save(self, file_path: Union[Path, str]) -> None
@root_validator
validate_prompt(cls, values: Dict) -> Dict
class ConversationalAgent(Agent)
ai_prefix: str = "AI"
@property
_agent_type(self) -> str
observation_prefix(self) -> str
llm_prefix(self) -> str
@property(オーバーライド)
finish_tool_name(self) -> str
@classmethod
create_prompt(cls, tools: Sequence[BaseTool], prefix: str = PREFIX, suffix: str = SUFFIX, format_instructions: str = FORMAT_INSTRUCTIONS, ai_prefix: str = "AI", human_prefix: str = "Human", input_variables: Optional[List[str]] = None,) -> PromptTemplate
from_llm_and_tools(cls, llm: BaseLLM, tools: Sequence[BaseTool], callback_manager: Optional[BaseCallbackManager] = None, prefix: str = PREFIX, suffix: str = SUFFIX, format_instructions: str = FORMAT_INSTRUCTIONS, ai_prefix: str = "AI", human_prefix: str = "Human", input_variables: Optional[List[str]] = None, **kwargs: Any) -> Agent:
メソッド
_extract_tool_and_input(self, llm_output: str) -> Optional[Tuple[str, str]]:
ConversationChain の predict でプロンプトがどのように作られるのかの流れを追う
ConversationChain.predict は定義されておらず、親の LLMChain.predict が呼ばれる。
さっそく意味がわからないwww
def predict(self, **kwargs: Any) -> str:
"""Format prompt with kwargs and pass to LLM.
Args:
**kwargs: Keys to pass to prompt template.
Returns:
Completion from LLM.
Example:
.. code-block:: python
completion = llm.predict(adjective="funny")
"""
return self(kwargs)[self.output_key]
インスタンスメソッド内での self() はクラスの __call__ が呼ばれるってことでいいのだろうか?
repl で試したら雰囲気それでokっぽい (引数受け取り間違ってそう)
>>> class Hoge:
... def __call__(self, message: str):
... print("__call__", str)
... def test(self):
... self("test!")
...
>>> hoge = Hoge()
>>> hoge.test()
__call__ <class 'str'>
__call__ があるのは ConversationChain の親の LLMChain の親である Chain.__call__だった。
def __call__(
self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
Args:
inputs: Dictionary of inputs, or single input if chain expects
only one param.
return_only_outputs: boolean for whether to return only outputs in the
response. If True, only new keys generated by this chain will be
returned. If False, both input keys and new keys generated by this
chain will be returned. Defaults to False.
"""
inputs = self.prep_inputs(inputs)
self.callback_manager.on_chain_start(
{"name": self.__class__.__name__},
inputs,
verbose=self.verbose,
)
try:
outputs = self._call(inputs)
except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_chain_error(e, verbose=self.verbose)
raise e
self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
return self.prep_outputs(inputs, outputs, return_only_outputs)
prep_inputs で inputs の dict 作ってから self._call(inputs) を呼んで返り値の outputs を prep_outputs に渡してして返してる
prep_inputs, inputs, outputs, prep_outputs はすべて Chain で定義されており子孫では定義されてない。
_call だけは abstract method で、LLMChain と AgentExectutor それぞれで定義されてる。
Memory(会話履歴) の memory_variable が inputs には必要なケースがあるのでその対処と、Memory オブジェクトがある場合は履歴の読み込みをここでやった上で Chain._validate_inputs を呼んでおそらくチェックをしてる。副作用あるのね。。
def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
"""Validate and prep inputs."""
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(self.memory.memory_variables)
if len(_input_keys) != 1:
raise ValueError(
f"A single string input was passed in, but this chain expects "
f"multiple inputs ({_input_keys}). When a chain expects "
f"multiple inputs, please call it by passing in a dictionary, "
"eg `chain({'foo': 1, 'bar': 2})`"
)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs)
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
return inputs
はい。ここでも副作用でメモリ(会話履歴)の保存をしてる。
def prep_outputs(
self,
inputs: Dict[str, str],
outputs: Dict[str, str],
return_only_outputs: bool = False,
) -> Dict[str, str]:
"""Validate and prep outputs."""
self._validate_outputs(outputs)
if self.memory is not None:
self.memory.save_context(inputs, outputs)
if return_only_outputs:
return outputs
else:
return {**inputs, **outputs}
ConversationChain の親である LLMChain の _call
apply 呼んで返り値のdict一個目返してるだけ!
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
return self.apply([inputs])[0]
apply は Chain でも定義されてるけど LLMChain にもあるのでそっちのほう
コメントが意味よくわからないけど、 self.generate して create_outputs してる、っぽい。
def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
"""Utilize the LLM generate method for speed gains."""
response = self.generate(input_list)
return self.create_outputs(response)
LLMChain の generate
やっとプロンプトの生成してるっぽい所まで来た。ここを掘る必要がある。
def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list)
response = self.llm.generate(prompts, stop=stop)
return response
同じく LLMChain の create_outputs
こっちは比較的簡単っぽい。python 特有のリスト内包表記を使ってるのでちょっと読みづらいけど、基本はLLMの出力結果から outputs の Dict[str,str] を作ってるだけだった。
stop 謎いけど、それを無視するとプロンプトの input_variables を引数に prompt.format でプロンプトを生成していますね。複数のプロンプトを作っています。この辺はよくわからない。
def prep_prompts(
self, input_list: List[Dict[str, Any]]
) -> Tuple[List[str], Optional[List[str]]]:
"""Prepare prompts from inputs."""
stop = None
if "stop" in input_list[0]:
stop = input_list[0]["stop"]
prompts = []
for inputs in input_list:
selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
prompt = self.prompt.format(**selected_inputs)
_colored_text = get_colored_text(prompt, "green")
_text = "Prompt after formatting:\n" + _colored_text
self.callback_manager.on_text(_text, end="\n", verbose=self.verbose)
if "stop" in inputs and inputs["stop"] != stop:
raise ValueError(
"If `stop` is present in any inputs, should be present in all."
)
prompts.append(prompt)
return prompts, stop
input_list ってそもそも何が入ってるんだっけ?つらい
LLMChain の _call で固定1個のリストにされているけど、元は inputs だった
それはもとをたどると ConversationChain.predict に渡される引数の dict
実際は以下のような感じ。
conversation.predict(input="こんにちは!")
ConversationChain のデフォルトで使われるプロンプトの定義
_DEFAULT_TEMPLATE = """The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.
Current conversation:
{history}
Human: {input}
AI:"""
PROMPT = PromptTemplate(
input_variables=["history", "input"], template=_DEFAULT_TEMPLATE
)
うーん、迷子になっちゃった。
ちょっと整理しよう
ConversationChain がプロンプトを作って LLM APIコールをするまでの流れ
サンプルとして npaka さんの記事にある ConversationBufferWindowMemory を使った場合の流れ。
呼び出し側
from langchain.llms import OpenAI
from langchain.chains import ConversationChain
from langchain.chains.conversation.memory import ConversationBufferWindowMemory
# ConversationChainの準備
conversation = ConversationChain(
llm=OpenAI(),
memory=ConversationBufferWindowMemory(k=2),
verbose=True
)
# 推論の実行。
conversation.predict(input="こんにちは")
1. LLMChain.predict が { "input": "こんにちは" } で呼び出される
2. Chain.__call__ が inputs={ "input": "こんにちは" } で呼び出される
2-1. Chain.pref_inputs が inputs = {"input": "こんにちは" } で呼び出される。返り値は { "input": "こんにちは", "history": "" } で戻る
3. LLMChain._call が inputs = {"input": "こんにちは", "history": "" } で呼び出される
4. LLMChain.apply が input_list: [{"input": "こんにちは", "history": ""}] で呼び出される
5. LLMChain.generate が input_list: [{"input": "こんにちは", "history": ""}] で呼び出される
5-1. LLMChain.prep_prompt が input_list: [{"input": "こんにちは", "history": ""}] で呼び出される。
5-1-1. PromptTemplate である self.promt の .format が kwargs= {"input": "こんにちは", "history": ""} で呼び出される
5-1-1-1. DEFAULT_FORMATTER_MAPPINGに定義された string.formatter.format が (self.template, kwargs={"input": "こんにちは", "history": "" } で呼び出される
今日のMPがゼロになったので撤退しました。手強い!
ちなみに最後できたプロンプトを使って OpenAI の completion を呼ぶ手前、langchain.llms.BaseOpenAI で get_sub_prompts関数が呼ばれてて、なにかバッチサイズに合わせるために分割して送ってそうなんだけど、よくわからない。
def get_sub_prompts(
self,
params: Dict[str, Any],
prompts: List[str],
stop: Optional[List[str]] = None,
) -> List[List[str]]:
"""Get the sub prompts for llm call."""
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
if params["max_tokens"] == -1:
if len(prompts) != 1:
raise ValueError(
"max_tokens set to -1 not supported for multiple inputs."
)
params["max_tokens"] = self.max_tokens_for_prompt(prompts[0])
sub_prompts = [
prompts[i : i + self.batch_size]
for i in range(0, len(prompts), self.batch_size)
]
return sub_prompts