From 60ba2a11c1d86cd6f296a227178ccc0b32f9e03e Mon Sep 17 00:00:00 2001 From: bracesproul Date: Thu, 6 Mar 2025 11:42:39 -0800 Subject: [PATCH] fix: Tool calls for trip planner --- agent/find-tool-call.ts | 2 +- agent/stockbroker/nodes/tools.tsx | 3 +- agent/trip-planner/nodes/tools.tsx | 117 ++++++++---------- .../accommodations-list/index.tsx | 2 +- 4 files changed, 57 insertions(+), 67 deletions(-) diff --git a/agent/find-tool-call.ts b/agent/find-tool-call.ts index 0372551..7c9cc40 100644 --- a/agent/find-tool-call.ts +++ b/agent/find-tool-call.ts @@ -10,5 +10,5 @@ interface ToolCall { export function findToolCall(name: Name) { return ( x: ToolCall, - ): x is { name: Name; args: z.infer } => x.name === name; + ): x is { name: Name; args: z.infer; id?: string } => x.name === name; } diff --git a/agent/stockbroker/nodes/tools.tsx b/agent/stockbroker/nodes/tools.tsx index 03bc0e5..1b1f252 100644 --- a/agent/stockbroker/nodes/tools.tsx +++ b/agent/stockbroker/nodes/tools.tsx @@ -158,8 +158,7 @@ export async function callTools( buyStockToolCall.args.ticker, ); ui.write("buy-stock", { - toolCallId: - message.tool_calls?.find((tc) => tc.name === "buy-stock")?.id ?? "", + toolCallId: buyStockToolCall.id ?? "", snapshot, quantity: buyStockToolCall.args.quantity, }); diff --git a/agent/trip-planner/nodes/tools.tsx b/agent/trip-planner/nodes/tools.tsx index 3a60fda..8966bbc 100644 --- a/agent/trip-planner/nodes/tools.tsx +++ b/agent/trip-planner/nodes/tools.tsx @@ -5,46 +5,39 @@ import type ComponentMap from "../../uis/index"; import { z } from "zod"; import { LangGraphRunnableConfig } from "@langchain/langgraph"; import { getAccommodationsListProps } from "../utils/get-accommodations"; +import { findToolCall } from "../../find-tool-call"; -const schema = z.object({ - listAccommodations: z - .boolean() - .optional() - .describe( - "Whether or not the user has requested a list of accommodations for their trip.", - ), - bookAccommodation: z - .boolean() - .optional() - .describe( - "Whether or not the user has requested to book a reservation for an accommodation. If true, you MUST also set the 'accommodationName' field", - ), - accommodationName: z - .string() - .optional() - .describe( - "The name of the accommodation to book a reservation for. Only required if the 'bookAccommodation' field is true.", - ), +const listAccommodationsSchema = z.object({}).describe("A tool to list accommodations for the user") +const bookAccommodationSchema = z.object({ + accommodationName: z.string().describe("The name of the accommodation to book a reservation for"), +}).describe("A tool to book a reservation for an accommodation"); +const listRestaurantsSchema = z.object({}).describe("A tool to list restaurants for the user"); +const bookRestaurantSchema = z.object({ + restaurantName: z.string().describe("The name of the restaurant to book a reservation for"), +}).describe("A tool to book a reservation for a restaurant"); - listRestaurants: z - .boolean() - .optional() - .describe( - "Whether or not the user has requested a list of restaurants for their trip.", - ), - bookRestaurant: z - .boolean() - .optional() - .describe( - "Whether or not the user has requested to book a reservation for a restaurant. If true, you MUST also set the 'restaurantName' field", - ), - restaurantName: z - .string() - .optional() - .describe( - "The name of the restaurant to book a reservation for. Only required if the 'bookRestaurant' field is true.", - ), -}); +const ACCOMMODATIONS_TOOLS = [ + { + name: "list-accommodations", + description: "A tool to list accommodations for the user", + schema: listAccommodationsSchema, + }, + { + name: "book-accommodation", + description: "A tool to book a reservation for an accommodation", + schema: bookAccommodationSchema, + }, + { + name: "list-restaurants", + description: "A tool to list restaurants for the user", + schema: listRestaurantsSchema, + }, + { + name: "book-restaurant", + description: "A tool to book a reservation for a restaurant", + schema: bookRestaurantSchema, + }, +]; export async function callTools( state: TripPlannerState, @@ -56,18 +49,7 @@ export async function callTools( const ui = typedUi(config); - const llm = new ChatOpenAI({ model: "gpt-4o", temperature: 0 }).bindTools( - [ - { - name: "trip-planner", - description: "A series of actions to take for planning a trip", - schema, - }, - ], - { - tool_choice: "trip-planner", - }, - ); + const llm = new ChatOpenAI({ model: "gpt-4o", temperature: 0 }).bindTools(ACCOMMODATIONS_TOOLS); const response = await llm.invoke([ { @@ -78,35 +60,44 @@ export async function callTools( ...state.messages, ]); - const tripPlan = response.tool_calls?.[0]?.args as - | z.infer - | undefined; - const toolCallId = response.tool_calls?.[0]?.id; - if (!tripPlan || !toolCallId) { - throw new Error("No trip plan found"); + const listAccommodationsToolCall = response.tool_calls?.find( + findToolCall("list-accommodations"), + ); + const bookAccommodationToolCall = response.tool_calls?.find( + findToolCall("book-accommodation"), + ); + const listRestaurantsToolCall = response.tool_calls?.find( + findToolCall("list-restaurants"), + ); + const bookRestaurantToolCall = response.tool_calls?.find( + findToolCall("book-restaurant"), + ); + + if (!listAccommodationsToolCall && !bookAccommodationToolCall && !listRestaurantsToolCall && !bookRestaurantToolCall) { + throw new Error("No tool calls found"); } - if (tripPlan.listAccommodations) { + if (listAccommodationsToolCall) { ui.write("accommodations-list", { - toolCallId, + toolCallId: listAccommodationsToolCall.id ?? "", ...getAccommodationsListProps(state.tripDetails), }); } - if (tripPlan.bookAccommodation && tripPlan.accommodationName) { + if (bookAccommodationToolCall && bookAccommodationToolCall.args.accommodationName) { ui.write("book-accommodation", { tripDetails: state.tripDetails, - accommodationName: tripPlan.accommodationName, + accommodationName: bookAccommodationToolCall.args.accommodationName, }); } - if (tripPlan.listRestaurants) { + if (listRestaurantsToolCall) { ui.write("restaurants-list", { tripDetails: state.tripDetails }); } - if (tripPlan.bookRestaurant && tripPlan.restaurantName) { + if (bookRestaurantToolCall && bookRestaurantToolCall.args.restaurantName) { ui.write("book-restaurant", { tripDetails: state.tripDetails, - restaurantName: tripPlan.restaurantName, + restaurantName: bookRestaurantToolCall.args.restaurantName, }); } diff --git a/agent/uis/trip-planner/accommodations-list/index.tsx b/agent/uis/trip-planner/accommodations-list/index.tsx index 87e28f2..0a46ede 100644 --- a/agent/uis/trip-planner/accommodations-list/index.tsx +++ b/agent/uis/trip-planner/accommodations-list/index.tsx @@ -275,7 +275,7 @@ export default function AccommodationsList({ type: "tool", tool_call_id: toolCallId, id: `${DO_NOT_RENDER_ID_PREFIX}${uuidv4()}`, - name: "trip-planner", + name: "book-accommodation", content: JSON.stringify(orderDetails), }, {