diff --git a/src/components/icons/langgraph.tsx b/src/components/icons/langgraph.tsx index c23aaeb..826927b 100644 --- a/src/components/icons/langgraph.tsx +++ b/src/components/icons/langgraph.tsx @@ -8,8 +8,8 @@ export function LangGraphLogoSVG({ width = 20, height = 20 }) { xmlns="http://www.w3.org/2000/svg" > diff --git a/src/components/thread/index.tsx b/src/components/thread/index.tsx index c71b9f0..0280276 100644 --- a/src/components/thread/index.tsx +++ b/src/components/thread/index.tsx @@ -5,7 +5,7 @@ import { useStreamContext } from "@/providers/Stream"; import { useState, FormEvent } from "react"; import { Input } from "../ui/input"; import { Button } from "../ui/button"; -import { Message } from "@langchain/langgraph-sdk"; +import { Checkpoint, Message } from "@langchain/langgraph-sdk"; import { AssistantMessage, AssistantMessageLoading } from "./messages/ai"; import { HumanMessage } from "./messages/human"; import { @@ -86,6 +86,18 @@ export function Thread() { setInput(""); }; + const handleRegenerate = ( + parentCheckpoint: Checkpoint | null | undefined, + ) => { + // Do this so the loading state is correct + prevMessageLength.current = prevMessageLength.current - 1; + setFirstTokenReceived(false); + stream.submit(undefined, { + checkpoint: parentCheckpoint, + streamMode: ["values"], + }); + }; + const chatStarted = isLoading || messages.length > 0; const renderMessages = messages.filter( (m) => !m.id?.startsWith(DO_NOT_RENDER_ID_PREFIX), @@ -128,6 +140,7 @@ export function Thread() { key={"id" in message ? message.id : `${message.type}-${index}`} message={message as Message} isLoading={isLoading} + handleRegenerate={handleRegenerate} /> ), )} diff --git a/src/components/thread/messages/ai.tsx b/src/components/thread/messages/ai.tsx index 189f55d..505eaa8 100644 --- a/src/components/thread/messages/ai.tsx +++ b/src/components/thread/messages/ai.tsx @@ -1,5 +1,5 @@ import { useStreamContext } from "@/providers/Stream"; -import { Message } from "@langchain/langgraph-sdk"; +import { Checkpoint, Message } from "@langchain/langgraph-sdk"; import { getContentString } from "../utils"; import { BranchSwitcher, CommandBar } from "./shared"; import { Avatar, AvatarFallback } from "@/components/ui/avatar"; @@ -43,9 +43,11 @@ function CustomComponent({ export function AssistantMessage({ message, isLoading, + handleRegenerate, }: { message: Message; isLoading: boolean; + handleRegenerate: (parentCheckpoint: Checkpoint | null | undefined) => void; }) { const thread = useStreamContext(); const meta = thread.getMessagesMetadata(message); @@ -53,10 +55,6 @@ export function AssistantMessage({ const contentString = getContentString(message.content); - const handleRegenerate = () => { - thread.submit(undefined, { checkpoint: parentCheckpoint, streamMode: ["values"] }); - }; - return (
@@ -80,7 +78,7 @@ export function AssistantMessage({ content={contentString} isLoading={isLoading} isAiMessage={true} - handleRegenerate={handleRegenerate} + handleRegenerate={() => handleRegenerate(parentCheckpoint)} />