fix: Tool calls for trip planner

This commit is contained in:
bracesproul
2025-03-06 11:42:39 -08:00
parent 098d954d06
commit 60ba2a11c1
4 changed files with 57 additions and 67 deletions

View File

@@ -10,5 +10,5 @@ interface ToolCall {
export function findToolCall<Name extends string>(name: Name) { export function findToolCall<Name extends string>(name: Name) {
return <Args extends ZodTypeAny>( return <Args extends ZodTypeAny>(
x: ToolCall, x: ToolCall,
): x is { name: Name; args: z.infer<Args> } => x.name === name; ): x is { name: Name; args: z.infer<Args>; id?: string } => x.name === name;
} }

View File

@@ -158,8 +158,7 @@ export async function callTools(
buyStockToolCall.args.ticker, buyStockToolCall.args.ticker,
); );
ui.write("buy-stock", { ui.write("buy-stock", {
toolCallId: toolCallId: buyStockToolCall.id ?? "",
message.tool_calls?.find((tc) => tc.name === "buy-stock")?.id ?? "",
snapshot, snapshot,
quantity: buyStockToolCall.args.quantity, quantity: buyStockToolCall.args.quantity,
}); });

View File

@@ -5,46 +5,39 @@ import type ComponentMap from "../../uis/index";
import { z } from "zod"; import { z } from "zod";
import { LangGraphRunnableConfig } from "@langchain/langgraph"; import { LangGraphRunnableConfig } from "@langchain/langgraph";
import { getAccommodationsListProps } from "../utils/get-accommodations"; import { getAccommodationsListProps } from "../utils/get-accommodations";
import { findToolCall } from "../../find-tool-call";
const schema = z.object({ const listAccommodationsSchema = z.object({}).describe("A tool to list accommodations for the user")
listAccommodations: z const bookAccommodationSchema = z.object({
.boolean() accommodationName: z.string().describe("The name of the accommodation to book a reservation for"),
.optional() }).describe("A tool to book a reservation for an accommodation");
.describe( const listRestaurantsSchema = z.object({}).describe("A tool to list restaurants for the user");
"Whether or not the user has requested a list of accommodations for their trip.", const bookRestaurantSchema = z.object({
), restaurantName: z.string().describe("The name of the restaurant to book a reservation for"),
bookAccommodation: z }).describe("A tool to book a reservation for a restaurant");
.boolean()
.optional()
.describe(
"Whether or not the user has requested to book a reservation for an accommodation. If true, you MUST also set the 'accommodationName' field",
),
accommodationName: z
.string()
.optional()
.describe(
"The name of the accommodation to book a reservation for. Only required if the 'bookAccommodation' field is true.",
),
listRestaurants: z const ACCOMMODATIONS_TOOLS = [
.boolean() {
.optional() name: "list-accommodations",
.describe( description: "A tool to list accommodations for the user",
"Whether or not the user has requested a list of restaurants for their trip.", schema: listAccommodationsSchema,
), },
bookRestaurant: z {
.boolean() name: "book-accommodation",
.optional() description: "A tool to book a reservation for an accommodation",
.describe( schema: bookAccommodationSchema,
"Whether or not the user has requested to book a reservation for a restaurant. If true, you MUST also set the 'restaurantName' field", },
), {
restaurantName: z name: "list-restaurants",
.string() description: "A tool to list restaurants for the user",
.optional() schema: listRestaurantsSchema,
.describe( },
"The name of the restaurant to book a reservation for. Only required if the 'bookRestaurant' field is true.", {
), name: "book-restaurant",
}); description: "A tool to book a reservation for a restaurant",
schema: bookRestaurantSchema,
},
];
export async function callTools( export async function callTools(
state: TripPlannerState, state: TripPlannerState,
@@ -56,18 +49,7 @@ export async function callTools(
const ui = typedUi<typeof ComponentMap>(config); const ui = typedUi<typeof ComponentMap>(config);
const llm = new ChatOpenAI({ model: "gpt-4o", temperature: 0 }).bindTools( const llm = new ChatOpenAI({ model: "gpt-4o", temperature: 0 }).bindTools(ACCOMMODATIONS_TOOLS);
[
{
name: "trip-planner",
description: "A series of actions to take for planning a trip",
schema,
},
],
{
tool_choice: "trip-planner",
},
);
const response = await llm.invoke([ const response = await llm.invoke([
{ {
@@ -78,35 +60,44 @@ export async function callTools(
...state.messages, ...state.messages,
]); ]);
const tripPlan = response.tool_calls?.[0]?.args as const listAccommodationsToolCall = response.tool_calls?.find(
| z.infer<typeof schema> findToolCall("list-accommodations")<typeof listAccommodationsSchema>,
| undefined; );
const toolCallId = response.tool_calls?.[0]?.id; const bookAccommodationToolCall = response.tool_calls?.find(
if (!tripPlan || !toolCallId) { findToolCall("book-accommodation")<typeof bookAccommodationSchema>,
throw new Error("No trip plan found"); );
const listRestaurantsToolCall = response.tool_calls?.find(
findToolCall("list-restaurants")<typeof listRestaurantsSchema>,
);
const bookRestaurantToolCall = response.tool_calls?.find(
findToolCall("book-restaurant")<typeof bookRestaurantSchema>,
);
if (!listAccommodationsToolCall && !bookAccommodationToolCall && !listRestaurantsToolCall && !bookRestaurantToolCall) {
throw new Error("No tool calls found");
} }
if (tripPlan.listAccommodations) { if (listAccommodationsToolCall) {
ui.write("accommodations-list", { ui.write("accommodations-list", {
toolCallId, toolCallId: listAccommodationsToolCall.id ?? "",
...getAccommodationsListProps(state.tripDetails), ...getAccommodationsListProps(state.tripDetails),
}); });
} }
if (tripPlan.bookAccommodation && tripPlan.accommodationName) { if (bookAccommodationToolCall && bookAccommodationToolCall.args.accommodationName) {
ui.write("book-accommodation", { ui.write("book-accommodation", {
tripDetails: state.tripDetails, tripDetails: state.tripDetails,
accommodationName: tripPlan.accommodationName, accommodationName: bookAccommodationToolCall.args.accommodationName,
}); });
} }
if (tripPlan.listRestaurants) { if (listRestaurantsToolCall) {
ui.write("restaurants-list", { tripDetails: state.tripDetails }); ui.write("restaurants-list", { tripDetails: state.tripDetails });
} }
if (tripPlan.bookRestaurant && tripPlan.restaurantName) { if (bookRestaurantToolCall && bookRestaurantToolCall.args.restaurantName) {
ui.write("book-restaurant", { ui.write("book-restaurant", {
tripDetails: state.tripDetails, tripDetails: state.tripDetails,
restaurantName: tripPlan.restaurantName, restaurantName: bookRestaurantToolCall.args.restaurantName,
}); });
} }

View File

@@ -275,7 +275,7 @@ export default function AccommodationsList({
type: "tool", type: "tool",
tool_call_id: toolCallId, tool_call_id: toolCallId,
id: `${DO_NOT_RENDER_ID_PREFIX}${uuidv4()}`, id: `${DO_NOT_RENDER_ID_PREFIX}${uuidv4()}`,
name: "trip-planner", name: "book-accommodation",
content: JSON.stringify(orderDetails), content: JSON.stringify(orderDetails),
}, },
{ {