fix: Custom tool call parsing for anthropic

This commit is contained in:
bracesproul
2025-03-07 16:57:30 -08:00
parent ecbcab7d24
commit c0cef943a6
6 changed files with 2286 additions and 4270 deletions

View File

@@ -7,10 +7,12 @@ import { ChatOpenAI } from "@langchain/openai";
import { tripPlannerGraph } from "./trip-planner"; import { tripPlannerGraph } from "./trip-planner";
import { formatMessages } from "./utils/format-messages"; import { formatMessages } from "./utils/format-messages";
import { graph as openCodeGraph } from "./open-code"; import { graph as openCodeGraph } from "./open-code";
import { graph as orderPizzaGraph } from "./pizza-orderer";
const allToolDescriptions = `- stockbroker: can fetch the price of a ticker, purchase/sell a ticker, or get the user's portfolio const allToolDescriptions = `- stockbroker: can fetch the price of a ticker, purchase/sell a ticker, or get the user's portfolio
- tripPlanner: helps the user plan their trip. it can suggest restaurants, and places to stay in any given location. - tripPlanner: helps the user plan their trip. it can suggest restaurants, and places to stay in any given location.
- openCode: can write code for the user. call this tool when the user asks you to write code`; - openCode: can write code for the user. call this tool when the user asks you to write code
- orderPizza: can order a pizza for the user`;
async function router( async function router(
state: GenerativeUIState, state: GenerativeUIState,
@@ -21,7 +23,7 @@ ${allToolDescriptions}
`; `;
const routerSchema = z.object({ const routerSchema = z.object({
route: z route: z
.enum(["stockbroker", "tripPlanner", "openCode", "generalInput"]) .enum(["stockbroker", "tripPlanner", "openCode", "orderPizza", "generalInput"])
.describe(routerDescription), .describe(routerDescription),
}); });
const routerTool = { const routerTool = {
@@ -75,7 +77,7 @@ Please pick the proper route based on the most recent message, in the context of
function handleRoute( function handleRoute(
state: GenerativeUIState, state: GenerativeUIState,
): "stockbroker" | "tripPlanner" | "openCode" | "generalInput" { ): "stockbroker" | "tripPlanner" | "openCode" | "orderPizza" | "generalInput" {
return state.next; return state.next;
} }
@@ -107,18 +109,21 @@ const builder = new StateGraph(GenerativeUIAnnotation)
.addNode("stockbroker", stockbrokerGraph) .addNode("stockbroker", stockbrokerGraph)
.addNode("tripPlanner", tripPlannerGraph) .addNode("tripPlanner", tripPlannerGraph)
.addNode("openCode", openCodeGraph) .addNode("openCode", openCodeGraph)
.addNode("orderPizza", orderPizzaGraph)
.addNode("generalInput", handleGeneralInput) .addNode("generalInput", handleGeneralInput)
.addConditionalEdges("router", handleRoute, [ .addConditionalEdges("router", handleRoute, [
"stockbroker", "stockbroker",
"tripPlanner", "tripPlanner",
"openCode", "openCode",
"orderPizza",
"generalInput", "generalInput",
]) ])
.addEdge(START, "router") .addEdge(START, "router")
.addEdge("stockbroker", END) .addEdge("stockbroker", END)
.addEdge("tripPlanner", END) .addEdge("tripPlanner", END)
.addEdge("openCode", END) .addEdge("openCode", END)
.addEdge("orderPizza", END)
.addEdge("generalInput", END); .addEdge("generalInput", END);
export const graph = builder.compile(); export const graph = builder.compile();

View File

@@ -0,0 +1,85 @@
import { ChatAnthropic } from "@langchain/anthropic";
import { Annotation, END, START, StateGraph } from "@langchain/langgraph";
import { GenerativeUIAnnotation } from "../types";
import { z } from "zod";
import { AIMessage, ToolMessage } from "@langchain/langgraph-sdk";
import { v4 as uuidv4 } from "uuid";
const PizzaOrdererAnnotation = Annotation.Root({
messages: GenerativeUIAnnotation.spec.messages,
})
async function sleep(ms = 5000) {
return new Promise((resolve) => setTimeout(resolve, ms));
}
const workflow = new StateGraph(PizzaOrdererAnnotation)
.addNode("findStore", async (state) => {
const findShopSchema = z.object({
location: z.string().describe("The location the user is in. E.g. 'San Francisco' or 'New York'"),
pizza_company: z.string().optional().describe("The name of the pizza company. E.g. 'Dominos' or 'Papa John's'. Optional, if not defined it will search for all pizza shops"),
}).describe("The schema for finding a pizza shop for the user")
const model = new ChatAnthropic({ model: "claude-3-5-sonnet-latest", temperature: 0 }).withStructuredOutput(findShopSchema, {
name: "find_pizza_shop",
includeRaw: true,
})
const response = await model.invoke([
{
role: "system",
content: "You are a helpful AI assistant, tasked with extracting information from the conversation between you, and the user, in order to find a pizza shop for them."
},
...state.messages,
])
await sleep();
const toolResponse: ToolMessage = {
type: "tool",
id: uuidv4(),
content: "I've found a pizza shop at 1119 19th St, San Francisco, CA 94107. The phone number for the shop is 415-555-1234.",
tool_call_id: (response.raw as unknown as AIMessage).tool_calls?.[0].id ?? "",
}
return {
messages: [response.raw, toolResponse]
}
})
.addNode("orderPizza", async (state) => {
await sleep(1500);
const placeOrderSchema = z.object({
address: z.string().describe("The address of the store to order the pizza from"),
phone_number: z.string().describe("The phone number of the store to order the pizza from"),
order: z.string().describe("The full pizza order for the user"),
}).describe("The schema for ordering a pizza for the user")
const model = new ChatAnthropic({ model: "claude-3-5-sonnet-latest", temperature: 0 }).withStructuredOutput(placeOrderSchema, {
name: "place_pizza_order",
includeRaw: true,
})
const response = await model.invoke([
{
role: "system",
content: "You are a helpful AI assistant, tasked with placing an order for a pizza for the user."
},
...state.messages,
])
const toolResponse: ToolMessage = {
type: "tool",
id: uuidv4(),
content: "Pizza order successfully placed.",
tool_call_id: (response.raw as unknown as AIMessage).tool_calls?.[0].id ?? "",
}
return {
messages: [response.raw, toolResponse]
}
})
.addEdge(START, "findStore")
.addEdge("findStore", "orderPizza")
.addEdge("orderPizza", END)
export const graph = workflow.compile()
graph.name = "Order Pizza Graph";

View File

@@ -13,7 +13,7 @@ export const GenerativeUIAnnotation = Annotation.Root({
>({ default: () => [], reducer: uiMessageReducer }), >({ default: () => [], reducer: uiMessageReducer }),
timestamp: Annotation<number>, timestamp: Annotation<number>,
next: Annotation< next: Annotation<
"stockbroker" | "tripPlanner" | "openCode" | "generalInput" "stockbroker" | "tripPlanner" | "openCode" | "orderPizza" | "generalInput"
>(), >(),
}); });

View File

@@ -16,6 +16,7 @@
"@assistant-ui/react-markdown": "^0.8.0", "@assistant-ui/react-markdown": "^0.8.0",
"@assistant-ui/react-syntax-highlighter": "^0.7.2", "@assistant-ui/react-syntax-highlighter": "^0.7.2",
"@faker-js/faker": "^9.5.1", "@faker-js/faker": "^9.5.1",
"@langchain/anthropic": "^0.3.15",
"@langchain/core": "^0.3.41", "@langchain/core": "^0.3.41",
"@langchain/google-genai": "^0.1.10", "@langchain/google-genai": "^0.1.10",
"@langchain/langgraph": "^0.2.49", "@langchain/langgraph": "^0.2.49",

6418
pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,6 @@
import { parsePartialJson } from "@langchain/core/output_parsers";
import { useStreamContext } from "@/providers/Stream"; import { useStreamContext } from "@/providers/Stream";
import { Checkpoint, Message } from "@langchain/langgraph-sdk"; import { AIMessage, Checkpoint, Message } from "@langchain/langgraph-sdk";
import { getContentString } from "../utils"; import { getContentString } from "../utils";
import { BranchSwitcher, CommandBar } from "./shared"; import { BranchSwitcher, CommandBar } from "./shared";
import { MarkdownText } from "../markdown-text"; import { MarkdownText } from "../markdown-text";
@@ -7,6 +8,7 @@ import { LoadExternalComponent } from "@langchain/langgraph-sdk/react-ui";
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
import { ToolCalls, ToolResult } from "./tool-calls"; import { ToolCalls, ToolResult } from "./tool-calls";
import { StringParam, useQueryParam } from "use-query-params"; import { StringParam, useQueryParam } from "use-query-params";
import { MessageContentComplex } from "@langchain/core/messages";
function CustomComponent({ function CustomComponent({
message, message,
@@ -44,6 +46,28 @@ function CustomComponent({
); );
} }
function parseAnthropicStreamedToolCalls(content: MessageContentComplex[]): AIMessage["tool_calls"] {
const toolCallContents = content.filter((c) => c.type === "tool_use" && c.id);
return toolCallContents.map((tc) => {
const toolCall = tc as Record<string, any>
let json: Record<string, any> = {};
if (toolCall?.input) {
try {
json = parsePartialJson(toolCall.input) ?? {}
} catch {
// Pass
}
}
return {
name: toolCall.name ?? "",
id: toolCall.id ?? "",
args: json,
type: "tool_call",
}
})
}
export function AssistantMessage({ export function AssistantMessage({
message, message,
isLoading, isLoading,
@@ -58,11 +82,14 @@ export function AssistantMessage({
const thread = useStreamContext(); const thread = useStreamContext();
const meta = thread.getMessagesMetadata(message); const meta = thread.getMessagesMetadata(message);
const parentCheckpoint = meta?.firstSeenState?.parent_checkpoint; const parentCheckpoint = meta?.firstSeenState?.parent_checkpoint;
const anthropicStreamedToolCalls = Array.isArray(message.content) ? parseAnthropicStreamedToolCalls(message.content) : undefined;
const hasToolCalls = const hasToolCalls =
"tool_calls" in message && ("tool_calls" in message &&
message.tool_calls && message.tool_calls &&
message.tool_calls.length > 0; message.tool_calls.length > 0);
const toolCallsHaveContents = hasToolCalls && message.tool_calls?.some((tc) => tc.args && Object.keys(tc.args).length > 0);
const hasAnthropicToolCalls = !!anthropicStreamedToolCalls?.length;
const isToolResult = message.type === "tool"; const isToolResult = message.type === "tool";
return ( return (
@@ -76,7 +103,9 @@ export function AssistantMessage({
<MarkdownText>{contentString}</MarkdownText> <MarkdownText>{contentString}</MarkdownText>
</div> </div>
)} )}
{hasToolCalls && <ToolCalls toolCalls={message.tool_calls} />} {(hasToolCalls && toolCallsHaveContents && <ToolCalls toolCalls={message.tool_calls} />) ||
(hasAnthropicToolCalls && <ToolCalls toolCalls={anthropicStreamedToolCalls} />) ||
(hasToolCalls && <ToolCalls toolCalls={message.tool_calls} />)}
<CustomComponent message={message} thread={thread} /> <CustomComponent message={message} thread={thread} />
<div <div
className={cn( className={cn(