From 84cdbbe5507e2feebe8bed857828fecd1becae72 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Fri, 7 Mar 2025 13:42:37 -0800 Subject: [PATCH] feat: implement accept and dont ask again feature --- agent/open-code/index.ts | 39 ++++++- agent/open-code/nodes/executor.ts | 17 ++- agent/open-code/nodes/planner.ts | 32 +++++- agent/uis/open-code/plan/index.tsx | 41 ++++--- agent/uis/open-code/proposed-change/index.tsx | 107 ++++++++++++++---- src/components/thread/messages/tool-calls.tsx | 2 +- 6 files changed, 185 insertions(+), 53 deletions(-) diff --git a/agent/open-code/index.ts b/agent/open-code/index.ts index 91f0a9e..2707cff 100644 --- a/agent/open-code/index.ts +++ b/agent/open-code/index.ts @@ -1,14 +1,45 @@ -import { END, START, StateGraph } from "@langchain/langgraph"; -import { OpenCodeAnnotation } from "./types"; +import { + END, + LangGraphRunnableConfig, + START, + StateGraph, +} from "@langchain/langgraph"; +import { OpenCodeAnnotation, OpenCodeState } from "./types"; import { planner } from "./nodes/planner"; -import { executor } from "./nodes/executor"; +import { + executor, + SUCCESSFULLY_COMPLETED_STEPS_CONTENT, +} from "./nodes/executor"; +import { AIMessage } from "@langchain/langgraph-sdk"; + +function conditionallyEnd( + state: OpenCodeState, + config: LangGraphRunnableConfig, +): typeof END | "planner" { + const fullWriteAccess = !!config.configurable?.permissions?.full_write_access; + const lastAiMessage = state.messages.findLast( + (m) => m.getType() === "ai", + ) as unknown as AIMessage; + + // If the user did not grant full write access, or the last AI message is the success message, end + // otherwise, loop back to the start. + if ( + (typeof lastAiMessage.content === "string" && + lastAiMessage.content === SUCCESSFULLY_COMPLETED_STEPS_CONTENT) || + !fullWriteAccess + ) { + return END; + } + + return "planner"; +} const workflow = new StateGraph(OpenCodeAnnotation) .addNode("planner", planner) .addNode("executor", executor) .addEdge(START, "planner") .addEdge("planner", "executor") - .addEdge("executor", END); + .addConditionalEdges("executor", conditionallyEnd, ["planner", END]); export const graph = workflow.compile(); graph.name = "Open Code Graph"; diff --git a/agent/open-code/nodes/executor.ts b/agent/open-code/nodes/executor.ts index a8e59e4..7452e54 100644 --- a/agent/open-code/nodes/executor.ts +++ b/agent/open-code/nodes/executor.ts @@ -6,6 +6,9 @@ import { LangGraphRunnableConfig } from "@langchain/langgraph"; import ComponentMap from "../../uis"; import { typedUi } from "@langchain/langgraph-sdk/react-ui/server"; +export const SUCCESSFULLY_COMPLETED_STEPS_CONTENT = + "Successfully completed all the steps in the plan. Please let me know if you need anything else!"; + export async function executor( state: OpenCodeState, config: LangGraphRunnableConfig, @@ -24,21 +27,24 @@ export async function executor( const nextPlanItem = planToolCallArgs?.remainingPlans?.[0] as | string | undefined; - const numOfExecutedPlanItems = planToolCallArgs?.executedPlans?.length ?? 0; + const numSeenPlans = + [ + ...(planToolCallArgs?.executedPlans ?? []), + ...(planToolCallArgs?.rejectedPlans ?? []), + ]?.length ?? 0; if (!nextPlanItem) { // All plans have been executed const successfullyFinishedMsg: AIMessage = { type: "ai", id: uuidv4(), - content: - "Successfully completed all the steps in the plan. Please let me know if you need anything else!", + content: SUCCESSFULLY_COMPLETED_STEPS_CONTENT, }; return { messages: [successfullyFinishedMsg] }; } let updateFileContents = ""; - switch (numOfExecutedPlanItems) { + switch (numSeenPlans) { case 0: updateFileContents = await fs.readFile( "agent/open-code/nodes/plan-code/step-1.txt", @@ -101,10 +107,13 @@ export async function executor( ], }; + const fullWriteAccess = !!config.configurable?.permissions?.full_write_access; + const msg = ui.create("proposed-change", { toolCallId, change: updateFileContents, planItem: nextPlanItem, + fullWriteAccess, }); msg.additional_kwargs["message_id"] = aiMessage.id; diff --git a/agent/open-code/nodes/planner.ts b/agent/open-code/nodes/planner.ts index 5c9bf5d..5135750 100644 --- a/agent/open-code/nodes/planner.ts +++ b/agent/open-code/nodes/planner.ts @@ -28,28 +28,46 @@ export async function planner( (tc) => tc.name === "update_file", ), ) as AIMessage | undefined; + const lastUpdateToolCallResponse = state.messages.findLast( + (m) => + m.getType() === "tool" && + (m as unknown as ToolMessage).tool_call_id === + lastUpdateCodeToolCall?.tool_calls?.[0]?.id, + ) as ToolMessage | undefined; const lastPlanToolCall = state.messages.findLast( (m) => m.getType() === "ai" && (m as unknown as AIMessage).tool_calls?.some((tc) => tc.name === "plan"), ) as AIMessage | undefined; + const wasPlanRejected = ( + lastUpdateToolCallResponse?.content as string | undefined + ) + ?.toLowerCase() + .includes("rejected"); + const planToolCallArgs = lastPlanToolCall?.tool_calls?.[0]?.args as Record< string, any >; const executedPlans: string[] = planToolCallArgs?.executedPlans ?? []; + const rejectedPlans: string[] = planToolCallArgs?.rejectedPlans ?? []; let remainingPlans: string[] = planToolCallArgs?.remainingPlans ?? PLAN; - const executedPlanItem = lastUpdateCodeToolCall?.tool_calls?.[0]?.args + const proposedChangePlanItem = lastUpdateCodeToolCall?.tool_calls?.[0]?.args ?.executed_plan_item as string | undefined; - if (executedPlanItem) { - executedPlans.push(executedPlanItem); - remainingPlans = remainingPlans.filter((p) => p !== executedPlanItem); + if (proposedChangePlanItem) { + if (wasPlanRejected) { + rejectedPlans.push(proposedChangePlanItem); + } else { + executedPlans.push(proposedChangePlanItem); + } + + remainingPlans = remainingPlans.filter((p) => p !== proposedChangePlanItem); } - const content = executedPlanItem - ? `I've updated the plan list based on the executed plans.` + const content = proposedChangePlanItem + ? `I've updated the plan list based on the last proposed change.` : `I've come up with a detailed plan for building the todo app.`; const toolCallId = uuidv4(); @@ -62,6 +80,7 @@ export async function planner( name: "plan", args: { executedPlans, + rejectedPlans, remainingPlans, }, id: toolCallId, @@ -73,6 +92,7 @@ export async function planner( const msg = ui.create("code-plan", { toolCallId, executedPlans, + rejectedPlans, remainingPlans, }); msg.additional_kwargs["message_id"] = aiMessage.id; diff --git a/agent/uis/open-code/plan/index.tsx b/agent/uis/open-code/plan/index.tsx index 0714577..d6bbce3 100644 --- a/agent/uis/open-code/plan/index.tsx +++ b/agent/uis/open-code/plan/index.tsx @@ -3,31 +3,42 @@ import "./index.css"; interface PlanProps { toolCallId: string; executedPlans: string[]; + rejectedPlans: string[]; remainingPlans: string[]; } export default function Plan(props: PlanProps) { return ( -
-

Code Plan

-
-
-

- Executed Plans -

- {props.executedPlans.map((step, index) => ( -

- {index + 1}. {step} -

- ))} -
-
+
+

Code Plan

+
+

Remaining Plans

{props.remainingPlans.map((step, index) => (

- {props.executedPlans.length + index + 1}. {step} + {index + 1}. {step} +

+ ))} +
+
+

+ Executed Plans +

+ {props.executedPlans.map((step, index) => ( +

+ {step} +

+ ))} +
+
+

+ Rejected Plans +

+ {props.rejectedPlans.map((step, index) => ( +

+ {step}

))}
diff --git a/agent/uis/open-code/proposed-change/index.tsx b/agent/uis/open-code/proposed-change/index.tsx index f6fac0e..d89c35e 100644 --- a/agent/uis/open-code/proposed-change/index.tsx +++ b/agent/uis/open-code/proposed-change/index.tsx @@ -9,18 +9,28 @@ import { Message } from "@langchain/langgraph-sdk"; import { DO_NOT_RENDER_ID_PREFIX } from "@/lib/ensure-tool-responses"; import { useEffect, useState } from "react"; import { getToolResponse } from "../../utils/get-tool-response"; +import { cn } from "@/lib/utils"; interface ProposedChangeProps { toolCallId: string; change: string; planItem: string; + /** + * Whether or not to show the "Accept"/"Reject" buttons + * If true, this means the user selected the "Accept, don't ask again" + * button for this session. + */ + fullWriteAccess: boolean; } const ACCEPTED_CHANGE_CONTENT = "User accepted the proposed change. Please continue."; +const REJECTED_CHANGE_CONTENT = + "User rejected the proposed change. Please continue."; export default function ProposedChange(props: ProposedChangeProps) { const [isAccepted, setIsAccepted] = useState(false); + const [isRejected, setIsRejected] = useState(false); const thread = useStreamContext< { messages: Message[]; ui: UIMessage[] }, @@ -28,41 +38,81 @@ export default function ProposedChange(props: ProposedChangeProps) { >(); const handleReject = () => { - alert("Rejected. (just kidding, you can't reject me silly!)"); - }; - const handleAccept = () => { thread.submit({ messages: [ { type: "tool", tool_call_id: props.toolCallId, id: `${DO_NOT_RENDER_ID_PREFIX}${uuidv4()}`, - name: "buy-stock", - content: ACCEPTED_CHANGE_CONTENT, + name: "update_file", + content: REJECTED_CHANGE_CONTENT, }, { type: "human", - content: `Accepted change.`, + content: `Rejected change.`, }, ], }); + setIsRejected(true); + }; + + const handleAccept = (shouldGrantFullWriteAccess = false) => { + const humanMessageContent = `Accepted change. ${shouldGrantFullWriteAccess ? "Granted full write access." : ""}`; + thread.submit( + { + messages: [ + { + type: "tool", + tool_call_id: props.toolCallId, + id: `${DO_NOT_RENDER_ID_PREFIX}${uuidv4()}`, + name: "update_file", + content: ACCEPTED_CHANGE_CONTENT, + }, + { + type: "human", + content: humanMessageContent, + }, + ], + }, + { + config: { + configurable: { + permissions: { + full_write_access: shouldGrantFullWriteAccess, + }, + }, + }, + }, + ); + setIsAccepted(true); }; useEffect(() => { if (typeof window === "undefined" || isAccepted) return; const toolResponse = getToolResponse(props.toolCallId, thread); - if (toolResponse && toolResponse.content === ACCEPTED_CHANGE_CONTENT) { - setIsAccepted(true); + if (toolResponse) { + if (toolResponse.content === ACCEPTED_CHANGE_CONTENT) { + setIsAccepted(true); + } else if (toolResponse.content === REJECTED_CHANGE_CONTENT) { + setIsRejected(true); + } } }, []); - if (isAccepted) { + if (isAccepted || isRejected) { return ( -
+
-

Accepted Change

+

+ {isAccepted ? "Accepted" : "Rejected"} Change +

{props.planItem}

-
- - -
+ {!props.fullWriteAccess && ( +
+ + + +
+ )}
); } diff --git a/src/components/thread/messages/tool-calls.tsx b/src/components/thread/messages/tool-calls.tsx index b94f885..816a440 100644 --- a/src/components/thread/messages/tool-calls.tsx +++ b/src/components/thread/messages/tool-calls.tsx @@ -15,7 +15,7 @@ export function ToolCalls({ if (!toolCalls || toolCalls.length === 0) return null; return ( -
+
{toolCalls.map((tc, idx) => { const args = tc.args as Record; const hasArgs = Object.keys(args).length > 0;