"""
MCP Client for connecting to TraceMind-mcp-server
Uses MCP protocol over HTTP to call remote MCP tools
"""
import os
import asyncio
from typing import Optional, Dict, Any, List

from mcp import ClientSession
from mcp.client.sse import sse_client

class MCPClient:
    """Client for interacting with the TraceMind MCP Server."""

    def __init__(self, server_url: Optional[str] = None):
        """
        Initialize the MCP client.

        Args:
            server_url: URL of the TraceMind-mcp-server endpoint.
                If None, uses MCP_SERVER_URL from the environment.
        """
        self.server_url = server_url or os.getenv(
            'MCP_SERVER_URL',
            'https://mcp-1st-birthday-tracemind-mcp-server.hf.space/gradio_api/mcp/'
        )
        self.session: Optional[ClientSession] = None
        self._initialized = False
        self._sse_context = None
        self._session_context = None
    async def initialize(self):
        """Initialize the connection to the MCP server."""
        if self._initialized:
            return
        try:
            # Connect to the SSE endpoint and keep it open
            self._sse_context = sse_client(self.server_url)
            read, write = await self._sse_context.__aenter__()
            # Create the session and keep it open
            self._session_context = ClientSession(read, write)
            self.session = await self._session_context.__aenter__()
            await self.session.initialize()
            self._initialized = True
            # List available tools for verification
            tools_result = await self.session.list_tools()
            print(f"✅ Connected to TraceMind MCP Server at {self.server_url}")
            print(f"📋 Available tools: {len(tools_result.tools)}")
            for tool in tools_result.tools:
                print(f"  - {tool.name}: {tool.description}")
        except Exception as e:
            print(f"❌ Failed to connect to MCP server: {e}")
            # Clean up on error
            await self._cleanup_connections()
            raise
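
    # Usage sketch (illustrative, not part of the class API): the intended
    # lifecycle is initialize once, call tools, then close. This assumes the
    # default public server URL above is reachable from the caller:
    #
    #   client = MCPClient()
    #   await client.initialize()
    #   analysis = await client.analyze_leaderboard(top_n=3)
    #   await client.close()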
    async def _ensure_connected(self):
        """Ensure the connection is active; reconnect if needed."""
        if not self._initialized or self.session is None:
            print("🔄 Reconnecting to MCP server...")
            await self._cleanup_connections()
            await self.initialize()
    async def _call_tool_with_retry(self, tool_name: str, arguments: dict, max_retries: int = 2):
        """Call an MCP tool with automatic retry on connection errors."""
        for attempt in range(max_retries):
            try:
                await self._ensure_connected()
                result = await self.session.call_tool(tool_name, arguments=arguments)
                return result
            except Exception as e:
                error_str = str(e)
                if "ClosedResourceError" in error_str or "closed" in error_str.lower():
                    if attempt < max_retries - 1:
                        print(f"⚠️ Connection lost, retrying... (attempt {attempt + 1}/{max_retries})")
                        await self._cleanup_connections()
                        continue
                raise
    async def analyze_leaderboard(
        self,
        leaderboard_repo: str = "kshitijthakkar/smoltrace-leaderboard",
        metric_focus: str = "overall",
        time_range: str = "last_week",
        top_n: int = 5,
        hf_token: Optional[str] = None,
        gemini_api_key: Optional[str] = None
    ) -> str:
        """
        Call the analyze_leaderboard tool on the MCP server.

        Args:
            leaderboard_repo: HuggingFace dataset repo for the leaderboard
            metric_focus: Focus metric (overall, accuracy, cost, latency, co2)
            time_range: Time range filter (last_week, last_month, all_time)
            top_n: Number of top models to highlight
            hf_token: HuggingFace API token (optional if the dataset is public)
            gemini_api_key: Google Gemini API key (optional; the server may have one)

        Returns:
            AI-generated analysis of the leaderboard
        """
        try:
            # Build arguments
            args = {
                "leaderboard_repo": leaderboard_repo,
                "metric_focus": metric_focus,
                "time_range": time_range,
                "top_n": top_n
            }
            # Add optional tokens if provided
            if hf_token:
                args["hf_token"] = hf_token
            if gemini_api_key:
                args["gemini_api_key"] = gemini_api_key
            # Call the MCP tool with retry
            result = await self._call_tool_with_retry("analyze_leaderboard", args)
            # Extract text from the result
            if result.content and len(result.content) > 0:
                return result.content[0].text
            else:
                return "No analysis generated"
        except Exception as e:
            return f"❌ Error calling analyze_leaderboard: {str(e)}"
    async def debug_trace(
        self,
        trace_data: Dict[str, Any],
        question: str,
        metrics_data: Optional[Dict[str, Any]] = None,
        hf_token: Optional[str] = None,
        gemini_api_key: Optional[str] = None
    ) -> str:
        """
        Call the debug_trace tool on the MCP server.

        Args:
            trace_data: OpenTelemetry trace data (dict with spans)
            question: User question about the trace
            metrics_data: Optional GPU metrics data
            hf_token: HuggingFace API token
            gemini_api_key: Google Gemini API key

        Returns:
            AI-generated answer to the trace question
        """
        try:
            args = {
                "trace_data": trace_data,
                "question": question
            }
            if metrics_data:
                args["metrics_data"] = metrics_data
            if hf_token:
                args["hf_token"] = hf_token
            if gemini_api_key:
                args["gemini_api_key"] = gemini_api_key
            result = await self._call_tool_with_retry("debug_trace", args)
            if result.content and len(result.content) > 0:
                return result.content[0].text
            else:
                return "No answer generated"
        except Exception as e:
            return f"❌ Error calling debug_trace: {str(e)}"
    async def estimate_cost(
        self,
        model: str,
        agent_type: str = "both",
        num_tests: int = 100,
        hardware: Optional[str] = None,
        hf_token: Optional[str] = None,
        gemini_api_key: Optional[str] = None
    ) -> str:
        """
        Call the estimate_cost tool on the MCP server.

        Args:
            model: Model name (e.g., 'openai/gpt-4', 'meta-llama/Llama-3.1-8B')
            agent_type: Agent type (tool, code, both)
            num_tests: Number of tests to run
            hardware: Hardware type (cpu, gpu_a10, gpu_h200)
            hf_token: HuggingFace API token
            gemini_api_key: Google Gemini API key

        Returns:
            Cost estimation with breakdown
        """
        try:
            args = {
                "model": model,
                "agent_type": agent_type,
                "num_tests": num_tests
            }
            if hardware:
                args["hardware"] = hardware
            if hf_token:
                args["hf_token"] = hf_token
            if gemini_api_key:
                args["gemini_api_key"] = gemini_api_key
            result = await self._call_tool_with_retry("estimate_cost", args)
            if result.content and len(result.content) > 0:
                return result.content[0].text
            else:
                return "No estimation generated"
        except Exception as e:
            return f"❌ Error calling estimate_cost: {str(e)}"
    async def compare_runs(
        self,
        run_data_list: List[Dict[str, Any]],
        focus_metrics: Optional[List[str]] = None,
        hf_token: Optional[str] = None,
        gemini_api_key: Optional[str] = None
    ) -> str:
        """
        Call the compare_runs tool on the MCP server.

        Args:
            run_data_list: List of run data dicts from the leaderboard
            focus_metrics: List of metrics to focus on
            hf_token: HuggingFace API token
            gemini_api_key: Google Gemini API key

        Returns:
            AI-generated comparison analysis
        """
        try:
            args = {
                "run_data_list": run_data_list
            }
            if focus_metrics:
                args["focus_metrics"] = focus_metrics
            if hf_token:
                args["hf_token"] = hf_token
            if gemini_api_key:
                args["gemini_api_key"] = gemini_api_key
            result = await self._call_tool_with_retry("compare_runs", args)
            if result.content and len(result.content) > 0:
                return result.content[0].text
            else:
                return "No comparison generated"
        except Exception as e:
            return f"❌ Error calling compare_runs: {str(e)}"
    async def analyze_results(
        self,
        results_data: List[Dict[str, Any]],
        analysis_focus: str = "optimization",
        hf_token: Optional[str] = None,
        gemini_api_key: Optional[str] = None
    ) -> str:
        """
        Call the analyze_results tool on the MCP server.

        Args:
            results_data: List of test case results
            analysis_focus: Focus area (optimization, failures, performance, cost)
            hf_token: HuggingFace API token
            gemini_api_key: Google Gemini API key

        Returns:
            AI-generated results analysis with recommendations
        """
        try:
            args = {
                "results_data": results_data,
                "analysis_focus": analysis_focus
            }
            if hf_token:
                args["hf_token"] = hf_token
            if gemini_api_key:
                args["gemini_api_key"] = gemini_api_key
            result = await self._call_tool_with_retry("analyze_results", args)
            if result.content and len(result.content) > 0:
                return result.content[0].text
            else:
                return "No analysis generated"
        except Exception as e:
            return f"❌ Error calling analyze_results: {str(e)}"
    async def get_dataset_info(
        self,
        dataset_repo: str,
        hf_token: Optional[str] = None,
        gemini_api_key: Optional[str] = None
    ) -> str:
        """
        Call the get_dataset tool (resource) on the MCP server.

        Args:
            dataset_repo: HuggingFace dataset repo
            hf_token: HuggingFace API token
            gemini_api_key: Google Gemini API key

        Returns:
            Dataset information and structure
        """
        try:
            args = {
                "dataset_repo": dataset_repo
            }
            if hf_token:
                args["hf_token"] = hf_token
            if gemini_api_key:
                args["gemini_api_key"] = gemini_api_key
            result = await self._call_tool_with_retry("get_dataset", args)
            if result.content and len(result.content) > 0:
                return result.content[0].text
            else:
                return "No dataset info generated"
        except Exception as e:
            return f"❌ Error calling get_dataset: {str(e)}"
    async def _cleanup_connections(self):
        """Internal helper to clean up connections."""
        if self._session_context:
            try:
                await self._session_context.__aexit__(None, None, None)
            except Exception as e:
                print(f"⚠️ Error closing session context: {e}")
            self._session_context = None
            self.session = None
        if self._sse_context:
            try:
                await self._sse_context.__aexit__(None, None, None)
            except Exception as e:
                print(f"⚠️ Error closing SSE context: {e}")
            self._sse_context = None
        self._initialized = False

    async def close(self):
        """Close the MCP client session."""
        await self._cleanup_connections()

# Singleton instance for use across the app
_mcp_client_instance: Optional[MCPClient] = None


def get_mcp_client() -> MCPClient:
    """Get or create the global MCP client instance."""
    global _mcp_client_instance
    if _mcp_client_instance is None:
        _mcp_client_instance = MCPClient()
    return _mcp_client_instance
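

# ---------------------------------------------------------------------------
# Minimal end-to-end usage sketch (runs only when this module is executed
# directly, never on import). It assumes the default public server URL is
# reachable; the HF_TOKEN and GEMINI_API_KEY environment variable names below
# are illustrative conventions, not requirements of the server.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    async def _demo() -> None:
        client = get_mcp_client()
        try:
            await client.initialize()
            # Ask for a short leaderboard analysis as a connectivity smoke test
            analysis = await client.analyze_leaderboard(
                metric_focus="accuracy",
                top_n=3,
                hf_token=os.getenv("HF_TOKEN"),
                gemini_api_key=os.getenv("GEMINI_API_KEY"),
            )
            print(analysis)
        finally:
            await client.close()

    asyncio.run(_demo())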