2025-02-27 15:41:47 -08:00
import { StateGraph , START , END } from "@langchain/langgraph" ;
import { ChatGoogleGenerativeAI } from "@langchain/google-genai" ;
import { z } from "zod" ;
import { GenerativeUIAnnotation , GenerativeUIState } from "./types" ;
import { stockbrokerGraph } from "./stockbroker" ;
2025-02-18 19:35:46 +01:00
import { ChatOpenAI } from "@langchain/openai" ;
2025-03-03 16:51:46 -08:00
import { tripPlannerGraph } from "./trip-planner" ;
/**
 * Human-readable bullet list of the specialised sub-agents, one per line.
 * Interpolated into model prompts so a model can describe what the assistant can do.
 */
const allToolDescriptions = [
  "- stockbroker: can fetch the price of a ticker, purchase/sell a ticker, or get the user's portfolio",
  "- tripPlanner: helps the user plan their trip. it can suggest restaurants, and places to stay in any given location.",
].join("\n");
2025-02-18 19:35:46 +01:00
2025-02-27 15:41:47 -08:00
/**
 * Router node: classifies the latest human message into one of the graph's routes.
 *
 * Gemini is bound to a single `router` tool with `tool_choice: "router"`, so it is
 * forced to answer with a tool call whose arguments carry the chosen route. Those
 * arguments come from the model and are untrusted, so they are validated against
 * `routerSchema` before being written to `next` (read later by `handleRoute`).
 *
 * @param state - Current graph state; must contain at least one human message.
 * @returns A partial state update that sets `next` to the chosen route.
 * @throws Error if no human message exists, if the model returns no tool call,
 *   or if the tool-call arguments do not match `routerSchema`.
 */
async function router(
  state: GenerativeUIState,
): Promise<Partial<GenerativeUIState>> {
  // Reuse the shared tool descriptions so this enum description cannot drift
  // from the one shown to the general-input model.
  const routerDescription = `The route to take based on the user's input.
${allToolDescriptions}
- generalInput: handles all other cases where the above tools don't apply
`;
  const routerSchema = z.object({
    route: z
      .enum(["stockbroker", "tripPlanner", "generalInput"])
      .describe(routerDescription),
  });
  const routerTool = {
    name: "router",
    description: "A tool to route the user's query to the appropriate tool.",
    schema: routerSchema,
  };

  const llm = new ChatGoogleGenerativeAI({
    model: "gemini-2.0-flash",
    temperature: 0,
  })
    .bindTools([routerTool], { tool_choice: "router" })
    // Routing output is internal plumbing; keep it out of the token stream.
    .withConfig({ tags: ["langsmith:nostream"] });

  const prompt = `You're a highly helpful AI assistant, tasked with routing the user's query to the appropriate tool.
You should analyze the user's input, and choose the appropriate tool to use.`;

  // Only the most recent human turn is used for classification.
  const recentHumanMessage = state.messages.findLast(
    (m) => m.getType() === "human",
  );

  if (!recentHumanMessage) {
    throw new Error("No human message found in state");
  }

  const response = await llm.invoke([
    { role: "system", content: prompt },
    recentHumanMessage,
  ]);

  const toolCall = response.tool_calls?.[0];
  if (!toolCall) {
    throw new Error("No tool call found in response");
  }

  // Validate instead of casting: an out-of-enum route would otherwise surface
  // later as an unknown destination node in the conditional edge.
  const parsed = routerSchema.safeParse(toolCall.args);
  if (!parsed.success) {
    throw new Error(
      `Invalid router tool call arguments: ${parsed.error.message}`,
    );
  }

  return {
    next: parsed.data.route,
  };
}
2025-02-18 19:35:46 +01:00
2025-02-27 15:41:47 -08:00
/**
 * Conditional-edge selector for the `router` node.
 *
 * @param state - Graph state after `router` has run.
 * @returns The destination node name stored in `state.next`.
 */
function handleRoute(
  state: GenerativeUIState,
): "stockbroker" | "tripPlanner" | "generalInput" {
  const { next: destination } = state;
  return destination;
}
2025-02-18 19:35:46 +01:00
2025-02-27 15:41:47 -08:00
/**
 * Fallback node for requests no specialised sub-agent handles.
 *
 * Sends the full conversation to gpt-4o-mini, prefixed with a system prompt that
 * lists the available sub-agents, and appends the model's reply to `messages`.
 *
 * @param state - Current graph state; all of `state.messages` is forwarded.
 * @returns A partial state update containing the single model reply.
 */
async function handleGeneralInput(state: GenerativeUIState) {
  const model = new ChatOpenAI({ model: "gpt-4o-mini", temperature: 0 });

  const systemPrompt = `You are an AI assistant.\nIf the user asks what you can do, describe these tools. Otherwise, just answer as normal.\n\n${allToolDescriptions}`;

  const reply = await model.invoke([
    { role: "system", content: systemPrompt },
    ...state.messages,
  ]);

  return { messages: [reply] };
}
2025-02-18 19:35:46 +01:00
2025-02-27 15:41:47 -08:00
// Topology: START -> router -> (stockbroker | tripPlanner | generalInput) -> END.
const builder = new StateGraph(GenerativeUIAnnotation)
  // Classifier that decides which branch handles the request.
  .addNode("router", router)
  // Branch nodes; stockbroker and tripPlanner come from ./stockbroker and ./trip-planner.
  .addNode("stockbroker", stockbrokerGraph)
  .addNode("tripPlanner", tripPlannerGraph)
  .addNode("generalInput", handleGeneralInput)
  // Every run begins at the router.
  .addEdge(START, "router")
  // Dispatch on the route chosen by the router; the list enumerates possible targets.
  .addConditionalEdges("router", handleRoute, [
    "stockbroker",
    "tripPlanner",
    "generalInput",
  ])
  // Each branch ends the run once it finishes.
  .addEdge("stockbroker", END)
  .addEdge("tripPlanner", END)
  .addEdge("generalInput", END);

export const graph = builder.compile();
graph.name = "Generative UI Agent";