From a15a9104b4183a6fa6188c12cb57681acd30330c Mon Sep 17 00:00:00 2001 From: bracesproul Date: Fri, 28 Feb 2025 14:15:37 -0800 Subject: [PATCH] fix message reversing --- agent/agent.tsx | 9 ++- agent/stockbroker/nodes/tools.tsx | 14 +--- agent/uis/stock-price/index.tsx | 4 +- src/components/assistant-ui/thread.tsx | 92 +------------------------- src/providers/Runtime.tsx | 47 ++++++++++++- src/providers/convert-messages.ts | 20 +----- 6 files changed, 57 insertions(+), 129 deletions(-) diff --git a/agent/agent.tsx b/agent/agent.tsx index e84c24d..76f3fb9 100644 --- a/agent/agent.tsx +++ b/agent/agent.tsx @@ -32,7 +32,8 @@ async function router( const prompt = `You're a highly helpful AI assistant, tasked with routing the user's query to the appropriate tool. You should analyze the user's input, and choose the appropriate tool to use.`; - const recentHumanMessage = state.messages + const messagesCopy = state.messages; + const recentHumanMessage = messagesCopy .reverse() .find((m) => m.getType() === "human"); @@ -65,7 +66,11 @@ function handleRoute( async function handleGeneralInput(state: GenerativeUIState) { const llm = new ChatOpenAI({ model: "gpt-4o-mini", temperature: 0 }); - const response = await llm.invoke(state.messages); + const messagesCopy = state.messages; + messagesCopy.reverse(); + console.log("messagesCopy", messagesCopy); + + const response = await llm.invoke(messagesCopy); return { messages: [response], diff --git a/agent/stockbroker/nodes/tools.tsx b/agent/stockbroker/nodes/tools.tsx index 8212965..9fbe432 100644 --- a/agent/stockbroker/nodes/tools.tsx +++ b/agent/stockbroker/nodes/tools.tsx @@ -1,5 +1,4 @@ import { StockbrokerState } from "../types"; -import { ToolMessage } from "@langchain/core/messages"; import { ChatOpenAI } from "@langchain/openai"; import { typedUi } from "@langchain/langgraph-sdk/react-ui/server"; import type ComponentMap from "../../uis/index"; @@ -64,19 +63,8 @@ export async function callTools( ui.write("portfolio", {}); } - const toolMessages = - message.tool_calls?.map((tc) => { - return new ToolMessage({ - name: tc.name, - tool_call_id: tc.id ?? "", - content: "Successfully handled tool call", - }); - }) || []; - - console.log("Returning", [message, ...toolMessages]); - return { - messages: [message, ...toolMessages], + messages: [message], // TODO: Fix the ui return type. ui: ui.collect as any[], timestamp: Date.now(), diff --git a/agent/uis/stock-price/index.tsx b/agent/uis/stock-price/index.tsx index 99ce8cd..d773f07 100644 --- a/agent/uis/stock-price/index.tsx +++ b/agent/uis/stock-price/index.tsx @@ -15,7 +15,9 @@ export default function StockPrice(props: { apiUrl: "http://localhost:3123", }); - const aiTool = thread.messages + const messagesCopy = thread.messages; + + const aiTool = messagesCopy .slice() .reverse() .find( diff --git a/src/components/assistant-ui/thread.tsx b/src/components/assistant-ui/thread.tsx index 5c3aa30..d99e53d 100644 --- a/src/components/assistant-ui/thread.tsx +++ b/src/components/assistant-ui/thread.tsx @@ -1,6 +1,5 @@ import { ActionBarPrimitive, - BranchPickerPrimitive, ComposerPrimitive, getExternalStoreMessages, MessagePrimitive, @@ -8,17 +7,7 @@ import { useMessage, } from "@assistant-ui/react"; import type { FC } from "react"; -import { - ArrowDownIcon, - CheckIcon, - ChevronLeftIcon, - ChevronRightIcon, - CopyIcon, - PencilIcon, - RefreshCwIcon, - SendHorizontalIcon, -} from "lucide-react"; -import { cn } from "@/lib/utils"; +import { ArrowDownIcon, PencilIcon, SendHorizontalIcon } from "lucide-react"; import { LoadExternalComponent } from "@langchain/langgraph-sdk/react-ui/client"; import { Avatar, AvatarFallback } from "@/components/ui/avatar"; @@ -168,8 +157,6 @@ const UserMessage: FC = () => {
- - ); }; @@ -218,7 +205,6 @@ function CustomComponent({ }) { const meta = thread.getMessagesMetadata(message, idx); const seenState = meta?.firstSeenState; - console.log("seenState", meta); const customComponent = seenState?.values.ui .slice() .reverse() @@ -228,7 +214,6 @@ function CustomComponent({ ); if (!customComponent) { - console.log("no custom component", message, meta); return null; } @@ -278,85 +263,10 @@ const AssistantMessage: FC = () => {
- - - - ); }; -const AssistantActionBar: FC = () => { - return ( - - {/* - - - - - - - - - - - - - */} - - - - - - - - - - - - - - - - - ); -}; - -const BranchPicker: FC = ({ - className, - ...rest -}) => { - return ( - - - - - - - - / - - - - - - - - ); -}; - const CircleStopIcon = () => { return ( { + if (message.type !== "ai" || message.tool_calls?.length === 0) { + // If it's not an AI message, or it doesn't have tool calls, we can ignore. + return; + } + // If it has tool calls, ensure the message which follows this is a tool message + const followingMessage = messages[index + 1]; + if (followingMessage && followingMessage.type === "tool") { + // Following message is a tool message, so we can ignore. + return; + } + + // Since the following message is not a tool message, we must create a new tool message + newMessages.push( + ...(message.tool_calls?.map((tc) => ({ + type: "tool" as const, + tool_call_id: tc.id ?? "", + id: tc.id ?? "", + name: tc.name, + content: "Successfully handled tool call.", + })) ?? []), + ); + }); + + return newMessages; +} + export function RuntimeProvider({ children, }: Readonly<{ @@ -21,9 +51,20 @@ export function RuntimeProvider({ const input = message.content[0].text; const humanMessage: HumanMessage = { type: "human", content: input }; - stream.submit({ messages: [humanMessage] }); + const newMessages = [ + ...ensureToolCallsHaveResponses(stream.messages), + humanMessage, + ]; + console.log("Sending new messages", newMessages); + stream.submit({ + messages: newMessages, + }); }; + useEffect(() => { + console.log("useEffect - stream.messages", stream.messages); + }, [stream.messages]); + const runtime = useExternalStoreRuntime({ isRunning: stream.isLoading, messages: stream.messages, diff --git a/src/providers/convert-messages.ts b/src/providers/convert-messages.ts index 506f9b0..8e8e556 100644 --- a/src/providers/convert-messages.ts +++ b/src/providers/convert-messages.ts @@ -79,15 +79,6 @@ export function convertLangChainMessages(message: Message): ThreadMessageLike { role: "user", id: message.id, content: [{ type: "text", text: content }], - // ...(message.additional_kwargs - // ? { - // metadata: { - // custom: { - // ...message.additional_kwargs, - // }, - // }, - // } - // : {}), }; case "ai": const aiMsg = message as AIMessage; @@ -110,20 +101,11 @@ export function convertLangChainMessages(message: Message): ThreadMessageLike { text: content, }, ], - // ...(message.additional_kwargs - // ? { - // metadata: { - // custom: { - // ...message.additional_kwargs, - // }, - // }, - // } - // : {}), }; case "tool": const toolMsg = message as ToolMessage; return { - role: "user", + role: "assistant", content: [ { type: "tool-call",