feat: Refetch threads when a new thread is created

This commit is contained in:
bracesproul
2025-03-05 15:28:27 -08:00
parent b5ed8d1b10
commit 76bf8b2533
4 changed files with 111 additions and 12 deletions

View File

@@ -1,7 +1,7 @@
import { Button } from "@/components/ui/button";
import { useThreads } from "@/hooks/useThreads";
import { useThreads } from "@/providers/Thread";
import { Thread } from "@langchain/langgraph-sdk";
import { useEffect, useState } from "react";
import { useEffect } from "react";
import { getContentString } from "../utils";
import { useQueryParam, StringParam, BooleanParam } from "use-query-params";
import {
@@ -67,29 +67,32 @@ function ThreadHistoryLoading() {
}
export default function ThreadHistory() {
const [threads, setThreads] = useState<Thread[]>([]);
const [loading, setLoading] = useState<boolean>(false);
const [chatHistoryOpen, setChatHistoryOpen] = useQueryParam(
"chatHistoryOpen",
BooleanParam,
);
const { getThreads } = useThreads();
const { getThreads, threads, setThreads, threadsLoading, setThreadsLoading } =
useThreads();
useEffect(() => {
if (typeof window === "undefined") return;
setLoading(true);
setThreadsLoading(true);
getThreads()
.then(setThreads)
.catch(console.error)
.finally(() => setLoading(false));
.finally(() => setThreadsLoading(false));
}, []);
return (
<>
<div className="hidden lg:flex flex-col border-r-[1px] border-slate-300 items-start justify-start gap-6 h-screen w-[300px] shrink-0 px-2 py-4 shadow-inner-right">
<h1 className="text-2xl font-medium pl-4">Thread History</h1>
{loading ? <ThreadHistoryLoading /> : <ThreadList threads={threads} />}
{threadsLoading ? (
<ThreadHistoryLoading />
) : (
<ThreadList threads={threads} />
)}
</div>
<Sheet open={!!chatHistoryOpen} onOpenChange={setChatHistoryOpen}>
<SheetContent side="left" className="lg:hidden flex">

View File

@@ -2,6 +2,7 @@ import { createRoot } from "react-dom/client";
import "./index.css";
import App from "./App.tsx";
import { StreamProvider } from "./providers/Stream.tsx";
import { ThreadProvider } from "./providers/Thread.tsx";
import { QueryParamProvider } from "use-query-params";
import { ReactRouter6Adapter } from "use-query-params/adapters/react-router-6";
import { BrowserRouter } from "react-router-dom";
@@ -10,9 +11,11 @@ import { Toaster } from "@/components/ui/sonner";
createRoot(document.getElementById("root")!).render(
<BrowserRouter>
<QueryParamProvider adapter={ReactRouter6Adapter}>
<StreamProvider>
<App />
</StreamProvider>
<ThreadProvider>
<StreamProvider>
<App />
</StreamProvider>
</ThreadProvider>
</QueryParamProvider>
<Toaster />
</BrowserRouter>,

View File

@@ -13,6 +13,7 @@ import { Label } from "@/components/ui/label";
import { ArrowRight } from "lucide-react";
import { PasswordInput } from "@/components/ui/password-input";
import { getApiKey } from "@/lib/api-key";
import { useThreads } from "./Thread";
export type StateType = { messages: Message[]; ui?: UIMessage[] };
@@ -30,6 +31,10 @@ const useTypedStream = useStream<
type StreamContextType = ReturnType<typeof useTypedStream>;
const StreamContext = createContext<StreamContextType | undefined>(undefined);
async function sleep(ms = 4000) {
return new Promise((resolve) => setTimeout(resolve, ms));
}
const StreamSession = ({
children,
apiKey,
@@ -42,12 +47,18 @@ const StreamSession = ({
assistantId: string;
}) => {
const [threadId, setThreadId] = useQueryParam("threadId", StringParam);
const { getThreads, setThreads } = useThreads();
const streamValue = useTypedStream({
apiUrl,
apiKey: apiKey ?? undefined,
assistantId,
threadId: threadId ?? null,
onThreadId: setThreadId,
onThreadId: (id) => {
setThreadId(id);
// Refetch threads list when thread ID changes.
// Wait for some seconds before fetching so we're able to get the new thread that was created.
sleep().then(() => getThreads().then(setThreads).catch(console.error));
},
});
return (

82
src/providers/Thread.tsx Normal file
View File

@@ -0,0 +1,82 @@
import { validate } from "uuid";
import { getApiKey } from "@/lib/api-key";
import { Client, Thread } from "@langchain/langgraph-sdk";
import { useQueryParam, StringParam } from "use-query-params";
import {
createContext,
useContext,
ReactNode,
useCallback,
useState,
Dispatch,
SetStateAction,
} from "react";
interface ThreadContextType {
getThreads: () => Promise<Thread[]>;
threads: Thread[];
setThreads: Dispatch<SetStateAction<Thread[]>>;
threadsLoading: boolean;
setThreadsLoading: Dispatch<SetStateAction<boolean>>;
}
const ThreadContext = createContext<ThreadContextType | undefined>(undefined);
function createClient(apiUrl: string, apiKey: string | undefined) {
return new Client({
apiKey,
apiUrl,
});
}
function getThreadSearchMetadata(
assistantId: string,
): { graph_id: string } | { assistant_id: string } {
if (validate(assistantId)) {
return { assistant_id: assistantId };
} else {
return { graph_id: assistantId };
}
}
export function ThreadProvider({ children }: { children: ReactNode }) {
const [apiUrl] = useQueryParam("apiUrl", StringParam);
const [assistantId] = useQueryParam("assistantId", StringParam);
const [threads, setThreads] = useState<Thread[]>([]);
const [threadsLoading, setThreadsLoading] = useState(false);
const getThreads = useCallback(async (): Promise<Thread[]> => {
if (!apiUrl || !assistantId) return [];
const client = createClient(apiUrl, getApiKey() ?? undefined);
const threads = await client.threads.search({
metadata: {
...getThreadSearchMetadata(assistantId),
},
limit: 100,
});
return threads;
}, [apiUrl, assistantId]);
const value = {
getThreads,
threads,
setThreads,
threadsLoading,
setThreadsLoading,
};
return (
<ThreadContext.Provider value={value}>{children}</ThreadContext.Provider>
);
}
export function useThreads() {
const context = useContext(ThreadContext);
if (context === undefined) {
throw new Error("useThreads must be used within a ThreadProvider");
}
return context;
}