From d654db38711e1b2057d88751f5fe847189d803a1 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Mon, 3 Mar 2025 12:40:24 -0800 Subject: [PATCH] fix tool call responses --- src/components/thread/index.tsx | 22 ++++++++++++-- src/components/thread/messages/ai.tsx | 44 +++++++++++++++++++++++++-- src/lib/ensure-tool-responses.ts | 34 +++++++++++++++++++++ 3 files changed, 94 insertions(+), 6 deletions(-) create mode 100644 src/lib/ensure-tool-responses.ts diff --git a/src/components/thread/index.tsx b/src/components/thread/index.tsx index 3be9989..761073f 100644 --- a/src/components/thread/index.tsx +++ b/src/components/thread/index.tsx @@ -1,3 +1,4 @@ +import { v4 as uuidv4 } from "uuid"; import { useEffect, useRef } from "react"; import { cn } from "@/lib/utils"; import { useStreamContext } from "@/providers/Stream"; @@ -7,6 +8,10 @@ import { Button } from "../ui/button"; import { Message } from "@langchain/langgraph-sdk"; import { AssistantMessage, AssistantMessageLoading } from "./messages/ai"; import { HumanMessage } from "./messages/human"; +import { + DO_NOT_RENDER_ID_PREFIX, + ensureToolCallsHaveResponses, +} from "@/lib/ensure-tool-responses"; // const dummyMessages = [ // { type: "human", content: "Hi! What can you do?" }, @@ -50,9 +55,18 @@ export function Thread() { if (!input.trim() || isLoading) return; setFirstTokenReceived(false); + const newHumanMessage: Message = { + id: uuidv4(), + type: "human", + content: input, + }; + stream.submit( { - messages: [{ type: "human", content: input }], + messages: [ + ...ensureToolCallsHaveResponses(stream.messages), + newHumanMessage, + ], }, { streamMode: ["values"], @@ -63,6 +77,9 @@ export function Thread() { }; const chatStarted = isLoading || messages.length > 0; + const renderMessages = messages.filter( + (m) => !m.id?.startsWith(DO_NOT_RENDER_ID_PREFIX), + ); return (
- {messages.map((message, index) => + {renderMessages.map((message, index) => message.type === "human" ? ( setInput(e.target.value)} - disabled={isLoading} placeholder="Type your message..." className="p-5 border-[0px] shadow-none ring-0 outline-none focus:outline-none focus:ring-0" /> diff --git a/src/components/thread/messages/ai.tsx b/src/components/thread/messages/ai.tsx index 1dd5e74..079d67d 100644 --- a/src/components/thread/messages/ai.tsx +++ b/src/components/thread/messages/ai.tsx @@ -4,6 +4,41 @@ import { getContentString } from "../utils"; import { BranchSwitcher, CommandBar } from "./shared"; import { Avatar, AvatarFallback } from "@/components/ui/avatar"; import { MarkdownText } from "../markdown-text"; +import { LoadExternalComponent } from "@langchain/langgraph-sdk/react-ui/client"; + +function CustomComponent({ + message, + thread, +}: { + message: Message; + thread: ReturnType; +}) { + const meta = thread.getMessagesMetadata(message); + const seenState = meta?.firstSeenState; + const customComponent = seenState?.values.ui + .slice() + .reverse() + .find( + ({ additional_kwargs }) => + additional_kwargs.run_id === seenState.metadata?.run_id, + ); + + if (!customComponent) { + return null; + } + + return ( +
+ {customComponent && ( + + )} +
+ ); +} export function AssistantMessage({ message, @@ -28,9 +63,12 @@ export function AssistantMessage({ A
-
- {contentString} -
+ + {contentString.length > 0 && ( +
+ {contentString} +
+ )}
{ + if (message.type !== "ai" || message.tool_calls?.length === 0) { + // If it's not an AI message, or it doesn't have tool calls, we can ignore. + return; + } + // If it has tool calls, ensure the message which follows this is a tool message + const followingMessage = messages[index + 1]; + if (followingMessage && followingMessage.type === "tool") { + // Following message is a tool message, so we can ignore. + return; + } + + // Since the following message is not a tool message, we must create a new tool message + newMessages.push( + ...(message.tool_calls?.map((tc) => ({ + type: "tool" as const, + tool_call_id: tc.id ?? "", + id: `${DO_NOT_RENDER_ID_PREFIX}${uuidv4()}`, + name: tc.name, + content: "Successfully handled tool call.", + })) ?? []), + ); + }); + + return newMessages; +}