feat: implement accept and dont ask again feature

This commit is contained in:
bracesproul
2025-03-07 13:42:37 -08:00
parent 81adff780f
commit 84cdbbe550
6 changed files with 185 additions and 53 deletions

View File

@@ -1,14 +1,45 @@
import { END, START, StateGraph } from "@langchain/langgraph"; import {
import { OpenCodeAnnotation } from "./types"; END,
LangGraphRunnableConfig,
START,
StateGraph,
} from "@langchain/langgraph";
import { OpenCodeAnnotation, OpenCodeState } from "./types";
import { planner } from "./nodes/planner"; import { planner } from "./nodes/planner";
import { executor } from "./nodes/executor"; import {
executor,
SUCCESSFULLY_COMPLETED_STEPS_CONTENT,
} from "./nodes/executor";
import { AIMessage } from "@langchain/langgraph-sdk";
function conditionallyEnd(
state: OpenCodeState,
config: LangGraphRunnableConfig,
): typeof END | "planner" {
const fullWriteAccess = !!config.configurable?.permissions?.full_write_access;
const lastAiMessage = state.messages.findLast(
(m) => m.getType() === "ai",
) as unknown as AIMessage;
// If the user did not grant full write access, or the last AI message is the success message, end
// otherwise, loop back to the start.
if (
(typeof lastAiMessage.content === "string" &&
lastAiMessage.content === SUCCESSFULLY_COMPLETED_STEPS_CONTENT) ||
!fullWriteAccess
) {
return END;
}
return "planner";
}
const workflow = new StateGraph(OpenCodeAnnotation) const workflow = new StateGraph(OpenCodeAnnotation)
.addNode("planner", planner) .addNode("planner", planner)
.addNode("executor", executor) .addNode("executor", executor)
.addEdge(START, "planner") .addEdge(START, "planner")
.addEdge("planner", "executor") .addEdge("planner", "executor")
.addEdge("executor", END); .addConditionalEdges("executor", conditionallyEnd, ["planner", END]);
export const graph = workflow.compile(); export const graph = workflow.compile();
graph.name = "Open Code Graph"; graph.name = "Open Code Graph";

View File

@@ -6,6 +6,9 @@ import { LangGraphRunnableConfig } from "@langchain/langgraph";
import ComponentMap from "../../uis"; import ComponentMap from "../../uis";
import { typedUi } from "@langchain/langgraph-sdk/react-ui/server"; import { typedUi } from "@langchain/langgraph-sdk/react-ui/server";
export const SUCCESSFULLY_COMPLETED_STEPS_CONTENT =
"Successfully completed all the steps in the plan. Please let me know if you need anything else!";
export async function executor( export async function executor(
state: OpenCodeState, state: OpenCodeState,
config: LangGraphRunnableConfig, config: LangGraphRunnableConfig,
@@ -24,21 +27,24 @@ export async function executor(
const nextPlanItem = planToolCallArgs?.remainingPlans?.[0] as const nextPlanItem = planToolCallArgs?.remainingPlans?.[0] as
| string | string
| undefined; | undefined;
const numOfExecutedPlanItems = planToolCallArgs?.executedPlans?.length ?? 0; const numSeenPlans =
[
...(planToolCallArgs?.executedPlans ?? []),
...(planToolCallArgs?.rejectedPlans ?? []),
]?.length ?? 0;
if (!nextPlanItem) { if (!nextPlanItem) {
// All plans have been executed // All plans have been executed
const successfullyFinishedMsg: AIMessage = { const successfullyFinishedMsg: AIMessage = {
type: "ai", type: "ai",
id: uuidv4(), id: uuidv4(),
content: content: SUCCESSFULLY_COMPLETED_STEPS_CONTENT,
"Successfully completed all the steps in the plan. Please let me know if you need anything else!",
}; };
return { messages: [successfullyFinishedMsg] }; return { messages: [successfullyFinishedMsg] };
} }
let updateFileContents = ""; let updateFileContents = "";
switch (numOfExecutedPlanItems) { switch (numSeenPlans) {
case 0: case 0:
updateFileContents = await fs.readFile( updateFileContents = await fs.readFile(
"agent/open-code/nodes/plan-code/step-1.txt", "agent/open-code/nodes/plan-code/step-1.txt",
@@ -101,10 +107,13 @@ export async function executor(
], ],
}; };
const fullWriteAccess = !!config.configurable?.permissions?.full_write_access;
const msg = ui.create("proposed-change", { const msg = ui.create("proposed-change", {
toolCallId, toolCallId,
change: updateFileContents, change: updateFileContents,
planItem: nextPlanItem, planItem: nextPlanItem,
fullWriteAccess,
}); });
msg.additional_kwargs["message_id"] = aiMessage.id; msg.additional_kwargs["message_id"] = aiMessage.id;

View File

@@ -28,28 +28,46 @@ export async function planner(
(tc) => tc.name === "update_file", (tc) => tc.name === "update_file",
), ),
) as AIMessage | undefined; ) as AIMessage | undefined;
const lastUpdateToolCallResponse = state.messages.findLast(
(m) =>
m.getType() === "tool" &&
(m as unknown as ToolMessage).tool_call_id ===
lastUpdateCodeToolCall?.tool_calls?.[0]?.id,
) as ToolMessage | undefined;
const lastPlanToolCall = state.messages.findLast( const lastPlanToolCall = state.messages.findLast(
(m) => (m) =>
m.getType() === "ai" && m.getType() === "ai" &&
(m as unknown as AIMessage).tool_calls?.some((tc) => tc.name === "plan"), (m as unknown as AIMessage).tool_calls?.some((tc) => tc.name === "plan"),
) as AIMessage | undefined; ) as AIMessage | undefined;
const wasPlanRejected = (
lastUpdateToolCallResponse?.content as string | undefined
)
?.toLowerCase()
.includes("rejected");
const planToolCallArgs = lastPlanToolCall?.tool_calls?.[0]?.args as Record< const planToolCallArgs = lastPlanToolCall?.tool_calls?.[0]?.args as Record<
string, string,
any any
>; >;
const executedPlans: string[] = planToolCallArgs?.executedPlans ?? []; const executedPlans: string[] = planToolCallArgs?.executedPlans ?? [];
const rejectedPlans: string[] = planToolCallArgs?.rejectedPlans ?? [];
let remainingPlans: string[] = planToolCallArgs?.remainingPlans ?? PLAN; let remainingPlans: string[] = planToolCallArgs?.remainingPlans ?? PLAN;
const executedPlanItem = lastUpdateCodeToolCall?.tool_calls?.[0]?.args const proposedChangePlanItem = lastUpdateCodeToolCall?.tool_calls?.[0]?.args
?.executed_plan_item as string | undefined; ?.executed_plan_item as string | undefined;
if (executedPlanItem) { if (proposedChangePlanItem) {
executedPlans.push(executedPlanItem); if (wasPlanRejected) {
remainingPlans = remainingPlans.filter((p) => p !== executedPlanItem); rejectedPlans.push(proposedChangePlanItem);
} else {
executedPlans.push(proposedChangePlanItem);
} }
const content = executedPlanItem remainingPlans = remainingPlans.filter((p) => p !== proposedChangePlanItem);
? `I've updated the plan list based on the executed plans.` }
const content = proposedChangePlanItem
? `I've updated the plan list based on the last proposed change.`
: `I've come up with a detailed plan for building the todo app.`; : `I've come up with a detailed plan for building the todo app.`;
const toolCallId = uuidv4(); const toolCallId = uuidv4();
@@ -62,6 +80,7 @@ export async function planner(
name: "plan", name: "plan",
args: { args: {
executedPlans, executedPlans,
rejectedPlans,
remainingPlans, remainingPlans,
}, },
id: toolCallId, id: toolCallId,
@@ -73,6 +92,7 @@ export async function planner(
const msg = ui.create("code-plan", { const msg = ui.create("code-plan", {
toolCallId, toolCallId,
executedPlans, executedPlans,
rejectedPlans,
remainingPlans, remainingPlans,
}); });
msg.additional_kwargs["message_id"] = aiMessage.id; msg.additional_kwargs["message_id"] = aiMessage.id;

View File

@@ -3,31 +3,42 @@ import "./index.css";
interface PlanProps { interface PlanProps {
toolCallId: string; toolCallId: string;
executedPlans: string[]; executedPlans: string[];
rejectedPlans: string[];
remainingPlans: string[]; remainingPlans: string[];
} }
export default function Plan(props: PlanProps) { export default function Plan(props: PlanProps) {
return ( return (
<div className="flex flex-col gap-4 w-full max-w-4xl p-6 border-[1px] rounded-xl border-slate-500"> <div className="flex flex-col gap-2 w-full max-w-4xl p-6 border-[1px] rounded-xl border-slate-200">
<h2 className="text-2xl font-semibold text-center mb-2">Code Plan</h2> <h2 className="text-2xl font-semibold text-left mb-2">Code Plan</h2>
<div className="grid grid-cols-2 divide-x divide-slate-300 w-full"> <div className="grid grid-cols-3 divide-x divide-slate-300 w-full border-t-[1px] pt-4">
<div className="flex flex-col gap-2 pr-6"> <div className="flex flex-col gap-2 px-6">
<h3 className="text-lg font-medium mb-4 text-slate-700">
Executed Plans
</h3>
{props.executedPlans.map((step, index) => (
<p key={index} className="font-mono text-sm">
{index + 1}. {step}
</p>
))}
</div>
<div className="flex flex-col gap-2 pl-6">
<h3 className="text-lg font-medium mb-4 text-slate-700"> <h3 className="text-lg font-medium mb-4 text-slate-700">
Remaining Plans Remaining Plans
</h3> </h3>
{props.remainingPlans.map((step, index) => ( {props.remainingPlans.map((step, index) => (
<p key={index} className="font-mono text-sm"> <p key={index} className="font-mono text-sm">
{props.executedPlans.length + index + 1}. {step} {index + 1}. {step}
</p>
))}
</div>
<div className="flex flex-col gap-2 px-6">
<h3 className="text-lg font-medium mb-4 text-slate-700">
Executed Plans
</h3>
{props.executedPlans.map((step, index) => (
<p key={index} className="font-mono text-sm">
{step}
</p>
))}
</div>
<div className="flex flex-col gap-2 px-6">
<h3 className="text-lg font-medium mb-4 text-slate-700">
Rejected Plans
</h3>
{props.rejectedPlans.map((step, index) => (
<p key={index} className="font-mono text-sm">
{step}
</p> </p>
))} ))}
</div> </div>

View File

@@ -9,18 +9,28 @@ import { Message } from "@langchain/langgraph-sdk";
import { DO_NOT_RENDER_ID_PREFIX } from "@/lib/ensure-tool-responses"; import { DO_NOT_RENDER_ID_PREFIX } from "@/lib/ensure-tool-responses";
import { useEffect, useState } from "react"; import { useEffect, useState } from "react";
import { getToolResponse } from "../../utils/get-tool-response"; import { getToolResponse } from "../../utils/get-tool-response";
import { cn } from "@/lib/utils";
interface ProposedChangeProps { interface ProposedChangeProps {
toolCallId: string; toolCallId: string;
change: string; change: string;
planItem: string; planItem: string;
/**
* Whether or not to show the "Accept"/"Reject" buttons
* If true, this means the user selected the "Accept, don't ask again"
* button for this session.
*/
fullWriteAccess: boolean;
} }
const ACCEPTED_CHANGE_CONTENT = const ACCEPTED_CHANGE_CONTENT =
"User accepted the proposed change. Please continue."; "User accepted the proposed change. Please continue.";
const REJECTED_CHANGE_CONTENT =
"User rejected the proposed change. Please continue.";
export default function ProposedChange(props: ProposedChangeProps) { export default function ProposedChange(props: ProposedChangeProps) {
const [isAccepted, setIsAccepted] = useState(false); const [isAccepted, setIsAccepted] = useState(false);
const [isRejected, setIsRejected] = useState(false);
const thread = useStreamContext< const thread = useStreamContext<
{ messages: Message[]; ui: UIMessage[] }, { messages: Message[]; ui: UIMessage[] },
@@ -28,24 +38,53 @@ export default function ProposedChange(props: ProposedChangeProps) {
>(); >();
const handleReject = () => { const handleReject = () => {
alert("Rejected. (just kidding, you can't reject me silly!)");
};
const handleAccept = () => {
thread.submit({ thread.submit({
messages: [ messages: [
{ {
type: "tool", type: "tool",
tool_call_id: props.toolCallId, tool_call_id: props.toolCallId,
id: `${DO_NOT_RENDER_ID_PREFIX}${uuidv4()}`, id: `${DO_NOT_RENDER_ID_PREFIX}${uuidv4()}`,
name: "buy-stock", name: "update_file",
content: REJECTED_CHANGE_CONTENT,
},
{
type: "human",
content: `Rejected change.`,
},
],
});
setIsRejected(true);
};
const handleAccept = (shouldGrantFullWriteAccess = false) => {
const humanMessageContent = `Accepted change. ${shouldGrantFullWriteAccess ? "Granted full write access." : ""}`;
thread.submit(
{
messages: [
{
type: "tool",
tool_call_id: props.toolCallId,
id: `${DO_NOT_RENDER_ID_PREFIX}${uuidv4()}`,
name: "update_file",
content: ACCEPTED_CHANGE_CONTENT, content: ACCEPTED_CHANGE_CONTENT,
}, },
{ {
type: "human", type: "human",
content: `Accepted change.`, content: humanMessageContent,
}, },
], ],
}); },
{
config: {
configurable: {
permissions: {
full_write_access: shouldGrantFullWriteAccess,
},
},
},
},
);
setIsAccepted(true); setIsAccepted(true);
}; };
@@ -53,16 +92,27 @@ export default function ProposedChange(props: ProposedChangeProps) {
useEffect(() => { useEffect(() => {
if (typeof window === "undefined" || isAccepted) return; if (typeof window === "undefined" || isAccepted) return;
const toolResponse = getToolResponse(props.toolCallId, thread); const toolResponse = getToolResponse(props.toolCallId, thread);
if (toolResponse && toolResponse.content === ACCEPTED_CHANGE_CONTENT) { if (toolResponse) {
if (toolResponse.content === ACCEPTED_CHANGE_CONTENT) {
setIsAccepted(true); setIsAccepted(true);
} else if (toolResponse.content === REJECTED_CHANGE_CONTENT) {
setIsRejected(true);
}
} }
}, []); }, []);
if (isAccepted) { if (isAccepted || isRejected) {
return ( return (
<div className="flex flex-col gap-4 w-full max-w-4xl p-4 border-[1px] rounded-xl border-green-300"> <div
className={cn(
"flex flex-col gap-4 w-full max-w-4xl p-4 border-[1px] rounded-xl",
isAccepted ? "border-green-300" : "border-red-300",
)}
>
<div className="flex flex-col items-start justify-start gap-2"> <div className="flex flex-col items-start justify-start gap-2">
<p className="text-lg font-medium">Accepted Change</p> <p className="text-lg font-medium">
{isAccepted ? "Accepted" : "Rejected"} Change
</p>
<p className="text-sm font-mono">{props.planItem}</p> <p className="text-sm font-mono">{props.planItem}</p>
</div> </div>
<ReactMarkdown <ReactMarkdown
@@ -111,18 +161,29 @@ export default function ProposedChange(props: ProposedChangeProps) {
}, },
}} }}
/> />
{!props.fullWriteAccess && (
<div className="flex gap-2 items-center w-full"> <div className="flex gap-2 items-center w-full">
<Button <Button
className="cursor-pointer" className="cursor-pointer w-full"
variant="destructive" variant="destructive"
onClick={handleReject} onClick={handleReject}
> >
Reject Reject
</Button> </Button>
<Button className="cursor-pointer" onClick={handleAccept}> <Button
className="cursor-pointer w-full"
onClick={() => handleAccept()}
>
Accept Accept
</Button> </Button>
<Button
className="cursor-pointer w-full bg-blue-500 hover:bg-blue-600"
onClick={() => handleAccept(true)}
>
Accept, don&apos;t ask again
</Button>
</div> </div>
)}
</div> </div>
); );
} }

View File

@@ -15,7 +15,7 @@ export function ToolCalls({
if (!toolCalls || toolCalls.length === 0) return null; if (!toolCalls || toolCalls.length === 0) return null;
return ( return (
<div className="space-y-4"> <div className="space-y-4 w-full max-w-4xl">
{toolCalls.map((tc, idx) => { {toolCalls.map((tc, idx) => {
const args = tc.args as Record<string, any>; const args = tc.args as Record<string, any>;
const hasArgs = Object.keys(args).length > 0; const hasArgs = Object.keys(args).length > 0;