This commit is contained in:
bracesproul
2025-02-27 15:41:47 -08:00
parent c7b61071a1
commit 3f4aad48e6
14 changed files with 254 additions and 104 deletions

View File

@@ -1,100 +1,94 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import {
StateGraph,
MessagesAnnotation,
START,
Annotation,
} from "@langchain/langgraph";
import { SystemMessage } from "@langchain/core/messages";
import { StateGraph, START, END } from "@langchain/langgraph";
import { ChatGoogleGenerativeAI } from "@langchain/google-genai";
import { z } from "zod";
import { GenerativeUIAnnotation, GenerativeUIState } from "./types";
import { stockbrokerGraph } from "./stockbroker";
import { ChatOpenAI } from "@langchain/openai";
import { typedUi } from "@langchain/langgraph-sdk/react-ui/server";
import { uiMessageReducer } from "@langchain/langgraph-sdk/react-ui/types";
import type ComponentMap from "./uis/index";
import { z, ZodTypeAny } from "zod";
const llm = new ChatOpenAI({ model: "gpt-4o-mini", temperature: 0 });
async function router(
state: GenerativeUIState,
): Promise<Partial<GenerativeUIState>> {
const routerDescription = `The route to take based on the user's input.
- stockbroker: can fetch the price of a ticker, purchase/sell a ticker, or get the user's portfolio
- weather: can fetch the current weather conditions for a location
- generalInput: handles all other cases where the above tools don't apply
`;
const routerSchema = z.object({
route: z
.enum(["stockbroker", "weather", "generalInput"])
.describe(routerDescription),
});
const routerTool = {
name: "router",
description: "A tool to route the user's query to the appropriate tool.",
schema: routerSchema,
};
const getStockPriceSchema = z.object({
ticker: z.string().describe("The ticker symbol of the company"),
});
const getPortfolioSchema = z.object({
get_portfolio: z.boolean().describe("Should be true."),
});
const llm = new ChatGoogleGenerativeAI({
model: "gemini-2.0-flash",
temperature: 0,
}).bindTools([routerTool], { tool_choice: "router" });
const STOCKBROKER_TOOLS = [
{
name: "get_stock_price",
description: "A tool to get the stock price of a company",
schema: getStockPriceSchema,
},
{
name: "get_portfolio",
description:
"A tool to get the user's portfolio details. Only call this tool if the user requests their portfolio details.",
schema: getPortfolioSchema,
},
];
const prompt = `You're a highly helpful AI assistant, tasked with routing the user's query to the appropriate tool.
You should analyze the user's input, and choose the appropriate tool to use.`;
interface ToolCall {
name: string;
args: Record<string, any>;
id?: string;
type?: "tool_call";
const recentHumanMessage = state.messages
.reverse()
.find((m) => m.getType() === "human");
if (!recentHumanMessage) {
throw new Error("No human message found in state");
}
const response = await llm.invoke([
{ role: "system", content: prompt },
recentHumanMessage,
]);
const toolCall = response.tool_calls?.[0]?.args as
| z.infer<typeof routerSchema>
| undefined;
if (!toolCall) {
throw new Error("No tool call found in response");
}
return {
next: toolCall.route,
};
}
function findToolCall<Name extends string>(name: Name) {
return <Args extends ZodTypeAny>(
x: ToolCall,
): x is { name: Name; args: z.infer<Args> } => x.name === name;
function handleRoute(
state: GenerativeUIState,
): "stockbroker" | "weather" | "generalInput" {
return state.next;
}
const builder = new StateGraph(
Annotation.Root({
messages: MessagesAnnotation.spec["messages"],
ui: Annotation({ default: () => [], reducer: uiMessageReducer }),
timestamp: Annotation<number>,
}),
)
.addNode("agent", async (state, config) => {
const ui = typedUi<typeof ComponentMap>(config);
async function handleGeneralInput(state: GenerativeUIState) {
const llm = new ChatOpenAI({ model: "gpt-4o-mini", temperature: 0 });
const response = await llm.invoke(state.messages);
const message = await llm
.bindTools(STOCKBROKER_TOOLS)
.invoke([
new SystemMessage(
"You are a stockbroker agent that uses tools to get the stock price of a company",
),
...state.messages,
]);
return {
messages: [response],
};
}
const stockbrokerToolCall = message.tool_calls?.find(
findToolCall("get_stock_price")<typeof getStockPriceSchema>,
);
const portfolioToolCall = message.tool_calls?.find(
findToolCall("get_portfolio")<typeof getStockPriceSchema>,
);
if (stockbrokerToolCall) {
const instruction = `The stock price of ${
stockbrokerToolCall.args.ticker
} is ${Math.random() * 100}`;
ui.write("stock-price", { instruction, logo: "hey" });
}
if (portfolioToolCall) {
ui.write("portfolio-view", {});
}
return { messages: message, ui: ui.collect, timestamp: Date.now() };
const builder = new StateGraph(GenerativeUIAnnotation)
.addNode("router", router)
.addNode("stockbroker", stockbrokerGraph)
.addNode("weather", () => {
throw new Error("Weather not implemented");
})
.addEdge(START, "agent");
.addNode("generalInput", handleGeneralInput)
.addConditionalEdges("router", handleRoute, [
"stockbroker",
"weather",
"generalInput",
])
.addEdge(START, "router")
.addEdge("stockbroker", END)
.addEdge("weather", END)
.addEdge("generalInput", END);
export const graph = builder.compile();
// event handler of evetns ˇtypes)
// event handler for specific node -> handle node
// TODO:
// - Send run ID & additional metadata for the client to properly use messages (maybe we even have a config)
// - Store that run ID in messages
graph.name = "Generative UI Agent";