fix: Tool calls for trip planner

This commit is contained in:
bracesproul
2025-03-06 11:42:39 -08:00
parent 098d954d06
commit 60ba2a11c1
4 changed files with 57 additions and 67 deletions

View File

@@ -5,46 +5,39 @@ import type ComponentMap from "../../uis/index";
import { z } from "zod";
import { LangGraphRunnableConfig } from "@langchain/langgraph";
import { getAccommodationsListProps } from "../utils/get-accommodations";
import { findToolCall } from "../../find-tool-call";
const schema = z.object({
listAccommodations: z
.boolean()
.optional()
.describe(
"Whether or not the user has requested a list of accommodations for their trip.",
),
bookAccommodation: z
.boolean()
.optional()
.describe(
"Whether or not the user has requested to book a reservation for an accommodation. If true, you MUST also set the 'accommodationName' field",
),
accommodationName: z
.string()
.optional()
.describe(
"The name of the accommodation to book a reservation for. Only required if the 'bookAccommodation' field is true.",
),
const listAccommodationsSchema = z.object({}).describe("A tool to list accommodations for the user")
const bookAccommodationSchema = z.object({
accommodationName: z.string().describe("The name of the accommodation to book a reservation for"),
}).describe("A tool to book a reservation for an accommodation");
const listRestaurantsSchema = z.object({}).describe("A tool to list restaurants for the user");
const bookRestaurantSchema = z.object({
restaurantName: z.string().describe("The name of the restaurant to book a reservation for"),
}).describe("A tool to book a reservation for a restaurant");
listRestaurants: z
.boolean()
.optional()
.describe(
"Whether or not the user has requested a list of restaurants for their trip.",
),
bookRestaurant: z
.boolean()
.optional()
.describe(
"Whether or not the user has requested to book a reservation for a restaurant. If true, you MUST also set the 'restaurantName' field",
),
restaurantName: z
.string()
.optional()
.describe(
"The name of the restaurant to book a reservation for. Only required if the 'bookRestaurant' field is true.",
),
});
const ACCOMMODATIONS_TOOLS = [
{
name: "list-accommodations",
description: "A tool to list accommodations for the user",
schema: listAccommodationsSchema,
},
{
name: "book-accommodation",
description: "A tool to book a reservation for an accommodation",
schema: bookAccommodationSchema,
},
{
name: "list-restaurants",
description: "A tool to list restaurants for the user",
schema: listRestaurantsSchema,
},
{
name: "book-restaurant",
description: "A tool to book a reservation for a restaurant",
schema: bookRestaurantSchema,
},
];
export async function callTools(
state: TripPlannerState,
@@ -56,18 +49,7 @@ export async function callTools(
const ui = typedUi<typeof ComponentMap>(config);
const llm = new ChatOpenAI({ model: "gpt-4o", temperature: 0 }).bindTools(
[
{
name: "trip-planner",
description: "A series of actions to take for planning a trip",
schema,
},
],
{
tool_choice: "trip-planner",
},
);
const llm = new ChatOpenAI({ model: "gpt-4o", temperature: 0 }).bindTools(ACCOMMODATIONS_TOOLS);
const response = await llm.invoke([
{
@@ -78,35 +60,44 @@ export async function callTools(
...state.messages,
]);
const tripPlan = response.tool_calls?.[0]?.args as
| z.infer<typeof schema>
| undefined;
const toolCallId = response.tool_calls?.[0]?.id;
if (!tripPlan || !toolCallId) {
throw new Error("No trip plan found");
const listAccommodationsToolCall = response.tool_calls?.find(
findToolCall("list-accommodations")<typeof listAccommodationsSchema>,
);
const bookAccommodationToolCall = response.tool_calls?.find(
findToolCall("book-accommodation")<typeof bookAccommodationSchema>,
);
const listRestaurantsToolCall = response.tool_calls?.find(
findToolCall("list-restaurants")<typeof listRestaurantsSchema>,
);
const bookRestaurantToolCall = response.tool_calls?.find(
findToolCall("book-restaurant")<typeof bookRestaurantSchema>,
);
if (!listAccommodationsToolCall && !bookAccommodationToolCall && !listRestaurantsToolCall && !bookRestaurantToolCall) {
throw new Error("No tool calls found");
}
if (tripPlan.listAccommodations) {
if (listAccommodationsToolCall) {
ui.write("accommodations-list", {
toolCallId,
toolCallId: listAccommodationsToolCall.id ?? "",
...getAccommodationsListProps(state.tripDetails),
});
}
if (tripPlan.bookAccommodation && tripPlan.accommodationName) {
if (bookAccommodationToolCall && bookAccommodationToolCall.args.accommodationName) {
ui.write("book-accommodation", {
tripDetails: state.tripDetails,
accommodationName: tripPlan.accommodationName,
accommodationName: bookAccommodationToolCall.args.accommodationName,
});
}
if (tripPlan.listRestaurants) {
if (listRestaurantsToolCall) {
ui.write("restaurants-list", { tripDetails: state.tripDetails });
}
if (tripPlan.bookRestaurant && tripPlan.restaurantName) {
if (bookRestaurantToolCall && bookRestaurantToolCall.args.restaurantName) {
ui.write("book-restaurant", {
tripDetails: state.tripDetails,
restaurantName: tripPlan.restaurantName,
restaurantName: bookRestaurantToolCall.args.restaurantName,
});
}