bracesproul
2025-02-27 15:41:47 -08:00
parent c7b61071a1
commit 3f4aad48e6
14 changed files with 254 additions and 104 deletions

View File

@@ -1,100 +1,94 @@
import { StateGraph, START, END } from "@langchain/langgraph";
import { ChatGoogleGenerativeAI } from "@langchain/google-genai";
import { z } from "zod";
import { GenerativeUIAnnotation, GenerativeUIState } from "./types";
import { stockbrokerGraph } from "./stockbroker";
import { ChatOpenAI } from "@langchain/openai";

async function router(
  state: GenerativeUIState,
): Promise<Partial<GenerativeUIState>> {
  const routerDescription = `The route to take based on the user's input.
- stockbroker: can fetch the price of a ticker, purchase/sell a ticker, or get the user's portfolio
- weather: can fetch the current weather conditions for a location
- generalInput: handles all other cases where the above tools don't apply
`;
  const routerSchema = z.object({
    route: z
      .enum(["stockbroker", "weather", "generalInput"])
      .describe(routerDescription),
  });
  const routerTool = {
    name: "router",
    description: "A tool to route the user's query to the appropriate tool.",
    schema: routerSchema,
  };

  const llm = new ChatGoogleGenerativeAI({
    model: "gemini-2.0-flash",
    temperature: 0,
  }).bindTools([routerTool], { tool_choice: "router" });

  const prompt = `You're a highly helpful AI assistant, tasked with routing the user's query to the appropriate tool.
You should analyze the user's input, and choose the appropriate tool to use.`;

  const recentHumanMessage = state.messages
    .reverse()
    .find((m) => m.getType() === "human");

  if (!recentHumanMessage) {
    throw new Error("No human message found in state");
  }

  const response = await llm.invoke([
    { role: "system", content: prompt },
    recentHumanMessage,
  ]);

  const toolCall = response.tool_calls?.[0]?.args as
    | z.infer<typeof routerSchema>
    | undefined;
  if (!toolCall) {
    throw new Error("No tool call found in response");
  }

  return {
    next: toolCall.route,
  };
}

function handleRoute(
  state: GenerativeUIState,
): "stockbroker" | "weather" | "generalInput" {
  return state.next;
}

async function handleGeneralInput(state: GenerativeUIState) {
  const llm = new ChatOpenAI({ model: "gpt-4o-mini", temperature: 0 });
  const response = await llm.invoke(state.messages);

  return {
    messages: [response],
  };
}

const builder = new StateGraph(GenerativeUIAnnotation)
  .addNode("router", router)
  .addNode("stockbroker", stockbrokerGraph)
  .addNode("weather", () => {
    throw new Error("Weather not implemented");
  })
  .addNode("generalInput", handleGeneralInput)
  .addConditionalEdges("router", handleRoute, [
    "stockbroker",
    "weather",
    "generalInput",
  ])
  .addEdge(START, "router")
  .addEdge("stockbroker", END)
  .addEdge("weather", END)
  .addEdge("generalInput", END);

export const graph = builder.compile();
graph.name = "Generative UI Agent";
// event handler of events (types)
// event handler for specific node -> handle node
// TODO:
// - Send run ID & additional metadata for the client to properly use messages (maybe we even have a config)
// - Store that run ID in messages
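
For orientation, a minimal sketch of exercising the compiled graph above. This is not part of the commit; it assumes the standard `invoke` signature on a compiled LangGraph graph, that GOOGLE_API_KEY / OPENAI_API_KEY are set, and a hypothetical file placed next to the agent/ directory. The sample question is made up.

// Hypothetical smoke test, not part of this commit.
import { HumanMessage } from "@langchain/core/messages";
import { graph } from "./agent/agent";

async function main() {
  const result = await graph.invoke({
    messages: [new HumanMessage("What is the current price of AAPL?")],
  });

  // The router should have set `next` to "stockbroker", and the stockbroker
  // subgraph should have appended its AI/tool messages to state.
  console.log(result.next);
  console.log(result.messages.at(-1)?.content);
}

main().catch(console.error);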

agent/find-tool-call.ts Normal file
View File

@@ -0,0 +1,14 @@
import { z, ZodTypeAny } from "zod";

interface ToolCall {
  name: string;
  args: Record<string, any>;
  id?: string;
  type?: "tool_call";
}

export function findToolCall<Name extends string>(name: Name) {
  return <Args extends ZodTypeAny>(
    x: ToolCall,
  ): x is { name: Name; args: z.infer<Args> } => x.name === name;
}
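
A small usage sketch of the curried type guard above; the "get_weather" tool name, its schema, and the `toolCalls` array are made up for illustration (the real call sites are in agent/stockbroker/nodes/tools.ts below).

// Illustration only — "get_weather" and its schema are hypothetical.
import { z } from "zod";
import { findToolCall } from "./find-tool-call";

const getWeatherSchema = z.object({
  city: z.string().describe("The city to fetch the weather for"),
});

const toolCalls = [
  { name: "get_weather", args: { city: "Paris" } },
  { name: "something_else", args: {} },
];

// Using the returned predicate with Array.prototype.find narrows the match,
// so `weatherCall.args` is typed as z.infer<typeof getWeatherSchema> instead
// of Record<string, any>.
const weatherCall = toolCalls.find(
  findToolCall("get_weather")<typeof getWeatherSchema>,
);

if (weatherCall) {
  console.log(weatherCall.args.city); // typed as string
}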

View File

@@ -0,0 +1,11 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { StateGraph, START } from "@langchain/langgraph";
import { StockbrokerAnnotation } from "./types";
import { callTools } from "./nodes/tools";

const builder = new StateGraph(StockbrokerAnnotation)
  .addNode("agent", callTools)
  .addEdge(START, "agent");

export const stockbrokerGraph = builder.compile();
stockbrokerGraph.name = "Stockbroker";

View File

@@ -0,0 +1,84 @@
import { StockbrokerState } from "../types";
import { ToolMessage } from "@langchain/core/messages";
import { ChatOpenAI } from "@langchain/openai";
import { typedUi } from "@langchain/langgraph-sdk/react-ui/server";
import type ComponentMap from "../../uis/index";
import { z } from "zod";
import { LangGraphRunnableConfig } from "@langchain/langgraph";
import { findToolCall } from "../../find-tool-call";

const llm = new ChatOpenAI({ model: "gpt-4o-mini", temperature: 0 });

const getStockPriceSchema = z.object({
  ticker: z.string().describe("The ticker symbol of the company"),
});
const getPortfolioSchema = z.object({
  get_portfolio: z.boolean().describe("Should be true."),
});

const STOCKBROKER_TOOLS = [
  {
    name: "stock-price",
    description: "A tool to get the stock price of a company",
    schema: getStockPriceSchema,
  },
  {
    name: "portfolio",
    description:
      "A tool to get the user's portfolio details. Only call this tool if the user requests their portfolio details.",
    schema: getPortfolioSchema,
  },
];

export async function callTools(
  state: StockbrokerState,
  config: LangGraphRunnableConfig,
): Promise<Partial<StockbrokerState>> {
  const ui = typedUi<typeof ComponentMap>(config);

  const message = await llm.bindTools(STOCKBROKER_TOOLS).invoke([
    {
      role: "system",
      content:
        "You are a stockbroker agent that uses tools to get the stock price of a company",
    },
    ...state.messages,
  ]);

  const stockbrokerToolCall = message.tool_calls?.find(
    findToolCall("stock-price")<typeof getStockPriceSchema>,
  );
  const portfolioToolCall = message.tool_calls?.find(
    findToolCall("portfolio")<typeof getPortfolioSchema>,
  );

  if (stockbrokerToolCall) {
    const instruction = `The stock price of ${
      stockbrokerToolCall.args.ticker
    } is ${Math.random() * 100}`;
    ui.write("stock-price", { instruction, logo: "hey" });
  }
  if (portfolioToolCall) {
    ui.write("portfolio", {});
  }

  const toolMessages =
    message.tool_calls?.map((tc) => {
      return new ToolMessage({
        name: tc.name,
        tool_call_id: tc.id ?? "",
        content: "Successfully handled tool call",
      });
    }) || [];

  console.log("Returning", [message, ...toolMessages]);

  return {
    messages: [message, ...toolMessages],
    // TODO: Fix the ui return type.
    ui: ui.collect as any[],
    timestamp: Date.now(),
  };
}
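
For reference, a rough sketch of the shape the client-side "stock-price" component might take. The real `StockPrice` implementation in agent/uis/ is not part of this diff, so the props below are only inferred from the ui.write("stock-price", { instruction, logo }) call above.

// Hypothetical sketch — the actual component lives in agent/uis/stock-price.tsx
// and is not shown in this commit view. Props are inferred, not confirmed.
export default function StockPrice(props: { instruction: string; logo: string }) {
  return (
    <div>
      <img src={props.logo} alt="Company logo" width={32} height={32} />
      <p>{props.instruction}</p>
    </div>
  );
}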

View File

@@ -0,0 +1,11 @@
import { Annotation } from "@langchain/langgraph";
import { GenerativeUIAnnotation } from "../types";

export const StockbrokerAnnotation = Annotation.Root({
  messages: GenerativeUIAnnotation.spec.messages,
  ui: GenerativeUIAnnotation.spec.ui,
  timestamp: GenerativeUIAnnotation.spec.timestamp,
  next: Annotation<"stockbroker" | "weather">(),
});

export type StockbrokerState = typeof StockbrokerAnnotation.State;

agent/types.ts Normal file
View File

@@ -0,0 +1,11 @@
import { MessagesAnnotation, Annotation } from "@langchain/langgraph";
import { uiMessageReducer } from "@langchain/langgraph-sdk/react-ui/types";

export const GenerativeUIAnnotation = Annotation.Root({
  messages: MessagesAnnotation.spec["messages"],
  ui: Annotation({ default: () => [], reducer: uiMessageReducer }),
  timestamp: Annotation<number>,
  next: Annotation<"stockbroker" | "weather" | "generalInput">(),
});

export type GenerativeUIState = typeof GenerativeUIAnnotation.State;

View File

@@ -3,6 +3,6 @@ import PortfolioView from "./portfolio-view";
 const ComponentMap = {
   "stock-price": StockPrice,
-  "portfolio-view": PortfolioView,
+  portfolio: PortfolioView,
 } as const;
 export default ComponentMap;

View File

@@ -3,7 +3,7 @@
"agent": "./agent/agent.tsx:graph" "agent": "./agent/agent.tsx:graph"
}, },
"ui": { "ui": {
"agent": "./agent/ui.tsx" "agent": "./agent/uis/index.tsx"
}, },
"_INTERNAL_docker_tag": "20", "_INTERNAL_docker_tag": "20",
"env": ".env" "env": ".env"

View File

@@ -15,6 +15,7 @@
"@assistant-ui/react": "^0.8.0", "@assistant-ui/react": "^0.8.0",
"@assistant-ui/react-markdown": "^0.8.0", "@assistant-ui/react-markdown": "^0.8.0",
"@langchain/core": "^0.3.41", "@langchain/core": "^0.3.41",
"@langchain/google-genai": "^0.1.10",
"@langchain/langgraph": "^0.2.49", "@langchain/langgraph": "^0.2.49",
"@langchain/langgraph-api": "*", "@langchain/langgraph-api": "*",
"@langchain/langgraph-cli": "*", "@langchain/langgraph-cli": "*",

pnpm-lock.yaml generated
View File

@@ -21,6 +21,9 @@ importers:
"@langchain/core": "@langchain/core":
specifier: ^0.3.41 specifier: ^0.3.41
version: 0.3.41(openai@4.85.4(zod@3.24.2)) version: 0.3.41(openai@4.85.4(zod@3.24.2))
"@langchain/google-genai":
specifier: ^0.1.10
version: 0.1.10(@langchain/core@0.3.41(openai@4.85.4(zod@3.24.2)))(zod@3.24.2)
"@langchain/langgraph": "@langchain/langgraph":
specifier: ^0.2.49 specifier: ^0.2.49
version: 0.2.49(@langchain/core@0.3.41(openai@4.85.4(zod@3.24.2)))(react@19.0.0) version: 0.2.49(@langchain/core@0.3.41(openai@4.85.4(zod@3.24.2)))(react@19.0.0)
@@ -879,6 +882,13 @@ packages:
         integrity: sha512-MDWhGtE+eHw5JW7lq4qhc5yRLS11ERl1c7Z6Xd0a58DozHES6EnNNwUWbMiG4J9Cgj053Bhk8zvlhFYKVhULwg==,
       }
+  "@google/generative-ai@0.21.0":
+    resolution:
+      {
+        integrity: sha512-7XhUbtnlkSEZK15kN3t+tzIMxsbKm/dSkKBFalj+20NvPKe1kBY7mR2P7vuijEn+f06z5+A8bVGKO0v39cr6Wg==,
+      }
+    engines: { node: ">=18.0.0" }
   "@hono/node-server@1.13.8":
     resolution:
       {
@@ -986,6 +996,15 @@ packages:
       }
     engines: { node: ">=18" }
+  "@langchain/google-genai@0.1.10":
+    resolution:
+      {
+        integrity: sha512-+0xFWvauNDNp8Nvhy5F5g8RbB5g4WWQSIxoPI4IQIUICBBT/kS/Omf1VJI6Loc0IH93m9ZSwYxRVCRu3qx51TQ==,
+      }
+    engines: { node: ">=18" }
+    peerDependencies:
+      "@langchain/core": ">=0.3.17 <0.4.0"
   "@langchain/langgraph-api@http://localhost:3123/17/@langchain/langgraph-api":
     resolution: { tarball: http://localhost:3123/17/@langchain/langgraph-api }
     version: 0.0.10
@@ -5313,6 +5332,8 @@ snapshots:
"@floating-ui/utils@0.2.9": {} "@floating-ui/utils@0.2.9": {}
"@google/generative-ai@0.21.0": {}
"@hono/node-server@1.13.8(hono@4.7.2)": "@hono/node-server@1.13.8(hono@4.7.2)":
dependencies: dependencies:
hono: 4.7.2 hono: 4.7.2
@@ -5382,6 +5403,14 @@ snapshots:
     transitivePeerDependencies:
       - openai
+  "@langchain/google-genai@0.1.10(@langchain/core@0.3.41(openai@4.85.4(zod@3.24.2)))(zod@3.24.2)":
+    dependencies:
+      "@google/generative-ai": 0.21.0
+      "@langchain/core": 0.3.41(openai@4.85.4(zod@3.24.2))
+      zod-to-json-schema: 3.24.3(zod@3.24.2)
+    transitivePeerDependencies:
+      - zod
   "@langchain/langgraph-api@http://localhost:3123/17/@langchain/langgraph-api(@langchain/core@0.3.41(openai@4.85.4(zod@3.24.2)))(@langchain/langgraph-checkpoint@0.0.15(@langchain/core@0.3.41(openai@4.85.4(zod@3.24.2))))(@langchain/langgraph@0.2.49(@langchain/core@0.3.41(openai@4.85.4(zod@3.24.2)))(react@19.0.0))(openai@4.85.4(zod@3.24.2))(typescript@5.7.3)":
     dependencies:
       "@babel/code-frame": 7.26.2

View File

@@ -3,7 +3,7 @@ import { Thread } from "@/components/assistant-ui/thread";
 function App() {
   return (
-    <div className="h-full">
+    <div className="h-screen">
       <Thread />
     </div>
   );

View File

@@ -218,6 +218,7 @@ function CustomComponent({
 }) {
   const meta = thread.getMessagesMetadata(message, idx);
   const seenState = meta?.firstSeenState;
+  console.log("seenState", meta);
   const customComponent = seenState?.values.ui
     .slice()
     .reverse()
@@ -226,9 +227,13 @@ function CustomComponent({
         additional_kwargs.run_id === seenState.metadata?.run_id,
     );
+  if (!customComponent) {
+    console.log("no custom component", message, meta);
+    return null;
+  }
   return (
     <div key={message.id}>
+      <pre>{JSON.stringify(message, null, 2)}</pre>
       {customComponent && (
         <LoadExternalComponent
           assistantId="agent"
@@ -245,10 +250,12 @@ const AssistantMessage: FC = () => {
   const assistantMsgs = useMessage((m) => {
     const langchainMessage = getExternalStoreMessages<Message>(m);
     return langchainMessage;
-  })?.[0];
+  });
+  const assistantMsg = assistantMsgs[0];
   let threadMsgIdx: number | undefined = undefined;
   const threadMsg = thread.messages.find((m, idx) => {
-    if (m.id === assistantMsgs?.id) {
+    if (m.id === assistantMsg?.id) {
       threadMsgIdx = idx;
       return true;
     }

View File

@@ -21,14 +21,7 @@ export function RuntimeProvider({
     const input = message.content[0].text;
     const humanMessage: HumanMessage = { type: "human", content: input };
-    // TODO: I dont think I need to do this, since we're passing stream.messages into the state hook, and it should update when we call `submit`
-    // setMessages((currentConversation) => [
-    //   ...currentConversation,
-    //   humanMessage,
-    // ]);
     stream.submit({ messages: [humanMessage] });
-    console.log("Sent message", humanMessage);
   };
const runtime = useExternalStoreRuntime({ const runtime = useExternalStoreRuntime({

View File

@@ -6,7 +6,6 @@ import type {
   RemoveUIMessage,
 } from "@langchain/langgraph-sdk/react-ui/types";
-// Define the type for the context value
 type StreamContextType = ReturnType<
   typeof useStream<
     { messages: Message[]; ui: UIMessage[] },
@@ -18,10 +17,8 @@ type StreamContextType = ReturnType<
   >
 >;
-// Create the context with a default undefined value
 const StreamContext = createContext<StreamContextType | undefined>(undefined);
-// Create a provider component
 export const StreamProvider: React.FC<{ children: ReactNode }> = ({
   children,
 }) => {
@@ -37,8 +34,6 @@ export const StreamProvider: React.FC<{ children: ReactNode }> = ({
assistantId: "agent", assistantId: "agent",
}); });
console.log("StreamProvider", streamValue);
return ( return (
<StreamContext.Provider value={streamValue}> <StreamContext.Provider value={streamValue}>
{children} {children}