From 76bf8b25336aba4e373c1f29346bda75cdd69d66 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Wed, 5 Mar 2025 15:28:27 -0800 Subject: [PATCH] feat: Refetch threads when a new thread is created --- src/components/thread/history/index.tsx | 19 +++--- src/main.tsx | 9 ++- src/providers/Stream.tsx | 13 +++- src/providers/Thread.tsx | 82 +++++++++++++++++++++++++ 4 files changed, 111 insertions(+), 12 deletions(-) create mode 100644 src/providers/Thread.tsx diff --git a/src/components/thread/history/index.tsx b/src/components/thread/history/index.tsx index 16f73e7..79f123d 100644 --- a/src/components/thread/history/index.tsx +++ b/src/components/thread/history/index.tsx @@ -1,7 +1,7 @@ import { Button } from "@/components/ui/button"; -import { useThreads } from "@/hooks/useThreads"; +import { useThreads } from "@/providers/Thread"; import { Thread } from "@langchain/langgraph-sdk"; -import { useEffect, useState } from "react"; +import { useEffect } from "react"; import { getContentString } from "../utils"; import { useQueryParam, StringParam, BooleanParam } from "use-query-params"; import { @@ -67,29 +67,32 @@ function ThreadHistoryLoading() { } export default function ThreadHistory() { - const [threads, setThreads] = useState([]); - const [loading, setLoading] = useState(false); const [chatHistoryOpen, setChatHistoryOpen] = useQueryParam( "chatHistoryOpen", BooleanParam, ); - const { getThreads } = useThreads(); + const { getThreads, threads, setThreads, threadsLoading, setThreadsLoading } = + useThreads(); useEffect(() => { if (typeof window === "undefined") return; - setLoading(true); + setThreadsLoading(true); getThreads() .then(setThreads) .catch(console.error) - .finally(() => setLoading(false)); + .finally(() => setThreadsLoading(false)); }, []); return ( <>

Thread History

- {loading ? : } + {threadsLoading ? ( + + ) : ( + + )}
diff --git a/src/main.tsx b/src/main.tsx index d81f850..97321f0 100644 --- a/src/main.tsx +++ b/src/main.tsx @@ -2,6 +2,7 @@ import { createRoot } from "react-dom/client"; import "./index.css"; import App from "./App.tsx"; import { StreamProvider } from "./providers/Stream.tsx"; +import { ThreadProvider } from "./providers/Thread.tsx"; import { QueryParamProvider } from "use-query-params"; import { ReactRouter6Adapter } from "use-query-params/adapters/react-router-6"; import { BrowserRouter } from "react-router-dom"; @@ -10,9 +11,11 @@ import { Toaster } from "@/components/ui/sonner"; createRoot(document.getElementById("root")!).render( - - - + + + + + , diff --git a/src/providers/Stream.tsx b/src/providers/Stream.tsx index baaa00f..bd2911a 100644 --- a/src/providers/Stream.tsx +++ b/src/providers/Stream.tsx @@ -13,6 +13,7 @@ import { Label } from "@/components/ui/label"; import { ArrowRight } from "lucide-react"; import { PasswordInput } from "@/components/ui/password-input"; import { getApiKey } from "@/lib/api-key"; +import { useThreads } from "./Thread"; export type StateType = { messages: Message[]; ui?: UIMessage[] }; @@ -30,6 +31,10 @@ const useTypedStream = useStream< type StreamContextType = ReturnType; const StreamContext = createContext(undefined); +async function sleep(ms = 4000) { + return new Promise((resolve) => setTimeout(resolve, ms)); +} + const StreamSession = ({ children, apiKey, @@ -42,12 +47,18 @@ const StreamSession = ({ assistantId: string; }) => { const [threadId, setThreadId] = useQueryParam("threadId", StringParam); + const { getThreads, setThreads } = useThreads(); const streamValue = useTypedStream({ apiUrl, apiKey: apiKey ?? undefined, assistantId, threadId: threadId ?? null, - onThreadId: setThreadId, + onThreadId: (id) => { + setThreadId(id); + // Refetch threads list when thread ID changes. + // Wait for some seconds before fetching so we're able to get the new thread that was created. + sleep().then(() => getThreads().then(setThreads).catch(console.error)); + }, }); return ( diff --git a/src/providers/Thread.tsx b/src/providers/Thread.tsx new file mode 100644 index 0000000..581d028 --- /dev/null +++ b/src/providers/Thread.tsx @@ -0,0 +1,82 @@ +import { validate } from "uuid"; +import { getApiKey } from "@/lib/api-key"; +import { Client, Thread } from "@langchain/langgraph-sdk"; +import { useQueryParam, StringParam } from "use-query-params"; +import { + createContext, + useContext, + ReactNode, + useCallback, + useState, + Dispatch, + SetStateAction, +} from "react"; + +interface ThreadContextType { + getThreads: () => Promise; + threads: Thread[]; + setThreads: Dispatch>; + threadsLoading: boolean; + setThreadsLoading: Dispatch>; +} + +const ThreadContext = createContext(undefined); + +function createClient(apiUrl: string, apiKey: string | undefined) { + return new Client({ + apiKey, + apiUrl, + }); +} + +function getThreadSearchMetadata( + assistantId: string, +): { graph_id: string } | { assistant_id: string } { + if (validate(assistantId)) { + return { assistant_id: assistantId }; + } else { + return { graph_id: assistantId }; + } +} + +export function ThreadProvider({ children }: { children: ReactNode }) { + const [apiUrl] = useQueryParam("apiUrl", StringParam); + const [assistantId] = useQueryParam("assistantId", StringParam); + const [threads, setThreads] = useState([]); + const [threadsLoading, setThreadsLoading] = useState(false); + + const getThreads = useCallback(async (): Promise => { + if (!apiUrl || !assistantId) return []; + + const client = createClient(apiUrl, getApiKey() ?? undefined); + + const threads = await client.threads.search({ + metadata: { + ...getThreadSearchMetadata(assistantId), + }, + limit: 100, + }); + + return threads; + }, [apiUrl, assistantId]); + + const value = { + getThreads, + threads, + setThreads, + threadsLoading, + setThreadsLoading, + }; + + return ( + {children} + ); +} + +export function useThreads() { + const context = useContext(ThreadContext); + if (context === undefined) { + throw new Error("useThreads must be used within a ThreadProvider"); + } + return context; +}