TraceMind / mcp_client /client.py
kshitijthakkar's picture
fix: Resolve compare screen and MCP connection issues
739f384
"""
MCP Client for connecting to TraceMind-mcp-server
Uses the MCP protocol over HTTP (Server-Sent Events transport) to call remote MCP tools
"""
import os
import asyncio
from typing import Optional, Dict, Any, List
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
import aiohttp
class MCPClient:
    """Client for interacting with TraceMind MCP Server.

    Keeps one SSE connection and one ``ClientSession`` open across calls and
    reconnects when a call fails because the connection was closed. The
    public tool wrappers do not raise ordinary exceptions: failures are
    returned as strings of the form ``"❌ Error calling <tool>: <detail>"``.
    """

    def __init__(self, server_url: Optional[str] = None):
        """
        Initialize MCP Client (no network I/O happens here).

        Args:
            server_url: URL of the TraceMind-mcp-server endpoint.
                If None (or empty), uses MCP_SERVER_URL from the environment,
                falling back to the public Hugging Face Space endpoint.
        """
        self.server_url = server_url or os.getenv(
            'MCP_SERVER_URL',
            'https://mcp-1st-birthday-tracemind-mcp-server.hf.space/gradio_api/mcp/'
        )
        self.session: Optional[ClientSession] = None
        self._initialized = False
        # Entered-but-not-yet-exited async context managers. They are kept so
        # _cleanup_connections() can close them later, outside any `async with`.
        self._sse_context = None
        self._session_context = None

    async def initialize(self):
        """Open the SSE connection and MCP session (no-op if already connected).

        Also lists the server's tools and prints them as a connectivity check.

        Raises:
            Exception: whatever the transport or session raised. Partially
                opened connections are cleaned up before re-raising.
        """
        if self._initialized:
            return
        try:
            # Enter the context managers by hand (instead of `async with`) so
            # the connection stays open after this method returns.
            self._sse_context = sse_client(self.server_url)
            read, write = await self._sse_context.__aenter__()
            self._session_context = ClientSession(read, write)
            self.session = await self._session_context.__aenter__()
            await self.session.initialize()
            self._initialized = True
            # List available tools for verification
            tools_result = await self.session.list_tools()
            print(f"✅ Connected to TraceMind MCP Server at {self.server_url}")
            print(f"📊 Available tools: {len(tools_result.tools)}")
            for tool in tools_result.tools:
                print(f" - {tool.name}: {tool.description}")
        except Exception as e:
            print(f"❌ Failed to connect to MCP server: {e}")
            # Clean up on error so a later call starts from a fresh state
            await self._cleanup_connections()
            raise

    async def _ensure_connected(self):
        """Ensure the connection is active; reconnect if it was torn down."""
        if not self._initialized or self.session is None:
            print("🔄 Reconnecting to MCP server...")
            await self._cleanup_connections()
            await self.initialize()

    @staticmethod
    def _is_connection_error(exc: BaseException) -> bool:
        """Return True if *exc* indicates a closed or dropped MCP connection.

        anyio raises ``ClosedResourceError`` / ``BrokenResourceError`` without a
        message, so ``str(exc)`` is often empty. The type name must be inspected
        as well as the message. Exception groups (anything exposing an
        ``exceptions`` sequence, as anyio task groups raise) are searched
        recursively.
        """
        if isinstance(exc, ConnectionError):
            return True
        type_name = type(exc).__name__
        if type_name in ("ClosedResourceError", "BrokenResourceError"):
            return True
        text = f"{type_name}: {exc}"
        if "ClosedResourceError" in text or "closed" in text.lower():
            return True
        nested = getattr(exc, "exceptions", None)
        if isinstance(nested, (list, tuple)):
            return any(MCPClient._is_connection_error(sub) for sub in nested)
        return False

    async def _call_tool_with_retry(self, tool_name: str, arguments: dict, max_retries: int = 2):
        """Call an MCP tool, reconnecting and retrying on connection errors.

        Args:
            tool_name: Name of the remote MCP tool.
            arguments: Tool arguments, passed through unchanged.
            max_retries: Total number of attempts (values below 1 still make
                one attempt).

        Returns:
            The result object returned by ``ClientSession.call_tool``.

        Raises:
            Exception: non-connection errors immediately; connection errors
                after the last attempt. In both cases the exception is the
                original one.
        """
        attempts = max(1, max_retries)
        for attempt in range(1, attempts + 1):
            try:
                await self._ensure_connected()
                return await self.session.call_tool(tool_name, arguments=arguments)
            except Exception as e:
                if not self._is_connection_error(e):
                    raise
                if attempt < attempts:
                    print(f"⚠️ Connection lost, retrying... (attempt {attempt}/{attempts})")
                    await self._cleanup_connections()
                    continue
                # Out of attempts: drop the dead session so the *next* call
                # reconnects instead of reusing it, then surface the error.
                await self._cleanup_connections()
                raise

    @staticmethod
    def _with_optional(args: Dict[str, Any], **optional: Any) -> Dict[str, Any]:
        """Add each truthy keyword value to *args* (in order) and return *args*.

        Falsy values (None, "", empty containers) are omitted so the server
        applies its own defaults.
        """
        args.update({key: value for key, value in optional.items() if value})
        return args

    async def _call_text_tool(self, tool_name: str, args: Dict[str, Any], empty_message: str) -> str:
        """Call *tool_name* and return the text of its first content item.

        Args:
            tool_name: Name of the remote MCP tool (also used in error text).
            args: Tool arguments.
            empty_message: Returned when the result has no content.

        Returns:
            The first content item's text, *empty_message*, or an
            ``"❌ Error calling <tool_name>: ..."`` string on failure. This
            includes results the server flags with ``isError``.
        """
        try:
            result = await self._call_tool_with_retry(tool_name, args)
            if not result.content:
                return empty_message
            text = result.content[0].text
            if getattr(result, "isError", False):
                # MCP reports tool-side failures in-band; don't present them as output.
                return f"❌ Error calling {tool_name}: {text}"
            return text
        except Exception as e:
            return f"❌ Error calling {tool_name}: {str(e)}"

    async def analyze_leaderboard(
        self,
        leaderboard_repo: str = "kshitijthakkar/smoltrace-leaderboard",
        metric_focus: str = "overall",
        time_range: str = "last_week",
        top_n: int = 5,
        hf_token: Optional[str] = None,
        gemini_api_key: Optional[str] = None
    ) -> str:
        """
        Call the analyze_leaderboard tool on MCP server

        Args:
            leaderboard_repo: HuggingFace dataset repo for leaderboard
            metric_focus: Focus metric (overall, accuracy, cost, latency, co2)
            time_range: Time range filter (last_week, last_month, all_time)
            top_n: Number of top models to highlight
            hf_token: HuggingFace API token (sent only if non-empty)
            gemini_api_key: Google Gemini API key (sent only if non-empty)

        Returns:
            AI-generated analysis of the leaderboard, or an error string
        """
        args = self._with_optional(
            {
                "leaderboard_repo": leaderboard_repo,
                "metric_focus": metric_focus,
                "time_range": time_range,
                "top_n": top_n,
            },
            hf_token=hf_token,
            gemini_api_key=gemini_api_key,
        )
        return await self._call_text_tool("analyze_leaderboard", args, "No analysis generated")

    async def debug_trace(
        self,
        trace_data: Dict[str, Any],
        question: str,
        metrics_data: Optional[Dict[str, Any]] = None,
        hf_token: Optional[str] = None,
        gemini_api_key: Optional[str] = None
    ) -> str:
        """
        Call the debug_trace tool on MCP server

        Args:
            trace_data: Trace data dict, passed through to the server
            question: User question about the trace
            metrics_data: Optional metrics dict (sent only if non-empty)
            hf_token: HuggingFace API token (sent only if non-empty)
            gemini_api_key: Google Gemini API key (sent only if non-empty)

        Returns:
            AI-generated answer to the trace question, or an error string
        """
        args = self._with_optional(
            {"trace_data": trace_data, "question": question},
            metrics_data=metrics_data,
            hf_token=hf_token,
            gemini_api_key=gemini_api_key,
        )
        return await self._call_text_tool("debug_trace", args, "No answer generated")

    async def estimate_cost(
        self,
        model: str,
        agent_type: str = "both",
        num_tests: int = 100,
        hardware: Optional[str] = None,
        hf_token: Optional[str] = None,
        gemini_api_key: Optional[str] = None
    ) -> str:
        """
        Call the estimate_cost tool on MCP server

        Args:
            model: Model name (e.g., 'openai/gpt-4', 'meta-llama/Llama-3.1-8B')
            agent_type: Agent type (tool, code, both)
            num_tests: Number of tests to run
            hardware: Hardware type, e.g. cpu, gpu_a10, gpu_h200 (sent only if non-empty)
            hf_token: HuggingFace API token (sent only if non-empty)
            gemini_api_key: Google Gemini API key (sent only if non-empty)

        Returns:
            Cost estimation with breakdown, or an error string
        """
        args = self._with_optional(
            {"model": model, "agent_type": agent_type, "num_tests": num_tests},
            hardware=hardware,
            hf_token=hf_token,
            gemini_api_key=gemini_api_key,
        )
        return await self._call_text_tool("estimate_cost", args, "No estimation generated")

    async def compare_runs(
        self,
        run_data_list: List[Dict[str, Any]],
        focus_metrics: Optional[List[str]] = None,
        hf_token: Optional[str] = None,
        gemini_api_key: Optional[str] = None
    ) -> str:
        """
        Call the compare_runs tool on MCP server

        Args:
            run_data_list: List of run data dicts from leaderboard
            focus_metrics: List of metrics to focus on (sent only if non-empty)
            hf_token: HuggingFace API token (sent only if non-empty)
            gemini_api_key: Google Gemini API key (sent only if non-empty)

        Returns:
            AI-generated comparison analysis, or an error string
        """
        args = self._with_optional(
            {"run_data_list": run_data_list},
            focus_metrics=focus_metrics,
            hf_token=hf_token,
            gemini_api_key=gemini_api_key,
        )
        return await self._call_text_tool("compare_runs", args, "No comparison generated")

    async def analyze_results(
        self,
        results_data: List[Dict[str, Any]],
        analysis_focus: str = "optimization",
        hf_token: Optional[str] = None,
        gemini_api_key: Optional[str] = None
    ) -> str:
        """
        Call the analyze_results tool on MCP server

        Args:
            results_data: List of test case results
            analysis_focus: Focus area (optimization, failures, performance, cost)
            hf_token: HuggingFace API token (sent only if non-empty)
            gemini_api_key: Google Gemini API key (sent only if non-empty)

        Returns:
            AI-generated results analysis with recommendations, or an error string
        """
        args = self._with_optional(
            {"results_data": results_data, "analysis_focus": analysis_focus},
            hf_token=hf_token,
            gemini_api_key=gemini_api_key,
        )
        return await self._call_text_tool("analyze_results", args, "No analysis generated")

    async def get_dataset_info(
        self,
        dataset_repo: str,
        hf_token: Optional[str] = None,
        gemini_api_key: Optional[str] = None
    ) -> str:
        """
        Call the get_dataset tool on MCP server

        Args:
            dataset_repo: HuggingFace dataset repo
            hf_token: HuggingFace API token (sent only if non-empty)
            gemini_api_key: Google Gemini API key (sent only if non-empty)

        Returns:
            Dataset information and structure, or an error string
        """
        args = self._with_optional(
            {"dataset_repo": dataset_repo},
            hf_token=hf_token,
            gemini_api_key=gemini_api_key,
        )
        return await self._call_text_tool("get_dataset", args, "No dataset info generated")

    async def _cleanup_connections(self):
        """Close the session and SSE connection (best effort) and reset state.

        State is cleared *before* awaiting the exits, so a cancellation during
        cleanup cannot leave stale references behind. Errors raised while
        exiting are printed, not propagated.
        """
        session_ctx, self._session_context = self._session_context, None
        sse_ctx, self._sse_context = self._sse_context, None
        self.session = None
        self._initialized = False
        # Close in reverse order of opening: session first, then transport.
        if session_ctx is not None:
            try:
                await session_ctx.__aexit__(None, None, None)
            except Exception as e:
                print(f"⚠️ Error closing session context: {e}")
        if sse_ctx is not None:
            try:
                await sse_ctx.__aexit__(None, None, None)
            except Exception as e:
                print(f"⚠️ Error closing SSE context: {e}")

    async def close(self):
        """Close the MCP client session"""
        await self._cleanup_connections()
# Lazily created, module-wide MCPClient shared by every caller of get_mcp_client()
_mcp_client_instance: Optional[MCPClient] = None
def get_mcp_client() -> MCPClient:
    """Return the shared MCPClient, constructing it on the first call."""
    global _mcp_client_instance
    if _mcp_client_instance is not None:
        return _mcp_client_instance
    _mcp_client_instance = MCPClient()
    return _mcp_client_instance