From da396ac83ea1b688c473ea3af5a183f8cd84f06e Mon Sep 17 00:00:00 2001 From: bracesproul Date: Mon, 3 Mar 2025 16:51:46 -0800 Subject: [PATCH] feat: Basic trip planner agent --- agent/agent.tsx | 26 ++-- agent/stockbroker/types.ts | 1 - agent/trip-planner/index.tsx | 51 ++++++++ agent/trip-planner/nodes/classify.ts | 70 ++++++++++ agent/trip-planner/nodes/extraction.tsx | 122 ++++++++++++++++++ agent/trip-planner/nodes/tools.tsx | 118 +++++++++++++++++ agent/trip-planner/types.ts | 18 +++ agent/types.ts | 2 +- agent/uis/index.tsx | 12 +- .../portfolio-view/index.css | 0 .../portfolio-view/index.tsx | 0 .../{ => stockbroker}/stock-price/index.css | 0 .../{ => stockbroker}/stock-price/index.tsx | 0 .../accommodations-list/index.css | 1 + .../accommodations-list/index.tsx | 9 ++ .../trip-planner/book-accommodation/index.css | 1 + .../trip-planner/book-accommodation/index.tsx | 15 +++ .../trip-planner/book-restaurant/index.css | 1 + .../trip-planner/book-restaurant/index.tsx | 15 +++ .../trip-planner/restaurants-list/index.css | 1 + .../trip-planner/restaurants-list/index.tsx | 9 ++ agent/utils/format-messages.ts | 12 ++ 22 files changed, 471 insertions(+), 13 deletions(-) create mode 100644 agent/trip-planner/index.tsx create mode 100644 agent/trip-planner/nodes/classify.ts create mode 100644 agent/trip-planner/nodes/extraction.tsx create mode 100644 agent/trip-planner/nodes/tools.tsx create mode 100644 agent/trip-planner/types.ts rename agent/uis/{ => stockbroker}/portfolio-view/index.css (100%) rename agent/uis/{ => stockbroker}/portfolio-view/index.tsx (100%) rename agent/uis/{ => stockbroker}/stock-price/index.css (100%) rename agent/uis/{ => stockbroker}/stock-price/index.tsx (100%) create mode 100644 agent/uis/trip-planner/accommodations-list/index.css create mode 100644 agent/uis/trip-planner/accommodations-list/index.tsx create mode 100644 agent/uis/trip-planner/book-accommodation/index.css create mode 100644 agent/uis/trip-planner/book-accommodation/index.tsx create mode 100644 agent/uis/trip-planner/book-restaurant/index.css create mode 100644 agent/uis/trip-planner/book-restaurant/index.tsx create mode 100644 agent/uis/trip-planner/restaurants-list/index.css create mode 100644 agent/uis/trip-planner/restaurants-list/index.tsx create mode 100644 agent/utils/format-messages.ts diff --git a/agent/agent.tsx b/agent/agent.tsx index bb7b667..f6d5c3b 100644 --- a/agent/agent.tsx +++ b/agent/agent.tsx @@ -4,18 +4,22 @@ import { z } from "zod"; import { GenerativeUIAnnotation, GenerativeUIState } from "./types"; import { stockbrokerGraph } from "./stockbroker"; import { ChatOpenAI } from "@langchain/openai"; +import { tripPlannerGraph } from "./trip-planner"; + +const allToolDescriptions = `- stockbroker: can fetch the price of a ticker, purchase/sell a ticker, or get the user's portfolio +- tripPlanner: helps the user plan their trip. it can suggest restaurants, and places to stay in any given location.`; async function router( state: GenerativeUIState, ): Promise> { const routerDescription = `The route to take based on the user's input. - stockbroker: can fetch the price of a ticker, purchase/sell a ticker, or get the user's portfolio -- weather: can fetch the current weather conditions for a location +- tripPlanner: helps the user plan their trip. it can suggest restaurants, and places to stay in any given location. - generalInput: handles all other cases where the above tools don't apply `; const routerSchema = z.object({ route: z - .enum(["stockbroker", "weather", "generalInput"]) + .enum(["stockbroker", "tripPlanner", "generalInput"]) .describe(routerDescription), }); const routerTool = { @@ -61,13 +65,19 @@ You should analyze the user's input, and choose the appropriate tool to use.`; function handleRoute( state: GenerativeUIState, -): "stockbroker" | "weather" | "generalInput" { +): "stockbroker" | "tripPlanner" | "generalInput" { return state.next; } async function handleGeneralInput(state: GenerativeUIState) { const llm = new ChatOpenAI({ model: "gpt-4o-mini", temperature: 0 }); - const response = await llm.invoke(state.messages); + const response = await llm.invoke([ + { + role: "system", + content: `You are an AI assistant.\nIf the user asks what you can do, describe these tools. Otherwise, just answer as normal.\n\n${allToolDescriptions}`, + }, + ...state.messages, + ]); return { messages: [response], @@ -77,19 +87,17 @@ async function handleGeneralInput(state: GenerativeUIState) { const builder = new StateGraph(GenerativeUIAnnotation) .addNode("router", router) .addNode("stockbroker", stockbrokerGraph) - .addNode("weather", () => { - throw new Error("Weather not implemented"); - }) + .addNode("tripPlanner", tripPlannerGraph) .addNode("generalInput", handleGeneralInput) .addConditionalEdges("router", handleRoute, [ "stockbroker", - "weather", + "tripPlanner", "generalInput", ]) .addEdge(START, "router") .addEdge("stockbroker", END) - .addEdge("weather", END) + .addEdge("tripPlanner", END) .addEdge("generalInput", END); export const graph = builder.compile(); diff --git a/agent/stockbroker/types.ts b/agent/stockbroker/types.ts index bd32d2e..ebe3fdb 100644 --- a/agent/stockbroker/types.ts +++ b/agent/stockbroker/types.ts @@ -5,7 +5,6 @@ export const StockbrokerAnnotation = Annotation.Root({ messages: GenerativeUIAnnotation.spec.messages, ui: GenerativeUIAnnotation.spec.ui, timestamp: GenerativeUIAnnotation.spec.timestamp, - next: Annotation<"stockbroker" | "weather">(), }); export type StockbrokerState = typeof StockbrokerAnnotation.State; diff --git a/agent/trip-planner/index.tsx b/agent/trip-planner/index.tsx new file mode 100644 index 0000000..60f8f25 --- /dev/null +++ b/agent/trip-planner/index.tsx @@ -0,0 +1,51 @@ +import { StateGraph, START, END } from "@langchain/langgraph"; +import { TripPlannerAnnotation, TripPlannerState } from "./types"; +import { extraction } from "./nodes/extraction"; +import { callTools } from "./nodes/tools"; +import { classify } from "./nodes/classify"; + +function routeStart(state: TripPlannerState): "classify" | "extraction" { + if (!state.tripDetails) { + return "extraction"; + } + + return "classify"; +} + +function routeAfterClassifying( + state: TripPlannerState, +): "callTools" | "extraction" { + // if `tripDetails` is undefined, this means they are not relevant to the conversation + if (!state.tripDetails) { + return "extraction"; + } + + // otherwise, they are relevant, and we should route to callTools + return "callTools"; +} + +function routeAfterExtraction( + state: TripPlannerState, +): "callTools" | typeof END { + // if `tripDetails` is undefined, this means they're missing some fields. + if (!state.tripDetails) { + return END; + } + + return "callTools"; +} + +const builder = new StateGraph(TripPlannerAnnotation) + .addNode("classify", classify) + .addNode("extraction", extraction) + .addNode("callTools", callTools) + .addConditionalEdges(START, routeStart, ["classify", "extraction"]) + .addConditionalEdges("classify", routeAfterClassifying, [ + "callTools", + "extraction", + ]) + .addConditionalEdges("extraction", routeAfterExtraction, ["callTools", END]) + .addEdge("callTools", END); + +export const tripPlannerGraph = builder.compile(); +tripPlannerGraph.name = "Trip Planner"; diff --git a/agent/trip-planner/nodes/classify.ts b/agent/trip-planner/nodes/classify.ts new file mode 100644 index 0000000..945c056 --- /dev/null +++ b/agent/trip-planner/nodes/classify.ts @@ -0,0 +1,70 @@ +import { ChatOpenAI } from "@langchain/openai"; +import { TripPlannerState } from "../types"; +import { z } from "zod"; +import { formatMessages } from "agent/utils/format-messages"; + +export async function classify( + state: TripPlannerState, +): Promise> { + if (!state.tripDetails) { + // Can not classify if tripDetails are undefined + return {}; + } + + const schema = z.object({ + isRelevant: z + .boolean() + .describe( + "Whether the trip details are still relevant to the user's request.", + ), + }); + + const model = new ChatOpenAI({ model: "gpt-4o", temperature: 0 }).bindTools( + [ + { + name: "classify", + description: + "A tool to classify whether or not the trip details are still relevant to the user's request.", + schema, + }, + ], + { + tool_choice: "classify", + }, + ); + + const prompt = `You're an AI assistant for planning trips. The user has already specified the following details for their trip: +- location - ${state.tripDetails.location} +- startDate - ${state.tripDetails.startDate} +- endDate - ${state.tripDetails.endDate} +- numberOfGuests - ${state.tripDetails.numberOfGuests} + +Your task is to carefully read over the user's conversation, and determine if their trip details are still relevant to their most recent request. +You should set is relevant to false if they are now asking about a new location, trip duration, or number of guests. +If they do NOT change their request details (or they never specified them), please set is relevant to true. +`; + + const humanMessage = `Here is the entire conversation so far:\n${formatMessages(state.messages)}`; + + const response = await model.invoke([ + { role: "system", content: prompt }, + { role: "human", content: humanMessage }, + ]); + + const classificationDetails = response.tool_calls?.[0]?.args as + | z.infer + | undefined; + + if (!classificationDetails) { + throw new Error("Could not classify trip details"); + } + + if (!classificationDetails.isRelevant) { + return { + tripDetails: undefined, + }; + } + + // If it is relevant, return the state unchanged + return {}; +} diff --git a/agent/trip-planner/nodes/extraction.tsx b/agent/trip-planner/nodes/extraction.tsx new file mode 100644 index 0000000..5019027 --- /dev/null +++ b/agent/trip-planner/nodes/extraction.tsx @@ -0,0 +1,122 @@ +import { ChatOpenAI } from "@langchain/openai"; +import { TripDetails, TripPlannerState } from "../types"; +import { z } from "zod"; +import { formatMessages } from "agent/utils/format-messages"; + +function calculateDates( + startDate: string | undefined, + endDate: string | undefined, +): { startDate: Date; endDate: Date } { + const now = new Date(); + + if (!startDate && !endDate) { + // Both undefined: 4 and 5 weeks in future + const start = new Date(now); + start.setDate(start.getDate() + 28); // 4 weeks + const end = new Date(now); + end.setDate(end.getDate() + 35); // 5 weeks + return { startDate: start, endDate: end }; + } + + if (startDate && !endDate) { + // Only start defined: end is 1 week after + const start = new Date(startDate); + const end = new Date(start); + end.setDate(end.getDate() + 7); + return { startDate: start, endDate: end }; + } + + if (!startDate && endDate) { + // Only end defined: start is 1 week before + const end = new Date(endDate); + const start = new Date(end); + start.setDate(start.getDate() - 7); + return { startDate: start, endDate: end }; + } + + // Both defined: use as is + return { + startDate: new Date(startDate!), + endDate: new Date(endDate!), + }; +} + +export async function extraction( + state: TripPlannerState, +): Promise> { + const schema = z.object({ + location: z + .string() + .describe( + "The location to plan the trip for. Can be a city, state, or country.", + ), + startDate: z + .string() + .optional() + .describe("The start date of the trip. Should be in YYYY-MM-DD format"), + endDate: z + .string() + .optional() + .describe("The end date of the trip. Should be in YYYY-MM-DD format"), + numberOfGuests: z + .number() + .optional() + .describe("The number of guests for the trip"), + }); + + const model = new ChatOpenAI({ model: "gpt-4o", temperature: 0 }).bindTools([ + { + name: "extract", + description: "A tool to extract information from a user's request.", + schema: schema, + }, + ]); + + const prompt = `You're an AI assistant for planning trips. The user has requested information about a trip they want to go on. +Before you can help them, you need to extract the following information from their request: +- location - The location to plan the trip for. Can be a city, state, or country. +- startDate - The start date of the trip. Should be in YYYY-MM-DD format. Optional +- endDate - The end date of the trip. Should be in YYYY-MM-DD format. Optional +- numberOfGuests - The number of guests for the trip. Optional + +You are provided with the ENTIRE conversation history between you, and the user. Use these messages to extract the necessary information. + +Do NOT guess, or make up any information. If the user did NOT specify a location, please respond with a request for them to specify the location. +It should be a single sentence, along the lines of "Please specify the location for the trip you want to go on". + +Extract only what is specified by the user. It is okay to leave fields blank if the user did not specify them. +`; + + const humanMessage = `Here is the entire conversation so far:\n${formatMessages(state.messages)}`; + + const response = await model.invoke([ + { role: "system", content: prompt }, + { role: "human", content: humanMessage }, + ]); + + const extractedDetails = response.tool_calls?.[0]?.args as + | z.infer + | undefined; + + if (!extractedDetails) { + return { + messages: [response], + }; + } + + const { startDate, endDate } = calculateDates( + extractedDetails.startDate, + extractedDetails.endDate, + ); + + const extractionDetailsWithDefaults: TripDetails = { + startDate, + endDate, + numberOfGuests: extractedDetails.numberOfGuests ?? 2, + location: extractedDetails.location, + }; + + return { + tripDetails: extractionDetailsWithDefaults, + }; +} diff --git a/agent/trip-planner/nodes/tools.tsx b/agent/trip-planner/nodes/tools.tsx new file mode 100644 index 0000000..9d45486 --- /dev/null +++ b/agent/trip-planner/nodes/tools.tsx @@ -0,0 +1,118 @@ +import { TripPlannerState } from "../types"; +import { ChatOpenAI } from "@langchain/openai"; +import { typedUi } from "@langchain/langgraph-sdk/react-ui/server"; +import type ComponentMap from "../../uis/index"; +import { z } from "zod"; +import { LangGraphRunnableConfig } from "@langchain/langgraph"; + +const schema = z.object({ + listAccommodations: z + .boolean() + .optional() + .describe( + "Whether or not the user has requested a list of accommodations for their trip.", + ), + bookAccommodation: z + .boolean() + .optional() + .describe( + "Whether or not the user has requested to book a reservation for an accommodation. If true, you MUST also set the 'accommodationName' field", + ), + accommodationName: z + .string() + .optional() + .describe( + "The name of the accommodation to book a reservation for. Only required if the 'bookAccommodation' field is true.", + ), + + listRestaurants: z + .boolean() + .optional() + .describe( + "Whether or not the user has requested a list of restaurants for their trip.", + ), + bookRestaurant: z + .boolean() + .optional() + .describe( + "Whether or not the user has requested to book a reservation for a restaurant. If true, you MUST also set the 'restaurantName' field", + ), + restaurantName: z + .string() + .optional() + .describe( + "The name of the restaurant to book a reservation for. Only required if the 'bookRestaurant' field is true.", + ), +}); + +export async function callTools( + state: TripPlannerState, + config: LangGraphRunnableConfig, +): Promise> { + if (!state.tripDetails) { + throw new Error("No trip details found"); + } + + const ui = typedUi(config); + + const llm = new ChatOpenAI({ model: "gpt-4o", temperature: 0 }).bindTools( + [ + { + name: "trip-planner", + description: "A series of actions to take for planning a trip", + schema, + }, + ], + { + tool_choice: "trip-planner", + }, + ); + + const response = await llm.invoke([ + { + role: "system", + content: + "You are an AI assistant who helps users book trips. Use the user's most recent message(s) to contextually generate a response.", + }, + ...state.messages, + ]); + + const tripPlan = response.tool_calls?.[0]?.args as + | z.infer + | undefined; + if (!tripPlan) { + throw new Error("No trip plan found"); + } + + if (tripPlan.listAccommodations) { + // TODO: Replace with an accommodations list UI component + ui.write("accommodations-list", { tripDetails: state.tripDetails }); + } + if (tripPlan.bookAccommodation && tripPlan.accommodationName) { + // TODO: Replace with a book accommodation UI component + ui.write("book-accommodation", { + tripDetails: state.tripDetails, + accommodationName: tripPlan.accommodationName, + }); + } + + if (tripPlan.listRestaurants) { + // TODO: Replace with a restaurants list UI component + ui.write("restaurants-list", { tripDetails: state.tripDetails }); + } + + if (tripPlan.bookRestaurant && tripPlan.restaurantName) { + // TODO: Replace with a book restaurant UI component + ui.write("book-restaurant", { + tripDetails: state.tripDetails, + restaurantName: tripPlan.restaurantName, + }); + } + + return { + messages: [response], + // TODO: Fix the ui return type. + ui: ui.collect as any[], + timestamp: Date.now(), + }; +} diff --git a/agent/trip-planner/types.ts b/agent/trip-planner/types.ts new file mode 100644 index 0000000..9e546bb --- /dev/null +++ b/agent/trip-planner/types.ts @@ -0,0 +1,18 @@ +import { Annotation } from "@langchain/langgraph"; +import { GenerativeUIAnnotation } from "../types"; + +export type TripDetails = { + location: string; + startDate: Date; + endDate: Date; + numberOfGuests: number; +}; + +export const TripPlannerAnnotation = Annotation.Root({ + messages: GenerativeUIAnnotation.spec.messages, + ui: GenerativeUIAnnotation.spec.ui, + timestamp: GenerativeUIAnnotation.spec.timestamp, + tripDetails: Annotation(), +}); + +export type TripPlannerState = typeof TripPlannerAnnotation.State; diff --git a/agent/types.ts b/agent/types.ts index 40b2d22..a5355bb 100644 --- a/agent/types.ts +++ b/agent/types.ts @@ -5,7 +5,7 @@ export const GenerativeUIAnnotation = Annotation.Root({ messages: MessagesAnnotation.spec["messages"], ui: Annotation({ default: () => [], reducer: uiMessageReducer }), timestamp: Annotation, - next: Annotation<"stockbroker" | "weather" | "generalInput">(), + next: Annotation<"stockbroker" | "tripPlanner" | "generalInput">(), }); export type GenerativeUIState = typeof GenerativeUIAnnotation.State; diff --git a/agent/uis/index.tsx b/agent/uis/index.tsx index a9259be..ad76b8c 100644 --- a/agent/uis/index.tsx +++ b/agent/uis/index.tsx @@ -1,8 +1,16 @@ -import StockPrice from "./stock-price"; -import PortfolioView from "./portfolio-view"; +import StockPrice from "./stockbroker/stock-price"; +import PortfolioView from "./stockbroker/portfolio-view"; +import AccommodationsList from "./trip-planner/accommodations-list"; +import BookAccommodation from "./trip-planner/book-accommodation"; +import RestaurantsList from "./trip-planner/restaurants-list"; +import BookRestaurant from "./trip-planner/book-restaurant"; const ComponentMap = { "stock-price": StockPrice, portfolio: PortfolioView, + "accommodations-list": AccommodationsList, + "book-accommodation": BookAccommodation, + "restaurants-list": RestaurantsList, + "book-restaurant": BookRestaurant, } as const; export default ComponentMap; diff --git a/agent/uis/portfolio-view/index.css b/agent/uis/stockbroker/portfolio-view/index.css similarity index 100% rename from agent/uis/portfolio-view/index.css rename to agent/uis/stockbroker/portfolio-view/index.css diff --git a/agent/uis/portfolio-view/index.tsx b/agent/uis/stockbroker/portfolio-view/index.tsx similarity index 100% rename from agent/uis/portfolio-view/index.tsx rename to agent/uis/stockbroker/portfolio-view/index.tsx diff --git a/agent/uis/stock-price/index.css b/agent/uis/stockbroker/stock-price/index.css similarity index 100% rename from agent/uis/stock-price/index.css rename to agent/uis/stockbroker/stock-price/index.css diff --git a/agent/uis/stock-price/index.tsx b/agent/uis/stockbroker/stock-price/index.tsx similarity index 100% rename from agent/uis/stock-price/index.tsx rename to agent/uis/stockbroker/stock-price/index.tsx diff --git a/agent/uis/trip-planner/accommodations-list/index.css b/agent/uis/trip-planner/accommodations-list/index.css new file mode 100644 index 0000000..f1d8c73 --- /dev/null +++ b/agent/uis/trip-planner/accommodations-list/index.css @@ -0,0 +1 @@ +@import "tailwindcss"; diff --git a/agent/uis/trip-planner/accommodations-list/index.tsx b/agent/uis/trip-planner/accommodations-list/index.tsx new file mode 100644 index 0000000..da97aad --- /dev/null +++ b/agent/uis/trip-planner/accommodations-list/index.tsx @@ -0,0 +1,9 @@ +import { TripDetails } from "../../../trip-planner/types"; + +export default function AccommodationsList({ + tripDetails, +}: { + tripDetails: TripDetails; +}) { + return
Accommodations list for {JSON.stringify(tripDetails)}
; +} diff --git a/agent/uis/trip-planner/book-accommodation/index.css b/agent/uis/trip-planner/book-accommodation/index.css new file mode 100644 index 0000000..f1d8c73 --- /dev/null +++ b/agent/uis/trip-planner/book-accommodation/index.css @@ -0,0 +1 @@ +@import "tailwindcss"; diff --git a/agent/uis/trip-planner/book-accommodation/index.tsx b/agent/uis/trip-planner/book-accommodation/index.tsx new file mode 100644 index 0000000..9c98a5e --- /dev/null +++ b/agent/uis/trip-planner/book-accommodation/index.tsx @@ -0,0 +1,15 @@ +import { TripDetails } from "../../../trip-planner/types"; + +export default function BookAccommodation({ + tripDetails, + accommodationName, +}: { + tripDetails: TripDetails; + accommodationName: string; +}) { + return ( +
+ Book accommodation {accommodationName} for {JSON.stringify(tripDetails)} +
+ ); +} diff --git a/agent/uis/trip-planner/book-restaurant/index.css b/agent/uis/trip-planner/book-restaurant/index.css new file mode 100644 index 0000000..f1d8c73 --- /dev/null +++ b/agent/uis/trip-planner/book-restaurant/index.css @@ -0,0 +1 @@ +@import "tailwindcss"; diff --git a/agent/uis/trip-planner/book-restaurant/index.tsx b/agent/uis/trip-planner/book-restaurant/index.tsx new file mode 100644 index 0000000..010ba9b --- /dev/null +++ b/agent/uis/trip-planner/book-restaurant/index.tsx @@ -0,0 +1,15 @@ +import { TripDetails } from "../../../trip-planner/types"; + +export default function BookRestaurant({ + tripDetails, + restaurantName, +}: { + tripDetails: TripDetails; + restaurantName: string; +}) { + return ( +
+ Book restaurant {restaurantName} for {JSON.stringify(tripDetails)} +
+ ); +} diff --git a/agent/uis/trip-planner/restaurants-list/index.css b/agent/uis/trip-planner/restaurants-list/index.css new file mode 100644 index 0000000..f1d8c73 --- /dev/null +++ b/agent/uis/trip-planner/restaurants-list/index.css @@ -0,0 +1 @@ +@import "tailwindcss"; diff --git a/agent/uis/trip-planner/restaurants-list/index.tsx b/agent/uis/trip-planner/restaurants-list/index.tsx new file mode 100644 index 0000000..a283c85 --- /dev/null +++ b/agent/uis/trip-planner/restaurants-list/index.tsx @@ -0,0 +1,9 @@ +import { TripDetails } from "../../../trip-planner/types"; + +export default function RestaurantsList({ + tripDetails, +}: { + tripDetails: TripDetails; +}) { + return
Restaurants list for {JSON.stringify(tripDetails)}
; +} diff --git a/agent/utils/format-messages.ts b/agent/utils/format-messages.ts new file mode 100644 index 0000000..91f3bf4 --- /dev/null +++ b/agent/utils/format-messages.ts @@ -0,0 +1,12 @@ +import { BaseMessage } from "@langchain/core/messages"; + +export function formatMessages(messages: BaseMessage[]): string { + return messages + .map((m, i) => { + const role = m.getType(); + const contentString = + typeof m.content === "string" ? m.content : JSON.stringify(m.content); + return `<${role} index="${i}">\n${contentString}\n`; + }) + .join("\n"); +}