From 3f4aad48e61f42a17abee80d4d31b811ce99b457 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Thu, 27 Feb 2025 15:41:47 -0800 Subject: [PATCH] split up --- agent/agent.tsx | 166 ++++++++++++------------- agent/find-tool-call.ts | 14 +++ agent/stockbroker/index.tsx | 11 ++ agent/stockbroker/nodes/tools.tsx | 84 +++++++++++++ agent/stockbroker/types.ts | 11 ++ agent/types.ts | 11 ++ agent/uis/index.tsx | 2 +- langgraph.json | 2 +- package.json | 1 + pnpm-lock.yaml | 29 +++++ src/App.tsx | 2 +- src/components/assistant-ui/thread.tsx | 13 +- src/providers/Runtime.tsx | 7 -- src/providers/Stream.tsx | 5 - 14 files changed, 254 insertions(+), 104 deletions(-) create mode 100644 agent/find-tool-call.ts create mode 100644 agent/stockbroker/index.tsx create mode 100644 agent/stockbroker/nodes/tools.tsx create mode 100644 agent/stockbroker/types.ts create mode 100644 agent/types.ts diff --git a/agent/agent.tsx b/agent/agent.tsx index eea101c..e84c24d 100644 --- a/agent/agent.tsx +++ b/agent/agent.tsx @@ -1,100 +1,94 @@ -/* eslint-disable @typescript-eslint/no-explicit-any */ -import { - StateGraph, - MessagesAnnotation, - START, - Annotation, -} from "@langchain/langgraph"; -import { SystemMessage } from "@langchain/core/messages"; +import { StateGraph, START, END } from "@langchain/langgraph"; +import { ChatGoogleGenerativeAI } from "@langchain/google-genai"; +import { z } from "zod"; +import { GenerativeUIAnnotation, GenerativeUIState } from "./types"; +import { stockbrokerGraph } from "./stockbroker"; import { ChatOpenAI } from "@langchain/openai"; -import { typedUi } from "@langchain/langgraph-sdk/react-ui/server"; -import { uiMessageReducer } from "@langchain/langgraph-sdk/react-ui/types"; -import type ComponentMap from "./uis/index"; -import { z, ZodTypeAny } from "zod"; -const llm = new ChatOpenAI({ model: "gpt-4o-mini", temperature: 0 }); +async function router( + state: GenerativeUIState, +): Promise> { + const routerDescription = `The route to take based on the user's input. +- stockbroker: can fetch the price of a ticker, purchase/sell a ticker, or get the user's portfolio +- weather: can fetch the current weather conditions for a location +- generalInput: handles all other cases where the above tools don't apply +`; + const routerSchema = z.object({ + route: z + .enum(["stockbroker", "weather", "generalInput"]) + .describe(routerDescription), + }); + const routerTool = { + name: "router", + description: "A tool to route the user's query to the appropriate tool.", + schema: routerSchema, + }; -const getStockPriceSchema = z.object({ - ticker: z.string().describe("The ticker symbol of the company"), -}); -const getPortfolioSchema = z.object({ - get_portfolio: z.boolean().describe("Should be true."), -}); + const llm = new ChatGoogleGenerativeAI({ + model: "gemini-2.0-flash", + temperature: 0, + }).bindTools([routerTool], { tool_choice: "router" }); -const STOCKBROKER_TOOLS = [ - { - name: "get_stock_price", - description: "A tool to get the stock price of a company", - schema: getStockPriceSchema, - }, - { - name: "get_portfolio", - description: - "A tool to get the user's portfolio details. Only call this tool if the user requests their portfolio details.", - schema: getPortfolioSchema, - }, -]; + const prompt = `You're a highly helpful AI assistant, tasked with routing the user's query to the appropriate tool. +You should analyze the user's input, and choose the appropriate tool to use.`; -interface ToolCall { - name: string; - args: Record; - id?: string; - type?: "tool_call"; + const recentHumanMessage = state.messages + .reverse() + .find((m) => m.getType() === "human"); + + if (!recentHumanMessage) { + throw new Error("No human message found in state"); + } + + const response = await llm.invoke([ + { role: "system", content: prompt }, + recentHumanMessage, + ]); + + const toolCall = response.tool_calls?.[0]?.args as + | z.infer + | undefined; + if (!toolCall) { + throw new Error("No tool call found in response"); + } + + return { + next: toolCall.route, + }; } -function findToolCall(name: Name) { - return ( - x: ToolCall, - ): x is { name: Name; args: z.infer } => x.name === name; +function handleRoute( + state: GenerativeUIState, +): "stockbroker" | "weather" | "generalInput" { + return state.next; } -const builder = new StateGraph( - Annotation.Root({ - messages: MessagesAnnotation.spec["messages"], - ui: Annotation({ default: () => [], reducer: uiMessageReducer }), - timestamp: Annotation, - }), -) - .addNode("agent", async (state, config) => { - const ui = typedUi(config); +async function handleGeneralInput(state: GenerativeUIState) { + const llm = new ChatOpenAI({ model: "gpt-4o-mini", temperature: 0 }); + const response = await llm.invoke(state.messages); - const message = await llm - .bindTools(STOCKBROKER_TOOLS) - .invoke([ - new SystemMessage( - "You are a stockbroker agent that uses tools to get the stock price of a company", - ), - ...state.messages, - ]); + return { + messages: [response], + }; +} - const stockbrokerToolCall = message.tool_calls?.find( - findToolCall("get_stock_price"), - ); - const portfolioToolCall = message.tool_calls?.find( - findToolCall("get_portfolio"), - ); - - if (stockbrokerToolCall) { - const instruction = `The stock price of ${ - stockbrokerToolCall.args.ticker - } is ${Math.random() * 100}`; - - ui.write("stock-price", { instruction, logo: "hey" }); - } - - if (portfolioToolCall) { - ui.write("portfolio-view", {}); - } - - return { messages: message, ui: ui.collect, timestamp: Date.now() }; +const builder = new StateGraph(GenerativeUIAnnotation) + .addNode("router", router) + .addNode("stockbroker", stockbrokerGraph) + .addNode("weather", () => { + throw new Error("Weather not implemented"); }) - .addEdge(START, "agent"); + .addNode("generalInput", handleGeneralInput) + + .addConditionalEdges("router", handleRoute, [ + "stockbroker", + "weather", + "generalInput", + ]) + .addEdge(START, "router") + .addEdge("stockbroker", END) + .addEdge("weather", END) + .addEdge("generalInput", END); export const graph = builder.compile(); - -// event handler of evetns ˇtypes) -// event handler for specific node -> handle node - -// TODO: -// - Send run ID & additional metadata for the client to properly use messages (maybe we even have a config) -// - Store that run ID in messages +graph.name = "Generative UI Agent"; diff --git a/agent/find-tool-call.ts b/agent/find-tool-call.ts new file mode 100644 index 0000000..0372551 --- /dev/null +++ b/agent/find-tool-call.ts @@ -0,0 +1,14 @@ +import { z, ZodTypeAny } from "zod"; + +interface ToolCall { + name: string; + args: Record; + id?: string; + type?: "tool_call"; +} + +export function findToolCall(name: Name) { + return ( + x: ToolCall, + ): x is { name: Name; args: z.infer } => x.name === name; +} diff --git a/agent/stockbroker/index.tsx b/agent/stockbroker/index.tsx new file mode 100644 index 0000000..7315838 --- /dev/null +++ b/agent/stockbroker/index.tsx @@ -0,0 +1,11 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ +import { StateGraph, START } from "@langchain/langgraph"; +import { StockbrokerAnnotation } from "./types"; +import { callTools } from "./nodes/tools"; + +const builder = new StateGraph(StockbrokerAnnotation) + .addNode("agent", callTools) + .addEdge(START, "agent"); + +export const stockbrokerGraph = builder.compile(); +stockbrokerGraph.name = "Stockbroker"; diff --git a/agent/stockbroker/nodes/tools.tsx b/agent/stockbroker/nodes/tools.tsx new file mode 100644 index 0000000..8212965 --- /dev/null +++ b/agent/stockbroker/nodes/tools.tsx @@ -0,0 +1,84 @@ +import { StockbrokerState } from "../types"; +import { ToolMessage } from "@langchain/core/messages"; +import { ChatOpenAI } from "@langchain/openai"; +import { typedUi } from "@langchain/langgraph-sdk/react-ui/server"; +import type ComponentMap from "../../uis/index"; +import { z } from "zod"; +import { LangGraphRunnableConfig } from "@langchain/langgraph"; +import { findToolCall } from "../../find-tool-call"; + +const llm = new ChatOpenAI({ model: "gpt-4o-mini", temperature: 0 }); + +const getStockPriceSchema = z.object({ + ticker: z.string().describe("The ticker symbol of the company"), +}); +const getPortfolioSchema = z.object({ + get_portfolio: z.boolean().describe("Should be true."), +}); + +const STOCKBROKER_TOOLS = [ + { + name: "stock-price", + description: "A tool to get the stock price of a company", + schema: getStockPriceSchema, + }, + { + name: "portfolio", + description: + "A tool to get the user's portfolio details. Only call this tool if the user requests their portfolio details.", + schema: getPortfolioSchema, + }, +]; + +export async function callTools( + state: StockbrokerState, + config: LangGraphRunnableConfig, +): Promise> { + const ui = typedUi(config); + + const message = await llm.bindTools(STOCKBROKER_TOOLS).invoke([ + { + role: "system", + content: + "You are a stockbroker agent that uses tools to get the stock price of a company", + }, + ...state.messages, + ]); + + const stockbrokerToolCall = message.tool_calls?.find( + findToolCall("stock-price"), + ); + const portfolioToolCall = message.tool_calls?.find( + findToolCall("portfolio"), + ); + + if (stockbrokerToolCall) { + const instruction = `The stock price of ${ + stockbrokerToolCall.args.ticker + } is ${Math.random() * 100}`; + + ui.write("stock-price", { instruction, logo: "hey" }); + } + + if (portfolioToolCall) { + ui.write("portfolio", {}); + } + + const toolMessages = + message.tool_calls?.map((tc) => { + return new ToolMessage({ + name: tc.name, + tool_call_id: tc.id ?? "", + content: "Successfully handled tool call", + }); + }) || []; + + console.log("Returning", [message, ...toolMessages]); + + return { + messages: [message, ...toolMessages], + // TODO: Fix the ui return type. + ui: ui.collect as any[], + timestamp: Date.now(), + }; +} diff --git a/agent/stockbroker/types.ts b/agent/stockbroker/types.ts new file mode 100644 index 0000000..bd32d2e --- /dev/null +++ b/agent/stockbroker/types.ts @@ -0,0 +1,11 @@ +import { Annotation } from "@langchain/langgraph"; +import { GenerativeUIAnnotation } from "../types"; + +export const StockbrokerAnnotation = Annotation.Root({ + messages: GenerativeUIAnnotation.spec.messages, + ui: GenerativeUIAnnotation.spec.ui, + timestamp: GenerativeUIAnnotation.spec.timestamp, + next: Annotation<"stockbroker" | "weather">(), +}); + +export type StockbrokerState = typeof StockbrokerAnnotation.State; diff --git a/agent/types.ts b/agent/types.ts new file mode 100644 index 0000000..40b2d22 --- /dev/null +++ b/agent/types.ts @@ -0,0 +1,11 @@ +import { MessagesAnnotation, Annotation } from "@langchain/langgraph"; +import { uiMessageReducer } from "@langchain/langgraph-sdk/react-ui/types"; + +export const GenerativeUIAnnotation = Annotation.Root({ + messages: MessagesAnnotation.spec["messages"], + ui: Annotation({ default: () => [], reducer: uiMessageReducer }), + timestamp: Annotation, + next: Annotation<"stockbroker" | "weather" | "generalInput">(), +}); + +export type GenerativeUIState = typeof GenerativeUIAnnotation.State; diff --git a/agent/uis/index.tsx b/agent/uis/index.tsx index 6886aaf..a9259be 100644 --- a/agent/uis/index.tsx +++ b/agent/uis/index.tsx @@ -3,6 +3,6 @@ import PortfolioView from "./portfolio-view"; const ComponentMap = { "stock-price": StockPrice, - "portfolio-view": PortfolioView, + portfolio: PortfolioView, } as const; export default ComponentMap; diff --git a/langgraph.json b/langgraph.json index e29c8c2..eec351e 100644 --- a/langgraph.json +++ b/langgraph.json @@ -3,7 +3,7 @@ "agent": "./agent/agent.tsx:graph" }, "ui": { - "agent": "./agent/ui.tsx" + "agent": "./agent/uis/index.tsx" }, "_INTERNAL_docker_tag": "20", "env": ".env" diff --git a/package.json b/package.json index 028fce0..bc51e0a 100644 --- a/package.json +++ b/package.json @@ -15,6 +15,7 @@ "@assistant-ui/react": "^0.8.0", "@assistant-ui/react-markdown": "^0.8.0", "@langchain/core": "^0.3.41", + "@langchain/google-genai": "^0.1.10", "@langchain/langgraph": "^0.2.49", "@langchain/langgraph-api": "*", "@langchain/langgraph-cli": "*", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index f168958..81ebc3c 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -21,6 +21,9 @@ importers: "@langchain/core": specifier: ^0.3.41 version: 0.3.41(openai@4.85.4(zod@3.24.2)) + "@langchain/google-genai": + specifier: ^0.1.10 + version: 0.1.10(@langchain/core@0.3.41(openai@4.85.4(zod@3.24.2)))(zod@3.24.2) "@langchain/langgraph": specifier: ^0.2.49 version: 0.2.49(@langchain/core@0.3.41(openai@4.85.4(zod@3.24.2)))(react@19.0.0) @@ -879,6 +882,13 @@ packages: integrity: sha512-MDWhGtE+eHw5JW7lq4qhc5yRLS11ERl1c7Z6Xd0a58DozHES6EnNNwUWbMiG4J9Cgj053Bhk8zvlhFYKVhULwg==, } + "@google/generative-ai@0.21.0": + resolution: + { + integrity: sha512-7XhUbtnlkSEZK15kN3t+tzIMxsbKm/dSkKBFalj+20NvPKe1kBY7mR2P7vuijEn+f06z5+A8bVGKO0v39cr6Wg==, + } + engines: { node: ">=18.0.0" } + "@hono/node-server@1.13.8": resolution: { @@ -986,6 +996,15 @@ packages: } engines: { node: ">=18" } + "@langchain/google-genai@0.1.10": + resolution: + { + integrity: sha512-+0xFWvauNDNp8Nvhy5F5g8RbB5g4WWQSIxoPI4IQIUICBBT/kS/Omf1VJI6Loc0IH93m9ZSwYxRVCRu3qx51TQ==, + } + engines: { node: ">=18" } + peerDependencies: + "@langchain/core": ">=0.3.17 <0.4.0" + "@langchain/langgraph-api@http://localhost:3123/17/@langchain/langgraph-api": resolution: { tarball: http://localhost:3123/17/@langchain/langgraph-api } version: 0.0.10 @@ -5313,6 +5332,8 @@ snapshots: "@floating-ui/utils@0.2.9": {} + "@google/generative-ai@0.21.0": {} + "@hono/node-server@1.13.8(hono@4.7.2)": dependencies: hono: 4.7.2 @@ -5382,6 +5403,14 @@ snapshots: transitivePeerDependencies: - openai + "@langchain/google-genai@0.1.10(@langchain/core@0.3.41(openai@4.85.4(zod@3.24.2)))(zod@3.24.2)": + dependencies: + "@google/generative-ai": 0.21.0 + "@langchain/core": 0.3.41(openai@4.85.4(zod@3.24.2)) + zod-to-json-schema: 3.24.3(zod@3.24.2) + transitivePeerDependencies: + - zod + "@langchain/langgraph-api@http://localhost:3123/17/@langchain/langgraph-api(@langchain/core@0.3.41(openai@4.85.4(zod@3.24.2)))(@langchain/langgraph-checkpoint@0.0.15(@langchain/core@0.3.41(openai@4.85.4(zod@3.24.2))))(@langchain/langgraph@0.2.49(@langchain/core@0.3.41(openai@4.85.4(zod@3.24.2)))(react@19.0.0))(openai@4.85.4(zod@3.24.2))(typescript@5.7.3)": dependencies: "@babel/code-frame": 7.26.2 diff --git a/src/App.tsx b/src/App.tsx index 613c89e..717de17 100644 --- a/src/App.tsx +++ b/src/App.tsx @@ -3,7 +3,7 @@ import { Thread } from "@/components/assistant-ui/thread"; function App() { return ( -
+
); diff --git a/src/components/assistant-ui/thread.tsx b/src/components/assistant-ui/thread.tsx index 10f9dc4..5c3aa30 100644 --- a/src/components/assistant-ui/thread.tsx +++ b/src/components/assistant-ui/thread.tsx @@ -218,6 +218,7 @@ function CustomComponent({ }) { const meta = thread.getMessagesMetadata(message, idx); const seenState = meta?.firstSeenState; + console.log("seenState", meta); const customComponent = seenState?.values.ui .slice() .reverse() @@ -226,9 +227,13 @@ function CustomComponent({ additional_kwargs.run_id === seenState.metadata?.run_id, ); + if (!customComponent) { + console.log("no custom component", message, meta); + return null; + } + return (
-
{JSON.stringify(message, null, 2)}
{customComponent && ( { const assistantMsgs = useMessage((m) => { const langchainMessage = getExternalStoreMessages(m); return langchainMessage; - })?.[0]; + }); + + const assistantMsg = assistantMsgs[0]; let threadMsgIdx: number | undefined = undefined; const threadMsg = thread.messages.find((m, idx) => { - if (m.id === assistantMsgs?.id) { + if (m.id === assistantMsg?.id) { threadMsgIdx = idx; return true; } diff --git a/src/providers/Runtime.tsx b/src/providers/Runtime.tsx index 38290b9..86ff9f6 100644 --- a/src/providers/Runtime.tsx +++ b/src/providers/Runtime.tsx @@ -21,14 +21,7 @@ export function RuntimeProvider({ const input = message.content[0].text; const humanMessage: HumanMessage = { type: "human", content: input }; - // TODO: I dont think I need to do this, since we're passing stream.messages into the state hook, and it should update when we call `submit` - // setMessages((currentConversation) => [ - // ...currentConversation, - // humanMessage, - // ]); - stream.submit({ messages: [humanMessage] }); - console.log("Sent message", humanMessage); }; const runtime = useExternalStoreRuntime({ diff --git a/src/providers/Stream.tsx b/src/providers/Stream.tsx index 0c79f9c..c830637 100644 --- a/src/providers/Stream.tsx +++ b/src/providers/Stream.tsx @@ -6,7 +6,6 @@ import type { RemoveUIMessage, } from "@langchain/langgraph-sdk/react-ui/types"; -// Define the type for the context value type StreamContextType = ReturnType< typeof useStream< { messages: Message[]; ui: UIMessage[] }, @@ -18,10 +17,8 @@ type StreamContextType = ReturnType< > >; -// Create the context with a default undefined value const StreamContext = createContext(undefined); -// Create a provider component export const StreamProvider: React.FC<{ children: ReactNode }> = ({ children, }) => { @@ -37,8 +34,6 @@ export const StreamProvider: React.FC<{ children: ReactNode }> = ({ assistantId: "agent", }); - console.log("StreamProvider", streamValue); - return ( {children}