update
This commit is contained in:
@@ -27,7 +27,7 @@ class Configuration:
|
||||
"kind": "llm",
|
||||
}
|
||||
},
|
||||
] = "claude-3-5-sonnet-20240620"
|
||||
] = "anthropic/claude-3-5-sonnet-20240620"
|
||||
"""The name of the language model to use for our chatbot."""
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -6,18 +6,17 @@ Works with a chat model with tool calling support.
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import anthropic
|
||||
from agent.configuration import Configuration
|
||||
from agent.state import State
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from agent.configuration import Configuration
|
||||
from agent.state import State
|
||||
from agent.utils import load_chat_model
|
||||
|
||||
# Define the function that calls the model
|
||||
|
||||
|
||||
async def call_model(
|
||||
state: State, config: RunnableConfig
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
async def call_model(state: State, config: RunnableConfig) -> Dict[str, List[Any]]:
|
||||
"""Call the LLM powering our "agent".
|
||||
|
||||
This function prepares the prompt, initializes the model, and processes the response.
|
||||
@@ -33,23 +32,11 @@ async def call_model(
|
||||
system_prompt = configuration.system_prompt.format(
|
||||
system_time=datetime.now(tz=timezone.utc).isoformat()
|
||||
)
|
||||
toks = []
|
||||
async with anthropic.AsyncAnthropic() as client:
|
||||
async with client.messages.stream(
|
||||
model=configuration.model_name,
|
||||
max_tokens=1024,
|
||||
system=system_prompt,
|
||||
messages=state.messages,
|
||||
) as stream:
|
||||
async for text in stream.text_stream:
|
||||
toks.append(text)
|
||||
model = load_chat_model(configuration.model_name)
|
||||
res = await model.ainvoke([("system", system_prompt), *state.messages])
|
||||
|
||||
# Return the model's response as a list to be added to existing messages
|
||||
return {
|
||||
"messages": [
|
||||
{"role": "assistant", "content": [{"type": "text", "text": "".join(toks)}]}
|
||||
]
|
||||
}
|
||||
return {"messages": [res]}
|
||||
|
||||
|
||||
# Define a new graph
|
||||
|
||||
@@ -2,10 +2,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Sequence
|
||||
from typing import List
|
||||
|
||||
from langchain_core.messages import AnyMessage
|
||||
from langgraph.graph import add_messages
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
@@ -16,7 +17,7 @@ class State:
|
||||
This class is used to define the initial state and structure of incoming data.
|
||||
"""
|
||||
|
||||
messages: Annotated[Sequence[dict], operator.add] = field(default_factory=list)
|
||||
messages: Annotated[List[AnyMessage], add_messages] = field(default_factory=list)
|
||||
"""
|
||||
Messages tracking the primary execution state of the agent.
|
||||
|
||||
|
||||
14
src/agent/utils.py
Normal file
14
src/agent/utils.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""Utility & helper functions."""
|
||||
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
|
||||
def load_chat_model(fully_specified_name: str) -> BaseChatModel:
|
||||
"""Load a chat model from a fully specified name.
|
||||
|
||||
Args:
|
||||
fully_specified_name (str): String in the format 'provider/model'.
|
||||
"""
|
||||
provider, model = fully_specified_name.split("/", maxsplit=1)
|
||||
return init_chat_model(model, model_provider=provider)
|
||||
Reference in New Issue
Block a user