How to pass tool outputs to chat models
Prerequisites
This guide assumes familiarity with the following concepts:
If we're using the model-generated tool invocations to actually call tools and want to pass the tool results back to the model, we can do so using ToolMessages and ToolCalls. First, let's define our tools and our model.
from langchain_core.tools import tool


@tool
def add(a: int, b: int) -> int:
    """Adds a and b."""
    return a + b


@tool
def multiply(a: int, b: int) -> int:
    """Multiplies a and b."""
    return a * b


tools = [add, multiply]
import os
from getpass import getpass
from langchain_openai import ChatOpenAI
os.environ["OPENAI_API_KEY"] = getpass()
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
llm_with_tools = llm.bind_tools(tools)
The nice thing about Tools is that if we invoke them with a ToolCall, we'll automatically get back a ToolMessage that can be fed back to the model:
langchain-core >= 0.2.19
This functionality was added in langchain-core == 0.2.19. Please make sure your package is up to date.
from langchain_core.messages import HumanMessage, ToolMessage
query = "What is 3 * 12? Also, what is 11 + 49?"
messages = [HumanMessage(query)]
ai_msg = llm_with_tools.invoke(messages)
messages.append(ai_msg)
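for tool_call in ai_msg.tool_calls:
    # Look up the tool by the name the model asked for, then invoke it with the
    # full ToolCall so that a ToolMessage (carrying the call id) is returned.
    selected_tool = {"add": add, "multiply": multiply}[tool_call["name"].lower()]
    tool_msg = selected_tool.invoke(tool_call)
    messages.append(tool_msg)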
messages
[HumanMessage(content='What is 3 * 12? Also, what is 11 + 49?'),
AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Smg3NHJNxrKfAmd4f9GkaYn3', 'function': {'arguments': '{"a": 3, "b": 12}', 'name': 'multiply'}, 'type': 'function'}, {'id': 'call_55K1C0DmH6U5qh810gW34xZ0', 'function': {'arguments': '{"a": 11, "b": 49}', 'name': 'add'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 49, 'prompt_tokens': 88, 'total_tokens': 137}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': None, 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-56657feb-96dd-456c-ab8e-1857eab2ade0-0', tool_calls=[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_Smg3NHJNxrKfAmd4f9GkaYn3', 'type': 'tool_call'}, {'name': 'add', 'args': {'a': 11, 'b': 49}, 'id': 'call_55K1C0DmH6U5qh810gW34xZ0', 'type': 'tool_call'}], usage_metadata={'input_tokens': 88, 'output_tokens': 49, 'total_tokens': 137}),
ToolMessage(content='36', name='multiply', tool_call_id='call_Smg3NHJNxrKfAmd4f9GkaYn3'),
ToolMessage(content='60', name='add', tool_call_id='call_55K1C0DmH6U5qh810gW34xZ0')]
llm_with_tools.invoke(messages)
AIMessage(content='3 * 12 is 36 and 11 + 49 is 60.', response_metadata={'token_usage': {'completion_tokens': 18, 'prompt_tokens': 153, 'total_tokens': 171}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None}, id='run-ba5032f0-f773-406d-a408-8314e66511d0-0', usage_metadata={'input_tokens': 153, 'output_tokens': 18, 'total_tokens': 171})
Note that we pass back the same id in the ToolMessage as the one we received from the model, which helps the model match tool responses with tool calls.
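To illustrate why the shared id matters, here is a library-free sketch in which plain dictionaries stand in for the model's tool calls and for the tool messages. The ids "call_A" and "call_B" and the helper match_results are hypothetical names used only for this illustration; this is not how any model provider performs the matching internally, only a picture of what the id makes possible.

```python
# Plain dicts stand in for the model's tool calls and for ToolMessage objects.
tool_calls = [
    {"name": "multiply", "args": {"a": 3, "b": 12}, "id": "call_A"},
    {"name": "add", "args": {"a": 11, "b": 49}, "id": "call_B"},
]

# The results are listed in a different order than the calls were issued.
tool_results = [
    {"content": "60", "tool_call_id": "call_B"},
    {"content": "36", "tool_call_id": "call_A"},
]


def match_results(tool_calls, tool_results):
    """Pair each tool call with its result using the shared id."""
    results_by_id = {r["tool_call_id"]: r for r in tool_results}
    matched = []
    for call in tool_calls:
        result = results_by_id.get(call["id"])
        if result is None:
            raise KeyError(f"no result for tool call {call['id']}")
        matched.append((call["name"], call["args"], result["content"]))
    return matched


matched = match_results(tool_calls, tool_results)
for name, args, content in matched:
    print(f"{name}({args}) -> {content}")
```

Because the pairing relies on the id rather than on position, each result is attributed to the right call even though the results are out of order.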