get working

This commit is contained in:
William Fu-Hinthorn
2024-09-13 16:57:34 -07:00
parent 4e6582bd91
commit 6cb0b0506e
8 changed files with 51 additions and 183 deletions

View File

@@ -1,11 +1,3 @@
TAVILY_API_KEY=... ANTHROPIC_API_KEY=...
# To separate your traces from other application # To separate your traces from other application
LANGSMITH_PROJECT=retrieval-agent LANGSMITH_PROJECT=new-agent
# The following depend on your selected configuration
## LLM choice:
ANTHROPIC_API_KEY=....
FIREWORKS_API_KEY=...
OPENAI_API_KEY=...

View File

@@ -8,7 +8,11 @@ authors = [
readme = "README.md" readme = "README.md"
license = { text = "MIT" } license = { text = "MIT" }
requires-python = ">=3.9" requires-python = ">=3.9"
dependencies = ["langgraph>=0.2.6", "python-dotenv>=1.0.1"] dependencies = [
"anthropic>=0.34.2",
"langgraph>=0.2.6",
"python-dotenv>=1.0.1",
]
[project.optional-dependencies] [project.optional-dependencies]

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
from dataclasses import dataclass, field, fields from dataclasses import dataclass, field, fields
from typing import Annotated, Optional from typing import Annotated, Optional
from langchain_core.runnables import RunnableConfig, ensure_config from langchain_core.runnables import RunnableConfig
from agent import prompts from agent import prompts
@@ -20,30 +20,21 @@ class Configuration:
This prompt sets the context and behavior for the agent. This prompt sets the context and behavior for the agent.
""" """
model_name: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = ( model_name: Annotated[
"anthropic/claude-3-5-sonnet-20240620" str,
) {
"""The name of the language model to use for the agent's main interactions. "__template_metadata__": {
"kind": "llm",
Should be in the form: provider/model-name. }
""" },
] = "claude-3-5-sonnet-20240620"
scraper_tool_model_name: Annotated[ """The name of the language model to use for our chatbot."""
str, {"__template_metadata__": {"kind": "llm"}}
] = "accounts/fireworks/models/firefunction-v2"
"""The name of the language model to use for the web scraping tool.
This model is specifically used for summarizing and extracting information from web pages.
"""
max_search_results: int = 10
"""The maximum number of search results to return for each search query."""
@classmethod @classmethod
def from_runnable_config( def from_runnable_config(
cls, config: Optional[RunnableConfig] = None cls, config: Optional[RunnableConfig] = None
) -> Configuration: ) -> Configuration:
"""Create a Configuration instance from a RunnableConfig object.""" """Create a Configuration instance from a RunnableConfig object."""
config = ensure_config(config) configurable = (config.get("configurable") or {}) if config else {}
configurable = config.get("configurable") or {}
_fields = {f.name for f in fields(cls) if f.init} _fields = {f.name for f in fields(cls) if f.init}
return cls(**{k: v for k, v in configurable.items() if k in _fields}) return cls(**{k: v for k, v in configurable.items() if k in _fields})

View File

@@ -4,25 +4,20 @@ Works with a chat model with tool calling support.
""" """
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Dict, List, Literal, cast from typing import Any, Dict, List
from langchain_core.messages import AIMessage import anthropic
from langchain_core.prompts import ChatPromptTemplate from agent.configuration import Configuration
from agent.state import State
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
from langgraph.graph import StateGraph from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode
from agent.configuration import Configuration
from agent.state import InputState, State
from agent.tools import TOOLS
from agent.utils import load_chat_model
# Define the function that calls the model # Define the function that calls the model
async def call_model( async def call_model(
state: State, config: RunnableConfig state: State, config: RunnableConfig
) -> Dict[str, List[AIMessage]]: ) -> Dict[str, List[Dict[str, Any]]]:
"""Call the LLM powering our "agent". """Call the LLM powering our "agent".
This function prepares the prompt, initializes the model, and processes the response. This function prepares the prompt, initializes the model, and processes the response.
@@ -35,93 +30,44 @@ async def call_model(
dict: A dictionary containing the model's response message. dict: A dictionary containing the model's response message.
""" """
configuration = Configuration.from_runnable_config(config) configuration = Configuration.from_runnable_config(config)
system_prompt = configuration.system_prompt.format(
# Create a prompt template. Customize this to change the agent's behavior. system_time=datetime.now(tz=timezone.utc).isoformat()
prompt = ChatPromptTemplate.from_messages(
[("system", configuration.system_prompt), ("placeholder", "{messages}")]
) )
toks = []
# Initialize the model with tool binding. Change the model or add more tools here. async with anthropic.AsyncAnthropic() as client:
model = load_chat_model(configuration.model_name).bind_tools(TOOLS) async with client.messages.stream(
model=configuration.model_name,
# Prepare the input for the model, including the current system time max_tokens=1024,
message_value = await prompt.ainvoke( system=system_prompt,
{ messages=state.messages,
"messages": state.messages, ) as stream:
"system_time": datetime.now(tz=timezone.utc).isoformat(), async for text in stream.text_stream:
}, toks.append(text)
config,
)
# Get the model's response
response = cast(AIMessage, await model.ainvoke(message_value, config))
# Handle the case when it's the last step and the model still wants to use a tool
if state.is_last_step and response.tool_calls:
return {
"messages": [
AIMessage(
id=response.id,
content="Sorry, I could not find an answer to your question in the specified number of steps.",
)
]
}
# Return the model's response as a list to be added to existing messages # Return the model's response as a list to be added to existing messages
return {"messages": [response]} return {
"messages": [
{"role": "assistant", "content": [{"type": "text", "text": "".join(toks)}]}
]
}
# Define a new graph # Define a new graph
workflow = StateGraph(State, input=InputState, config_schema=Configuration) workflow = StateGraph(State, config_schema=Configuration)
# Define the two nodes we will cycle between # Define the two nodes we will cycle between
workflow.add_node(call_model) workflow.add_node(call_model)
workflow.add_node("tools", ToolNode(TOOLS))
# Set the entrypoint as `call_model` # Set the entrypoint as `call_model`
# This means that this node is the first one called # This means that this node is the first one called
workflow.add_edge("__start__", "call_model") workflow.add_edge("__start__", "call_model")
def route_model_output(state: State) -> Literal["__end__", "tools"]:
"""Determine the next node based on the model's output.
This function checks if the model's last message contains tool calls.
Args:
state (State): The current state of the conversation.
Returns:
str: The name of the next node to call ("__end__" or "tools").
"""
last_message = state.messages[-1]
if not isinstance(last_message, AIMessage):
raise ValueError(
f"Expected AIMessage in output edges, but got {type(last_message).__name__}"
)
# If there is no tool call, then we finish
if not last_message.tool_calls:
return "__end__"
# Otherwise we execute the requested actions
return "tools"
# Add a conditional edge to determine the next step after `call_model`
workflow.add_conditional_edges(
"call_model",
# After call_model finishes running, the next node(s) are scheduled
# based on the output from route_model_output
route_model_output,
)
# Add a normal edge from `tools` to `call_model`
# This creates a cycle: after using tools, we always return to the model
workflow.add_edge("tools", "call_model")
# Compile the workflow into an executable graph # Compile the workflow into an executable graph
# You can customize this by adding interrupt points for state updates # You can customize this by adding interrupt points for state updates
graph = workflow.compile( graph = workflow.compile(
interrupt_before=[], # Add node names here to update state before they're called interrupt_before=[], # Add node names here to update state before they're called
interrupt_after=[], # Add node names here to update state after they're called interrupt_after=[], # Add node names here to update state after they're called
) )
graph.name = "My New Graph" # This defines the custom name in LangSmith

View File

@@ -1,5 +1,5 @@
"""Default prompts used by the agent.""" """Default prompts used by the chatbot."""
SYSTEM_PROMPT = """You are a helpful AI assistant. SYSTEM_PROMPT = """You are a helpful (if not sassy) personal assistant.
System time: {system_time}""" System time: {system_time}"""

View File

@@ -2,59 +2,23 @@
from __future__ import annotations from __future__ import annotations
import operator
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Sequence from typing import Sequence
from langchain_core.messages import AnyMessage
from langgraph.graph import add_messages
from langgraph.managed import IsLastStep
from typing_extensions import Annotated from typing_extensions import Annotated
@dataclass @dataclass
class InputState: class State:
"""Defines the input state for the agent, representing a narrower interface to the outside world. """Defines the input state for the agent, representing a narrower interface to the outside world.
This class is used to define the initial state and structure of incoming data. This class is used to define the initial state and structure of incoming data.
""" """
messages: Annotated[Sequence[AnyMessage], add_messages] = field( messages: Annotated[Sequence[dict], operator.add] = field(default_factory=list)
default_factory=list
)
""" """
Messages tracking the primary execution state of the agent. Messages tracking the primary execution state of the agent.
Typically accumulates a pattern of: Typically accumulates a pattern of user, assistant, user, ... etc. messages.
1. HumanMessage - user input
2. AIMessage with .tool_calls - agent picking tool(s) to use to collect information
3. ToolMessage(s) - the responses (or errors) from the executed tools
4. AIMessage without .tool_calls - agent responding in unstructured format to the user
5. HumanMessage - user responds with the next conversational turn
Steps 2-5 may repeat as needed.
The `add_messages` annotation ensures that new messages are merged with existing ones,
updating by ID to maintain an "append-only" state unless a message with the same ID is provided.
""" """
@dataclass
class State(InputState):
"""Represents the complete state of the agent, extending InputState with additional attributes.
This class can be used to store any information needed throughout the agent's lifecycle.
"""
is_last_step: IsLastStep = field(default=False)
"""
Indicates whether the current step is the last one before the graph raises an error.
This is a 'managed' variable, controlled by the state machine rather than user code.
It is set to 'True' when the step count reaches recursion_limit - 1.
"""
# Additional attributes can be added here as needed.
# Common examples include:
# retrieved_documents: List[Document] = field(default_factory=list)
# extracted_entities: Dict[str, Any] = field(default_factory=dict)
# api_connections: Dict[str, Any] = field(default_factory=dict)

View File

@@ -1,27 +0,0 @@
"""Utility & helper functions."""
from langchain.chat_models import init_chat_model
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage
def get_message_text(msg: BaseMessage) -> str:
"""Get the text content of a message."""
content = msg.content
if isinstance(content, str):
return content
elif isinstance(content, dict):
return content.get("text", "")
else:
txts = [c if isinstance(c, str) else (c.get("text") or "") for c in content]
return "".join(txts).strip()
def load_chat_model(fully_specified_name: str) -> BaseChatModel:
"""Load a chat model from a fully specified name.
Args:
fully_specified_name (str): String in the format 'provider/model'.
"""
provider, model = fully_specified_name.split("/", maxsplit=1)
return init_chat_model(model, model_provider=provider)

View File

@@ -1,14 +1,12 @@
import pytest import pytest
from langsmith import unit
from agent import graph from agent import graph
from langsmith import expect, unit
@pytest.mark.asyncio @pytest.mark.asyncio
@unit @unit
async def test_agent_simple_passthrough() -> None: async def test_agent_simple_passthrough() -> None:
res = await graph.ainvoke( res = await graph.ainvoke(
{"messages": [("user", "Who is the founder of LangChain?")]} {"messages": [{"role": "user", "content": "What's 62 - 19?"}]}
) )
expect(res["messages"][-1]["content"][0]["text"]).to_contain("43")
assert "harrison" in str(res["messages"][-1].content).lower()