import os
import warnings
from typing import Annotated, TypedDict

from dotenv import load_dotenv
from langchain_chroma import Chroma
from langchain_community.cache import SQLiteCache
from langchain_core.globals import set_llm_cache
from langchain_core.messages.human import HumanMessage
from langchain_core.messages.system import SystemMessage
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph.message import add_messages
from langgraph.graph.state import START, CompiledStateGraph, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from pydantic import SecretStr

warnings.filterwarnings("ignore", category=UserWarning, module="langchain_tavily")

load_dotenv()

# from langchain_core.caches import InMemoryCache
# set_llm_cache(InMemoryCache())
set_llm_cache(SQLiteCache(database_path=".langchain_cache.db"))

# Initialize RAG vector store
CHROMA_PATH = "./chroma_gaia_db"
EMBEDDINGS = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
VECTOR_STORE = Chroma(persist_directory=CHROMA_PATH, embedding_function=EMBEDDINGS)


class AgentState(TypedDict):
    """State passed between nodes in the graph."""

    messages: Annotated[list, add_messages]


def load_system_prompt() -> SystemMessage:
    with open("system_prompt.txt", "r") as f:
        system_prompt = f.read()
    return SystemMessage(content=system_prompt)


SYSTEM_PROMPT: SystemMessage = load_system_prompt()


class GaiaAgent:
    """A LangGraph agent for GAIA questions."""

    def __init__(self, model: str, temperature: float):
        """Initialize the agent with a specific model."""
        import asyncio

        from tools import get_tools

        self.tools = asyncio.run(get_tools())

        if model.startswith("glm"):
            api_key = SecretStr(secret_value=os.getenv("ZAI_API_KEY", ""))
            api_base = "https://api.z.ai/api/coding/paas/v4/"
        else:
            api_key = SecretStr(secret_value=os.getenv("OPENAI_API_KEY") or "")
            api_base = None

        self.llm = ChatOpenAI(
            model=model, temperature=temperature, base_url=api_base, api_key=api_key
        ).bind_tools(self.tools)
        self.graph = self._build_graph()
        print(f"Initialized GaiaAgent with model: {model}, temperature: {temperature}")
        print(f"Available tools: {[tool.name for tool in self.tools]}")

    def _build_graph(self) -> CompiledStateGraph:
        """Build the state graph for the agent."""
        graph = StateGraph(AgentState)
        graph.add_node("agent", self._agent_node)
        graph.add_node("tools", ToolNode(self.tools))
        graph.add_edge(START, "agent")
        graph.add_conditional_edges("agent", tools_condition)
        graph.add_edge("tools", "agent")
        memory = MemorySaver()
        return graph.compile(checkpointer=memory)

    # Note: not wired into the graph built by _build_graph; kept as an optional
    # RAG preprocessing step.
    def _retriever_node(self, state: AgentState) -> AgentState:
        """Retrieve similar questions and inject a solving strategy into the question."""
        original_question = state["messages"][0].content
        similar_docs = VECTOR_STORE.similarity_search(original_question, k=1)

        if similar_docs:
            doc = similar_docs[0]
            steps = (
                doc.page_content.split("Steps to solve:")[-1]
                .split("Tools needed:")[0]
                .strip()
            )
            tools = doc.metadata.get("tools", "")

            # Build enhanced question with strategy
            enhanced_question = f"""{original_question}

---
Strategy (from a similar solved question):
{steps}

Tools needed: {tools}

Follow a similar approach to solve the question above."""
            enhanced_msg = HumanMessage(content=enhanced_question)
            return {"messages": [SYSTEM_PROMPT, enhanced_msg]}

        return {"messages": [SYSTEM_PROMPT] + state["messages"]}

    # Note: a logging alternative to the bare ToolNode used in _build_graph;
    # not currently wired into the compiled graph.
    def _tools_node(self, state: AgentState) -> AgentState:
        """Execute tools and log results."""
        tool_node = ToolNode(self.tools)
        result = tool_node.invoke(state)

        # Log tool results and check for answers
        for msg in result.get("messages", []):
            content = getattr(msg, "content", str(msg))
            name = getattr(msg, "name", "unknown")
            print(f" Tool result [{name}]: {content[:300]}...")

        return result

    async def __call__(self, question: str) -> str:
        """
        Run the agent on a given question and return the answer.

        Args:
            question (str): The input question to the agent

        Returns:
            str: The agent's answer to the question
        """
        print(f"\n{'='*60}")
        print(f"Agent received question: {question[:100]}...")
        print(f"{'='*60}\n")

        initial_state = {
            "messages": [HumanMessage(content=question)],
        }

        try:
            import uuid

            thread_id = str(uuid.uuid4())
            config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 50}
            final_state = await self.graph.ainvoke(initial_state, config)
            last_message = final_state["messages"][-1]
            answer = (
                str(last_message.content)
                if hasattr(last_message, "content")
                else str(last_message)
            )

            # Clean up answer - extract from tags if present
            answer = self._clean_answer(answer)

            print(f"Agent final response: {answer[:200]}...\n")
            return answer
        except Exception as e:
            print(f"Error during agent execution: {e}")
            return f"AGENT ERROR: {e}"

    def _clean_answer(self, answer: str) -> str:
        """Extract a clean answer from various formats."""
        import re

        # Extract from <answer>...</answer> tags
        match = re.search(r"<answer>(.*?)</answer>", answer, re.DOTALL)
        if match:
            return match.group(1).strip()

        # Extract from FINAL ANSWER: ... (to end of line or string)
        match = re.search(r"FINAL ANSWER:\s*(.+?)(?:\n|$)", answer, re.IGNORECASE)
        if match:
            return match.group(1).strip()

        # Extract from **FINAL ANSWER:** or similar markdown
        match = re.search(
            r"\*\*FINAL ANSWER:?\*\*:?\s*(.+?)(?:\n|$)", answer, re.IGNORECASE
        )
        if match:
            return match.group(1).strip()

        # If the answer contains a colon followed by a list, extract just the list part
        # e.g., "...ingredients: cornstarch, sugar, ..."
        match = re.search(
            r":\s*\n?\s*([a-z][a-z\s,]+(?:,\s*[a-z][a-z\s]+)+)\s*$",
            answer,
            re.IGNORECASE,
        )
        if match:
            return match.group(1).strip()

        # Last resort: if there's a clear comma-separated list at the end, extract it
        lines = answer.strip().split("\n")
        last_line = lines[-1].strip()
        if "," in last_line and len(last_line) < 500:
            # Check if it looks like a list (multiple comma-separated items)
            items = [i.strip() for i in last_line.split(",")]
            if len(items) >= 2 and all(len(i) < 100 for i in items):
                return last_line

        return answer.strip()

    def _agent_node(self, state: AgentState) -> AgentState:
        """The main agent node that processes messages and generates responses."""
        messages = state["messages"]

        # Debug: show message count
        print(f"\n[AGENT] Message count: {len(messages)}")

        # Prepend system prompt if not already there
        if not messages or not isinstance(messages[0], SystemMessage):
            messages = [SYSTEM_PROMPT] + messages

        # Print the full prompt/messages
        print("[AGENT] === MESSAGES ===")
        for i, msg in enumerate(messages):
            msg_type = type(msg).__name__
            content = (
                str(msg.content)[:500] if hasattr(msg, "content") else str(msg)[:500]
            )
            print(f" [{i}] {msg_type}: {content}...")
        print("[AGENT] === END MESSAGES ===\n")

        response = self.llm.invoke(messages)

        # Log what the agent is doing
        if hasattr(response, "tool_calls") and response.tool_calls:
            print(
                f"[AGENT] Calling tools: {[tc['name'] for tc in response.tool_calls]}"
            )
        else:
            content = (
                str(response.content)[:200]
                if hasattr(response, "content")
                else str(response)[:200]
            )
            print(f"[AGENT] Final response: {content}...")

        return {"messages": [response]}


# model="o3-mini"
MODEL = "glm-4.7"
BasicAgent = GaiaAgent(model=MODEL, temperature=1.0)
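
# Minimal smoke test, a sketch only: the question below is a made-up placeholder,
# and this assumes system_prompt.txt, the local tools module, and the relevant
# API key environment variable (ZAI_API_KEY or OPENAI_API_KEY) are all present.
if __name__ == "__main__":
    import asyncio

    # BasicAgent.__call__ is async, so run it to completion with asyncio.run
    print(asyncio.run(BasicAgent("What is the capital of France?")))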