This commit is contained in:
William Fu-Hinthorn
2024-09-13 17:06:33 -07:00
parent ef0d1af289
commit 609708eafe
8 changed files with 60 additions and 248 deletions

View File

@@ -27,7 +27,7 @@ class Configuration:
"kind": "llm",
}
},
] = "claude-3-5-sonnet-20240620"
] = "anthropic/claude-3-5-sonnet-20240620"
"""The name of the language model to use for our chatbot."""
@classmethod

View File

@@ -6,18 +6,17 @@ Works with a chat model with tool calling support.
from datetime import datetime, timezone
from typing import Any, Dict, List
import anthropic
from agent.configuration import Configuration
from agent.state import State
from langchain_core.runnables import RunnableConfig
from langgraph.graph import StateGraph
from agent.configuration import Configuration
from agent.state import State
from agent.utils import load_chat_model
# Define the function that calls the model
async def call_model(
state: State, config: RunnableConfig
) -> Dict[str, List[Dict[str, Any]]]:
async def call_model(state: State, config: RunnableConfig) -> Dict[str, List[Any]]:
"""Call the LLM powering our "agent".
This function prepares the prompt, initializes the model, and processes the response.
@@ -33,23 +32,11 @@ async def call_model(
system_prompt = configuration.system_prompt.format(
system_time=datetime.now(tz=timezone.utc).isoformat()
)
toks = []
async with anthropic.AsyncAnthropic() as client:
async with client.messages.stream(
model=configuration.model_name,
max_tokens=1024,
system=system_prompt,
messages=state.messages,
) as stream:
async for text in stream.text_stream:
toks.append(text)
model = load_chat_model(configuration.model_name)
res = await model.ainvoke([("system", system_prompt), *state.messages])
# Return the model's response as a list to be added to existing messages
return {
"messages": [
{"role": "assistant", "content": [{"type": "text", "text": "".join(toks)}]}
]
}
return {"messages": [res]}
# Define a new graph

View File

@@ -2,10 +2,11 @@
from __future__ import annotations
import operator
from dataclasses import dataclass, field
from typing import Sequence
from typing import List
from langchain_core.messages import AnyMessage
from langgraph.graph import add_messages
from typing_extensions import Annotated
@@ -16,7 +17,7 @@ class State:
This class is used to define the initial state and structure of incoming data.
"""
messages: Annotated[Sequence[dict], operator.add] = field(default_factory=list)
messages: Annotated[List[AnyMessage], add_messages] = field(default_factory=list)
"""
Messages tracking the primary execution state of the agent.

14
src/agent/utils.py Normal file
View File

@@ -0,0 +1,14 @@
"""Utility & helper functions."""
from langchain.chat_models import init_chat_model
from langchain_core.language_models import BaseChatModel
def load_chat_model(fully_specified_name: str) -> BaseChatModel:
"""Load a chat model from a fully specified name.
Args:
fully_specified_name (str): String in the format 'provider/model'.
"""
provider, model = fully_specified_name.split("/", maxsplit=1)
return init_chat_model(model, model_provider=provider)