Got 45%
Browse files- agent.py +76 -21
- app.py +10 -3
- build_rag_index.py +104 -0
- chroma_gaia_db/.gitattributes +36 -0
- chroma_gaia_db/99bb1417-fe53-457a-8b1f-42a54fb4c17c/data_level0.bin +3 -0
- chroma_gaia_db/99bb1417-fe53-457a-8b1f-42a54fb4c17c/header.bin +3 -0
- chroma_gaia_db/99bb1417-fe53-457a-8b1f-42a54fb4c17c/length.bin +3 -0
- chroma_gaia_db/99bb1417-fe53-457a-8b1f-42a54fb4c17c/link_lists.bin +0 -0
- chroma_gaia_db/chroma.sqlite3 +3 -0
- pyproject.toml +9 -0
- system_prompt.txt +11 -32
- test_bench.py +3 -1
- tools.py +423 -83
agent.py
CHANGED
|
@@ -3,11 +3,14 @@ import warnings
|
|
| 3 |
from typing import Annotated, TypedDict
|
| 4 |
|
| 5 |
from dotenv import load_dotenv
|
|
|
|
| 6 |
from langchain_community.cache import SQLiteCache
|
| 7 |
from langchain_core.globals import set_llm_cache
|
| 8 |
from langchain_core.messages.human import HumanMessage
|
| 9 |
from langchain_core.messages.system import SystemMessage
|
| 10 |
from langchain_huggingface import HuggingFaceEmbeddings
|
|
|
|
|
|
|
| 11 |
from langgraph.graph.message import add_messages
|
| 12 |
from langgraph.graph.state import END, START, CompiledStateGraph, StateGraph
|
| 13 |
from langgraph.prebuilt import ToolNode, tools_condition
|
|
@@ -22,6 +25,11 @@ load_dotenv()
|
|
| 22 |
# set_llm_cache(InMemoryCache())
|
| 23 |
set_llm_cache(SQLiteCache(database_path=".langchain_cache.db"))
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
class AgentState(TypedDict):
|
| 27 |
"""State passed between nods in the graph"""
|
|
@@ -38,25 +46,22 @@ def load_system_prompt() -> SystemMessage:
|
|
| 38 |
SYSTEM_PROMPT: SystemMessage = load_system_prompt()
|
| 39 |
|
| 40 |
|
| 41 |
-
|
| 42 |
class GaiaAgent:
|
| 43 |
"""
|
| 44 |
A LangGraph agent for Gaia questions
|
| 45 |
"""
|
| 46 |
|
| 47 |
-
def __init__(self, model: str
|
| 48 |
"""Initialize the agent with a specific model"""
|
| 49 |
-
|
| 50 |
|
| 51 |
from tools import get_tools
|
| 52 |
|
| 53 |
-
self.tools = get_tools()
|
| 54 |
|
| 55 |
if model.startswith("glm"):
|
| 56 |
api_key = SecretStr(secret_value=os.getenv("ZAI_API_KEY", ""))
|
| 57 |
-
api_base = "https://api.z.ai/api/paas/v4/"
|
| 58 |
-
if os.getenv("ZAI_USE_CODING_PLAN", "f") == "t":
|
| 59 |
-
api_base = "https://api.z.ai/api/coding/paas/v4/"
|
| 60 |
else:
|
| 61 |
api_key = SecretStr(secret_value=os.getenv("OPENAI_API_KEY") or "")
|
| 62 |
api_base = None
|
|
@@ -75,30 +80,44 @@ class GaiaAgent:
|
|
| 75 |
|
| 76 |
graph = StateGraph(AgentState)
|
| 77 |
|
| 78 |
-
# graph.add_node("retriever", self._retriever_node)
|
| 79 |
graph.add_node("agent", self._agent_node)
|
| 80 |
graph.add_node("tools", ToolNode(self.tools))
|
| 81 |
-
# graph.add_node("tools", self._tools_node)
|
| 82 |
|
| 83 |
-
# graph.add_edge(START, "retriever")
|
| 84 |
-
# graph.add_edge("retriever", "agent")
|
| 85 |
graph.add_edge(START, "agent")
|
| 86 |
graph.add_conditional_edges("agent", tools_condition)
|
| 87 |
graph.add_edge("tools", "agent")
|
| 88 |
|
| 89 |
-
|
|
|
|
| 90 |
|
| 91 |
def _retriever_node(self, state: AgentState) -> AgentState:
|
| 92 |
-
"""Retrieve similar questions
|
| 93 |
-
|
| 94 |
|
| 95 |
-
similar_docs = VECTOR_STORE.similarity_search(
|
| 96 |
|
| 97 |
if similar_docs:
|
| 98 |
-
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
| 100 |
)
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
return {"messages": [SYSTEM_PROMPT] + state["messages"]}
|
| 104 |
|
|
@@ -115,7 +134,7 @@ class GaiaAgent:
|
|
| 115 |
|
| 116 |
return result
|
| 117 |
|
| 118 |
-
def __call__(self, question: str) -> str:
|
| 119 |
"""
|
| 120 |
Run the agent on a given question and return the answer
|
| 121 |
|
|
@@ -135,7 +154,11 @@ class GaiaAgent:
|
|
| 135 |
}
|
| 136 |
|
| 137 |
try:
|
| 138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
last_message = final_state["messages"][-1]
|
| 141 |
|
|
@@ -200,11 +223,43 @@ class GaiaAgent:
|
|
| 200 |
def _agent_node(self, state: AgentState) -> AgentState:
|
| 201 |
"""The main agent node that processes messages and generates responses"""
|
| 202 |
messages = state["messages"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
response = self.llm.invoke(messages)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
return {"messages": [response]}
|
| 205 |
|
| 206 |
|
| 207 |
# model="o3-mini"
|
| 208 |
MODEL = "glm-4.7"
|
| 209 |
|
| 210 |
-
BasicAgent = GaiaAgent(model=MODEL, temperature=
|
|
|
|
| 3 |
from typing import Annotated, TypedDict
|
| 4 |
|
| 5 |
from dotenv import load_dotenv
|
| 6 |
+
from langchain_chroma import Chroma
|
| 7 |
from langchain_community.cache import SQLiteCache
|
| 8 |
from langchain_core.globals import set_llm_cache
|
| 9 |
from langchain_core.messages.human import HumanMessage
|
| 10 |
from langchain_core.messages.system import SystemMessage
|
| 11 |
from langchain_huggingface import HuggingFaceEmbeddings
|
| 12 |
+
from langchain_openai import ChatOpenAI
|
| 13 |
+
from langgraph.checkpoint.memory import MemorySaver
|
| 14 |
from langgraph.graph.message import add_messages
|
| 15 |
from langgraph.graph.state import END, START, CompiledStateGraph, StateGraph
|
| 16 |
from langgraph.prebuilt import ToolNode, tools_condition
|
|
|
|
| 25 |
# set_llm_cache(InMemoryCache())
|
| 26 |
set_llm_cache(SQLiteCache(database_path=".langchain_cache.db"))
|
| 27 |
|
| 28 |
+
# Initialize RAG vector store
|
| 29 |
+
CHROMA_PATH = "./chroma_gaia_db"
|
| 30 |
+
EMBEDDINGS = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
|
| 31 |
+
VECTOR_STORE = Chroma(persist_directory=CHROMA_PATH, embedding_function=EMBEDDINGS)
|
| 32 |
+
|
| 33 |
|
| 34 |
class AgentState(TypedDict):
|
| 35 |
"""State passed between nods in the graph"""
|
|
|
|
| 46 |
SYSTEM_PROMPT: SystemMessage = load_system_prompt()
|
| 47 |
|
| 48 |
|
|
|
|
| 49 |
class GaiaAgent:
|
| 50 |
"""
|
| 51 |
A LangGraph agent for Gaia questions
|
| 52 |
"""
|
| 53 |
|
| 54 |
+
def __init__(self, model: str, temperature: float):
|
| 55 |
"""Initialize the agent with a specific model"""
|
| 56 |
+
import asyncio
|
| 57 |
|
| 58 |
from tools import get_tools
|
| 59 |
|
| 60 |
+
self.tools = asyncio.run(get_tools())
|
| 61 |
|
| 62 |
if model.startswith("glm"):
|
| 63 |
api_key = SecretStr(secret_value=os.getenv("ZAI_API_KEY", ""))
|
| 64 |
+
api_base = "https://api.z.ai/api/coding/paas/v4/"
|
|
|
|
|
|
|
| 65 |
else:
|
| 66 |
api_key = SecretStr(secret_value=os.getenv("OPENAI_API_KEY") or "")
|
| 67 |
api_base = None
|
|
|
|
| 80 |
|
| 81 |
graph = StateGraph(AgentState)
|
| 82 |
|
|
|
|
| 83 |
graph.add_node("agent", self._agent_node)
|
| 84 |
graph.add_node("tools", ToolNode(self.tools))
|
|
|
|
| 85 |
|
|
|
|
|
|
|
| 86 |
graph.add_edge(START, "agent")
|
| 87 |
graph.add_conditional_edges("agent", tools_condition)
|
| 88 |
graph.add_edge("tools", "agent")
|
| 89 |
|
| 90 |
+
memory = MemorySaver()
|
| 91 |
+
return graph.compile(checkpointer=memory)
|
| 92 |
|
| 93 |
def _retriever_node(self, state: AgentState) -> AgentState:
    """Attach the strategy of the closest indexed question to the user's question.

    The content of the first message in ``state["messages"]`` is used as the
    query for a top-1 similarity search against ``VECTOR_STORE``.

    Returns:
        A state update whose ``messages`` are the system prompt followed by
        either a new ``HumanMessage`` (question + retrieved strategy) or, when
        nothing was retrieved, the unchanged incoming messages.
    """
    original_question = state["messages"][0].content
    matches = VECTOR_STORE.similarity_search(original_question, k=1)

    # Nothing retrieved: pass the conversation through, only prefixing the
    # system prompt.
    if not matches:
        return {"messages": [SYSTEM_PROMPT] + state["messages"]}

    best_match = matches[0]
    # The "steps" text is whatever sits between the two markers in the
    # document body.
    after_steps_marker = best_match.page_content.split("Steps to solve:")[-1]
    steps = after_steps_marker.split("Tools needed:")[0].strip()
    tools = best_match.metadata.get("tools", "")

    # Question followed by the retrieved strategy and tool list.
    enhanced_question = f"""{original_question}

---
Strategy (from similar solved question):
{steps}

Tools needed: {tools}

Follow a similar approach to solve the question above."""

    return {"messages": [SYSTEM_PROMPT, HumanMessage(content=enhanced_question)]}
|
| 123 |
|
|
|
|
| 134 |
|
| 135 |
return result
|
| 136 |
|
| 137 |
+
async def __call__(self, question: str) -> str:
|
| 138 |
"""
|
| 139 |
Run the agent on a given question and return the answer
|
| 140 |
|
|
|
|
| 154 |
}
|
| 155 |
|
| 156 |
try:
|
| 157 |
+
import uuid
|
| 158 |
+
|
| 159 |
+
thread_id = str(uuid.uuid4())
|
| 160 |
+
config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 50}
|
| 161 |
+
final_state = await self.graph.ainvoke(initial_state, config)
|
| 162 |
|
| 163 |
last_message = final_state["messages"][-1]
|
| 164 |
|
|
|
|
| 223 |
def _agent_node(self, state: AgentState) -> AgentState:
    """Run the LLM on the current conversation and return its reply.

    The system prompt is prepended when the conversation is empty or does not
    already start with a ``SystemMessage``. The outgoing messages and a summary
    of the model's reply are printed for debugging.

    Returns:
        A state update containing only the new LLM response message.
    """
    conversation = state["messages"]

    # Debug: size of the conversation as received from the graph state.
    print(f"\n[AGENT] Message count: {len(conversation)}")

    starts_with_system = bool(conversation) and isinstance(
        conversation[0], SystemMessage
    )
    if not starts_with_system:
        conversation = [SYSTEM_PROMPT] + conversation

    # Debug: dump a truncated (500 chars) preview of every outgoing message.
    print("[AGENT] === MESSAGES ===")
    for idx, message in enumerate(conversation):
        # Objects without a ``content`` attribute are previewed via str().
        preview = str(getattr(message, "content", message))[:500]
        print(f" [{idx}] {type(message).__name__}: {preview}...")
    print("[AGENT] === END MESSAGES ===\n")

    response = self.llm.invoke(conversation)

    # Debug: report either the requested tool names or the start of the reply.
    requested_calls = getattr(response, "tool_calls", None)
    if requested_calls:
        print(f"[AGENT] Calling tools: {[tc['name'] for tc in requested_calls]}")
    else:
        reply_preview = str(getattr(response, "content", response))[:200]
        print(f"[AGENT] Final response: {reply_preview}...")

    return {"messages": [response]}
|
| 260 |
|
| 261 |
|
| 262 |
# model="o3-mini"
|
| 263 |
MODEL = "glm-4.7"
|
| 264 |
|
| 265 |
+
BasicAgent = GaiaAgent(model=MODEL, temperature=1.0)
|
app.py
CHANGED
|
@@ -11,7 +11,7 @@ from agent import BasicAgent
|
|
| 11 |
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
| 12 |
|
| 13 |
|
| 14 |
-
def run_and_submit_all(profile: gr.OAuthProfile | None):
|
| 15 |
"""
|
| 16 |
Fetches all questions, runs the BasicAgent on them, submits all answers,
|
| 17 |
and displays the results.
|
|
@@ -72,7 +72,7 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
|
|
| 72 |
print(f"Skipping item with missing task_id or question: {item}")
|
| 73 |
continue
|
| 74 |
try:
|
| 75 |
-
submitted_answer = agent(question_text)
|
| 76 |
print(f"Task ID: {task_id}")
|
| 77 |
print(f"Answer: {submitted_answer}")
|
| 78 |
print("-" * 40)
|
|
@@ -183,7 +183,8 @@ with gr.Blocks() as demo:
|
|
| 183 |
|
| 184 |
run_button.click(fn=run_and_submit_all, outputs=[status_output, results_table])
|
| 185 |
|
| 186 |
-
|
|
|
|
| 187 |
print("\n" + "-" * 30 + " App Starting " + "-" * 30)
|
| 188 |
# Check for SPACE_HOST and SPACE_ID at startup for information
|
| 189 |
space_host_startup = os.getenv("SPACE_HOST")
|
|
@@ -210,3 +211,9 @@ if __name__ == "__main__":
|
|
| 210 |
|
| 211 |
print("Launching Gradio Interface for Basic Agent Evaluation...")
|
| 212 |
demo.launch(debug=True, share=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
| 12 |
|
| 13 |
|
| 14 |
+
async def run_and_submit_all(profile: gr.OAuthProfile | None):
|
| 15 |
"""
|
| 16 |
Fetches all questions, runs the BasicAgent on them, submits all answers,
|
| 17 |
and displays the results.
|
|
|
|
| 72 |
print(f"Skipping item with missing task_id or question: {item}")
|
| 73 |
continue
|
| 74 |
try:
|
| 75 |
+
submitted_answer = await agent(question_text)
|
| 76 |
print(f"Task ID: {task_id}")
|
| 77 |
print(f"Answer: {submitted_answer}")
|
| 78 |
print("-" * 40)
|
|
|
|
| 183 |
|
| 184 |
run_button.click(fn=run_and_submit_all, outputs=[status_output, results_table])
|
| 185 |
|
| 186 |
+
|
| 187 |
+
async def main():
|
| 188 |
print("\n" + "-" * 30 + " App Starting " + "-" * 30)
|
| 189 |
# Check for SPACE_HOST and SPACE_ID at startup for information
|
| 190 |
space_host_startup = os.getenv("SPACE_HOST")
|
|
|
|
| 211 |
|
| 212 |
print("Launching Gradio Interface for Basic Agent Evaluation...")
|
| 213 |
demo.launch(debug=True, share=False)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
if __name__ == "__main__":
|
| 217 |
+
import asyncio
|
| 218 |
+
|
| 219 |
+
asyncio.run(main())
|
build_rag_index.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Build RAG index from GAIA validation dataset with Annotator Metadata."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
from datasets import load_dataset
|
| 7 |
+
from langchain_chroma import Chroma
|
| 8 |
+
from langchain_core.documents import Document
|
| 9 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 10 |
+
|
| 11 |
+
CHROMA_PATH = "./chroma_gaia_db"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _parse_annotator_metadata(raw):
    """Return the "Annotator Metadata" field of a dataset row as a mapping.

    Accepts a mapping (returned unchanged) or a JSON-encoded string. Anything
    else -- ``None``, invalid JSON, or JSON that does not decode to an
    object -- yields an empty dict so callers can always use ``.get``.
    """
    from collections.abc import Mapping

    if isinstance(raw, str):
        try:
            raw = json.loads(raw)
        except json.JSONDecodeError:
            return {}
    return raw if isinstance(raw, Mapping) else {}


def build_index():
    """Load the GAIA validation split and index it into a Chroma store.

    Each row becomes one ``Document`` whose text holds the question, the final
    answer, the annotator's solving steps and the tools used. The same fields
    are stored in the document metadata. Any existing directory at
    ``CHROMA_PATH`` is deleted before the new index is written, and a sample
    query is run at the end as a smoke test.
    """
    print("Loading GAIA dataset...")
    ds = load_dataset("gaia-benchmark/GAIA", "2023_all", split="validation")

    print(f"Found {len(ds)} examples")

    # Create documents from dataset
    documents = []
    for item in ds:
        question = item.get("Question", "")
        answer = item.get("Final answer", "")
        level = item.get("Level", "")
        task_id = item.get("task_id", "")

        annotations = _parse_annotator_metadata(item.get("Annotator Metadata"))
        # ``or ""`` keeps a present-but-null value from being rendered as the
        # literal text "None" in the indexed document.
        steps = annotations.get("Steps") or ""
        tools = annotations.get("Tools") or ""
        num_steps = annotations.get("Number of steps", "")

        # The "Steps to solve:" / "Tools needed:" markers are what the
        # retrieval code splits on, so they must stay exactly as written.
        content = f"""Question: {question}

Final Answer: {answer}

Steps to solve:
{steps}

Tools needed: {tools}"""

        documents.append(
            Document(
                page_content=content,
                metadata={
                    "task_id": task_id,
                    "question": question,
                    "answer": answer,
                    "level": str(level),
                    "num_steps": str(num_steps),
                    "tools": tools,
                },
            )
        )

    print(f"Created {len(documents)} documents")

    # Initialize embeddings
    print("Initializing embeddings...")
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2"
    )

    # Clear existing index if present so stale documents are not kept.
    chroma_path = Path(CHROMA_PATH)
    if chroma_path.exists():
        import shutil

        shutil.rmtree(chroma_path)
        print("Cleared existing index")

    # Create and persist vector store
    print("Building vector store...")
    vectorstore = Chroma.from_documents(
        documents=documents,
        embedding=embeddings,
        persist_directory=CHROMA_PATH,
    )

    print(f"Indexed {len(documents)} documents to {CHROMA_PATH}")

    # Smoke test: run one sample query against the fresh index.
    print("\nTesting retrieval...")
    test_query = (
        "How many studio albums did Mercedes Sosa release between 2000 and 2009?"
    )
    results = vectorstore.similarity_search(test_query, k=2)
    print(f"Query: {test_query}")
    for rank, doc in enumerate(results, start=1):
        print(f"\n--- Result {rank} ---")
        print(f"Question: {doc.metadata.get('question', '')[:100]}...")
        print(f"Answer: {doc.metadata.get('answer', '')}")
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
if __name__ == "__main__":
|
| 104 |
+
build_index()
|
chroma_gaia_db/.gitattributes
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.sqlite3 filter=lfs diff=lfs merge=lfs -text
|
chroma_gaia_db/99bb1417-fe53-457a-8b1f-42a54fb4c17c/data_level0.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b92e8e451752ee2cb1c2e5bba20ff2aa94ba02b270bcfc3f8f6efffb8b948333
|
| 3 |
+
size 321200
|
chroma_gaia_db/99bb1417-fe53-457a-8b1f-42a54fb4c17c/header.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:03cb3ac86f3e5bcb15e88b9bf99f760ec6b33e31d64a699e129b49868db6d733
|
| 3 |
+
size 100
|
chroma_gaia_db/99bb1417-fe53-457a-8b1f-42a54fb4c17c/length.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5
|
| 3 |
+
size 400
|
chroma_gaia_db/99bb1417-fe53-457a-8b1f-42a54fb4c17c/link_lists.bin
ADDED
|
File without changes
|
chroma_gaia_db/chroma.sqlite3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e7524be1035bb3131badbf5dc2828251aa5df02a6248d4675e7966bc8eef2ddd
|
| 3 |
+
size 2830336
|
pyproject.toml
CHANGED
|
@@ -29,6 +29,15 @@ dependencies = [
|
|
| 29 |
"ddgs>=9.10.0",
|
| 30 |
"sentence-transformers>=5.2.0",
|
| 31 |
"typer>=0.9.0",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
]
|
| 33 |
|
| 34 |
[dependency-groups]
|
|
|
|
| 29 |
"ddgs>=9.10.0",
|
| 30 |
"sentence-transformers>=5.2.0",
|
| 31 |
"typer>=0.9.0",
|
| 32 |
+
"httpx>=0.28.1",
|
| 33 |
+
"pyjwt>=2.10.1",
|
| 34 |
+
"openpyxl>=3.1.5",
|
| 35 |
+
"python-docx>=1.2.0",
|
| 36 |
+
"python-pptx>=1.0.2",
|
| 37 |
+
"langchain-chroma>=1.1.0",
|
| 38 |
+
"zai>=0.0.2",
|
| 39 |
+
"zai-sdk>=0.2.0",
|
| 40 |
+
"langchain-mcp-adapters>=0.2.1",
|
| 41 |
]
|
| 42 |
|
| 43 |
[dependency-groups]
|
system_prompt.txt
CHANGED
|
@@ -1,38 +1,17 @@
|
|
| 1 |
-
You are a precise
|
| 2 |
|
| 3 |
-
##
|
| 4 |
-
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
##
|
| 7 |
-
1.
|
| 8 |
-
2.
|
| 9 |
-
3.
|
| 10 |
-
4.
|
| 11 |
-
5.
|
| 12 |
-
6. If source says "ripe strawberries", write "ripe strawberries" (NOT "strawberries").
|
| 13 |
-
7. Use tools to find information. Do not guess.
|
| 14 |
-
8. If information is not found, respond: "I don't know"
|
| 15 |
-
|
| 16 |
-
## Search Tips
|
| 17 |
-
- For discography/albums questions: search "[Artist] discography" to find the full album list
|
| 18 |
-
- For counting items in a date range: list each item with its year, then count
|
| 19 |
-
- If wiki_search doesn't have enough detail, use web_search or jina_search
|
| 20 |
-
- Read the full Wikipedia page with jina_read if needed
|
| 21 |
|
| 22 |
## Output Format (严格按照此格式)
|
| 23 |
Write ONLY this, nothing else:
|
| 24 |
FINAL ANSWER: [your answer here]
|
| 25 |
-
|
| 26 |
-
## Examples
|
| 27 |
-
|
| 28 |
-
Question: What are the filling ingredients?
|
| 29 |
-
Source: "You'll need ripe strawberries, granulated sugar, and freshly squeezed lemon juice"
|
| 30 |
-
FINAL ANSWER: freshly squeezed lemon juice, granulated sugar, ripe strawberries
|
| 31 |
-
|
| 32 |
-
Question: What ingredients are in the sauce?
|
| 33 |
-
Source: "Mix pure vanilla extract with heavy whipping cream"
|
| 34 |
-
FINAL ANSWER: heavy whipping cream, pure vanilla extract
|
| 35 |
-
|
| 36 |
-
Question: Who wrote the book?
|
| 37 |
-
Source: "The novel was written by Jane Smith in 1995"
|
| 38 |
-
FINAL ANSWER: Jane Smith
|
|
|
|
| 1 |
+
You are a precise assistant for the GAIA benchmark.
|
| 2 |
|
| 3 |
+
## 工作流程 (workflow)
|
| 4 |
+
1. 首先: call `get_solving_strategy` with your question
|
| 5 |
+
2. follow the strategy steps using appropriate tools
|
| 6 |
+
3. when you find the answer, call `submit_answer` immediately
|
| 7 |
|
| 8 |
+
## rules (必须严格遵守)
|
| 9 |
+
1. use exact wording from sources. do not paraphrase or shorten.
|
| 10 |
+
2. for lists: sort items alphabetically, separate with comma and space.
|
| 11 |
+
3. use tools to find information. do not guess.
|
| 12 |
+
4. when you have the answer, call `submit_answer` immediately. 不要继续搜索。
|
| 13 |
+
5. if information is not found, keep trying different tools and approaches.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
## Output Format (严格按照此格式)
|
| 16 |
Write ONLY this, nothing else:
|
| 17 |
FINAL ANSWER: [your answer here]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_bench.py
CHANGED
|
@@ -243,7 +243,9 @@ def run_test_bench(
|
|
| 243 |
question_with_file = q.question
|
| 244 |
if q.file_path:
|
| 245 |
question_with_file += f"\n\nFile path: {q.file_path}"
|
| 246 |
-
|
|
|
|
|
|
|
| 247 |
except Exception as e:
|
| 248 |
actual = f"ERROR: {e}"
|
| 249 |
|
|
|
|
| 243 |
question_with_file = q.question
|
| 244 |
if q.file_path:
|
| 245 |
question_with_file += f"\n\nFile path: {q.file_path}"
|
| 246 |
+
import asyncio
|
| 247 |
+
|
| 248 |
+
actual = asyncio.run(agent(question_with_file))
|
| 249 |
except Exception as e:
|
| 250 |
actual = f"ERROR: {e}"
|
| 251 |
|
tools.py
CHANGED
|
@@ -1,10 +1,77 @@
|
|
| 1 |
import os
|
| 2 |
from typing import List
|
| 3 |
|
|
|
|
| 4 |
from langchain_core.documents.base import Document
|
| 5 |
from langchain_core.tools import tool
|
| 6 |
from langchain_core.tools.base import ArgsSchema
|
|
|
|
| 7 |
from pydantic import SecretStr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
def _get_llm():
|
|
@@ -25,55 +92,102 @@ def _get_llm():
|
|
| 25 |
return ChatOpenAI(model="gpt-4o-mini", temperature=0)
|
| 26 |
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
|
|
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
for doc in search_docs
|
| 41 |
-
]
|
| 42 |
-
)
|
| 43 |
-
return formatted_search_docs
|
| 44 |
|
| 45 |
|
| 46 |
@tool
|
| 47 |
-
def
|
| 48 |
-
"""Search
|
| 49 |
|
| 50 |
Args:
|
| 51 |
query: The search query."""
|
| 52 |
-
import
|
| 53 |
-
|
| 54 |
-
api_key = os.getenv("ZAI_API_KEY", "")
|
| 55 |
-
|
| 56 |
-
response = requests.post(
|
| 57 |
-
"https://api.z.ai/api/coding/paas/v4/web_search",
|
| 58 |
-
headers={
|
| 59 |
-
"Authorization": f"Bearer {api_key}",
|
| 60 |
-
"Content-Type": "application/json",
|
| 61 |
-
},
|
| 62 |
-
json={"search_engine": "search-prime", "search_query": query, "count": 5},
|
| 63 |
-
timeout=30,
|
| 64 |
-
)
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
|
| 79 |
@tool
|
|
@@ -91,54 +205,123 @@ def jina_search(query: str) -> str:
|
|
| 91 |
return response.text
|
| 92 |
|
| 93 |
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
-
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
-
|
| 106 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
|
| 109 |
@tool
|
| 110 |
-
def
|
| 111 |
-
"""Read a webpage
|
| 112 |
|
| 113 |
Args:
|
| 114 |
-
url: The URL to read.
|
|
|
|
| 115 |
import requests
|
| 116 |
|
| 117 |
-
api_key = os.getenv("
|
| 118 |
-
|
| 119 |
-
response = requests.post(
|
| 120 |
-
"https://api.z.ai/api/mcp/web_reader/mcp",
|
| 121 |
-
headers={
|
| 122 |
-
"Authorization": f"Bearer {api_key}",
|
| 123 |
-
"Content-Type": "application/json",
|
| 124 |
-
},
|
| 125 |
-
json={
|
| 126 |
-
"method": "tools/call",
|
| 127 |
-
"params": {
|
| 128 |
-
"name": "webReader",
|
| 129 |
-
"arguments": {"url": url},
|
| 130 |
-
},
|
| 131 |
-
},
|
| 132 |
-
timeout=60,
|
| 133 |
-
)
|
| 134 |
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
return
|
| 141 |
-
return f"Error: {data}"
|
| 142 |
|
| 143 |
|
| 144 |
@tool
|
|
@@ -205,20 +388,177 @@ def analyze_text(text: str, question: str) -> str:
|
|
| 205 |
return response.content
|
| 206 |
|
| 207 |
|
| 208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
"""Retrieve the list of available tools for the agent."""
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
jina_read,
|
| 215 |
download_file,
|
| 216 |
read_pdf,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
py_calc_tool,
|
| 218 |
youtube_transcript_tool,
|
| 219 |
transcribe_audio,
|
| 220 |
arxiv_search,
|
| 221 |
]
|
|
|
|
|
|
|
|
|
|
| 222 |
|
| 223 |
|
| 224 |
@tool
|
|
|
|
| 1 |
import os
|
| 2 |
from typing import List
|
| 3 |
|
| 4 |
+
from langchain_chroma import Chroma
|
| 5 |
from langchain_core.documents.base import Document
|
| 6 |
from langchain_core.tools import tool
|
| 7 |
from langchain_core.tools.base import ArgsSchema
|
| 8 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 9 |
from pydantic import SecretStr
|
| 10 |
+
from sqlalchemy.sql.selectable import ForUpdateParameter
|
| 11 |
+
|
| 12 |
+
# Initialize RAG vector store for strategy retrieval
|
| 13 |
+
CHROMA_PATH = "./chroma_gaia_db"
|
| 14 |
+
_embeddings = None
|
| 15 |
+
_vector_store = None
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _get_vector_store():
    """Return the module-wide Chroma store, creating it on first use.

    The embeddings model and the store are cached in the module globals
    ``_embeddings`` and ``_vector_store``, so later calls return the cached
    store without rebuilding it.
    """
    global _embeddings, _vector_store

    # Already initialised by an earlier call.
    if _vector_store is not None:
        return _vector_store

    _embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2"
    )
    _vector_store = Chroma(
        persist_directory=CHROMA_PATH, embedding_function=_embeddings
    )
    return _vector_store
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@tool
def get_solving_strategy(question: str) -> str:
    """Search for similar solved questions and get the solving strategy.
    Use this FIRST to understand how to approach a problem before using other tools.

    Args:
        question: The question you need to solve."""
    # NOTE: the docstring above is left verbatim on purpose -- it is presumably
    # what the @tool decorator exposes to the model as the tool description,
    # so editing it would change runtime behaviour.
    #
    # Returns the strategy text for the closest indexed question, a fixed
    # "no similar questions" message, or an "Error searching..." string if
    # anything raises; this function itself never raises.
    print(f"\n[GET_SOLVING_STRATEGY] Searching for: {question[:80]}...")
    try:
        vector_store = _get_vector_store()
        similar_docs = vector_store.similarity_search(question, k=1)
        print(f"[GET_SOLVING_STRATEGY] Found {len(similar_docs)} similar questions")

        if not similar_docs:
            return "No similar questions found. Use your best judgment."

        doc = similar_docs[0]
        # The strategy is the text between the two markers in the document.
        steps = (
            doc.page_content.split("Steps to solve:")[-1]
            .split("Tools needed:")[0]
            .strip()
        )
        # The former bullet-formatting of the "tools" metadata was removed: its
        # result was never included in the reply, and a non-string value there
        # would raise and turn a successful hit into an error message.

        set_current_strategy(steps)

        return f"""Similar question found!

## Strategy to solve (按此策略执行):
{steps}

## Rules (必须严格遵守):
1. Use EXACT wording from sources. Do not paraphrase or shorten.
2. For lists: sort items alphabetically, separate with comma and space.
3. Use tools to find information. Do not guess.
4. When you find the answer, call `submit_answer` immediately. 不要继续搜索。

"""
    except Exception as e:
        # Best-effort tool: report the failure to the model as text.
        return f"Error searching for strategy: {e}"
|
| 75 |
|
| 76 |
|
| 77 |
def _get_llm():
|
|
|
|
| 92 |
return ChatOpenAI(model="gpt-4o-mini", temperature=0)
|
| 93 |
|
| 94 |
|
| 95 |
+
def _fetch_url_with_tables(url: str) -> str:
    """Fetch a page through the Jina reader endpoint (``https://r.jina.ai/``).

    Args:
        url: Address of the page to read; it is appended to the reader URL.

    Returns:
        The response body as text, or ``""`` when the request fails for any
        reason (network error, timeout, or a non-success HTTP status).
    """
    import requests

    try:
        # The Authorization header is only sent when JINA_API_KEY is set.
        api_key = os.getenv("JINA_API_KEY", "")
        headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}

        response = requests.get(f"https://r.jina.ai/{url}", headers=headers, timeout=30)
        # Without this check the body of a 4xx/5xx error response was returned
        # as if it were page content, so callers testing for an empty result
        # never reached their fallback.
        response.raise_for_status()
        return response.text
    except Exception:
        # Deliberate best-effort: callers treat "" as "no content available".
        return ""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
|
| 110 |
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return relevant content including tables.

    Args:
        query: The search query."""
    # NOTE(review): docstring kept verbatim; the `tool` decorator presumably
    # uses it as the tool description.
    import wikipedia

    try:
        titles = wikipedia.search(query, results=3)
        if not titles:
            return "No Wikipedia results found."

        documents = []
        # Only the first two hits are expanded into full documents.
        for title in titles[:2]:
            try:
                page = wikipedia.page(title, auto_suggest=False)
                url = page.url

                # Prefer the Jina rendering (keeps tables); fall back to the
                # plain text from the wikipedia package when it comes back empty.
                content = _fetch_url_with_tables(url) or page.content

                extracted = _extract_relevant_content(content, query)
                documents.append(
                    f'<Document source="{url}" title="{title}">\n{extracted}\n</Document>'
                )
            except Exception:
                # Covers wikipedia.DisambiguationError / wikipedia.PageError as
                # well as any other per-page failure: skip this hit.
                continue

        if not documents:
            return "No results found."
        return "\n\n---\n\n".join(documents)
    except Exception as e:
        return f"Wikipedia search error: {e}"
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
# Cache for the Z.AI MCP tool list; filled on the first successful load.
_zai_mcp_tools = None


async def _get_zai_mcp_tools():
    """Return the Z.AI MCP tools, connecting on the first call only.

    The result of ``client.get_tools()`` is stored in the module-level
    ``_zai_mcp_tools`` and handed back unchanged on later calls.
    """
    global _zai_mcp_tools
    if _zai_mcp_tools is not None:
        return _zai_mcp_tools

    from langchain_mcp_adapters.client import MultiServerMCPClient

    api_key = os.getenv("ZAI_API_KEY", "")

    def _http_server(endpoint):
        # Each HTTP server entry gets its own headers dict.
        return {
            "transport": "streamable_http",
            "url": endpoint,
            "headers": {"Authorization": f"Bearer {api_key}"},
        }

    servers = {
        "web-search": _http_server("https://api.z.ai/api/mcp/web_search_prime/mcp"),
        "web-reader": _http_server("https://api.z.ai/api/mcp/web_reader/mcp"),
        # Local server started through npx and spoken to over stdio.
        "zai-mcp": {
            "transport": "stdio",
            "command": "npx",
            "args": ["-y", "@z_ai/mcp-server"],
            "env": {
                "Z_AI_API_KEY": api_key,
                "Z_AI_MODE": "ZAI",
            },
        },
    }

    client = MultiServerMCPClient(servers)
    _zai_mcp_tools = await client.get_tools()
    return _zai_mcp_tools
|
| 191 |
|
| 192 |
|
| 193 |
@tool
|
|
|
|
| 205 |
return response.text
|
| 206 |
|
| 207 |
|
| 208 |
+
def _extract_section_by_marker(
|
| 209 |
+
content: str, section_marker: str, context_lines: int = 50
|
| 210 |
+
) -> str:
|
| 211 |
+
"""Extract a section starting from a marker found in strategy steps.
|
| 212 |
+
|
| 213 |
+
This is the SMART extraction - uses strategy steps like "scrolled down to Studio albums"
|
| 214 |
+
to find the exact section we need.
|
| 215 |
+
"""
|
| 216 |
+
import re
|
| 217 |
+
|
| 218 |
+
lines = content.split("\n")
|
| 219 |
+
marker_lower = section_marker.lower().strip()
|
| 220 |
+
|
| 221 |
+
print(f"[EXTRACT_SECTION] Looking for section marker: '{section_marker}'")
|
| 222 |
+
|
| 223 |
+
# Find the line containing the section marker
|
| 224 |
+
start_idx = None
|
| 225 |
+
for i, line in enumerate(lines):
|
| 226 |
+
if marker_lower in line.lower():
|
| 227 |
+
start_idx = i
|
| 228 |
+
print(f"[EXTRACT_SECTION] Found marker at line {i}: {line[:80]}")
|
| 229 |
+
break
|
| 230 |
+
|
| 231 |
+
if start_idx is None:
|
| 232 |
+
# Try partial matching (e.g., "Studio albums" might be "Studio Albums" or "Discography")
|
| 233 |
+
for i, line in enumerate(lines):
|
| 234 |
+
# Check if most words from marker are in line
|
| 235 |
+
marker_words = [
|
| 236 |
+
w for w in re.findall(r"\b\w+\b", marker_lower) if len(w) > 2
|
| 237 |
+
]
|
| 238 |
+
line_lower = line.lower()
|
| 239 |
+
matches = sum(1 for w in marker_words if w in line_lower)
|
| 240 |
+
if matches >= len(marker_words) * 0.6: # 60% match threshold
|
| 241 |
+
start_idx = i
|
| 242 |
+
print(f"[EXTRACT_SECTION] Found partial match at line {i}: {line[:80]}")
|
| 243 |
+
break
|
| 244 |
+
|
| 245 |
+
if start_idx is None:
|
| 246 |
+
print(f"[EXTRACT_SECTION] Section marker not found")
|
| 247 |
+
return ""
|
| 248 |
+
|
| 249 |
+
# Extract from marker line + context_lines after it
|
| 250 |
+
end_idx = min(start_idx + context_lines, len(lines))
|
| 251 |
+
section = "\n".join(lines[start_idx:end_idx])
|
| 252 |
+
|
| 253 |
+
print(f"[EXTRACT_SECTION] Extracted {end_idx - start_idx} lines from section")
|
| 254 |
+
return section
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def _parse_section_markers_from_strategy(strategy: str) -> list:
|
| 258 |
+
"""Parse strategy steps to extract section markers.
|
| 259 |
+
|
| 260 |
+
Looks for phrases like:
|
| 261 |
+
- "scrolled down to Studio albums" -> "Studio albums"
|
| 262 |
+
- "found the Discography section" -> "Discography"
|
| 263 |
+
- "went to Studio albums" -> "Studio albums"
|
| 264 |
+
"""
|
| 265 |
+
import re
|
| 266 |
+
|
| 267 |
+
markers = []
|
| 268 |
+
|
| 269 |
+
# Patterns that indicate a section name
|
| 270 |
+
patterns = [
|
| 271 |
+
r'scrolled?\s+(?:down\s+)?to\s+["\']?([^"\',.]+)["\']?', # scrolled down to X
|
| 272 |
+
r'went\s+to\s+(?:the\s+)?["\']?([^"\',.]+)["\']?\s+section', # went to X section
|
| 273 |
+
r'found\s+(?:the\s+)?["\']?([^"\',.]+)["\']?\s+section', # found X section
|
| 274 |
+
r'clicked\s+on\s+["\']?([^"\',.]+)["\']?', # clicked on X
|
| 275 |
+
r'looked\s+(?:at|under)\s+["\']?([^"\',.]+)["\']?', # looked at/under X
|
| 276 |
+
r'(?:in|under)\s+(?:the\s+)?["\']?([^"\',.]+)["\']?\s+section', # in/under X section
|
| 277 |
+
]
|
| 278 |
|
| 279 |
+
for pattern in patterns:
|
| 280 |
+
matches = re.findall(pattern, strategy.lower())
|
| 281 |
+
for match in matches:
|
| 282 |
+
cleaned = match.strip()
|
| 283 |
+
if cleaned and len(cleaned) > 2 and len(cleaned) < 50:
|
| 284 |
+
markers.append(cleaned)
|
| 285 |
|
| 286 |
+
# Also look for quoted section names
|
| 287 |
+
quoted = re.findall(r'"([^"]+)"', strategy)
|
| 288 |
+
for q in quoted:
|
| 289 |
+
if len(q) > 2 and len(q) < 50 and q.lower() not in ["wikipedia", "google"]:
|
| 290 |
+
markers.append(q)
|
| 291 |
|
| 292 |
+
print(f"[PARSE_STRATEGY] Extracted section markers: {markers}")
|
| 293 |
+
return markers
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
# Most recent strategy text handed to set_current_strategy(); None until the
# first call. Kept at module level so content extraction can consult it.
_current_strategy = None


def set_current_strategy(strategy: str) -> None:
    """Record ``strategy`` as the module-wide current strategy.

    Args:
        strategy: Strategy text to remember; it replaces any previous value.
    """
    global _current_strategy
    _current_strategy = strategy
    print("[STRATEGY] Updated current strategy")
|
| 305 |
|
| 306 |
|
| 307 |
@tool
def jina_read(url: str, question: str = "") -> str:
    """Read a webpage and extract content relevant to the question.

    Args:
        url: The URL to read.
        question: The question to extract relevant info for."""
    # NOTE(review): docstring kept verbatim; the `tool` decorator presumably
    # uses it as the tool description. `question` is accepted for interface
    # compatibility but does not currently influence the result.
    import requests

    # The Authorization header is only sent when JINA_API_KEY is set.
    api_key = os.getenv("JINA_API_KEY", "")
    headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}

    response = requests.get(f"https://r.jina.ai/{url}", headers=headers, timeout=30)
    content = response.text

    # Previously the only `return` sat under `if question:`, so a call with the
    # default empty question fell off the end and returned None. Both cases
    # yield the same truncated text, so return it unconditionally.
    return content[:10000]
|
|
|
|
| 325 |
|
| 326 |
|
| 327 |
@tool
|
|
|
|
| 388 |
return response.content
|
| 389 |
|
| 390 |
|
| 391 |
+
@tool
def read_excel(file_path: str) -> str:
    """Read and extract data from an Excel file (.xlsx, .xls).

    Args:
        file_path: Path to the Excel file."""
    # NOTE(review): docstring kept verbatim; the `tool` decorator presumably
    # uses it as the tool description.
    #
    # Returns every sheet rendered with DataFrame.to_string(), each under a
    # "=== Sheet: <name> ===" header, truncated to 15000 characters; on failure
    # an "Error reading Excel: ..." string is returned instead of raising.
    import pandas as pd

    try:
        # The context manager closes the workbook handle even when parsing a
        # sheet fails; the previous version never closed it.
        with pd.ExcelFile(file_path) as workbook:
            rendered_sheets = [
                f"=== Sheet: {sheet_name} ===\n"
                f"{pd.read_excel(workbook, sheet_name=sheet_name).to_string()}"
                for sheet_name in workbook.sheet_names
            ]
        return "\n\n".join(rendered_sheets)[:15000]
    except Exception as e:
        return f"Error reading Excel: {e}"
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
@tool
def read_csv(file_path: str) -> str:
    """Read and extract data from a CSV file.

    Args:
        file_path: Path to the CSV file."""
    # NOTE(review): docstring kept verbatim; the `tool` decorator presumably
    # uses it as the tool description.
    import pandas as pd

    try:
        frame = pd.read_csv(file_path)
        rendered = frame.to_string()
    except Exception as e:
        # Failures are reported as text rather than raised.
        return f"Error reading CSV: {e}"
    # Cap the output at 15000 characters.
    return rendered[:15000]
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
@tool
def read_docx(file_path: str) -> str:
    """Read and extract text from a Word document (.docx).

    Args:
        file_path: Path to the Word document."""
    # NOTE(review): docstring kept verbatim; the `tool` decorator presumably
    # uses it as the tool description.
    try:
        from docx import Document

        document = Document(file_path)
        # One output line per paragraph, capped at 15000 characters overall.
        paragraphs = (paragraph.text for paragraph in document.paragraphs)
        return "\n".join(paragraphs)[:15000]
    except Exception as e:
        # Import errors and read errors alike are reported as text.
        return f"Error reading Word doc: {e}"
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
@tool
def read_pptx(file_path: str) -> str:
    """Read and extract text from a PowerPoint presentation (.pptx).

    Args:
        file_path: Path to the PowerPoint file."""
    # NOTE(review): docstring kept verbatim; the `tool` decorator presumably
    # uses it as the tool description.
    try:
        from pptx import Presentation

        presentation = Presentation(file_path)
        rendered_slides = []
        for number, slide in enumerate(presentation.slides, 1):
            # Header line first, then the text of every shape that has one.
            lines = [f"=== Slide {number} ==="]
            lines.extend(
                shape.text for shape in slide.shapes if hasattr(shape, "text")
            )
            rendered_slides.append("\n".join(lines))
        # Slides are separated by a blank line; output capped at 15000 chars.
        return "\n\n".join(rendered_slides)[:15000]
    except Exception as e:
        return f"Error reading PowerPoint: {e}"
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
@tool
def extract_zip(file_path: str) -> str:
    """Extract a zip file and list its contents.

    Args:
        file_path: Path to the zip file."""
    # NOTE(review): docstring kept verbatim; the `tool` decorator presumably
    # uses it as the tool description.
    #
    # The archive is unpacked into a sibling directory named after the file
    # (its name without the final suffix). Returns the directory and member
    # names, or an "Error extracting zip: ..." string on failure.
    import zipfile
    from pathlib import Path

    try:
        archive_path = Path(file_path)
        extract_dir = archive_path.parent / archive_path.stem

        with zipfile.ZipFile(file_path, "r") as zip_ref:
            # Create the target only once the archive has opened successfully,
            # so a missing or corrupt file no longer leaves an empty directory.
            extract_dir.mkdir(exist_ok=True)
            zip_ref.extractall(extract_dir)
            file_list = zip_ref.namelist()

        return f"Extracted to: {extract_dir}\nContents:\n" + "\n".join(file_list)
    except Exception as e:
        return f"Error extracting zip: {e}"
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
@tool
def analyze_image(file_path: str, question: str) -> str:
    """Analyze an image and answer a question about it using vision model.

    Args:
        file_path: Path to the image file (png, jpg, etc.)
        question: Question to answer about the image."""
    # NOTE(review): docstring kept verbatim; the `tool` decorator presumably
    # uses it as the tool description.
    #
    # The image is sent inline as a base64 data URL. Returns the model's reply
    # content, or an "Error analyzing image: ..." string on failure.
    import base64

    from langchain_openai import ChatOpenAI

    try:
        with open(file_path, "rb") as f:
            image_data = base64.b64encode(f.read()).decode("utf-8")

        # Pick the media type from the file extension. GIF and WebP used to be
        # labelled image/png, which made the data URL disagree with the bytes.
        # NOTE(review): assumes the vision endpoint accepts gif/webp — confirm.
        ext = file_path.lower().rsplit(".", 1)[-1]
        mime_type = {
            "png": "image/png",
            "jpg": "image/jpeg",
            "jpeg": "image/jpeg",
            "gif": "image/gif",
            "webp": "image/webp",
        }.get(ext, "image/png")  # unknown extensions keep the old PNG default

        # Use GPT-4o for vision
        llm = ChatOpenAI(model="gpt-4o", temperature=0)
        response = llm.invoke(
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": question},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{mime_type};base64,{image_data}"
                            },
                        },
                    ],
                }
            ]
        )
        return response.content
    except Exception as e:
        return f"Error analyzing image: {e}"
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
@tool
def submit_answer(answer: str) -> str:
    """Submit your final answer. Use this when you have found the answer.

    Args:
        answer: The final answer to submit."""
    # NOTE(review): docstring kept verbatim; the `tool` decorator presumably
    # uses it as the tool description.
    # Log the answer, then echo it back behind the "FINAL ANSWER: " prefix.
    print(f"[SUBMIT_ANSWER] {answer}")
    final_message = f"FINAL ANSWER: {answer}"
    return final_message
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
async def get_tools() -> list:
    """Assemble every tool the agent may call.

    Returns:
        A new list holding the locally defined tools followed by whatever
        ``_get_zai_mcp_tools()`` returns.
    """
    local_tools = [
        get_solving_strategy,  # Use FIRST to get approach
        submit_answer,
        # wiki_search,
        download_file,
        read_pdf,
        read_excel,
        read_csv,
        read_docx,
        read_pptx,
        extract_zip,
        analyze_image,
        py_calc_tool,
        youtube_transcript_tool,
        transcribe_audio,
        arxiv_search,
    ]
    # Add Z.AI MCP tools (webSearchPrime, webReader)
    remote_tools = await _get_zai_mcp_tools()
    return local_tools + remote_tools
|
| 562 |
|
| 563 |
|
| 564 |
@tool
|