William Fu-Hinthorn
2024-09-13 16:30:31 -07:00
parent 523dc46aff
commit 4e6582bd91
13 changed files with 29 additions and 113 deletions

8
src/agent/__init__.py Normal file

@@ -0,0 +1,8 @@
"""New LangGraph Agent.
This module defines a custom graph.
"""
from agent.graph import graph
__all__ = ["graph"]

49
src/agent/configuration.py Normal file

@@ -0,0 +1,49 @@
"""Define the configurable parameters for the agent."""
from __future__ import annotations
from dataclasses import dataclass, field, fields
from typing import Annotated, Optional
from langchain_core.runnables import RunnableConfig, ensure_config
from agent import prompts
@dataclass(kw_only=True)
class Configuration:
"""The configuration for the agent."""
system_prompt: str = field(default=prompts.SYSTEM_PROMPT)
"""The system prompt to use for the agent's interactions.
This prompt sets the context and behavior for the agent.
"""
model_name: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = (
"anthropic/claude-3-5-sonnet-20240620"
)
"""The name of the language model to use for the agent's main interactions.
Should be in the form: provider/model-name.
"""
scraper_tool_model_name: Annotated[
str, {"__template_metadata__": {"kind": "llm"}}
] = "accounts/fireworks/models/firefunction-v2"
"""The name of the language model to use for the web scraping tool.
This model is specifically used for summarizing and extracting information from web pages.
"""
max_search_results: int = 10
"""The maximum number of search results to return for each search query."""
@classmethod
def from_runnable_config(
cls, config: Optional[RunnableConfig] = None
) -> Configuration:
"""Create a Configuration instance from a RunnableConfig object."""
config = ensure_config(config)
configurable = config.get("configurable") or {}
_fields = {f.name for f in fields(cls) if f.init}
return cls(**{k: v for k, v in configurable.items() if k in _fields})
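
A minimal usage sketch for the configuration above (illustrative only, not part of this commit): from_runnable_config keeps only keys that match init-able dataclass fields, so unrecognized "configurable" entries are silently dropped and everything else falls back to the defaults.

from agent.configuration import Configuration

config = {"configurable": {"max_search_results": 3, "unknown_key": "ignored"}}
configuration = Configuration.from_runnable_config(config)
assert configuration.max_search_results == 3  # overridden from config
assert configuration.model_name.startswith("anthropic/")  # default retained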

127
src/agent/graph.py Normal file

@@ -0,0 +1,127 @@
"""Define a custom Reasoning and Action agent.
Works with a chat model with tool calling support.
"""
from datetime import datetime, timezone
from typing import Dict, List, Literal, cast
from langchain_core.messages import AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode
from agent.configuration import Configuration
from agent.state import InputState, State
from agent.tools import TOOLS
from agent.utils import load_chat_model
# Define the function that calls the model
async def call_model(
state: State, config: RunnableConfig
) -> Dict[str, List[AIMessage]]:
"""Call the LLM powering our "agent".
This function prepares the prompt, initializes the model, and processes the response.
Args:
state (State): The current state of the conversation.
config (RunnableConfig): Configuration for the model run.
Returns:
dict: A dictionary containing the model's response message.
"""
configuration = Configuration.from_runnable_config(config)
# Create a prompt template. Customize this to change the agent's behavior.
prompt = ChatPromptTemplate.from_messages(
[("system", configuration.system_prompt), ("placeholder", "{messages}")]
)
# Initialize the model with tool binding. Change the model or add more tools here.
model = load_chat_model(configuration.model_name).bind_tools(TOOLS)
# Prepare the input for the model, including the current system time
message_value = await prompt.ainvoke(
{
"messages": state.messages,
"system_time": datetime.now(tz=timezone.utc).isoformat(),
},
config,
)
# Get the model's response
response = cast(AIMessage, await model.ainvoke(message_value, config))
# Handle the case when it's the last step and the model still wants to use a tool
if state.is_last_step and response.tool_calls:
return {
"messages": [
AIMessage(
id=response.id,
content="Sorry, I could not find an answer to your question in the specified number of steps.",
)
]
}
# Return the model's response as a list to be added to existing messages
return {"messages": [response]}
# Define a new graph
workflow = StateGraph(State, input=InputState, config_schema=Configuration)
# Define the two nodes we will cycle between
workflow.add_node(call_model)
workflow.add_node("tools", ToolNode(TOOLS))
# Set the entrypoint as `call_model`
# This means that this node is the first one called
workflow.add_edge("__start__", "call_model")
def route_model_output(state: State) -> Literal["__end__", "tools"]:
"""Determine the next node based on the model's output.
This function checks if the model's last message contains tool calls.
Args:
state (State): The current state of the conversation.
Returns:
str: The name of the next node to call ("__end__" or "tools").
"""
last_message = state.messages[-1]
if not isinstance(last_message, AIMessage):
raise ValueError(
f"Expected AIMessage in output edges, but got {type(last_message).__name__}"
)
# If there is no tool call, then we finish
if not last_message.tool_calls:
return "__end__"
# Otherwise we execute the requested actions
return "tools"
# Add a conditional edge to determine the next step after `call_model`
workflow.add_conditional_edges(
"call_model",
# After call_model finishes running, the next node(s) are scheduled
# based on the output from route_model_output
route_model_output,
)
# Add a normal edge from `tools` to `call_model`
# This creates a cycle: after using tools, we always return to the model
workflow.add_edge("tools", "call_model")
# Compile the workflow into an executable graph
# You can customize this by adding interrupt points for state updates
graph = workflow.compile(
interrupt_before=[], # Add node names here to update state before they're called
interrupt_after=[], # Add node names here to update state after they're called
)
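
For reference, a minimal invocation sketch for the compiled graph (illustrative, not part of this commit; assumes the provider API keys are set in the environment and that agent.tools.TOOLS is importable):

import asyncio

from agent.graph import graph

async def main() -> None:
    # Input must match InputState; (role, content) tuples are coerced
    # into messages by the add_messages reducer.
    result = await graph.ainvoke(
        {"messages": [("user", "What's the weather in SF?")]},
        config={"configurable": {"max_search_results": 3}},
    )
    print(result["messages"][-1].content)

asyncio.run(main())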

5
src/agent/prompts.py Normal file

@@ -0,0 +1,5 @@
"""Default prompts used by the agent."""
SYSTEM_PROMPT = """You are a helpful AI assistant.
System time: {system_time}"""
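
The {system_time} placeholder is filled at run time; graph.py supplies it through ChatPromptTemplate, but the substitution itself is plain string formatting (illustrative sketch):

from datetime import datetime, timezone

rendered = SYSTEM_PROMPT.format(
    system_time=datetime.now(tz=timezone.utc).isoformat()
)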

60
src/agent/state.py Normal file

@@ -0,0 +1,60 @@
"""Define the state structures for the agent."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Sequence
from langchain_core.messages import AnyMessage
from langgraph.graph import add_messages
from langgraph.managed import IsLastStep
from typing_extensions import Annotated
@dataclass
class InputState:
"""Defines the input state for the agent, representing a narrower interface to the outside world.
This class is used to define the initial state and structure of incoming data.
"""
messages: Annotated[Sequence[AnyMessage], add_messages] = field(
default_factory=list
)
"""
Messages tracking the primary execution state of the agent.
Typically accumulates a pattern of:
1. HumanMessage - user input
2. AIMessage with .tool_calls - agent picking tool(s) to use to collect information
3. ToolMessage(s) - the responses (or errors) from the executed tools
4. AIMessage without .tool_calls - agent responding in unstructured format to the user
5. HumanMessage - user responds with the next conversational turn
Steps 2-5 may repeat as needed.
The `add_messages` annotation ensures that new messages are merged with existing ones,
updating by ID to maintain an "append-only" state unless a message with the same ID is provided.
"""
@dataclass
class State(InputState):
"""Represents the complete state of the agent, extending InputState with additional attributes.
This class can be used to store any information needed throughout the agent's lifecycle.
"""
is_last_step: IsLastStep = field(default=False)
"""
Indicates whether the current step is the last one before the graph raises an error.
This is a 'managed' variable, controlled by the state machine rather than user code.
It is set to 'True' when the step count reaches recursion_limit - 1.
"""
# Additional attributes can be added here as needed.
# Common examples include:
# retrieved_documents: List[Document] = field(default_factory=list)
# extracted_entities: Dict[str, Any] = field(default_factory=dict)
# api_connections: Dict[str, Any] = field(default_factory=dict)
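
A quick sketch of the add_messages behavior documented above (illustrative, not part of this commit): new messages are appended, but a message that reuses an existing ID replaces the original in place.

from langchain_core.messages import AIMessage, HumanMessage
from langgraph.graph import add_messages

existing = [HumanMessage(content="hi", id="1"), AIMessage(content="draft", id="2")]
merged = add_messages(existing, [AIMessage(content="final", id="2")])
assert [m.content for m in merged] == ["hi", "final"]  # id "2" was replaced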

27
src/agent/utils.py Normal file

@@ -0,0 +1,27 @@
"""Utility & helper functions."""
from langchain.chat_models import init_chat_model
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage
def get_message_text(msg: BaseMessage) -> str:
"""Get the text content of a message."""
content = msg.content
if isinstance(content, str):
return content
elif isinstance(content, dict):
return content.get("text", "")
else:
txts = [c if isinstance(c, str) else (c.get("text") or "") for c in content]
return "".join(txts).strip()
def load_chat_model(fully_specified_name: str) -> BaseChatModel:
"""Load a chat model from a fully specified name.
Args:
fully_specified_name (str): String in the format 'provider/model'.
"""
provider, model = fully_specified_name.split("/", maxsplit=1)
return init_chat_model(model, model_provider=provider)
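
Usage sketch for the helpers above (illustrative, not part of this commit; requires the matching provider package, e.g. langchain-anthropic, and its API key; the model string is the default from configuration.py):

model = load_chat_model("anthropic/claude-3-5-sonnet-20240620")
reply = model.invoke("Say hello in one word.")
print(get_message_text(reply))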