get working
--- a/.env.example
+++ b/.env.example
@@ -1,11 +1,3 @@
-TAVILY_API_KEY=...
-
+ANTHROPIC_API_KEY=...
 # To separate your traces from other application
-LANGSMITH_PROJECT=retrieval-agent
-
-# The following depend on your selected configuration
-
-## LLM choice:
-ANTHROPIC_API_KEY=....
-FIREWORKS_API_KEY=...
-OPENAI_API_KEY=...
+LANGSMITH_PROJECT=new-agent
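
The trimmed .env.example leaves Anthropic as the only required provider key. Not part of the commit, but a minimal sanity check using the template's existing python-dotenv dependency (assumes you copied .env.example to .env):

    import os

    from dotenv import load_dotenv

    load_dotenv()  # reads .env from the current working directory
    assert os.environ.get("ANTHROPIC_API_KEY"), "set ANTHROPIC_API_KEY in .env"
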
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,11 @@ authors = [
 readme = "README.md"
 license = { text = "MIT" }
 requires-python = ">=3.9"
-dependencies = ["langgraph>=0.2.6", "python-dotenv>=1.0.1"]
+dependencies = [
+    "anthropic>=0.34.2",
+    "langgraph>=0.2.6",
+    "python-dotenv>=1.0.1",
+]
 
 
 [project.optional-dependencies]
--- a/src/agent/configuration.py
+++ b/src/agent/configuration.py
@@ -5,7 +5,7 @@ from __future__ import annotations
 from dataclasses import dataclass, field, fields
 from typing import Annotated, Optional
 
-from langchain_core.runnables import RunnableConfig, ensure_config
+from langchain_core.runnables import RunnableConfig
 
 from agent import prompts
 
@@ -20,30 +20,21 @@ class Configuration:
     This prompt sets the context and behavior for the agent.
     """
 
-    model_name: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = (
-        "anthropic/claude-3-5-sonnet-20240620"
-    )
-    """The name of the language model to use for the agent's main interactions.
-
-    Should be in the form: provider/model-name.
-    """
-
-    scraper_tool_model_name: Annotated[
-        str, {"__template_metadata__": {"kind": "llm"}}
-    ] = "accounts/fireworks/models/firefunction-v2"
-    """The name of the language model to use for the web scraping tool.
-
-    This model is specifically used for summarizing and extracting information from web pages.
-    """
-    max_search_results: int = 10
-    """The maximum number of search results to return for each search query."""
+    model_name: Annotated[
+        str,
+        {
+            "__template_metadata__": {
+                "kind": "llm",
+            }
+        },
+    ] = "claude-3-5-sonnet-20240620"
+    """The name of the language model to use for our chatbot."""
 
     @classmethod
     def from_runnable_config(
         cls, config: Optional[RunnableConfig] = None
     ) -> Configuration:
         """Create a Configuration instance from a RunnableConfig object."""
-        config = ensure_config(config)
-        configurable = config.get("configurable") or {}
+        configurable = (config.get("configurable") or {}) if config else {}
         _fields = {f.name for f in fields(cls) if f.init}
         return cls(**{k: v for k, v in configurable.items() if k in _fields})
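
The rewritten from_runnable_config drops ensure_config and tolerates a missing config directly. A usage sketch (the override model name is illustrative, not part of the commit):

    from agent.configuration import Configuration

    # With no config at all, the dataclass defaults apply
    assert Configuration.from_runnable_config().model_name == "claude-3-5-sonnet-20240620"

    # Keys under "configurable" override matching init fields; unknown keys are ignored
    cfg = Configuration.from_runnable_config(
        {"configurable": {"model_name": "claude-3-5-haiku-20241022", "not_a_field": 1}}
    )
    assert cfg.model_name == "claude-3-5-haiku-20241022"
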
--- a/src/agent/graph.py
+++ b/src/agent/graph.py
@@ -4,25 +4,20 @@ Works with a chat model with tool calling support.
 """
 
 from datetime import datetime, timezone
-from typing import Dict, List, Literal, cast
+from typing import Any, Dict, List
 
-from langchain_core.messages import AIMessage
-from langchain_core.prompts import ChatPromptTemplate
+import anthropic
+from agent.configuration import Configuration
+from agent.state import State
 from langchain_core.runnables import RunnableConfig
 from langgraph.graph import StateGraph
-from langgraph.prebuilt import ToolNode
-
-from agent.configuration import Configuration
-from agent.state import InputState, State
-from agent.tools import TOOLS
-from agent.utils import load_chat_model
 
 # Define the function that calls the model
 
 
 async def call_model(
     state: State, config: RunnableConfig
-) -> Dict[str, List[AIMessage]]:
+) -> Dict[str, List[Dict[str, Any]]]:
     """Call the LLM powering our "agent".
 
     This function prepares the prompt, initializes the model, and processes the response.
@@ -35,93 +30,44 @@ async def call_model(
         dict: A dictionary containing the model's response message.
     """
     configuration = Configuration.from_runnable_config(config)
-
-    # Create a prompt template. Customize this to change the agent's behavior.
-    prompt = ChatPromptTemplate.from_messages(
-        [("system", configuration.system_prompt), ("placeholder", "{messages}")]
+    system_prompt = configuration.system_prompt.format(
+        system_time=datetime.now(tz=timezone.utc).isoformat()
     )
-
-    # Initialize the model with tool binding. Change the model or add more tools here.
-    model = load_chat_model(configuration.model_name).bind_tools(TOOLS)
-
-    # Prepare the input for the model, including the current system time
-    message_value = await prompt.ainvoke(
-        {
-            "messages": state.messages,
-            "system_time": datetime.now(tz=timezone.utc).isoformat(),
-        },
-        config,
-    )
-
-    # Get the model's response
-    response = cast(AIMessage, await model.ainvoke(message_value, config))
-
-    # Handle the case when it's the last step and the model still wants to use a tool
-    if state.is_last_step and response.tool_calls:
-        return {
-            "messages": [
-                AIMessage(
-                    id=response.id,
-                    content="Sorry, I could not find an answer to your question in the specified number of steps.",
-                )
-            ]
-        }
+    toks = []
+    async with anthropic.AsyncAnthropic() as client:
+        async with client.messages.stream(
+            model=configuration.model_name,
+            max_tokens=1024,
+            system=system_prompt,
+            messages=state.messages,
+        ) as stream:
+            async for text in stream.text_stream:
+                toks.append(text)
 
     # Return the model's response as a list to be added to existing messages
-    return {"messages": [response]}
+    return {
+        "messages": [
+            {"role": "assistant", "content": [{"type": "text", "text": "".join(toks)}]}
+        ]
+    }
 
 
 # Define a new graph
-workflow = StateGraph(State, input=InputState, config_schema=Configuration)
+workflow = StateGraph(State, config_schema=Configuration)
 
 # Define the two nodes we will cycle between
 workflow.add_node(call_model)
-workflow.add_node("tools", ToolNode(TOOLS))
 
 # Set the entrypoint as `call_model`
 # This means that this node is the first one called
 workflow.add_edge("__start__", "call_model")
 
-
-def route_model_output(state: State) -> Literal["__end__", "tools"]:
-    """Determine the next node based on the model's output.
-
-    This function checks if the model's last message contains tool calls.
-
-    Args:
-        state (State): The current state of the conversation.
-
-    Returns:
-        str: The name of the next node to call ("__end__" or "tools").
-    """
-    last_message = state.messages[-1]
-    if not isinstance(last_message, AIMessage):
-        raise ValueError(
-            f"Expected AIMessage in output edges, but got {type(last_message).__name__}"
-        )
-    # If there is no tool call, then we finish
-    if not last_message.tool_calls:
-        return "__end__"
-    # Otherwise we execute the requested actions
-    return "tools"
-
-
-# Add a conditional edge to determine the next step after `call_model`
-workflow.add_conditional_edges(
-    "call_model",
-    # After call_model finishes running, the next node(s) are scheduled
-    # based on the output from route_model_output
-    route_model_output,
-)
-
-# Add a normal edge from `tools` to `call_model`
-# This creates a cycle: after using tools, we always return to the model
-workflow.add_edge("tools", "call_model")
-
 # Compile the workflow into an executable graph
 # You can customize this by adding interrupt points for state updates
 graph = workflow.compile(
     interrupt_before=[],  # Add node names here to update state before they're called
     interrupt_after=[],  # Add node names here to update state after they're called
 )
+graph.name = "My New Graph"  # This defines the custom name in LangSmith
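
With the tool loop gone, the graph is a single call_model node that streams from the Anthropic SDK and returns one assistant message in Anthropic's content-block shape. A minimal invocation sketch (not part of the commit; assumes ANTHROPIC_API_KEY is set, message shapes mirror the diff above):

    import asyncio

    from agent import graph


    async def main() -> None:
        res = await graph.ainvoke(
            {"messages": [{"role": "user", "content": "Hello!"}]}
        )
        # call_model returns content blocks, so dig the text out of the first block
        print(res["messages"][-1]["content"][0]["text"])


    asyncio.run(main())
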
--- a/src/agent/prompts.py
+++ b/src/agent/prompts.py
@@ -1,5 +1,5 @@
-"""Default prompts used by the agent."""
+"""Default prompts used by the chatbot."""
 
-SYSTEM_PROMPT = """You are a helpful AI assistant.
+SYSTEM_PROMPT = """You are a helpful (if not sassy) personal assistant.
 
 System time: {system_time}"""
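
The {system_time} placeholder is still filled with str.format, exactly as call_model now does in graph.py. A quick sketch:

    from datetime import datetime, timezone

    from agent import prompts

    # Mirrors the system_prompt construction inside call_model
    print(prompts.SYSTEM_PROMPT.format(system_time=datetime.now(tz=timezone.utc).isoformat()))
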
--- a/src/agent/state.py
+++ b/src/agent/state.py
@@ -2,59 +2,23 @@
 
 from __future__ import annotations
 
+import operator
 from dataclasses import dataclass, field
 from typing import Sequence
 
-from langchain_core.messages import AnyMessage
-from langgraph.graph import add_messages
-from langgraph.managed import IsLastStep
 from typing_extensions import Annotated
 
 
 @dataclass
-class InputState:
+class State:
     """Defines the input state for the agent, representing a narrower interface to the outside world.
 
     This class is used to define the initial state and structure of incoming data.
     """
 
-    messages: Annotated[Sequence[AnyMessage], add_messages] = field(
-        default_factory=list
-    )
+    messages: Annotated[Sequence[dict], operator.add] = field(default_factory=list)
     """
     Messages tracking the primary execution state of the agent.
 
-    Typically accumulates a pattern of:
-    1. HumanMessage - user input
-    2. AIMessage with .tool_calls - agent picking tool(s) to use to collect information
-    3. ToolMessage(s) - the responses (or errors) from the executed tools
-    4. AIMessage without .tool_calls - agent responding in unstructured format to the user
-    5. HumanMessage - user responds with the next conversational turn
-
-    Steps 2-5 may repeat as needed.
-
-    The `add_messages` annotation ensures that new messages are merged with existing ones,
-    updating by ID to maintain an "append-only" state unless a message with the same ID is provided.
+    Typically accumulates a pattern of user, assistant, user, ... etc. messages.
     """
-
-
-@dataclass
-class State(InputState):
-    """Represents the complete state of the agent, extending InputState with additional attributes.
-
-    This class can be used to store any information needed throughout the agent's lifecycle.
-    """
-
-    is_last_step: IsLastStep = field(default=False)
-    """
-    Indicates whether the current step is the last one before the graph raises an error.
-
-    This is a 'managed' variable, controlled by the state machine rather than user code.
-    It is set to 'True' when the step count reaches recursion_limit - 1.
-    """
-
-    # Additional attributes can be added here as needed.
-    # Common examples include:
-    # retrieved_documents: List[Document] = field(default_factory=list)
-    # extracted_entities: Dict[str, Any] = field(default_factory=dict)
-    # api_connections: Dict[str, Any] = field(default_factory=dict)
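
Swapping add_messages for operator.add means state updates are plain list concatenation: no ID-based merging, every returned message is simply appended. A small sketch of the reducer's behavior (example messages are illustrative):

    import operator

    history = [{"role": "user", "content": "What's 62 - 19?"}]
    update = [{"role": "assistant", "content": [{"type": "text", "text": "43"}]}]

    # LangGraph applies the annotated reducer to combine the old and new values
    assert operator.add(history, update) == history + update
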
--- a/src/agent/utils.py
+++ /dev/null
@@ -1,27 +0,0 @@
-"""Utility & helper functions."""
-
-from langchain.chat_models import init_chat_model
-from langchain_core.language_models import BaseChatModel
-from langchain_core.messages import BaseMessage
-
-
-def get_message_text(msg: BaseMessage) -> str:
-    """Get the text content of a message."""
-    content = msg.content
-    if isinstance(content, str):
-        return content
-    elif isinstance(content, dict):
-        return content.get("text", "")
-    else:
-        txts = [c if isinstance(c, str) else (c.get("text") or "") for c in content]
-        return "".join(txts).strip()
-
-
-def load_chat_model(fully_specified_name: str) -> BaseChatModel:
-    """Load a chat model from a fully specified name.
-
-    Args:
-        fully_specified_name (str): String in the format 'provider/model'.
-    """
-    provider, model = fully_specified_name.split("/", maxsplit=1)
-    return init_chat_model(model, model_provider=provider)
--- a/tests/integration_tests/test_graph.py
+++ b/tests/integration_tests/test_graph.py
@@ -1,14 +1,12 @@
 import pytest
-from langsmith import unit
 
 from agent import graph
+from langsmith import expect, unit
 
 
 @pytest.mark.asyncio
 @unit
 async def test_agent_simple_passthrough() -> None:
     res = await graph.ainvoke(
-        {"messages": [("user", "Who is the founder of LangChain?")]}
+        {"messages": [{"role": "user", "content": "What's 62 - 19?"}]}
     )
-
-    assert "harrison" in str(res["messages"][-1].content).lower()
+    expect(res["messages"][-1]["content"][0]["text"]).to_contain("43")
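
The updated test hits the live API, so it needs a real ANTHROPIC_API_KEY in the environment. To run just this check, one option (a suggestion, not part of the commit) is invoking pytest programmatically:

    import pytest

    # Select the test by name; add "-s" to stream output while debugging
    pytest.main(["-k", "test_agent_simple_passthrough"])
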