2025-02-27 15:41:47 -08:00
|
|
|
import { StateGraph, START, END } from "@langchain/langgraph";
|
|
|
|
|
import { ChatGoogleGenerativeAI } from "@langchain/google-genai";
|
|
|
|
|
import { z } from "zod";
|
|
|
|
|
import { GenerativeUIAnnotation, GenerativeUIState } from "./types";
|
|
|
|
|
import { stockbrokerGraph } from "./stockbroker";
|
2025-02-18 19:35:46 +01:00
|
|
|
import { ChatOpenAI } from "@langchain/openai";
|
|
|
|
|
|
2025-02-27 15:41:47 -08:00
|
|
|
/**
 * Routing node: asks Gemini to classify the most recent human message into one
 * of the graph's routes and stores that choice in `state.next`.
 *
 * Only the latest human message, plus a fixed system prompt, is sent to the
 * model. Earlier turns in `state.messages` are not considered when routing.
 *
 * @param state - Current graph state; `messages` must contain at least one human message.
 * @returns A partial state update whose `next` is one of
 *   `"stockbroker" | "weather" | "generalInput"`.
 * @throws Error if there is no human message in state, if the model response
 *   contains no tool call, or if the tool-call arguments fail schema validation.
 */
async function router(
  state: GenerativeUIState,
): Promise<Partial<GenerativeUIState>> {
  const routerDescription = `The route to take based on the user's input.
- stockbroker: can fetch the price of a ticker, purchase/sell a ticker, or get the user's portfolio
- weather: can fetch the current weather conditions for a location
- generalInput: handles all other cases where the above tools don't apply
`;
  const routerSchema = z.object({
    route: z
      .enum(["stockbroker", "weather", "generalInput"])
      .describe(routerDescription),
  });
  const routerTool = {
    name: "router",
    description: "A tool to route the user's query to the appropriate tool.",
    schema: routerSchema,
  };

  // `tool_choice: "router"` forces the model to answer via the router tool,
  // so a well-behaved response always carries exactly this tool call.
  const llm = new ChatGoogleGenerativeAI({
    model: "gemini-2.0-flash",
    temperature: 0,
  })
    .bindTools([routerTool], { tool_choice: "router" })
    // NOTE(review): presumably keeps this internal routing call out of the
    // user-facing token stream — confirm against LangGraph/LangSmith docs.
    .withConfig({ tags: ["langsmith:nostream"] });

  const prompt = `You're a highly helpful AI assistant, tasked with routing the user's query to the appropriate tool.
You should analyze the user's input, and choose the appropriate tool to use.`;

  const recentHumanMessage = state.messages.findLast(
    (m) => m.getType() === "human",
  );

  if (!recentHumanMessage) {
    throw new Error("No human message found in state");
  }

  const response = await llm.invoke([
    { role: "system", content: prompt },
    recentHumanMessage,
  ]);

  const toolCallArgs = response.tool_calls?.[0]?.args;
  if (toolCallArgs == null) {
    throw new Error("No tool call found in response");
  }

  // Tool-call arguments are model output: validate them against the schema
  // instead of casting, so an unknown route fails here with a clear message
  // rather than later inside the conditional edge.
  const parsed = routerSchema.safeParse(toolCallArgs);
  if (!parsed.success) {
    throw new Error(
      `Invalid router tool call arguments: ${parsed.error.message}`,
    );
  }

  return {
    next: parsed.data.route,
  };
}
|
2025-02-18 19:35:46 +01:00
|
|
|
|
2025-02-27 15:41:47 -08:00
|
|
|
/**
 * Conditional-edge selector used after the `router` node.
 *
 * Hands the route previously written to `state.next` back to LangGraph, which
 * uses it to choose the next node to run.
 *
 * @param state - Current graph state; `next` holds the selected route.
 * @returns The name of the node to execute next.
 */
function handleRoute(
  state: GenerativeUIState,
): "stockbroker" | "weather" | "generalInput" {
  const { next: selectedRoute } = state;
  return selectedRoute;
}
|
2025-02-18 19:35:46 +01:00
|
|
|
|
2025-02-27 15:41:47 -08:00
|
|
|
/**
 * Fallback node for queries that match none of the specialised routes.
 *
 * Sends the full conversation history in `state.messages` to `gpt-4o-mini`
 * (temperature 0) and returns the model's reply under `messages`.
 * NOTE(review): how that reply is merged into state depends on the reducer
 * declared in GenerativeUIAnnotation (./types) — presumably appended; confirm.
 *
 * @param state - Current graph state; only `messages` is read.
 * @returns A partial state update containing the model's reply.
 */
async function handleGeneralInput(state: GenerativeUIState) {
  const chatModel = new ChatOpenAI({ temperature: 0, model: "gpt-4o-mini" });
  const reply = await chatModel.invoke(state.messages);
  return { messages: [reply] };
}
|
2025-02-18 19:35:46 +01:00
|
|
|
|
2025-02-27 15:41:47 -08:00
|
|
|
// Top-level graph wiring:
//   START -> router -> (stockbroker | weather | generalInput) -> END
// Built as one chained expression over the GenerativeUIAnnotation state.
// NOTE(review): keeping it chained presumably lets StateGraph's generics track
// the registered node names used by addEdge/addConditionalEdges — confirm.
const builder = new StateGraph(GenerativeUIAnnotation)
  .addNode("router", router)
  // stockbrokerGraph comes from ./stockbroker and is mounted as a node here
  // (presumably a compiled subgraph — confirm in ./stockbroker).
  .addNode("stockbroker", stockbrokerGraph)
  // Placeholder node: it always throws.
  // NOTE(review): the router's schema still offers "weather", so any query
  // routed here will fail the run until this node is implemented.
  .addNode("weather", () => {
    throw new Error("Weather not implemented");
  })
  .addNode("generalInput", handleGeneralInput)
  // handleRoute returns state.next; the array lists every destination it may return.
  .addConditionalEdges("router", handleRoute, [
    "stockbroker",
    "weather",
    "generalInput",
  ])
  .addEdge(START, "router")
  // Every route node finishes the run.
  .addEdge("stockbroker", END)
  .addEdge("weather", END)
  .addEdge("generalInput", END);
|
2025-02-18 19:35:46 +01:00
|
|
|
|
2025-02-27 15:41:47 -08:00
|
|
|
// Compile the builder into the runnable graph exported by this module.
export const graph = builder.compile();
// Human-readable name assigned to the compiled graph.
// NOTE(review): presumably what LangGraph/LangSmith tooling displays — confirm.
graph.name = "Generative UI Agent";
|