feat: Basic trip planner agent

This commit is contained in:
bracesproul
2025-03-03 16:51:46 -08:00
parent 5256efb23f
commit da396ac83e
22 changed files with 471 additions and 13 deletions

View File

@@ -0,0 +1,70 @@
import { ChatOpenAI } from "@langchain/openai";
import { TripPlannerState } from "../types";
import { z } from "zod";
import { formatMessages } from "agent/utils/format-messages";
export async function classify(
state: TripPlannerState,
): Promise<Partial<TripPlannerState>> {
if (!state.tripDetails) {
// Can not classify if tripDetails are undefined
return {};
}
const schema = z.object({
isRelevant: z
.boolean()
.describe(
"Whether the trip details are still relevant to the user's request.",
),
});
const model = new ChatOpenAI({ model: "gpt-4o", temperature: 0 }).bindTools(
[
{
name: "classify",
description:
"A tool to classify whether or not the trip details are still relevant to the user's request.",
schema,
},
],
{
tool_choice: "classify",
},
);
const prompt = `You're an AI assistant for planning trips. The user has already specified the following details for their trip:
- location - ${state.tripDetails.location}
- startDate - ${state.tripDetails.startDate}
- endDate - ${state.tripDetails.endDate}
- numberOfGuests - ${state.tripDetails.numberOfGuests}
Your task is to carefully read over the user's conversation, and determine if their trip details are still relevant to their most recent request.
You should set is relevant to false if they are now asking about a new location, trip duration, or number of guests.
If they do NOT change their request details (or they never specified them), please set is relevant to true.
`;
const humanMessage = `Here is the entire conversation so far:\n${formatMessages(state.messages)}`;
const response = await model.invoke([
{ role: "system", content: prompt },
{ role: "human", content: humanMessage },
]);
const classificationDetails = response.tool_calls?.[0]?.args as
| z.infer<typeof schema>
| undefined;
if (!classificationDetails) {
throw new Error("Could not classify trip details");
}
if (!classificationDetails.isRelevant) {
return {
tripDetails: undefined,
};
}
// If it is relevant, return the state unchanged
return {};
}

View File

@@ -0,0 +1,122 @@
import { ChatOpenAI } from "@langchain/openai";
import { TripDetails, TripPlannerState } from "../types";
import { z } from "zod";
import { formatMessages } from "agent/utils/format-messages";
function calculateDates(
startDate: string | undefined,
endDate: string | undefined,
): { startDate: Date; endDate: Date } {
const now = new Date();
if (!startDate && !endDate) {
// Both undefined: 4 and 5 weeks in future
const start = new Date(now);
start.setDate(start.getDate() + 28); // 4 weeks
const end = new Date(now);
end.setDate(end.getDate() + 35); // 5 weeks
return { startDate: start, endDate: end };
}
if (startDate && !endDate) {
// Only start defined: end is 1 week after
const start = new Date(startDate);
const end = new Date(start);
end.setDate(end.getDate() + 7);
return { startDate: start, endDate: end };
}
if (!startDate && endDate) {
// Only end defined: start is 1 week before
const end = new Date(endDate);
const start = new Date(end);
start.setDate(start.getDate() - 7);
return { startDate: start, endDate: end };
}
// Both defined: use as is
return {
startDate: new Date(startDate!),
endDate: new Date(endDate!),
};
}
export async function extraction(
state: TripPlannerState,
): Promise<Partial<TripPlannerState>> {
const schema = z.object({
location: z
.string()
.describe(
"The location to plan the trip for. Can be a city, state, or country.",
),
startDate: z
.string()
.optional()
.describe("The start date of the trip. Should be in YYYY-MM-DD format"),
endDate: z
.string()
.optional()
.describe("The end date of the trip. Should be in YYYY-MM-DD format"),
numberOfGuests: z
.number()
.optional()
.describe("The number of guests for the trip"),
});
const model = new ChatOpenAI({ model: "gpt-4o", temperature: 0 }).bindTools([
{
name: "extract",
description: "A tool to extract information from a user's request.",
schema: schema,
},
]);
const prompt = `You're an AI assistant for planning trips. The user has requested information about a trip they want to go on.
Before you can help them, you need to extract the following information from their request:
- location - The location to plan the trip for. Can be a city, state, or country.
- startDate - The start date of the trip. Should be in YYYY-MM-DD format. Optional
- endDate - The end date of the trip. Should be in YYYY-MM-DD format. Optional
- numberOfGuests - The number of guests for the trip. Optional
You are provided with the ENTIRE conversation history between you, and the user. Use these messages to extract the necessary information.
Do NOT guess, or make up any information. If the user did NOT specify a location, please respond with a request for them to specify the location.
It should be a single sentence, along the lines of "Please specify the location for the trip you want to go on".
Extract only what is specified by the user. It is okay to leave fields blank if the user did not specify them.
`;
const humanMessage = `Here is the entire conversation so far:\n${formatMessages(state.messages)}`;
const response = await model.invoke([
{ role: "system", content: prompt },
{ role: "human", content: humanMessage },
]);
const extractedDetails = response.tool_calls?.[0]?.args as
| z.infer<typeof schema>
| undefined;
if (!extractedDetails) {
return {
messages: [response],
};
}
const { startDate, endDate } = calculateDates(
extractedDetails.startDate,
extractedDetails.endDate,
);
const extractionDetailsWithDefaults: TripDetails = {
startDate,
endDate,
numberOfGuests: extractedDetails.numberOfGuests ?? 2,
location: extractedDetails.location,
};
return {
tripDetails: extractionDetailsWithDefaults,
};
}

View File

@@ -0,0 +1,118 @@
import { TripPlannerState } from "../types";
import { ChatOpenAI } from "@langchain/openai";
import { typedUi } from "@langchain/langgraph-sdk/react-ui/server";
import type ComponentMap from "../../uis/index";
import { z } from "zod";
import { LangGraphRunnableConfig } from "@langchain/langgraph";
const schema = z.object({
listAccommodations: z
.boolean()
.optional()
.describe(
"Whether or not the user has requested a list of accommodations for their trip.",
),
bookAccommodation: z
.boolean()
.optional()
.describe(
"Whether or not the user has requested to book a reservation for an accommodation. If true, you MUST also set the 'accommodationName' field",
),
accommodationName: z
.string()
.optional()
.describe(
"The name of the accommodation to book a reservation for. Only required if the 'bookAccommodation' field is true.",
),
listRestaurants: z
.boolean()
.optional()
.describe(
"Whether or not the user has requested a list of restaurants for their trip.",
),
bookRestaurant: z
.boolean()
.optional()
.describe(
"Whether or not the user has requested to book a reservation for a restaurant. If true, you MUST also set the 'restaurantName' field",
),
restaurantName: z
.string()
.optional()
.describe(
"The name of the restaurant to book a reservation for. Only required if the 'bookRestaurant' field is true.",
),
});
export async function callTools(
state: TripPlannerState,
config: LangGraphRunnableConfig,
): Promise<Partial<TripPlannerState>> {
if (!state.tripDetails) {
throw new Error("No trip details found");
}
const ui = typedUi<typeof ComponentMap>(config);
const llm = new ChatOpenAI({ model: "gpt-4o", temperature: 0 }).bindTools(
[
{
name: "trip-planner",
description: "A series of actions to take for planning a trip",
schema,
},
],
{
tool_choice: "trip-planner",
},
);
const response = await llm.invoke([
{
role: "system",
content:
"You are an AI assistant who helps users book trips. Use the user's most recent message(s) to contextually generate a response.",
},
...state.messages,
]);
const tripPlan = response.tool_calls?.[0]?.args as
| z.infer<typeof schema>
| undefined;
if (!tripPlan) {
throw new Error("No trip plan found");
}
if (tripPlan.listAccommodations) {
// TODO: Replace with an accommodations list UI component
ui.write("accommodations-list", { tripDetails: state.tripDetails });
}
if (tripPlan.bookAccommodation && tripPlan.accommodationName) {
// TODO: Replace with a book accommodation UI component
ui.write("book-accommodation", {
tripDetails: state.tripDetails,
accommodationName: tripPlan.accommodationName,
});
}
if (tripPlan.listRestaurants) {
// TODO: Replace with a restaurants list UI component
ui.write("restaurants-list", { tripDetails: state.tripDetails });
}
if (tripPlan.bookRestaurant && tripPlan.restaurantName) {
// TODO: Replace with a book restaurant UI component
ui.write("book-restaurant", {
tripDetails: state.tripDetails,
restaurantName: tripPlan.restaurantName,
});
}
return {
messages: [response],
// TODO: Fix the ui return type.
ui: ui.collect as any[],
timestamp: Date.now(),
};
}