# fa_agents / agent.py
# j14i's picture
# Got 45%
# e04e3db
import os
import warnings
from typing import Annotated, TypedDict
from dotenv import load_dotenv
from langchain_chroma import Chroma
from langchain_community.cache import SQLiteCache
from langchain_core.globals import set_llm_cache
from langchain_core.messages.human import HumanMessage
from langchain_core.messages.system import SystemMessage
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph.message import add_messages
from langgraph.graph.state import END, START, CompiledStateGraph, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from pydantic import SecretStr
# Silence UserWarning messages attributed to the langchain_tavily module.
warnings.filterwarnings("ignore", category=UserWarning, module="langchain_tavily")
# Load variables from a local .env file into the process environment; the agent
# below reads ZAI_API_KEY / OPENAI_API_KEY through os.getenv.
load_dotenv()
# Install a process-wide LLM response cache backed by a SQLite file in the
# current working directory. In-memory alternative, kept for reference:
# from langchain_core.caches import InMemoryCache
# set_llm_cache(InMemoryCache())
set_llm_cache(SQLiteCache(database_path=".langchain_cache.db"))
# Initialize RAG vector store
# NOTE: both objects are created at import time. VECTOR_STORE is queried by
# GaiaAgent._retriever_node to look up a similar, already-solved question.
CHROMA_PATH = "./chroma_gaia_db"
EMBEDDINGS = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
VECTOR_STORE = Chroma(persist_directory=CHROMA_PATH, embedding_function=EMBEDDINGS)
class AgentState(TypedDict):
    """Shared state handed from node to node in the agent graph."""

    # Conversation so far. The add_messages annotation is the reducer LangGraph
    # applies when a node returns a "messages" update for this key.
    messages: Annotated[list, add_messages]
def load_system_prompt() -> SystemMessage:
    """Read ``system_prompt.txt`` and wrap its text in a ``SystemMessage``.

    The path is relative to the current working directory.

    Returns:
        A SystemMessage whose content is the full text of the prompt file.

    Raises:
        OSError: If the prompt file is missing or unreadable.
        UnicodeDecodeError: If the prompt file is not valid UTF-8.
    """
    # Decode explicitly as UTF-8: without an encoding argument open() falls back
    # to the platform locale, so the same prompt could load differently per host.
    with open("system_prompt.txt", "r", encoding="utf-8") as prompt_file:
        return SystemMessage(content=prompt_file.read())
# Loaded once at import time; GaiaAgent prepends this message to the
# conversation before calling the model.
SYSTEM_PROMPT: SystemMessage = load_system_prompt()
class GaiaAgent:
    """A LangGraph agent for Gaia questions.

    The compiled graph alternates between an "agent" node (one LLM call with
    the tools bound) and a "tools" node (executes the requested tool calls)
    until the model answers without requesting a tool.

    ``_retriever_node`` and ``_tools_node`` are defined here but are not wired
    into the graph built by ``_build_graph``.
    """

    def __init__(self, model: str, temperature: float):
        """Initialize the agent with a specific model.

        Args:
            model: Chat model name. Names starting with "glm" are sent to the
                Z.ai endpoint with ZAI_API_KEY; everything else uses the
                default OpenAI endpoint with OPENAI_API_KEY.
            temperature: Sampling temperature forwarded to the chat model.
        """
        import asyncio

        from tools import get_tools

        # get_tools() is awaited to completion here. asyncio.run cannot be used
        # while another event loop is running in this thread, so the agent has
        # to be constructed from synchronous code.
        self.tools = asyncio.run(get_tools())

        # A missing key becomes an empty secret; the failure then surfaces on
        # the first request rather than at construction time.
        if model.startswith("glm"):
            api_key = SecretStr(secret_value=os.getenv("ZAI_API_KEY") or "")
            api_base = "https://api.z.ai/api/coding/paas/v4/"
        else:
            api_key = SecretStr(secret_value=os.getenv("OPENAI_API_KEY") or "")
            api_base = None

        self.llm = ChatOpenAI(
            model=model, temperature=temperature, base_url=api_base, api_key=api_key
        ).bind_tools(self.tools)
        self.graph = self._build_graph()

        print(f"Initialized GaiaAgent with model: {model}, temperature: {temperature}")
        print(f"Available tools: {[tool.name for tool in self.tools]}")

    @staticmethod
    def _message_text(message: object) -> str:
        """Return ``message.content`` as a string, or ``str(message)`` if it has none."""
        if hasattr(message, "content"):
            return str(message.content)
        return str(message)

    def _build_graph(self) -> "CompiledStateGraph":
        """Build the state graph for the agent.

        Returns:
            The compiled graph: START -> agent, agent -> tools or end (decided
            by ``tools_condition``), tools -> agent.
        """
        builder = StateGraph(AgentState)
        builder.add_node("agent", self._agent_node)
        builder.add_node("tools", ToolNode(self.tools))
        builder.add_edge(START, "agent")
        # tools_condition sends the run to "tools" when the last model message
        # carries tool calls and ends the run otherwise.
        builder.add_conditional_edges("agent", tools_condition)
        builder.add_edge("tools", "agent")
        return builder.compile(checkpointer=MemorySaver())

    def _retriever_node(self, state: "AgentState") -> "AgentState":
        """Retrieve similar questions and inject solving strategy into the question.

        Looks up the single nearest document in VECTOR_STORE for the first
        message's content. When one is found, returns the system prompt plus a
        new human message holding the question and the retrieved strategy;
        otherwise returns the system prompt followed by the existing messages.
        """
        original_question = state["messages"][0].content
        similar_docs = VECTOR_STORE.similarity_search(original_question, k=1)
        if not similar_docs:
            return {"messages": [SYSTEM_PROMPT] + state["messages"]}

        doc = similar_docs[0]
        # Keep only the text between the "Steps to solve:" and "Tools needed:"
        # markers of the stored document.
        steps = (
            doc.page_content.split("Steps to solve:")[-1]
            .split("Tools needed:")[0]
            .strip()
        )
        tools = doc.metadata.get("tools", "")

        # Build enhanced question with strategy
        enhanced_question = (
            f"{original_question}\n"
            "---\n"
            "Strategy (from similar solved question):\n"
            f"{steps}\n"
            f"Tools needed: {tools}\n"
            "Follow a similar approach to solve the question above."
        )
        return {"messages": [SYSTEM_PROMPT, HumanMessage(content=enhanced_question)]}

    def _tools_node(self, state: "AgentState") -> "AgentState":
        """Execute tools and log results."""
        result = ToolNode(self.tools).invoke(state)
        # Log a short preview of every tool result.
        for msg in result.get("messages", []):
            name = getattr(msg, "name", "unknown")
            # Stringify before slicing so non-string content is previewed the
            # same way as in _agent_node.
            print(f" Tool result [{name}]: {self._message_text(msg)[:300]}...")
        return result

    async def __call__(self, question: str) -> str:
        """Run the agent on a given question and return the answer.

        Args:
            question (str): The input question to the agent

        Returns:
            str: The agent's answer to the question, or a string starting with
            "AGENT ERROR:" if anything raised during the run.
        """
        import uuid

        print(f"\n{'='*60}")
        print(f"Agent received question: {question[:100]}...")
        print(f"{'='*60}\n")

        initial_state = {
            "messages": [HumanMessage(content=question)],
        }
        try:
            # A fresh thread id per call keeps each question's checkpointed
            # history separate from every other question's.
            thread_id = str(uuid.uuid4())
            config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 50}
            final_state = await self.graph.ainvoke(initial_state, config)
            # Clean up answer - extract from tags if present
            answer = self._clean_answer(
                self._message_text(final_state["messages"][-1])
            )
            print(f"Agent final response: {answer[:200]}...\n")
            return answer
        except Exception as e:
            # Deliberate best effort: a failed question yields an error string
            # instead of propagating to the caller.
            print(f"Error during agent execution: {e}")
            return f"AGENT ERROR: {e}"

    def _clean_answer(self, answer: str) -> str:
        """Extract clean answer from various formats.

        Tried in order; the first rule that applies wins:
          1. text inside ``<solution>...</solution>``;
          2. the rest of the line after the first ``FINAL ANSWER:`` marker,
             plain or wrapped in markdown bold;
          3. a comma-separated list of words that follows a colon and runs to
             the end of the text;
          4. the last line, if it looks like a short comma-separated list;
          5. the whole text, stripped.

        Args:
            answer: Raw text of the model's last message.

        Returns:
            The extracted answer with surrounding whitespace removed.
        """
        import re

        # Extract from <solution>...</solution>
        solution = re.search(r"<solution>(.*?)</solution>", answer, re.DOTALL)
        if solution:
            return solution.group(1).strip()

        # Extract from FINAL ANSWER: ... (to end of line or string). The bold
        # alternative is listed first so "**FINAL ANSWER:** 42" is consumed as
        # a whole; matching only the plain marker there would capture "** 42".
        final = re.search(
            r"(?:\*\*FINAL ANSWER:?\*\*:?|FINAL ANSWER:)\s*(.+?)(?:\n|$)",
            answer,
            re.IGNORECASE,
        )
        if final:
            extracted = final.group(1).strip()
            # Drop leftover bold markers ("**FINAL ANSWER: 42**" -> "42") but
            # never reduce a non-empty capture to nothing.
            return extracted.strip("*").strip() or extracted

        # If answer contains a colon followed by a list, extract just the list part
        # e.g., "...ingredients: cornstarch, sugar, ..."
        trailing_list = re.search(
            r":\s*\n?\s*([a-z][a-z\s,]+(?:,\s*[a-z][a-z\s]+)+)\s*$",
            answer,
            re.IGNORECASE,
        )
        if trailing_list:
            return trailing_list.group(1).strip()

        # Last resort: if there's a clear comma-separated list at the end, extract it
        stripped = answer.strip()
        last_line = stripped.split("\n")[-1].strip()
        if "," in last_line and len(last_line) < 500:
            # Check if it looks like a list (multiple comma-separated items)
            items = [item.strip() for item in last_line.split(",")]
            if len(items) >= 2 and all(len(item) < 100 for item in items):
                return last_line
        return stripped

    def _agent_node(self, state: "AgentState") -> "AgentState":
        """The main agent node that processes messages and generates responses.

        Prepends SYSTEM_PROMPT for the model call when the history does not
        already start with a SystemMessage, logs the prompt, invokes the LLM
        and returns its reply as the state update. The prepended prompt is not
        written back to the state.
        """
        messages = state["messages"]
        # Debug: show message count
        print(f"\n[AGENT] Message count: {len(messages)}")

        # Prepend system prompt if not already there
        if not messages or not isinstance(messages[0], SystemMessage):
            messages = [SYSTEM_PROMPT] + messages

        # Print the full prompt/messages (each truncated to 500 characters)
        print("[AGENT] === MESSAGES ===")
        for index, message in enumerate(messages):
            kind = type(message).__name__
            print(f" [{index}] {kind}: {self._message_text(message)[:500]}...")
        print("[AGENT] === END MESSAGES ===\n")

        response = self.llm.invoke(messages)

        # Log what the agent is doing
        tool_calls = getattr(response, "tool_calls", None)
        if tool_calls:
            print(f"[AGENT] Calling tools: {[call['name'] for call in tool_calls]}")
        else:
            print(f"[AGENT] Final response: {self._message_text(response)[:200]}...")
        return {"messages": [response]}
# Alternative model; any name not starting with "glm" is sent to the default
# OpenAI endpoint using OPENAI_API_KEY:
# model="o3-mini"
MODEL = "glm-4.7"
# NOTE: despite the class-like name, BasicAgent is a ready-to-call GaiaAgent
# *instance*. It is constructed as soon as this module is imported, which
# loads the tools and compiles the graph.
BasicAgent = GaiaAgent(model=MODEL, temperature=1.0)