improved agent

This commit is contained in:
bracesproul
2025-02-27 14:08:24 -08:00
parent 4e6a831214
commit c7b61071a1
25 changed files with 4438 additions and 2270 deletions

View File

@@ -9,12 +9,32 @@ import { SystemMessage } from "@langchain/core/messages";
import { ChatOpenAI } from "@langchain/openai";
import { typedUi } from "@langchain/langgraph-sdk/react-ui/server";
import { uiMessageReducer } from "@langchain/langgraph-sdk/react-ui/types";
import type ComponentMap from "./ui";
import type ComponentMap from "./uis/index";
import { z, ZodTypeAny } from "zod";
// const llm = new ChatOllama({ model: "deepseek-r1" });
const llm = new ChatOpenAI({ model: "gpt-4o-mini", temperature: 0 });
const getStockPriceSchema = z.object({
ticker: z.string().describe("The ticker symbol of the company"),
});
const getPortfolioSchema = z.object({
get_portfolio: z.boolean().describe("Should be true."),
});
const STOCKBROKER_TOOLS = [
{
name: "get_stock_price",
description: "A tool to get the stock price of a company",
schema: getStockPriceSchema,
},
{
name: "get_portfolio",
description:
"A tool to get the user's portfolio details. Only call this tool if the user requests their portfolio details.",
schema: getPortfolioSchema,
},
];
interface ToolCall {
name: string;
args: Record<string, any>;
@@ -24,7 +44,7 @@ interface ToolCall {
function findToolCall<Name extends string>(name: Name) {
return <Args extends ZodTypeAny>(
x: ToolCall
x: ToolCall,
): x is { name: Name; args: z.infer<Args> } => x.name === name;
}
@@ -33,46 +53,37 @@ const builder = new StateGraph(
messages: MessagesAnnotation.spec["messages"],
ui: Annotation({ default: () => [], reducer: uiMessageReducer }),
timestamp: Annotation<number>,
})
}),
)
.addNode("agent", async (state, config) => {
const ui = typedUi<typeof ComponentMap>(config);
// const result = ui.interrupt("react-component", {
// instruction: "Hello world",
// });
// // throw new Error("Random error");
// // stream custom events
// for (let count = 0; count < 10; count++) config.writer?.({ count });
// How do I properly assign
const stockbrokerSchema = z.object({ company: z.string() });
const message = await llm
.bindTools([
{
name: "stockbroker",
description: "A tool to get the stock price of a company",
schema: stockbrokerSchema,
},
])
.bindTools(STOCKBROKER_TOOLS)
.invoke([
new SystemMessage(
"You are a stockbroker agent that uses tools to get the stock price of a company"
"You are a stockbroker agent that uses tools to get the stock price of a company",
),
...state.messages,
]);
const stockbrokerToolCall = message.tool_calls?.find(
findToolCall("stockbroker")<typeof stockbrokerSchema>
findToolCall("get_stock_price")<typeof getStockPriceSchema>,
);
const portfolioToolCall = message.tool_calls?.find(
findToolCall("get_portfolio")<typeof getStockPriceSchema>,
);
if (stockbrokerToolCall) {
const instruction = `The stock price of ${
stockbrokerToolCall.args.company
stockbrokerToolCall.args.ticker
} is ${Math.random() * 100}`;
ui.write("react-component", { instruction, logo: "hey" });
ui.write("stock-price", { instruction, logo: "hey" });
}
if (portfolioToolCall) {
ui.write("portfolio-view", {});
}
return { messages: message, ui: ui.collect, timestamp: Date.now() };

8
agent/uis/index.tsx Normal file
View File

@@ -0,0 +1,8 @@
import StockPrice from "./stock-price";
import PortfolioView from "./portfolio-view";
const ComponentMap = {
"stock-price": StockPrice,
"portfolio-view": PortfolioView,
} as const;
export default ComponentMap;

View File

@@ -0,0 +1,9 @@
import "./index.css";
export default function PortfolioView() {
return (
<div className="flex flex-col gap-2 border border-solid border-slate-500 p-4 rounded-md">
Portfolio View
</div>
);
}

View File

@@ -0,0 +1 @@
@import "tailwindcss";

View File

@@ -1,9 +1,12 @@
import "./ui.css";
import "./index.css";
import { useStream } from "@langchain/langgraph-sdk/react";
import type { AIMessage, Message } from "@langchain/langgraph-sdk";
import { useState } from "react";
function ReactComponent(props: { instruction: string; logo: string }) {
export default function StockPrice(props: {
instruction: string;
logo: string;
}) {
const [counter, setCounter] = useState(0);
// useStream should be able to be infered from context
@@ -17,7 +20,7 @@ function ReactComponent(props: { instruction: string; logo: string }) {
.reverse()
.find(
(message): message is AIMessage =>
message.type === "ai" && !!message.tool_calls?.length
message.type === "ai" && !!message.tool_calls?.length,
);
const toolCallId = aiTool?.tool_calls?.[0]?.id;
@@ -52,6 +55,3 @@ function ReactComponent(props: { instruction: string; logo: string }) {
</div>
);
}
const ComponentMap = { "react-component": ReactComponent } as const;
export default ComponentMap;