// agent-chat-ui/agent/agent.ts
import { StateGraph, START, END } from "@langchain/langgraph";
import { ChatGoogleGenerativeAI } from "@langchain/google-genai";
import { z } from "zod";
import { GenerativeUIAnnotation, GenerativeUIState } from "./types";
import { stockbrokerGraph } from "./stockbroker";
import { ChatOpenAI } from "@langchain/openai";
import { tripPlannerGraph } from "./trip-planner";
import { formatMessages } from "./utils/format-messages";
// Human-readable catalog of the routable tools; injected verbatim into the
// router prompt and the general-input system prompt.
const toolCatalog: readonly string[] = [
  "- stockbroker: can fetch the price of a ticker, purchase/sell a ticker, or get the user's portfolio",
  "- tripPlanner: helps the user plan their trip. it can suggest restaurants, and places to stay in any given location.",
];
const allToolDescriptions = toolCatalog.join("\n");
async function router(
state: GenerativeUIState,
2025-02-27 15:41:47 -08:00
): Promise<Partial<GenerativeUIState>> {
const routerDescription = `The route to take based on the user's input.
2025-03-03 16:52:58 -08:00
${allToolDescriptions}
2025-02-27 15:41:47 -08:00
- generalInput: handles all other cases where the above tools don't apply
`;
const routerSchema = z.object({
route: z
2025-03-03 16:51:46 -08:00
.enum(["stockbroker", "tripPlanner", "generalInput"])
2025-02-27 15:41:47 -08:00
.describe(routerDescription),
});
const routerTool = {
name: "router",
description: "A tool to route the user's query to the appropriate tool.",
schema: routerSchema,
};
2025-02-18 19:35:46 +01:00
2025-02-27 15:41:47 -08:00
const llm = new ChatGoogleGenerativeAI({
model: "gemini-2.0-flash",
temperature: 0,
})
.bindTools([routerTool], { tool_choice: "router" })
.withConfig({ tags: ["langsmith:nostream"] });
2025-02-27 14:08:24 -08:00
2025-02-27 15:41:47 -08:00
const prompt = `You're a highly helpful AI assistant, tasked with routing the user's query to the appropriate tool.
You should analyze the user's input, and choose the appropriate tool to use.`;
2025-02-27 14:08:24 -08:00
2025-03-03 18:09:49 -08:00
const allMessagesButLast = state.messages.slice(0, -1);
const lastMessage = state.messages.at(-1);
2025-02-18 19:35:46 +01:00
2025-03-03 18:09:49 -08:00
const formattedPreviousMessages = formatMessages(allMessagesButLast);
const formattedLastMessage = lastMessage ? formatMessages([lastMessage]) : "";
const humanMessage = `Here is the full conversation, excluding the most recent message:
${formattedPreviousMessages}
Here is the most recent message:
${formattedLastMessage}
Please pick the proper route based on the most recent message, in the context of the entire conversation.`;
2025-02-18 19:35:46 +01:00
2025-02-27 15:41:47 -08:00
const response = await llm.invoke([
{ role: "system", content: prompt },
2025-03-03 18:09:49 -08:00
{ role: "user", content: humanMessage },
2025-02-27 15:41:47 -08:00
]);
2025-02-18 19:35:46 +01:00
2025-02-27 15:41:47 -08:00
const toolCall = response.tool_calls?.[0]?.args as
| z.infer<typeof routerSchema>
| undefined;
if (!toolCall) {
throw new Error("No tool call found in response");
}
2025-02-18 19:35:46 +01:00
2025-02-27 15:41:47 -08:00
return {
next: toolCall.route,
};
}
function handleRoute(
state: GenerativeUIState,
2025-03-03 16:51:46 -08:00
): "stockbroker" | "tripPlanner" | "generalInput" {
2025-02-27 15:41:47 -08:00
return state.next;
}
// System prompt for the fallback "generalInput" node: answer directly,
// advertise the routable tools on request, and acknowledge tool results
// when the previous turn was a tool action.
const GENERAL_INPUT_SYSTEM_PROMPT = `You are an AI assistant.
If the user asks what you can do, describe these tools.
${allToolDescriptions}
If the last message is a tool result, describe what the action was, congratulate the user, or send a friendly followup in response to the tool action. Ensure this is a clear and concise message.
Otherwise, just answer as normal.`;
async function handleGeneralInput(state: GenerativeUIState) {
const llm = new ChatOpenAI({ model: "gpt-4o-mini", temperature: 0 });
2025-03-03 16:51:46 -08:00
const response = await llm.invoke([
{
role: "system",
content: GENERAL_INPUT_SYSTEM_PROMPT,
2025-03-03 16:51:46 -08:00
},
...state.messages,
]);
2025-02-27 14:08:24 -08:00
2025-02-27 15:41:47 -08:00
return {
messages: [response],
};
}
const builder = new StateGraph(GenerativeUIAnnotation)
.addNode("router", router)
.addNode("stockbroker", stockbrokerGraph)
2025-03-03 16:51:46 -08:00
.addNode("tripPlanner", tripPlannerGraph)
2025-02-27 15:41:47 -08:00
.addNode("generalInput", handleGeneralInput)
2025-02-18 19:35:46 +01:00
2025-02-27 15:41:47 -08:00
.addConditionalEdges("router", handleRoute, [
"stockbroker",
2025-03-03 16:51:46 -08:00
"tripPlanner",
2025-02-27 15:41:47 -08:00
"generalInput",
])
.addEdge(START, "router")
.addEdge("stockbroker", END)
2025-03-03 16:51:46 -08:00
.addEdge("tripPlanner", END)
2025-02-27 15:41:47 -08:00
.addEdge("generalInput", END);
2025-02-18 19:35:46 +01:00
2025-02-27 15:41:47 -08:00
export const graph = builder.compile();
graph.name = "Generative UI Agent";