fix message reversing

This commit is contained in:
bracesproul
2025-02-28 14:15:37 -08:00
parent 3f4aad48e6
commit a15a9104b4
6 changed files with 57 additions and 129 deletions

View File

@@ -32,7 +32,8 @@ async function router(
const prompt = `You're a highly helpful AI assistant, tasked with routing the user's query to the appropriate tool. const prompt = `You're a highly helpful AI assistant, tasked with routing the user's query to the appropriate tool.
You should analyze the user's input, and choose the appropriate tool to use.`; You should analyze the user's input, and choose the appropriate tool to use.`;
const recentHumanMessage = state.messages const messagesCopy = state.messages;
const recentHumanMessage = messagesCopy
.reverse() .reverse()
.find((m) => m.getType() === "human"); .find((m) => m.getType() === "human");
@@ -65,7 +66,11 @@ function handleRoute(
async function handleGeneralInput(state: GenerativeUIState) { async function handleGeneralInput(state: GenerativeUIState) {
const llm = new ChatOpenAI({ model: "gpt-4o-mini", temperature: 0 }); const llm = new ChatOpenAI({ model: "gpt-4o-mini", temperature: 0 });
const response = await llm.invoke(state.messages); const messagesCopy = state.messages;
messagesCopy.reverse();
console.log("messagesCopy", messagesCopy);
const response = await llm.invoke(messagesCopy);
return { return {
messages: [response], messages: [response],

View File

@@ -1,5 +1,4 @@
import { StockbrokerState } from "../types"; import { StockbrokerState } from "../types";
import { ToolMessage } from "@langchain/core/messages";
import { ChatOpenAI } from "@langchain/openai"; import { ChatOpenAI } from "@langchain/openai";
import { typedUi } from "@langchain/langgraph-sdk/react-ui/server"; import { typedUi } from "@langchain/langgraph-sdk/react-ui/server";
import type ComponentMap from "../../uis/index"; import type ComponentMap from "../../uis/index";
@@ -64,19 +63,8 @@ export async function callTools(
ui.write("portfolio", {}); ui.write("portfolio", {});
} }
const toolMessages =
message.tool_calls?.map((tc) => {
return new ToolMessage({
name: tc.name,
tool_call_id: tc.id ?? "",
content: "Successfully handled tool call",
});
}) || [];
console.log("Returning", [message, ...toolMessages]);
return { return {
messages: [message, ...toolMessages], messages: [message],
// TODO: Fix the ui return type. // TODO: Fix the ui return type.
ui: ui.collect as any[], ui: ui.collect as any[],
timestamp: Date.now(), timestamp: Date.now(),

View File

@@ -15,7 +15,9 @@ export default function StockPrice(props: {
apiUrl: "http://localhost:3123", apiUrl: "http://localhost:3123",
}); });
const aiTool = thread.messages const messagesCopy = thread.messages;
const aiTool = messagesCopy
.slice() .slice()
.reverse() .reverse()
.find( .find(

View File

@@ -1,6 +1,5 @@
import { import {
ActionBarPrimitive, ActionBarPrimitive,
BranchPickerPrimitive,
ComposerPrimitive, ComposerPrimitive,
getExternalStoreMessages, getExternalStoreMessages,
MessagePrimitive, MessagePrimitive,
@@ -8,17 +7,7 @@ import {
useMessage, useMessage,
} from "@assistant-ui/react"; } from "@assistant-ui/react";
import type { FC } from "react"; import type { FC } from "react";
import { import { ArrowDownIcon, PencilIcon, SendHorizontalIcon } from "lucide-react";
ArrowDownIcon,
CheckIcon,
ChevronLeftIcon,
ChevronRightIcon,
CopyIcon,
PencilIcon,
RefreshCwIcon,
SendHorizontalIcon,
} from "lucide-react";
import { cn } from "@/lib/utils";
import { LoadExternalComponent } from "@langchain/langgraph-sdk/react-ui/client"; import { LoadExternalComponent } from "@langchain/langgraph-sdk/react-ui/client";
import { Avatar, AvatarFallback } from "@/components/ui/avatar"; import { Avatar, AvatarFallback } from "@/components/ui/avatar";
@@ -168,8 +157,6 @@ const UserMessage: FC = () => {
<div className="bg-muted text-foreground max-w-[calc(var(--thread-max-width)*0.8)] break-words rounded-3xl px-5 py-2.5 col-start-2 row-start-2"> <div className="bg-muted text-foreground max-w-[calc(var(--thread-max-width)*0.8)] break-words rounded-3xl px-5 py-2.5 col-start-2 row-start-2">
<MessagePrimitive.Content /> <MessagePrimitive.Content />
</div> </div>
<BranchPicker className="col-span-full col-start-1 row-start-3 -mr-1 justify-end" />
</MessagePrimitive.Root> </MessagePrimitive.Root>
); );
}; };
@@ -218,7 +205,6 @@ function CustomComponent({
}) { }) {
const meta = thread.getMessagesMetadata(message, idx); const meta = thread.getMessagesMetadata(message, idx);
const seenState = meta?.firstSeenState; const seenState = meta?.firstSeenState;
console.log("seenState", meta);
const customComponent = seenState?.values.ui const customComponent = seenState?.values.ui
.slice() .slice()
.reverse() .reverse()
@@ -228,7 +214,6 @@ function CustomComponent({
); );
if (!customComponent) { if (!customComponent) {
console.log("no custom component", message, meta);
return null; return null;
} }
@@ -278,85 +263,10 @@ const AssistantMessage: FC = () => {
<div className="text-foreground max-w-[calc(var(--thread-max-width)*0.8)] break-words leading-7 col-span-2 col-start-2 row-start-1 my-1.5"> <div className="text-foreground max-w-[calc(var(--thread-max-width)*0.8)] break-words leading-7 col-span-2 col-start-2 row-start-1 my-1.5">
<MessagePrimitive.Content components={{ Text: MarkdownText }} /> <MessagePrimitive.Content components={{ Text: MarkdownText }} />
</div> </div>
<AssistantActionBar />
<BranchPicker className="col-start-2 row-start-2 -ml-2 mr-2" />
</MessagePrimitive.Root> </MessagePrimitive.Root>
); );
}; };
const AssistantActionBar: FC = () => {
return (
<ActionBarPrimitive.Root
hideWhenRunning
autohide="not-last"
autohideFloat="single-branch"
className="text-muted-foreground flex gap-1 col-start-3 row-start-2 -ml-1 data-[floating]:bg-background data-[floating]:absolute data-[floating]:rounded-md data-[floating]:border data-[floating]:p-1 data-[floating]:shadow-sm"
>
{/* <MessagePrimitive.If speaking={false}>
<ActionBarPrimitive.Speak asChild>
<TooltipIconButton tooltip="Read aloud">
<AudioLinesIcon />
</TooltipIconButton>
</ActionBarPrimitive.Speak>
</MessagePrimitive.If>
<MessagePrimitive.If speaking>
<ActionBarPrimitive.StopSpeaking asChild>
<TooltipIconButton tooltip="Stop">
<StopCircleIcon />
</TooltipIconButton>
</ActionBarPrimitive.StopSpeaking>
</MessagePrimitive.If> */}
<ActionBarPrimitive.Copy asChild>
<TooltipIconButton tooltip="Copy">
<MessagePrimitive.If copied>
<CheckIcon />
</MessagePrimitive.If>
<MessagePrimitive.If copied={false}>
<CopyIcon />
</MessagePrimitive.If>
</TooltipIconButton>
</ActionBarPrimitive.Copy>
<ActionBarPrimitive.Reload asChild>
<TooltipIconButton tooltip="Refresh">
<RefreshCwIcon />
</TooltipIconButton>
</ActionBarPrimitive.Reload>
</ActionBarPrimitive.Root>
);
};
const BranchPicker: FC<BranchPickerPrimitive.Root.Props> = ({
className,
...rest
}) => {
return (
<BranchPickerPrimitive.Root
hideWhenSingleBranch
className={cn(
"text-muted-foreground inline-flex items-center text-xs",
className,
)}
{...rest}
>
<BranchPickerPrimitive.Previous asChild>
<TooltipIconButton tooltip="Previous">
<ChevronLeftIcon />
</TooltipIconButton>
</BranchPickerPrimitive.Previous>
<span className="font-medium">
<BranchPickerPrimitive.Number /> / <BranchPickerPrimitive.Count />
</span>
<BranchPickerPrimitive.Next asChild>
<TooltipIconButton tooltip="Next">
<ChevronRightIcon />
</TooltipIconButton>
</BranchPickerPrimitive.Next>
</BranchPickerPrimitive.Root>
);
};
const CircleStopIcon = () => { const CircleStopIcon = () => {
return ( return (
<svg <svg

View File

@@ -1,13 +1,43 @@
import { ReactNode } from "react"; import { ReactNode, useEffect } from "react";
import { import {
useExternalStoreRuntime, useExternalStoreRuntime,
AppendMessage, AppendMessage,
AssistantRuntimeProvider, AssistantRuntimeProvider,
} from "@assistant-ui/react"; } from "@assistant-ui/react";
import { HumanMessage } from "@langchain/langgraph-sdk"; import { HumanMessage, Message, ToolMessage } from "@langchain/langgraph-sdk";
import { useStreamContext } from "./Stream"; import { useStreamContext } from "./Stream";
import { convertLangChainMessages } from "./convert-messages"; import { convertLangChainMessages } from "./convert-messages";
function ensureToolCallsHaveResponses(messages: Message[]): Message[] {
const newMessages: ToolMessage[] = [];
messages.forEach((message, index) => {
if (message.type !== "ai" || message.tool_calls?.length === 0) {
// If it's not an AI message, or it doesn't have tool calls, we can ignore.
return;
}
// If it has tool calls, ensure the message which follows this is a tool message
const followingMessage = messages[index + 1];
if (followingMessage && followingMessage.type === "tool") {
// Following message is a tool message, so we can ignore.
return;
}
// Since the following message is not a tool message, we must create a new tool message
newMessages.push(
...(message.tool_calls?.map((tc) => ({
type: "tool" as const,
tool_call_id: tc.id ?? "",
id: tc.id ?? "",
name: tc.name,
content: "Successfully handled tool call.",
})) ?? []),
);
});
return newMessages;
}
export function RuntimeProvider({ export function RuntimeProvider({
children, children,
}: Readonly<{ }: Readonly<{
@@ -21,9 +51,20 @@ export function RuntimeProvider({
const input = message.content[0].text; const input = message.content[0].text;
const humanMessage: HumanMessage = { type: "human", content: input }; const humanMessage: HumanMessage = { type: "human", content: input };
stream.submit({ messages: [humanMessage] }); const newMessages = [
...ensureToolCallsHaveResponses(stream.messages),
humanMessage,
];
console.log("Sending new messages", newMessages);
stream.submit({
messages: newMessages,
});
}; };
useEffect(() => {
console.log("useEffect - stream.messages", stream.messages);
}, [stream.messages]);
const runtime = useExternalStoreRuntime({ const runtime = useExternalStoreRuntime({
isRunning: stream.isLoading, isRunning: stream.isLoading,
messages: stream.messages, messages: stream.messages,

View File

@@ -79,15 +79,6 @@ export function convertLangChainMessages(message: Message): ThreadMessageLike {
role: "user", role: "user",
id: message.id, id: message.id,
content: [{ type: "text", text: content }], content: [{ type: "text", text: content }],
// ...(message.additional_kwargs
// ? {
// metadata: {
// custom: {
// ...message.additional_kwargs,
// },
// },
// }
// : {}),
}; };
case "ai": case "ai":
const aiMsg = message as AIMessage; const aiMsg = message as AIMessage;
@@ -110,20 +101,11 @@ export function convertLangChainMessages(message: Message): ThreadMessageLike {
text: content, text: content,
}, },
], ],
// ...(message.additional_kwargs
// ? {
// metadata: {
// custom: {
// ...message.additional_kwargs,
// },
// },
// }
// : {}),
}; };
case "tool": case "tool":
const toolMsg = message as ToolMessage; const toolMsg = message as ToolMessage;
return { return {
role: "user", role: "assistant",
content: [ content: [
{ {
type: "tool-call", type: "tool-call",