import os
import warnings
from typing import Annotated, TypedDict

from dotenv import load_dotenv
from langchain_chroma import Chroma
from langchain_community.cache import SQLiteCache
from langchain_core.globals import set_llm_cache
from langchain_core.messages.human import HumanMessage
from langchain_core.messages.system import SystemMessage
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph.message import add_messages
from langgraph.graph.state import END, START, CompiledStateGraph, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from pydantic import SecretStr

warnings.filterwarnings("ignore", category=UserWarning, module="langchain_tavily")

load_dotenv()

# Cache LLM responses in a local SQLite file so repeated identical calls are
# served from disk instead of the API.
set_llm_cache(SQLiteCache(database_path=".langchain_cache.db"))

CHROMA_PATH = "./chroma_gaia_db"
EMBEDDINGS = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
VECTOR_STORE = Chroma(persist_directory=CHROMA_PATH, embedding_function=EMBEDDINGS)
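
# The Chroma collection is assumed to be populated offline with previously
# solved GAIA questions whose page content embeds "Steps to solve:" and
# "Tools needed:" sections; _retriever_node below parses those markers.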


class AgentState(TypedDict):
    """State passed between nodes in the graph."""

    # add_messages is a reducer: node updates are appended to the running
    # message history instead of replacing it.
    messages: Annotated[list, add_messages]


def load_system_prompt() -> SystemMessage:
    with open("system_prompt.txt", "r") as f:
        system_prompt = f.read()
    return SystemMessage(content=system_prompt)


SYSTEM_PROMPT: SystemMessage = load_system_prompt()
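
# NOTE: system_prompt.txt is resolved against the current working directory,
# not this file's location, so the module must be run from the directory that
# contains it.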


class GaiaAgent:
    """
    A LangGraph agent for GAIA questions.
    """

    def __init__(self, model: str, temperature: float):
        """Initialize the agent with a specific model."""
        import asyncio

        from tools import get_tools

        self.tools = asyncio.run(get_tools())

        # GLM models are served through Z.AI's OpenAI-compatible endpoint;
        # anything else falls through to the default OpenAI base URL.
        if model.startswith("glm"):
            api_key = SecretStr(secret_value=os.getenv("ZAI_API_KEY", ""))
            api_base = "https://api.z.ai/api/coding/paas/v4/"
        else:
            api_key = SecretStr(secret_value=os.getenv("OPENAI_API_KEY") or "")
            api_base = None

        # bind_tools attaches the tools' schemas to every request so the
        # model can emit tool calls.
        self.llm = ChatOpenAI(
            model=model, temperature=temperature, base_url=api_base, api_key=api_key
        ).bind_tools(self.tools)

        self.graph = self._build_graph()

        print(f"Initialized GaiaAgent with model: {model}, temperature: {temperature}")
        print(f"Available tools: {[tool.name for tool in self.tools]}")

    def _build_graph(self) -> CompiledStateGraph:
        """Build the state graph for the agent."""
        graph = StateGraph(AgentState)

        graph.add_node("agent", self._agent_node)
        graph.add_node("tools", ToolNode(self.tools))

        graph.add_edge(START, "agent")
        graph.add_conditional_edges("agent", tools_condition)
        graph.add_edge("tools", "agent")

        memory = MemorySaver()
        return graph.compile(checkpointer=memory)
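
    # Control flow of the compiled graph: START -> agent; tools_condition then
    # routes to "tools" whenever the model's reply carries tool calls and to
    # END otherwise, and "tools" always loops back to "agent". MemorySaver
    # checkpoints state between steps, keyed by thread_id.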

    def _retriever_node(self, state: AgentState) -> AgentState:
        """Retrieve similar questions and inject solving strategy into the question."""
        original_question = state["messages"][0].content

        similar_docs = VECTOR_STORE.similarity_search(original_question, k=1)

        if similar_docs:
            doc = similar_docs[0]
            steps = (
                doc.page_content.split("Steps to solve:")[-1]
                .split("Tools needed:")[0]
                .strip()
            )
            tools = doc.metadata.get("tools", "")

            enhanced_question = f"""{original_question}

---
Strategy (from similar solved question):
{steps}

Tools needed: {tools}

Follow a similar approach to solve the question above."""

            enhanced_msg = HumanMessage(content=enhanced_question)
            return {"messages": [SYSTEM_PROMPT, enhanced_msg]}

        return {"messages": [SYSTEM_PROMPT] + state["messages"]}

    def _tools_node(self, state: AgentState) -> AgentState:
        """Execute tools and log results."""
        tool_node = ToolNode(self.tools)
        result = tool_node.invoke(state)

        for msg in result.get("messages", []):
            content = getattr(msg, "content", str(msg))
            name = getattr(msg, "name", "unknown")
            print(f" Tool result [{name}]: {content[:300]}...")

        return result
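
    # NOTE: _retriever_node and _tools_node are standalone helpers;
    # _build_graph currently wires in _agent_node and a bare ToolNode instead,
    # so neither helper runs in the compiled graph.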

    async def __call__(self, question: str) -> str:
        """
        Run the agent on a given question and return the answer.

        Args:
            question (str): The input question to the agent

        Returns:
            str: The agent's answer to the question
        """
        print(f"\n{'='*60}")
        print(f"Agent received question: {question[:100]}...")
        print(f"{'='*60}\n")

        initial_state = {
            "messages": [HumanMessage(content=question)],
        }

        try:
            import uuid

            # Each run gets a fresh thread id so MemorySaver state is not
            # shared across questions.
            thread_id = str(uuid.uuid4())
            config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 50}
            final_state = await self.graph.ainvoke(initial_state, config)

            last_message = final_state["messages"][-1]

            answer = (
                str(last_message.content)
                if hasattr(last_message, "content")
                else str(last_message)
            )

            answer = self._clean_answer(answer)

            print(f"Agent final response: {answer[:200]}...\n")

            return answer
        except Exception as e:
            print(f"Error during agent execution: {e}")
            return f"AGENT ERROR: {e}"

    def _clean_answer(self, answer: str) -> str:
        """Extract the clean answer from various formats."""
        import re

        # <solution>...</solution> tags.
        match = re.search(r"<solution>(.*?)</solution>", answer, re.DOTALL)
        if match:
            return match.group(1).strip()

        # Check the markdown-bold variant before the plain one: the plain
        # pattern would otherwise match inside "**FINAL ANSWER:**" and
        # capture the stray asterisks.
        match = re.search(
            r"\*\*FINAL ANSWER:?\*\*:?\s*(.+?)(?:\n|$)", answer, re.IGNORECASE
        )
        if match:
            return match.group(1).strip()

        # Plain "FINAL ANSWER: ..." prefix.
        match = re.search(r"FINAL ANSWER:\s*(.+?)(?:\n|$)", answer, re.IGNORECASE)
        if match:
            return match.group(1).strip()

        # A trailing comma-separated list introduced by a colon.
        match = re.search(
            r":\s*\n?\s*([a-z][a-z\s,]+(?:,\s*[a-z][a-z\s]+)+)\s*$",
            answer,
            re.IGNORECASE,
        )
        if match:
            return match.group(1).strip()

        # Fall back to the last line when it looks like a short list.
        lines = answer.strip().split("\n")
        last_line = lines[-1].strip()
        if "," in last_line and len(last_line) < 500:
            items = [i.strip() for i in last_line.split(",")]
            if len(items) >= 2 and all(len(i) < 100 for i in items):
                return last_line

        return answer.strip()
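
    # Illustrative inputs (not from the original source) and the values
    # _clean_answer returns for them:
    #   "<solution>42</solution>"    -> "42"
    #   "**FINAL ANSWER:** 3, 5, 7"  -> "3, 5, 7"
    #   "FINAL ANSWER: Paris"        -> "Paris"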

    def _agent_node(self, state: AgentState) -> AgentState:
        """The main agent node that processes messages and generates responses."""
        messages = state["messages"]

        print(f"\n[AGENT] Message count: {len(messages)}")

        # Ensure the system prompt leads the conversation.
        if not messages or not isinstance(messages[0], SystemMessage):
            messages = [SYSTEM_PROMPT] + messages

        print("[AGENT] === MESSAGES ===")
        for i, msg in enumerate(messages):
            msg_type = type(msg).__name__
            content = (
                str(msg.content)[:500] if hasattr(msg, "content") else str(msg)[:500]
            )
            print(f" [{i}] {msg_type}: {content}...")
        print("[AGENT] === END MESSAGES ===\n")

        response = self.llm.invoke(messages)

        if hasattr(response, "tool_calls") and response.tool_calls:
            print(
                f"[AGENT] Calling tools: {[tc['name'] for tc in response.tool_calls]}"
            )
        else:
            content = (
                str(response.content)[:200]
                if hasattr(response, "content")
                else str(response)[:200]
            )
            print(f"[AGENT] Final response: {content}...")

        return {"messages": [response]}


MODEL = "glm-4.7"

BasicAgent = GaiaAgent(model=MODEL, temperature=1.0)
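
# Example usage (a minimal sketch; the question string is illustrative):
#
#     import asyncio
#
#     answer = asyncio.run(BasicAgent("What is the capital of France?"))
#     print(answer)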