This commit is contained in:
bracesproul
2025-03-06 11:51:11 -08:00
parent a76ebb2bc5
commit 5eb600f60d

View File

@@ -1,7 +1,10 @@
import { v4 as uuidv4 } from "uuid";
import { ChatOpenAI } from "@langchain/openai";
import { TripDetails, TripPlannerState, TripPlannerUpdate } from "../types";
import { z } from "zod";
import { formatMessages } from "agent/utils/format-messages";
import { ToolMessage } from "@langchain/langgraph-sdk";
import { DO_NOT_RENDER_ID_PREFIX } from "@/lib/ensure-tool-responses";
function calculateDates(
startDate: string | undefined,
@@ -60,9 +63,7 @@ export async function extraction(
.describe("The end date of the trip. Should be in YYYY-MM-DD format"),
numberOfGuests: z
.number()
.optional()
.default(2)
.describe("The number of guests for the trip"),
.describe("The number of guests for the trip. Should default to 2 if not specified"),
});
const model = new ChatOpenAI({ model: "gpt-4o", temperature: 0 }).bindTools([
@@ -96,15 +97,13 @@ Extract only what is specified by the user. It is okay to leave fields blank if
{ role: "human", content: humanMessage },
]);
const extractedDetails = response.tool_calls?.[0]?.args as
| z.infer<typeof schema>
| undefined;
if (!extractedDetails) {
const toolCall = response.tool_calls?.[0];
if (!toolCall) {
return {
messages: [response],
};
}
const extractedDetails = toolCall.args as z.infer<typeof schema>;
const { startDate, endDate } = calculateDates(
extractedDetails.startDate,
@@ -114,13 +113,19 @@ Extract only what is specified by the user. It is okay to leave fields blank if
const extractionDetailsWithDefaults: TripDetails = {
startDate,
endDate,
numberOfGuests: extractedDetails.numberOfGuests
? extractedDetails.numberOfGuests
: 2,
numberOfGuests: extractedDetails.numberOfGuests ?? 2,
location: extractedDetails.location,
};
const extractToolResponse: ToolMessage = {
type: "tool",
id: `${DO_NOT_RENDER_ID_PREFIX}${uuidv4()}`,
tool_call_id: toolCall.id ?? "",
content: "Successfully extracted trip details",
};
return {
tripDetails: extractionDetailsWithDefaults,
messages: [response, extractToolResponse]
};
}