fix tool call responses
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
import { v4 as uuidv4 } from "uuid";
|
||||||
import { useEffect, useRef } from "react";
|
import { useEffect, useRef } from "react";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
import { useStreamContext } from "@/providers/Stream";
|
import { useStreamContext } from "@/providers/Stream";
|
||||||
@@ -7,6 +8,10 @@ import { Button } from "../ui/button";
|
|||||||
import { Message } from "@langchain/langgraph-sdk";
|
import { Message } from "@langchain/langgraph-sdk";
|
||||||
import { AssistantMessage, AssistantMessageLoading } from "./messages/ai";
|
import { AssistantMessage, AssistantMessageLoading } from "./messages/ai";
|
||||||
import { HumanMessage } from "./messages/human";
|
import { HumanMessage } from "./messages/human";
|
||||||
|
import {
|
||||||
|
DO_NOT_RENDER_ID_PREFIX,
|
||||||
|
ensureToolCallsHaveResponses,
|
||||||
|
} from "@/lib/ensure-tool-responses";
|
||||||
|
|
||||||
// const dummyMessages = [
|
// const dummyMessages = [
|
||||||
// { type: "human", content: "Hi! What can you do?" },
|
// { type: "human", content: "Hi! What can you do?" },
|
||||||
@@ -50,9 +55,18 @@ export function Thread() {
|
|||||||
if (!input.trim() || isLoading) return;
|
if (!input.trim() || isLoading) return;
|
||||||
setFirstTokenReceived(false);
|
setFirstTokenReceived(false);
|
||||||
|
|
||||||
|
const newHumanMessage: Message = {
|
||||||
|
id: uuidv4(),
|
||||||
|
type: "human",
|
||||||
|
content: input,
|
||||||
|
};
|
||||||
|
|
||||||
stream.submit(
|
stream.submit(
|
||||||
{
|
{
|
||||||
messages: [{ type: "human", content: input }],
|
messages: [
|
||||||
|
...ensureToolCallsHaveResponses(stream.messages),
|
||||||
|
newHumanMessage,
|
||||||
|
],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
streamMode: ["values"],
|
streamMode: ["values"],
|
||||||
@@ -63,6 +77,9 @@ export function Thread() {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const chatStarted = isLoading || messages.length > 0;
|
const chatStarted = isLoading || messages.length > 0;
|
||||||
|
const renderMessages = messages.filter(
|
||||||
|
(m) => !m.id?.startsWith(DO_NOT_RENDER_ID_PREFIX),
|
||||||
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
@@ -87,7 +104,7 @@ export function Thread() {
|
|||||||
!chatStarted && "hidden",
|
!chatStarted && "hidden",
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
{messages.map((message, index) =>
|
{renderMessages.map((message, index) =>
|
||||||
message.type === "human" ? (
|
message.type === "human" ? (
|
||||||
<HumanMessage
|
<HumanMessage
|
||||||
key={"id" in message ? message.id : `${message.type}-${index}`}
|
key={"id" in message ? message.id : `${message.type}-${index}`}
|
||||||
@@ -120,7 +137,6 @@ export function Thread() {
|
|||||||
type="text"
|
type="text"
|
||||||
value={input}
|
value={input}
|
||||||
onChange={(e) => setInput(e.target.value)}
|
onChange={(e) => setInput(e.target.value)}
|
||||||
disabled={isLoading}
|
|
||||||
placeholder="Type your message..."
|
placeholder="Type your message..."
|
||||||
className="p-5 border-[0px] shadow-none ring-0 outline-none focus:outline-none focus:ring-0"
|
className="p-5 border-[0px] shadow-none ring-0 outline-none focus:outline-none focus:ring-0"
|
||||||
/>
|
/>
|
||||||
|
|||||||
@@ -4,6 +4,41 @@ import { getContentString } from "../utils";
|
|||||||
import { BranchSwitcher, CommandBar } from "./shared";
|
import { BranchSwitcher, CommandBar } from "./shared";
|
||||||
import { Avatar, AvatarFallback } from "@/components/ui/avatar";
|
import { Avatar, AvatarFallback } from "@/components/ui/avatar";
|
||||||
import { MarkdownText } from "../markdown-text";
|
import { MarkdownText } from "../markdown-text";
|
||||||
|
import { LoadExternalComponent } from "@langchain/langgraph-sdk/react-ui/client";
|
||||||
|
|
||||||
|
function CustomComponent({
|
||||||
|
message,
|
||||||
|
thread,
|
||||||
|
}: {
|
||||||
|
message: Message;
|
||||||
|
thread: ReturnType<typeof useStreamContext>;
|
||||||
|
}) {
|
||||||
|
const meta = thread.getMessagesMetadata(message);
|
||||||
|
const seenState = meta?.firstSeenState;
|
||||||
|
const customComponent = seenState?.values.ui
|
||||||
|
.slice()
|
||||||
|
.reverse()
|
||||||
|
.find(
|
||||||
|
({ additional_kwargs }) =>
|
||||||
|
additional_kwargs.run_id === seenState.metadata?.run_id,
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!customComponent) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div key={message.id}>
|
||||||
|
{customComponent && (
|
||||||
|
<LoadExternalComponent
|
||||||
|
assistantId="agent"
|
||||||
|
stream={thread}
|
||||||
|
message={customComponent}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
export function AssistantMessage({
|
export function AssistantMessage({
|
||||||
message,
|
message,
|
||||||
@@ -28,9 +63,12 @@ export function AssistantMessage({
|
|||||||
<AvatarFallback>A</AvatarFallback>
|
<AvatarFallback>A</AvatarFallback>
|
||||||
</Avatar>
|
</Avatar>
|
||||||
<div className="flex flex-col gap-2">
|
<div className="flex flex-col gap-2">
|
||||||
<div className="rounded-2xl bg-muted px-4 py-2">
|
<CustomComponent message={message} thread={thread} />
|
||||||
<MarkdownText>{contentString}</MarkdownText>
|
{contentString.length > 0 && (
|
||||||
</div>
|
<div className="rounded-2xl bg-muted px-4 py-2">
|
||||||
|
<MarkdownText>{contentString}</MarkdownText>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
<div className="flex gap-2 items-center mr-auto opacity-0 group-hover:opacity-100 transition-opacity">
|
<div className="flex gap-2 items-center mr-auto opacity-0 group-hover:opacity-100 transition-opacity">
|
||||||
<BranchSwitcher
|
<BranchSwitcher
|
||||||
branch={meta?.branch}
|
branch={meta?.branch}
|
||||||
|
|||||||
34
src/lib/ensure-tool-responses.ts
Normal file
34
src/lib/ensure-tool-responses.ts
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
import { v4 as uuidv4 } from "uuid";
|
||||||
|
import { Message, ToolMessage } from "@langchain/langgraph-sdk";
|
||||||
|
|
||||||
|
export const DO_NOT_RENDER_ID_PREFIX = "do-not-render-";
|
||||||
|
|
||||||
|
export function ensureToolCallsHaveResponses(messages: Message[]): Message[] {
|
||||||
|
const newMessages: ToolMessage[] = [];
|
||||||
|
|
||||||
|
messages.forEach((message, index) => {
|
||||||
|
if (message.type !== "ai" || message.tool_calls?.length === 0) {
|
||||||
|
// If it's not an AI message, or it doesn't have tool calls, we can ignore.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// If it has tool calls, ensure the message which follows this is a tool message
|
||||||
|
const followingMessage = messages[index + 1];
|
||||||
|
if (followingMessage && followingMessage.type === "tool") {
|
||||||
|
// Following message is a tool message, so we can ignore.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Since the following message is not a tool message, we must create a new tool message
|
||||||
|
newMessages.push(
|
||||||
|
...(message.tool_calls?.map((tc) => ({
|
||||||
|
type: "tool" as const,
|
||||||
|
tool_call_id: tc.id ?? "",
|
||||||
|
id: `${DO_NOT_RENDER_ID_PREFIX}${uuidv4()}`,
|
||||||
|
name: tc.name,
|
||||||
|
content: "Successfully handled tool call.",
|
||||||
|
})) ?? []),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
return newMessages;
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user