diff --git a/src/components/icons/langgraph.tsx b/src/components/icons/langgraph.tsx
index c23aaeb..826927b 100644
--- a/src/components/icons/langgraph.tsx
+++ b/src/components/icons/langgraph.tsx
@@ -8,8 +8,8 @@ export function LangGraphLogoSVG({ width = 20, height = 20 }) {
xmlns="http://www.w3.org/2000/svg"
>
diff --git a/src/components/thread/index.tsx b/src/components/thread/index.tsx
index c71b9f0..0280276 100644
--- a/src/components/thread/index.tsx
+++ b/src/components/thread/index.tsx
@@ -5,7 +5,7 @@ import { useStreamContext } from "@/providers/Stream";
import { useState, FormEvent } from "react";
import { Input } from "../ui/input";
import { Button } from "../ui/button";
-import { Message } from "@langchain/langgraph-sdk";
+import { Checkpoint, Message } from "@langchain/langgraph-sdk";
import { AssistantMessage, AssistantMessageLoading } from "./messages/ai";
import { HumanMessage } from "./messages/human";
import {
@@ -86,6 +86,18 @@ export function Thread() {
setInput("");
};
+ const handleRegenerate = (
+ parentCheckpoint: Checkpoint | null | undefined,
+ ) => {
+ // Do this so the loading state is correct
+ prevMessageLength.current = prevMessageLength.current - 1;
+ setFirstTokenReceived(false);
+ stream.submit(undefined, {
+ checkpoint: parentCheckpoint,
+ streamMode: ["values"],
+ });
+ };
+
const chatStarted = isLoading || messages.length > 0;
const renderMessages = messages.filter(
(m) => !m.id?.startsWith(DO_NOT_RENDER_ID_PREFIX),
@@ -128,6 +140,7 @@ export function Thread() {
key={"id" in message ? message.id : `${message.type}-${index}`}
message={message as Message}
isLoading={isLoading}
+ handleRegenerate={handleRegenerate}
/>
),
)}
diff --git a/src/components/thread/messages/ai.tsx b/src/components/thread/messages/ai.tsx
index 189f55d..505eaa8 100644
--- a/src/components/thread/messages/ai.tsx
+++ b/src/components/thread/messages/ai.tsx
@@ -1,5 +1,5 @@
import { useStreamContext } from "@/providers/Stream";
-import { Message } from "@langchain/langgraph-sdk";
+import { Checkpoint, Message } from "@langchain/langgraph-sdk";
import { getContentString } from "../utils";
import { BranchSwitcher, CommandBar } from "./shared";
import { Avatar, AvatarFallback } from "@/components/ui/avatar";
@@ -43,9 +43,11 @@ function CustomComponent({
export function AssistantMessage({
message,
isLoading,
+ handleRegenerate,
}: {
message: Message;
isLoading: boolean;
+ handleRegenerate: (parentCheckpoint: Checkpoint | null | undefined) => void;
}) {
const thread = useStreamContext();
const meta = thread.getMessagesMetadata(message);
@@ -53,10 +55,6 @@ export function AssistantMessage({
const contentString = getContentString(message.content);
- const handleRegenerate = () => {
- thread.submit(undefined, { checkpoint: parentCheckpoint, streamMode: ["values"] });
- };
-
return (
@@ -80,7 +78,7 @@ export function AssistantMessage({
content={contentString}
isLoading={isLoading}
isAiMessage={true}
- handleRegenerate={handleRegenerate}
+ handleRegenerate={() => handleRegenerate(parentCheckpoint)}
/>