split up
This commit is contained in:
166
agent/agent.tsx
166
agent/agent.tsx
@@ -1,100 +1,94 @@
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
import {
|
||||
StateGraph,
|
||||
MessagesAnnotation,
|
||||
START,
|
||||
Annotation,
|
||||
} from "@langchain/langgraph";
|
||||
import { SystemMessage } from "@langchain/core/messages";
|
||||
import { StateGraph, START, END } from "@langchain/langgraph";
|
||||
import { ChatGoogleGenerativeAI } from "@langchain/google-genai";
|
||||
import { z } from "zod";
|
||||
import { GenerativeUIAnnotation, GenerativeUIState } from "./types";
|
||||
import { stockbrokerGraph } from "./stockbroker";
|
||||
import { ChatOpenAI } from "@langchain/openai";
|
||||
import { typedUi } from "@langchain/langgraph-sdk/react-ui/server";
|
||||
import { uiMessageReducer } from "@langchain/langgraph-sdk/react-ui/types";
|
||||
import type ComponentMap from "./uis/index";
|
||||
import { z, ZodTypeAny } from "zod";
|
||||
|
||||
const llm = new ChatOpenAI({ model: "gpt-4o-mini", temperature: 0 });
|
||||
async function router(
|
||||
state: GenerativeUIState,
|
||||
): Promise<Partial<GenerativeUIState>> {
|
||||
const routerDescription = `The route to take based on the user's input.
|
||||
- stockbroker: can fetch the price of a ticker, purchase/sell a ticker, or get the user's portfolio
|
||||
- weather: can fetch the current weather conditions for a location
|
||||
- generalInput: handles all other cases where the above tools don't apply
|
||||
`;
|
||||
const routerSchema = z.object({
|
||||
route: z
|
||||
.enum(["stockbroker", "weather", "generalInput"])
|
||||
.describe(routerDescription),
|
||||
});
|
||||
const routerTool = {
|
||||
name: "router",
|
||||
description: "A tool to route the user's query to the appropriate tool.",
|
||||
schema: routerSchema,
|
||||
};
|
||||
|
||||
const getStockPriceSchema = z.object({
|
||||
ticker: z.string().describe("The ticker symbol of the company"),
|
||||
});
|
||||
const getPortfolioSchema = z.object({
|
||||
get_portfolio: z.boolean().describe("Should be true."),
|
||||
});
|
||||
const llm = new ChatGoogleGenerativeAI({
|
||||
model: "gemini-2.0-flash",
|
||||
temperature: 0,
|
||||
}).bindTools([routerTool], { tool_choice: "router" });
|
||||
|
||||
const STOCKBROKER_TOOLS = [
|
||||
{
|
||||
name: "get_stock_price",
|
||||
description: "A tool to get the stock price of a company",
|
||||
schema: getStockPriceSchema,
|
||||
},
|
||||
{
|
||||
name: "get_portfolio",
|
||||
description:
|
||||
"A tool to get the user's portfolio details. Only call this tool if the user requests their portfolio details.",
|
||||
schema: getPortfolioSchema,
|
||||
},
|
||||
];
|
||||
const prompt = `You're a highly helpful AI assistant, tasked with routing the user's query to the appropriate tool.
|
||||
You should analyze the user's input, and choose the appropriate tool to use.`;
|
||||
|
||||
interface ToolCall {
|
||||
name: string;
|
||||
args: Record<string, any>;
|
||||
id?: string;
|
||||
type?: "tool_call";
|
||||
const recentHumanMessage = state.messages
|
||||
.reverse()
|
||||
.find((m) => m.getType() === "human");
|
||||
|
||||
if (!recentHumanMessage) {
|
||||
throw new Error("No human message found in state");
|
||||
}
|
||||
|
||||
const response = await llm.invoke([
|
||||
{ role: "system", content: prompt },
|
||||
recentHumanMessage,
|
||||
]);
|
||||
|
||||
const toolCall = response.tool_calls?.[0]?.args as
|
||||
| z.infer<typeof routerSchema>
|
||||
| undefined;
|
||||
if (!toolCall) {
|
||||
throw new Error("No tool call found in response");
|
||||
}
|
||||
|
||||
return {
|
||||
next: toolCall.route,
|
||||
};
|
||||
}
|
||||
|
||||
function findToolCall<Name extends string>(name: Name) {
|
||||
return <Args extends ZodTypeAny>(
|
||||
x: ToolCall,
|
||||
): x is { name: Name; args: z.infer<Args> } => x.name === name;
|
||||
function handleRoute(
|
||||
state: GenerativeUIState,
|
||||
): "stockbroker" | "weather" | "generalInput" {
|
||||
return state.next;
|
||||
}
|
||||
|
||||
const builder = new StateGraph(
|
||||
Annotation.Root({
|
||||
messages: MessagesAnnotation.spec["messages"],
|
||||
ui: Annotation({ default: () => [], reducer: uiMessageReducer }),
|
||||
timestamp: Annotation<number>,
|
||||
}),
|
||||
)
|
||||
.addNode("agent", async (state, config) => {
|
||||
const ui = typedUi<typeof ComponentMap>(config);
|
||||
async function handleGeneralInput(state: GenerativeUIState) {
|
||||
const llm = new ChatOpenAI({ model: "gpt-4o-mini", temperature: 0 });
|
||||
const response = await llm.invoke(state.messages);
|
||||
|
||||
const message = await llm
|
||||
.bindTools(STOCKBROKER_TOOLS)
|
||||
.invoke([
|
||||
new SystemMessage(
|
||||
"You are a stockbroker agent that uses tools to get the stock price of a company",
|
||||
),
|
||||
...state.messages,
|
||||
]);
|
||||
return {
|
||||
messages: [response],
|
||||
};
|
||||
}
|
||||
|
||||
const stockbrokerToolCall = message.tool_calls?.find(
|
||||
findToolCall("get_stock_price")<typeof getStockPriceSchema>,
|
||||
);
|
||||
const portfolioToolCall = message.tool_calls?.find(
|
||||
findToolCall("get_portfolio")<typeof getStockPriceSchema>,
|
||||
);
|
||||
|
||||
if (stockbrokerToolCall) {
|
||||
const instruction = `The stock price of ${
|
||||
stockbrokerToolCall.args.ticker
|
||||
} is ${Math.random() * 100}`;
|
||||
|
||||
ui.write("stock-price", { instruction, logo: "hey" });
|
||||
}
|
||||
|
||||
if (portfolioToolCall) {
|
||||
ui.write("portfolio-view", {});
|
||||
}
|
||||
|
||||
return { messages: message, ui: ui.collect, timestamp: Date.now() };
|
||||
const builder = new StateGraph(GenerativeUIAnnotation)
|
||||
.addNode("router", router)
|
||||
.addNode("stockbroker", stockbrokerGraph)
|
||||
.addNode("weather", () => {
|
||||
throw new Error("Weather not implemented");
|
||||
})
|
||||
.addEdge(START, "agent");
|
||||
.addNode("generalInput", handleGeneralInput)
|
||||
|
||||
.addConditionalEdges("router", handleRoute, [
|
||||
"stockbroker",
|
||||
"weather",
|
||||
"generalInput",
|
||||
])
|
||||
.addEdge(START, "router")
|
||||
.addEdge("stockbroker", END)
|
||||
.addEdge("weather", END)
|
||||
.addEdge("generalInput", END);
|
||||
|
||||
export const graph = builder.compile();
|
||||
|
||||
// event handler of evetns ˇtypes)
|
||||
// event handler for specific node -> handle node
|
||||
|
||||
// TODO:
|
||||
// - Send run ID & additional metadata for the client to properly use messages (maybe we even have a config)
|
||||
// - Store that run ID in messages
|
||||
graph.name = "Generative UI Agent";
|
||||
|
||||
14
agent/find-tool-call.ts
Normal file
14
agent/find-tool-call.ts
Normal file
@@ -0,0 +1,14 @@
|
||||
import { z, ZodTypeAny } from "zod";
|
||||
|
||||
interface ToolCall {
|
||||
name: string;
|
||||
args: Record<string, any>;
|
||||
id?: string;
|
||||
type?: "tool_call";
|
||||
}
|
||||
|
||||
export function findToolCall<Name extends string>(name: Name) {
|
||||
return <Args extends ZodTypeAny>(
|
||||
x: ToolCall,
|
||||
): x is { name: Name; args: z.infer<Args> } => x.name === name;
|
||||
}
|
||||
11
agent/stockbroker/index.tsx
Normal file
11
agent/stockbroker/index.tsx
Normal file
@@ -0,0 +1,11 @@
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
import { StateGraph, START } from "@langchain/langgraph";
|
||||
import { StockbrokerAnnotation } from "./types";
|
||||
import { callTools } from "./nodes/tools";
|
||||
|
||||
const builder = new StateGraph(StockbrokerAnnotation)
|
||||
.addNode("agent", callTools)
|
||||
.addEdge(START, "agent");
|
||||
|
||||
export const stockbrokerGraph = builder.compile();
|
||||
stockbrokerGraph.name = "Stockbroker";
|
||||
84
agent/stockbroker/nodes/tools.tsx
Normal file
84
agent/stockbroker/nodes/tools.tsx
Normal file
@@ -0,0 +1,84 @@
|
||||
import { StockbrokerState } from "../types";
|
||||
import { ToolMessage } from "@langchain/core/messages";
|
||||
import { ChatOpenAI } from "@langchain/openai";
|
||||
import { typedUi } from "@langchain/langgraph-sdk/react-ui/server";
|
||||
import type ComponentMap from "../../uis/index";
|
||||
import { z } from "zod";
|
||||
import { LangGraphRunnableConfig } from "@langchain/langgraph";
|
||||
import { findToolCall } from "../../find-tool-call";
|
||||
|
||||
const llm = new ChatOpenAI({ model: "gpt-4o-mini", temperature: 0 });
|
||||
|
||||
const getStockPriceSchema = z.object({
|
||||
ticker: z.string().describe("The ticker symbol of the company"),
|
||||
});
|
||||
const getPortfolioSchema = z.object({
|
||||
get_portfolio: z.boolean().describe("Should be true."),
|
||||
});
|
||||
|
||||
const STOCKBROKER_TOOLS = [
|
||||
{
|
||||
name: "stock-price",
|
||||
description: "A tool to get the stock price of a company",
|
||||
schema: getStockPriceSchema,
|
||||
},
|
||||
{
|
||||
name: "portfolio",
|
||||
description:
|
||||
"A tool to get the user's portfolio details. Only call this tool if the user requests their portfolio details.",
|
||||
schema: getPortfolioSchema,
|
||||
},
|
||||
];
|
||||
|
||||
export async function callTools(
|
||||
state: StockbrokerState,
|
||||
config: LangGraphRunnableConfig,
|
||||
): Promise<Partial<StockbrokerState>> {
|
||||
const ui = typedUi<typeof ComponentMap>(config);
|
||||
|
||||
const message = await llm.bindTools(STOCKBROKER_TOOLS).invoke([
|
||||
{
|
||||
role: "system",
|
||||
content:
|
||||
"You are a stockbroker agent that uses tools to get the stock price of a company",
|
||||
},
|
||||
...state.messages,
|
||||
]);
|
||||
|
||||
const stockbrokerToolCall = message.tool_calls?.find(
|
||||
findToolCall("stock-price")<typeof getStockPriceSchema>,
|
||||
);
|
||||
const portfolioToolCall = message.tool_calls?.find(
|
||||
findToolCall("portfolio")<typeof getStockPriceSchema>,
|
||||
);
|
||||
|
||||
if (stockbrokerToolCall) {
|
||||
const instruction = `The stock price of ${
|
||||
stockbrokerToolCall.args.ticker
|
||||
} is ${Math.random() * 100}`;
|
||||
|
||||
ui.write("stock-price", { instruction, logo: "hey" });
|
||||
}
|
||||
|
||||
if (portfolioToolCall) {
|
||||
ui.write("portfolio", {});
|
||||
}
|
||||
|
||||
const toolMessages =
|
||||
message.tool_calls?.map((tc) => {
|
||||
return new ToolMessage({
|
||||
name: tc.name,
|
||||
tool_call_id: tc.id ?? "",
|
||||
content: "Successfully handled tool call",
|
||||
});
|
||||
}) || [];
|
||||
|
||||
console.log("Returning", [message, ...toolMessages]);
|
||||
|
||||
return {
|
||||
messages: [message, ...toolMessages],
|
||||
// TODO: Fix the ui return type.
|
||||
ui: ui.collect as any[],
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
}
|
||||
11
agent/stockbroker/types.ts
Normal file
11
agent/stockbroker/types.ts
Normal file
@@ -0,0 +1,11 @@
|
||||
import { Annotation } from "@langchain/langgraph";
|
||||
import { GenerativeUIAnnotation } from "../types";
|
||||
|
||||
export const StockbrokerAnnotation = Annotation.Root({
|
||||
messages: GenerativeUIAnnotation.spec.messages,
|
||||
ui: GenerativeUIAnnotation.spec.ui,
|
||||
timestamp: GenerativeUIAnnotation.spec.timestamp,
|
||||
next: Annotation<"stockbroker" | "weather">(),
|
||||
});
|
||||
|
||||
export type StockbrokerState = typeof StockbrokerAnnotation.State;
|
||||
11
agent/types.ts
Normal file
11
agent/types.ts
Normal file
@@ -0,0 +1,11 @@
|
||||
import { MessagesAnnotation, Annotation } from "@langchain/langgraph";
|
||||
import { uiMessageReducer } from "@langchain/langgraph-sdk/react-ui/types";
|
||||
|
||||
export const GenerativeUIAnnotation = Annotation.Root({
|
||||
messages: MessagesAnnotation.spec["messages"],
|
||||
ui: Annotation({ default: () => [], reducer: uiMessageReducer }),
|
||||
timestamp: Annotation<number>,
|
||||
next: Annotation<"stockbroker" | "weather" | "generalInput">(),
|
||||
});
|
||||
|
||||
export type GenerativeUIState = typeof GenerativeUIAnnotation.State;
|
||||
@@ -3,6 +3,6 @@ import PortfolioView from "./portfolio-view";
|
||||
|
||||
const ComponentMap = {
|
||||
"stock-price": StockPrice,
|
||||
"portfolio-view": PortfolioView,
|
||||
portfolio: PortfolioView,
|
||||
} as const;
|
||||
export default ComponentMap;
|
||||
|
||||
Reference in New Issue
Block a user