feat: drop assistant ui, use custom chat ui

This commit is contained in:
bracesproul
2025-03-03 12:31:27 -08:00
parent a75c710990
commit 3f3f50d5c5
20 changed files with 4553 additions and 2477 deletions

View File

@@ -6,7 +6,7 @@ import { stockbrokerGraph } from "./stockbroker";
import { ChatOpenAI } from "@langchain/openai"; import { ChatOpenAI } from "@langchain/openai";
async function router( async function router(
state: GenerativeUIState state: GenerativeUIState,
): Promise<Partial<GenerativeUIState>> { ): Promise<Partial<GenerativeUIState>> {
const routerDescription = `The route to take based on the user's input. const routerDescription = `The route to take based on the user's input.
- stockbroker: can fetch the price of a ticker, purchase/sell a ticker, or get the user's portfolio - stockbroker: can fetch the price of a ticker, purchase/sell a ticker, or get the user's portfolio
@@ -35,7 +35,7 @@ async function router(
You should analyze the user's input, and choose the appropriate tool to use.`; You should analyze the user's input, and choose the appropriate tool to use.`;
const recentHumanMessage = state.messages.findLast( const recentHumanMessage = state.messages.findLast(
(m) => m.getType() === "human" (m) => m.getType() === "human",
); );
if (!recentHumanMessage) { if (!recentHumanMessage) {
@@ -60,7 +60,7 @@ You should analyze the user's input, and choose the appropriate tool to use.`;
} }
function handleRoute( function handleRoute(
state: GenerativeUIState state: GenerativeUIState,
): "stockbroker" | "weather" | "generalInput" { ): "stockbroker" | "weather" | "generalInput" {
return state.next; return state.next;
} }

View File

@@ -22,7 +22,7 @@ export default function StockPrice(props: {
.reverse() .reverse()
.find( .find(
(message): message is AIMessage => (message): message is AIMessage =>
message.type === "ai" && !!message.tool_calls?.length message.type === "ai" && !!message.tool_calls?.length,
); );
const toolCallId = aiTool?.tool_calls?.[0]?.id; const toolCallId = aiTool?.tool_calls?.[0]?.id;

View File

@@ -30,10 +30,12 @@
"clsx": "^2.1.1", "clsx": "^2.1.1",
"esbuild": "^0.25.0", "esbuild": "^0.25.0",
"esbuild-plugin-tailwindcss": "^2.0.1", "esbuild-plugin-tailwindcss": "^2.0.1",
"framer-motion": "^12.4.9",
"lucide-react": "^0.476.0", "lucide-react": "^0.476.0",
"prettier": "^3.5.2", "prettier": "^3.5.2",
"react": "^19.0.0", "react": "^19.0.0",
"react-dom": "^19.0.0", "react-dom": "^19.0.0",
"react-markdown": "^10.0.1",
"remark-gfm": "^4.0.1", "remark-gfm": "^4.0.1",
"tailwind-merge": "^3.0.2", "tailwind-merge": "^3.0.2",
"tailwindcss-animate": "^1.0.7", "tailwindcss-animate": "^1.0.7",

5930
pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,9 +1,9 @@
import "./App.css"; import "./App.css";
import { Thread } from "@/components/assistant-ui/thread"; import { Thread } from "@/components/thread";
function App() { function App() {
return ( return (
<div className="h-screen"> <div>
<Thread /> <Thread />
</div> </div>
); );

View File

@@ -1,282 +0,0 @@
import {
ActionBarPrimitive,
ComposerPrimitive,
getExternalStoreMessages,
MessagePrimitive,
ThreadPrimitive,
useMessage,
} from "@assistant-ui/react";
import type { FC } from "react";
import { ArrowDownIcon, PencilIcon, SendHorizontalIcon } from "lucide-react";
import { LoadExternalComponent } from "@langchain/langgraph-sdk/react-ui/client";
import { Avatar, AvatarFallback } from "@/components/ui/avatar";
import { Button } from "@/components/ui/button";
import { MarkdownText } from "@/components/assistant-ui/markdown-text";
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
import { Message } from "@langchain/langgraph-sdk";
import { useStreamContext } from "@/providers/Stream";
export const Thread: FC = () => {
return (
<ThreadPrimitive.Root
className="bg-background box-border h-full"
style={{
["--thread-max-width" as string]: "42rem",
}}
>
<ThreadPrimitive.Viewport className="flex h-full flex-col items-center overflow-y-scroll scroll-smooth bg-inherit px-4 pt-8">
<ThreadWelcome />
<ThreadPrimitive.Messages
components={{
UserMessage: UserMessage,
EditComposer: EditComposer,
AssistantMessage: AssistantMessage,
}}
/>
<ThreadPrimitive.If empty={false}>
<div className="min-h-8 flex-grow" />
</ThreadPrimitive.If>
<div className="sticky bottom-0 mt-3 flex w-full max-w-[var(--thread-max-width)] flex-col items-center justify-end rounded-t-lg bg-inherit pb-4">
<ThreadScrollToBottom />
<Composer />
</div>
</ThreadPrimitive.Viewport>
</ThreadPrimitive.Root>
);
};
const ThreadScrollToBottom: FC = () => {
return (
<ThreadPrimitive.ScrollToBottom asChild>
<TooltipIconButton
tooltip="Scroll to bottom"
variant="outline"
className="absolute -top-8 rounded-full disabled:invisible"
>
<ArrowDownIcon />
</TooltipIconButton>
</ThreadPrimitive.ScrollToBottom>
);
};
const ThreadWelcome: FC = () => {
return (
<ThreadPrimitive.Empty>
<div className="flex w-full max-w-[var(--thread-max-width)] flex-grow flex-col">
<div className="flex w-full flex-grow flex-col items-center justify-center">
<Avatar>
<AvatarFallback>C</AvatarFallback>
</Avatar>
<p className="mt-4 font-medium">How can I help you today?</p>
</div>
<ThreadWelcomeSuggestions />
</div>
</ThreadPrimitive.Empty>
);
};
const ThreadWelcomeSuggestions: FC = () => {
return (
<div className="mt-3 flex w-full items-stretch justify-center gap-4">
<ThreadPrimitive.Suggestion
className="hover:bg-muted/80 flex max-w-sm grow basis-0 flex-col items-center justify-center rounded-lg border p-3 transition-colors ease-in"
prompt="What's the current price of $APPL?"
method="replace"
autoSend
>
<span className="line-clamp-2 text-ellipsis text-sm font-semibold">
What's the current price of $APPL?
</span>
</ThreadPrimitive.Suggestion>
<ThreadPrimitive.Suggestion
className="hover:bg-muted/80 flex max-w-sm grow basis-0 flex-col items-center justify-center rounded-lg border p-3 transition-colors ease-in"
prompt="What is assistant-ui?"
method="replace"
autoSend
>
<span className="line-clamp-2 text-ellipsis text-sm font-semibold">
What's the weather like in San Francisco Today?
</span>
</ThreadPrimitive.Suggestion>
</div>
);
};
const Composer: FC = () => {
return (
<ComposerPrimitive.Root className="focus-within:border-ring/20 flex w-full flex-wrap items-end rounded-lg border bg-inherit px-2.5 shadow-sm transition-colors ease-in">
<ComposerPrimitive.Input
rows={1}
autoFocus
placeholder="Write a message..."
className="placeholder:text-muted-foreground max-h-40 flex-grow resize-none border-none bg-transparent px-2 py-4 text-sm outline-none focus:ring-0 disabled:cursor-not-allowed"
/>
<ComposerAction />
</ComposerPrimitive.Root>
);
};
const ComposerAction: FC = () => {
return (
<>
<ThreadPrimitive.If running={false}>
<ComposerPrimitive.Send asChild>
<TooltipIconButton
tooltip="Send"
variant="default"
className="my-2.5 size-8 p-2 transition-opacity ease-in"
>
<SendHorizontalIcon />
</TooltipIconButton>
</ComposerPrimitive.Send>
</ThreadPrimitive.If>
<ThreadPrimitive.If running>
<ComposerPrimitive.Cancel asChild>
<TooltipIconButton
tooltip="Cancel"
variant="default"
className="my-2.5 size-8 p-2 transition-opacity ease-in"
>
<CircleStopIcon />
</TooltipIconButton>
</ComposerPrimitive.Cancel>
</ThreadPrimitive.If>
</>
);
};
const UserMessage: FC = () => {
return (
<MessagePrimitive.Root className="grid auto-rows-auto grid-cols-[minmax(72px,1fr)_auto] gap-y-2 [&:where(>*)]:col-start-2 w-full max-w-[var(--thread-max-width)] py-4">
<UserActionBar />
<div className="bg-muted text-foreground max-w-[calc(var(--thread-max-width)*0.8)] break-words rounded-3xl px-5 py-2.5 col-start-2 row-start-2">
<MessagePrimitive.Content />
</div>
</MessagePrimitive.Root>
);
};
const UserActionBar: FC = () => {
return (
<ActionBarPrimitive.Root
hideWhenRunning
autohide="not-last"
className="flex flex-col items-end col-start-1 row-start-2 mr-3 mt-2.5"
>
<ActionBarPrimitive.Edit asChild>
<TooltipIconButton tooltip="Edit">
<PencilIcon />
</TooltipIconButton>
</ActionBarPrimitive.Edit>
</ActionBarPrimitive.Root>
);
};
const EditComposer: FC = () => {
return (
<ComposerPrimitive.Root className="bg-muted my-4 flex w-full max-w-[var(--thread-max-width)] flex-col gap-2 rounded-xl">
<ComposerPrimitive.Input className="text-foreground flex h-8 w-full resize-none bg-transparent p-4 pb-0 outline-none" />
<div className="mx-3 mb-3 flex items-center justify-center gap-2 self-end">
<ComposerPrimitive.Cancel asChild>
<Button variant="ghost">Cancel</Button>
</ComposerPrimitive.Cancel>
<ComposerPrimitive.Send asChild>
<Button>Send</Button>
</ComposerPrimitive.Send>
</div>
</ComposerPrimitive.Root>
);
};
function CustomComponent({
message,
idx,
thread,
}: {
message: Message;
idx: number;
thread: ReturnType<typeof useStreamContext>;
}) {
const meta = thread.getMessagesMetadata(message, idx);
const seenState = meta?.firstSeenState;
const customComponent = seenState?.values.ui
.slice()
.reverse()
.find(
({ additional_kwargs }) =>
additional_kwargs.run_id === seenState.metadata?.run_id,
);
if (!customComponent) {
return null;
}
return (
<div key={message.id}>
{customComponent && (
<LoadExternalComponent
assistantId="agent"
stream={thread}
message={customComponent}
/>
)}
</div>
);
}
const AssistantMessage: FC = () => {
const thread = useStreamContext();
const assistantMsgs = useMessage((m) => {
const langchainMessage = getExternalStoreMessages<Message>(m);
return langchainMessage;
});
const assistantMsg = assistantMsgs[0];
let threadMsgIdx: number | undefined = undefined;
const threadMsg = thread.messages.find((m, idx) => {
if (m.id === assistantMsg?.id) {
threadMsgIdx = idx;
return true;
}
});
return (
<MessagePrimitive.Root className="grid grid-cols-[auto_auto_1fr] grid-rows-[auto_1fr] relative w-full max-w-[var(--thread-max-width)] py-4">
<Avatar className="col-start-1 row-span-full row-start-1 mr-4">
<AvatarFallback>A</AvatarFallback>
</Avatar>
{threadMsg && threadMsgIdx !== undefined && (
<CustomComponent
message={threadMsg}
idx={threadMsgIdx}
thread={thread}
/>
)}
<div className="text-foreground max-w-[calc(var(--thread-max-width)*0.8)] break-words leading-7 col-span-2 col-start-2 row-start-1 my-1.5">
<MessagePrimitive.Content components={{ Text: MarkdownText }} />
</div>
</MessagePrimitive.Root>
);
};
const CircleStopIcon = () => {
return (
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
width="16"
height="16"
>
<rect width="10" height="10" x="3" y="3" rx="2" />
</svg>
);
};

View File

@@ -0,0 +1,138 @@
import { useEffect, useRef } from "react";
import { cn } from "@/lib/utils";
import { useStreamContext } from "@/providers/Stream";
import { useState, FormEvent } from "react";
import { Input } from "../ui/input";
import { Button } from "../ui/button";
import { Message } from "@langchain/langgraph-sdk";
import { AssistantMessage, AssistantMessageLoading } from "./messages/ai";
import { HumanMessage } from "./messages/human";
// const dummyMessages = [
// { type: "human", content: "Hi! What can you do?" },
// {
// type: "ai",
// content: `Hello! I can assist you with a variety of tasks, including:
// 1. **Answering Questions**: I can provide information on a wide range of topics, from science and history to technology and culture.
// 2. **Writing Assistance**: I can help you draft emails, essays, reports, and creative writing pieces.
// 3. **Learning Support**: I can explain concepts, help with homework, and provide study tips.
// 4. **Language Help**: I can assist with translations, grammar, and vocabulary in multiple languages.
// 5. **Recommendations**: I can suggest books, movies, recipes, and more based on your interests.
// 6. **General Advice**: I can offer tips on various subjects, including productivity, wellness, and personal development.
// If you have something specific in mind, feel free to ask!`,
// },
// ];
export function Thread() {
const [input, setInput] = useState("");
const [firstTokenReceived, setFirstTokenReceived] = useState(false);
const stream = useStreamContext();
// const messages = [...dummyMessages, ...stream.messages];
const messages = stream.messages;
const isLoading = stream.isLoading;
const prevMessageLength = useRef(0);
useEffect(() => {
if (
messages.length !== prevMessageLength.current &&
messages?.length &&
messages[messages.length - 1].type === "ai"
) {
setFirstTokenReceived(true);
prevMessageLength.current = messages.length;
}
}, [messages]);
const handleSubmit = (e: FormEvent) => {
e.preventDefault();
if (!input.trim() || isLoading) return;
setFirstTokenReceived(false);
stream.submit(
{
messages: [{ type: "human", content: input }],
},
{
streamMode: ["values"],
},
);
setInput("");
};
const chatStarted = isLoading || messages.length > 0;
return (
<div
className={cn(
"flex flex-col w-full h-full",
chatStarted ? "relative" : "",
)}
>
<div className={cn("flex-1 px-4", chatStarted ? "pb-28" : "mt-64")}>
<h1
className={cn(
"text-2xl font-medium mb-12 text-center",
chatStarted && "hidden",
)}
>
Chat
</h1>
<div
className={cn(
"flex flex-col gap-4 max-w-4xl w-full mx-auto mt-12 overflow-y-auto",
!chatStarted && "hidden",
)}
>
{messages.map((message, index) =>
message.type === "human" ? (
<HumanMessage
key={"id" in message ? message.id : `${message.type}-${index}`}
message={message as Message}
isLoading={isLoading}
/>
) : (
<AssistantMessage
key={"id" in message ? message.id : `${message.type}-${index}`}
message={message as Message}
isLoading={isLoading}
/>
),
)}
{isLoading && !firstTokenReceived && <AssistantMessageLoading />}
</div>
</div>
<div
className={cn(
"bg-white rounded-2xl border-[1px] border-gray-200 shadow-md p-3 mx-auto w-full max-w-5xl",
chatStarted ? "fixed bottom-6 left-0 right-0" : "",
)}
>
<form
onSubmit={handleSubmit}
className="flex w-full gap-2 max-w-5xl mx-auto"
>
<Input
type="text"
value={input}
onChange={(e) => setInput(e.target.value)}
disabled={isLoading}
placeholder="Type your message..."
className="p-5 border-[0px] shadow-none ring-0 outline-none focus:outline-none focus:ring-0"
/>
<Button
type="submit"
className="p-5"
disabled={isLoading || !input.trim()}
>
Send
</Button>
</form>
</div>
</div>
);
}

View File

@@ -4,24 +4,22 @@ import "@assistant-ui/react-markdown/styles/dot.css";
import { import {
CodeHeaderProps, CodeHeaderProps,
MarkdownTextPrimitive,
unstable_memoizeMarkdownComponents as memoizeMarkdownComponents, unstable_memoizeMarkdownComponents as memoizeMarkdownComponents,
useIsMarkdownCodeBlock, useIsMarkdownCodeBlock,
} from "@assistant-ui/react-markdown"; } from "@assistant-ui/react-markdown";
import ReactMarkdown from "react-markdown";
import remarkGfm from "remark-gfm"; import remarkGfm from "remark-gfm";
import { FC, memo, useState } from "react"; import { FC, memo, useState } from "react";
import { CheckIcon, CopyIcon } from "lucide-react"; import { CheckIcon, CopyIcon } from "lucide-react";
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; import { TooltipIconButton } from "@/components/thread/tooltip-icon-button";
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
const MarkdownTextImpl = () => { const MarkdownTextImpl = ({ children }: { children: string }) => {
return ( return (
<MarkdownTextPrimitive <ReactMarkdown remarkPlugins={[remarkGfm]} components={defaultComponents}>
remarkPlugins={[remarkGfm]} {children}
className="aui-md" </ReactMarkdown>
components={defaultComponents}
/>
); );
}; };

View File

@@ -0,0 +1,66 @@
import { useStreamContext } from "@/providers/Stream";
import { Message } from "@langchain/langgraph-sdk";
import { getContentString } from "../utils";
import { BranchSwitcher, CommandBar } from "./shared";
import { Avatar, AvatarFallback } from "@/components/ui/avatar";
import { MarkdownText } from "../markdown-text";
export function AssistantMessage({
message,
isLoading,
}: {
message: Message;
isLoading: boolean;
}) {
const thread = useStreamContext();
const meta = thread.getMessagesMetadata(message);
const parentCheckpoint = meta?.firstSeenState?.parent_checkpoint;
const contentString = getContentString(message.content);
const handleRegenerate = () => {
thread.submit(undefined, { checkpoint: parentCheckpoint });
};
return (
<div className="flex items-start mr-auto gap-2 group">
<Avatar>
<AvatarFallback>A</AvatarFallback>
</Avatar>
<div className="flex flex-col gap-2">
<div className="rounded-2xl bg-muted px-4 py-2">
<MarkdownText>{contentString}</MarkdownText>
</div>
<div className="flex gap-2 items-center mr-auto opacity-0 group-hover:opacity-100 transition-opacity">
<BranchSwitcher
branch={meta?.branch}
branchOptions={meta?.branchOptions}
onSelect={(branch) => thread.setBranch(branch)}
isLoading={isLoading}
/>
<CommandBar
content={contentString}
isLoading={isLoading}
isAiMessage={true}
handleRegenerate={handleRegenerate}
/>
</div>
</div>
</div>
);
}
export function AssistantMessageLoading() {
return (
<div className="flex items-start mr-auto gap-2">
<Avatar>
<AvatarFallback>A</AvatarFallback>
</Avatar>
<div className="flex items-center gap-1 rounded-2xl bg-muted px-4 py-2 h-8">
<div className="w-1.5 h-1.5 rounded-full bg-foreground/50 animate-[pulse_1.5s_ease-in-out_infinite]"></div>
<div className="w-1.5 h-1.5 rounded-full bg-foreground/50 animate-[pulse_1.5s_ease-in-out_0.5s_infinite]"></div>
<div className="w-1.5 h-1.5 rounded-full bg-foreground/50 animate-[pulse_1.5s_ease-in-out_1s_infinite]"></div>
</div>
</div>
);
}

View File

@@ -0,0 +1,107 @@
import { useStreamContext } from "@/providers/Stream";
import { Message } from "@langchain/langgraph-sdk";
import { useState } from "react";
import { getContentString } from "../utils";
import { cn } from "@/lib/utils";
import { Textarea } from "@/components/ui/textarea";
import { BranchSwitcher, CommandBar } from "./shared";
function EditableContent({
value,
setValue,
onSubmit,
}: {
value: string;
setValue: React.Dispatch<React.SetStateAction<string>>;
onSubmit: () => void;
}) {
const handleKeyDown = (e: React.KeyboardEvent) => {
if ((e.metaKey || e.ctrlKey) && e.key === "Enter") {
e.preventDefault();
onSubmit();
}
};
return (
<Textarea
value={value}
onChange={(e) => setValue(e.target.value)}
onKeyDown={handleKeyDown}
/>
);
}
export function HumanMessage({
message,
isLoading,
}: {
message: Message;
isLoading: boolean;
}) {
const thread = useStreamContext();
const meta = thread.getMessagesMetadata(message);
const parentCheckpoint = meta?.firstSeenState?.parent_checkpoint;
const [isEditing, setIsEditing] = useState(false);
const [value, setValue] = useState("");
const contentString = getContentString(message.content);
const handleSubmitEdit = () => {
setIsEditing(false);
thread.submit(
{
messages: [
{
...message,
content: value,
},
],
},
{
checkpoint: parentCheckpoint,
},
);
};
return (
<div
className={cn(
"flex items-center ml-auto gap-2 px-4 py-2 group",
isEditing && "w-full max-w-xl",
)}
>
<div className={cn("flex flex-col gap-2", isEditing && "w-full")}>
{isEditing ? (
<EditableContent
value={value}
setValue={setValue}
onSubmit={handleSubmitEdit}
/>
) : (
<p>{contentString}</p>
)}
<div className="flex gap-2 items-center ml-auto opacity-0 group-hover:opacity-100 transition-opacity">
<BranchSwitcher
branch={meta?.branch}
branchOptions={meta?.branchOptions}
onSelect={(branch) => thread.setBranch(branch)}
isLoading={isLoading}
/>
<CommandBar
isLoading={isLoading}
content={contentString}
isEditing={isEditing}
setIsEditing={(c) => {
if (c) {
setValue(contentString);
}
setIsEditing(c);
}}
handleSubmitEdit={handleSubmitEdit}
isHumanMessage={true}
/>
</div>
</div>
</div>
);
}

View File

@@ -0,0 +1,213 @@
import {
XIcon,
SendHorizontal,
RefreshCcw,
Pencil,
Copy,
CopyCheck,
ChevronLeft,
ChevronRight,
} from "lucide-react";
import { TooltipIconButton } from "../tooltip-icon-button";
import { AnimatePresence, motion } from "framer-motion";
import { useState } from "react";
import { Button } from "@/components/ui/button";
function ContentCopyable({
content,
disabled,
}: {
content: string;
disabled: boolean;
}) {
const [copied, setCopied] = useState(false);
const handleCopy = (e: React.MouseEvent<HTMLButtonElement, MouseEvent>) => {
e.stopPropagation();
navigator.clipboard.writeText(content);
setCopied(true);
setTimeout(() => setCopied(false), 2000);
};
return (
<TooltipIconButton
onClick={(e: any) => handleCopy(e)}
variant="ghost"
tooltip="Copy content"
disabled={disabled}
>
<AnimatePresence mode="wait" initial={false}>
{copied ? (
<motion.div
key="check"
initial={{ opacity: 0, scale: 0.8 }}
animate={{ opacity: 1, scale: 1 }}
exit={{ opacity: 0, scale: 0.8 }}
transition={{ duration: 0.15 }}
>
<CopyCheck className="text-green-500" />
</motion.div>
) : (
<motion.div
key="copy"
initial={{ opacity: 0, scale: 0.8 }}
animate={{ opacity: 1, scale: 1 }}
exit={{ opacity: 0, scale: 0.8 }}
transition={{ duration: 0.15 }}
>
<Copy />
</motion.div>
)}
</AnimatePresence>
</TooltipIconButton>
);
}
export function BranchSwitcher({
branch,
branchOptions,
onSelect,
isLoading,
}: {
branch: string | undefined;
branchOptions: string[] | undefined;
onSelect: (branch: string) => void;
isLoading: boolean;
}) {
if (!branchOptions || !branch) return null;
const index = branchOptions.indexOf(branch);
return (
<div className="flex items-center gap-2">
<Button
variant="ghost"
size="icon"
onClick={() => {
const prevBranch = branchOptions[index - 1];
if (!prevBranch) return;
onSelect(prevBranch);
}}
disabled={isLoading}
>
<ChevronLeft />
</Button>
<span className="text-sm">
{index + 1} / {branchOptions.length}
</span>
<Button
variant="ghost"
size="icon"
onClick={() => {
const nextBranch = branchOptions[index + 1];
if (!nextBranch) return;
onSelect(nextBranch);
}}
disabled={isLoading}
>
<ChevronRight />
</Button>
</div>
);
}
export function CommandBar({
content,
isHumanMessage,
isAiMessage,
isEditing,
setIsEditing,
handleSubmitEdit,
handleRegenerate,
isLoading,
}: {
content: string;
isHumanMessage?: boolean;
isAiMessage?: boolean;
isEditing?: boolean;
setIsEditing?: React.Dispatch<React.SetStateAction<boolean>>;
handleSubmitEdit?: () => void;
handleRegenerate?: () => void;
isLoading: boolean;
}) {
if (isHumanMessage && isAiMessage) {
throw new Error(
"Can only set one of isHumanMessage or isAiMessage to true, not both.",
);
}
if (!isHumanMessage && !isAiMessage) {
throw new Error(
"One of isHumanMessage or isAiMessage must be set to true.",
);
}
if (
isHumanMessage &&
(isEditing === undefined ||
setIsEditing === undefined ||
handleSubmitEdit === undefined)
) {
throw new Error(
"If isHumanMessage is true, all of isEditing, setIsEditing, and handleSubmitEdit must be set.",
);
}
const showEdit =
isHumanMessage &&
isEditing !== undefined &&
!!setIsEditing &&
!!handleSubmitEdit;
if (isHumanMessage && isEditing && !!setIsEditing && !!handleSubmitEdit) {
return (
<div className="flex items-center gap-2">
<TooltipIconButton
disabled={isLoading}
tooltip="Cancel edit"
variant="ghost"
onClick={() => {
setIsEditing(false);
}}
>
<XIcon />
</TooltipIconButton>
<TooltipIconButton
disabled={isLoading}
tooltip="Submit"
variant="secondary"
onClick={handleSubmitEdit}
>
<SendHorizontal />
</TooltipIconButton>
</div>
);
}
return (
<div className="flex items-center gap-2">
<ContentCopyable content={content} disabled={isLoading} />
{isAiMessage && !!handleRegenerate && (
<TooltipIconButton
disabled={isLoading}
tooltip="Refresh"
variant="ghost"
onClick={handleRegenerate}
>
<RefreshCcw />
</TooltipIconButton>
)}
{showEdit && (
<TooltipIconButton
disabled={isLoading}
tooltip="Edit"
variant="ghost"
onClick={() => {
setIsEditing?.(true);
}}
>
<Pencil />
</TooltipIconButton>
)}
</div>
);
}

View File

@@ -0,0 +1,9 @@
import { MessageContent } from "@langchain/core/messages";
export function getContentString(content: MessageContent): string {
if (typeof content === "string") return content;
const texts = content
.filter((c): c is { type: "text"; text: string } => c.type === "text")
.map((c) => c.text);
return texts.join(" ");
}

View File

@@ -34,16 +34,18 @@ const buttonVariants = cva(
}, },
); );
type ButtonProps = React.ComponentProps<"button"> &
VariantProps<typeof buttonVariants> & {
asChild?: boolean;
};
function Button({ function Button({
className, className,
variant, variant,
size, size,
asChild = false, asChild = false,
...props ...props
}: React.ComponentProps<"button"> & }: ButtonProps) {
VariantProps<typeof buttonVariants> & {
asChild?: boolean;
}) {
const Comp = asChild ? Slot : "button"; const Comp = asChild ? Slot : "button";
return ( return (
@@ -55,4 +57,4 @@ function Button({
); );
} }
export { Button, buttonVariants }; export { Button, buttonVariants, type ButtonProps };

View File

@@ -0,0 +1,20 @@
import * as React from "react";
import { cn } from "@/lib/utils";
function Input({ className, type, ...props }: React.ComponentProps<"input">) {
return (
<input
type={type}
data-slot="input"
className={cn(
"border-input file:text-foreground placeholder:text-muted-foreground selection:bg-primary selection:text-primary-foreground flex h-9 w-full min-w-0 rounded-md border bg-transparent px-3 py-1 text-base shadow-xs transition-[color,box-shadow] outline-none file:inline-flex file:h-7 file:border-0 file:bg-transparent file:text-sm file:font-medium disabled:pointer-events-none disabled:cursor-not-allowed disabled:opacity-50 md:text-sm",
"aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive",
className,
)}
{...props}
/>
);
}
export { Input };

View File

@@ -0,0 +1,18 @@
import * as React from "react";
import { cn } from "@/lib/utils";
function Textarea({ className, ...props }: React.ComponentProps<"textarea">) {
return (
<textarea
data-slot="textarea"
className={cn(
"border-input placeholder:text-muted-foreground focus-visible:border-ring focus-visible:ring-ring/50 aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive flex field-sizing-content min-h-16 w-full rounded-md border bg-transparent px-3 py-2 text-base shadow-xs transition-[color,box-shadow] outline-none focus-visible:ring-[3px] disabled:cursor-not-allowed disabled:opacity-50 md:text-sm",
className,
)}
{...props}
/>
);
}
export { Textarea };

View File

@@ -1,13 +1,10 @@
import { createRoot } from "react-dom/client"; import { createRoot } from "react-dom/client";
import "./index.css"; import "./index.css";
import App from "./App.tsx"; import App from "./App.tsx";
import { RuntimeProvider } from "./providers/Runtime.tsx";
import { StreamProvider } from "./providers/Stream.tsx"; import { StreamProvider } from "./providers/Stream.tsx";
createRoot(document.getElementById("root")!).render( createRoot(document.getElementById("root")!).render(
<StreamProvider> <StreamProvider>
<RuntimeProvider> <App />
<App />
</RuntimeProvider>
</StreamProvider>, </StreamProvider>,
); );

View File

@@ -1,78 +0,0 @@
import { ReactNode, useEffect } from "react";
import {
useExternalStoreRuntime,
AppendMessage,
AssistantRuntimeProvider,
} from "@assistant-ui/react";
import { HumanMessage, Message, ToolMessage } from "@langchain/langgraph-sdk";
import { useStreamContext } from "./Stream";
import { convertLangChainMessages } from "./convert-messages";
function ensureToolCallsHaveResponses(messages: Message[]): Message[] {
const newMessages: ToolMessage[] = [];
messages.forEach((message, index) => {
if (message.type !== "ai" || message.tool_calls?.length === 0) {
// If it's not an AI message, or it doesn't have tool calls, we can ignore.
return;
}
// If it has tool calls, ensure the message which follows this is a tool message
const followingMessage = messages[index + 1];
if (followingMessage && followingMessage.type === "tool") {
// Following message is a tool message, so we can ignore.
return;
}
// Since the following message is not a tool message, we must create a new tool message
newMessages.push(
...(message.tool_calls?.map((tc) => ({
type: "tool" as const,
tool_call_id: tc.id ?? "",
id: tc.id ?? "",
name: tc.name,
content: "Successfully handled tool call.",
})) ?? [])
);
});
return newMessages;
}
export function RuntimeProvider({
children,
}: Readonly<{
children: ReactNode;
}>) {
const stream = useStreamContext();
const onNew = async (message: AppendMessage) => {
if (message.content[0]?.type !== "text")
throw new Error("Only text messages are supported");
const input = message.content[0].text;
const humanMessage: HumanMessage = { type: "human", content: input };
const newMessages = [
...ensureToolCallsHaveResponses(stream.messages),
humanMessage,
];
console.log("Sending new messages", newMessages);
stream.submit({ messages: newMessages }, { streamMode: ["values"] });
};
useEffect(() => {
console.log("useEffect - stream.messages", stream.messages);
}, [stream.messages]);
const runtime = useExternalStoreRuntime({
isRunning: stream.isLoading,
messages: stream.messages,
convertMessage: convertLangChainMessages,
onNew,
});
return (
<AssistantRuntimeProvider runtime={runtime}>
{children}
</AssistantRuntimeProvider>
);
}

View File

@@ -13,7 +13,7 @@ const useTypedStream = useStream<
messages?: Message[] | Message | string; messages?: Message[] | Message | string;
ui?: (UIMessage | RemoveUIMessage)[] | UIMessage | RemoveUIMessage; ui?: (UIMessage | RemoveUIMessage)[] | UIMessage | RemoveUIMessage;
}; };
CustomType: UIMessage | RemoveUIMessage; CustomUpdateType: UIMessage | RemoveUIMessage;
} }
>; >;

View File

@@ -1,122 +0,0 @@
import { ThreadMessageLike, ToolCallContentPart } from "@assistant-ui/react";
import { Message, AIMessage, ToolMessage } from "@langchain/langgraph-sdk";
export const getMessageType = (message: Record<string, any>): string => {
if (Array.isArray(message.id)) {
const lastItem = message.id[message.id.length - 1];
if (lastItem.startsWith("HumanMessage")) {
return "human";
} else if (lastItem.startsWith("AIMessage")) {
return "ai";
} else if (lastItem.startsWith("ToolMessage")) {
return "tool";
} else if (
lastItem.startsWith("BaseMessage") ||
lastItem.startsWith("SystemMessage")
) {
return "system";
}
}
if ("getType" in message && typeof message.getType === "function") {
return message.getType();
} else if ("_getType" in message && typeof message._getType === "function") {
return message._getType();
} else if ("type" in message) {
return message.type as string;
} else {
console.error(message);
throw new Error("Unsupported message type");
}
};
function getMessageContentOrThrow(message: unknown): string {
if (typeof message !== "object" || message === null) {
return "";
}
const castMsg = message as Record<string, any>;
if (
typeof castMsg?.content !== "string" &&
(!Array.isArray(castMsg.content) || castMsg.content[0]?.type !== "text") &&
(!castMsg.kwargs ||
!castMsg.kwargs?.content ||
typeof castMsg.kwargs?.content !== "string")
) {
console.error(castMsg);
throw new Error("Only text messages are supported");
}
let content = "";
if (Array.isArray(castMsg.content) && castMsg.content[0]?.type === "text") {
content = castMsg.content[0].text;
} else if (typeof castMsg.content === "string") {
content = castMsg.content;
} else if (
castMsg?.kwargs &&
castMsg.kwargs?.content &&
typeof castMsg.kwargs?.content === "string"
) {
content = castMsg.kwargs.content;
}
return content;
}
export function convertLangChainMessages(message: Message): ThreadMessageLike {
const content = getMessageContentOrThrow(message);
switch (getMessageType(message)) {
case "system":
return {
role: "system",
id: message.id,
content: [{ type: "text", text: content }],
};
case "human":
return {
role: "user",
id: message.id,
content: [{ type: "text", text: content }],
};
case "ai":
const aiMsg = message as AIMessage;
const toolCallsContent: ToolCallContentPart[] = aiMsg.tool_calls?.length
? aiMsg.tool_calls.map((tc) => ({
type: "tool-call" as const,
toolCallId: tc.id ?? "",
toolName: tc.name,
args: tc.args,
argsText: JSON.stringify(tc.args),
}))
: [];
return {
role: "assistant",
id: message.id,
content: [
...toolCallsContent,
{
type: "text",
text: content,
},
],
};
case "tool":
const toolMsg = message as ToolMessage;
return {
role: "assistant",
content: [
{
type: "tool-call",
toolName: toolMsg.name ?? "ToolCall",
toolCallId: toolMsg.tool_call_id,
result: content,
},
],
};
default:
console.error(message);
throw new Error(`Unsupported message type: ${getMessageType(message)}`);
}
}