"""
Financial AI Assistant - direct method library (no HTTP).

Imports the functions defined in ``EasyFinancialAgent/easy_financial_mcp.py``
and calls them in-process instead of going through an MCP/HTTP server.
Supports both local runs and Hugging Face Space deployments.

Sections:
  * thin wrappers around the MCP functions (company search, filings,
    financial data) that return ``{"error": ...}`` dicts instead of raising;
  * combined queries (``query_company_direct`` and friends);
  * LLM helpers built on ``huggingface_hub.InferenceClient``;
  * ``chatbot_response``: a tool-calling chat loop over the same functions.
"""
import sys
from pathlib import Path
import os
import json
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
import requests  # noqa: F401  (currently unused in this module)
import warnings

# Silence DeprecationWarning noise (e.g. from asyncio) for this process.
warnings.filterwarnings('ignore', category=DeprecationWarning)
# PYTHONWARNINGS is read when an interpreter starts, so this line only affects
# Python subprocesses started after this point, not the current process.
os.environ['PYTHONWARNINGS'] = 'ignore'

# Load .env first so HF_TOKEN / HUGGING_FACE_HUB_TOKEN are visible below.
load_dotenv()

# Put the project root (parent of this file's directory) on sys.path so the
# top-level packages imported below (EasyFinancialAgent, MarketandStockMCP)
# resolve.
PROJECT_ROOT = Path(__file__).parent.parent.absolute()
sys.path.insert(0, str(PROJECT_ROOT))

# Import the functions defined by the MCP module directly.
try:
    from EasyFinancialAgent.easy_financial_mcp import (
        search_company as _search_company,
        get_company_info as _get_company_info,
        get_company_filings as _get_company_filings,
        get_financial_data as _get_financial_data,
        extract_financial_metrics as _extract_financial_metrics,
        get_latest_financial_data as _get_latest_financial_data,
        advanced_search_company as _advanced_search_company,
    )
    MCP_DIRECT_AVAILABLE = True
    print("[FinancialAI] ✓ Direct MCP functions imported successfully")
except ImportError as e:
    MCP_DIRECT_AVAILABLE = False
    print(f"[FinancialAI] ✗ Failed to import MCP functions: {e}")

    # Placeholders so every name bound by the import above still exists.
    # The public wrappers check MCP_DIRECT_AVAILABLE first, but
    # call_mcp_tool() calls these names directly.
    def _search_company(x):
        return {"error": "MCP not available"}

    def _advanced_search_company(x):
        return {"error": "MCP not available"}

    def _get_company_info(x):
        return {"error": "MCP not available"}

    def _get_company_filings(x, y=None):
        return {"error": "MCP not available"}

    def _get_financial_data(x, y):
        return {"error": "MCP not available"}

    def _get_latest_financial_data(x):
        return {"error": "MCP not available"}

    def _extract_financial_metrics(x, y=3):
        return {"error": "MCP not available"}


# ============================================================
# Shared helpers
# ============================================================

def _is_error_result(result):
    """Return True when *result* is an error payload: a dict with an "error" key.

    A plain ``"error" in result`` check is wrong for lists (it tests list
    membership) and raises on None, so every error check goes through here.
    """
    return isinstance(result, dict) and "error" in result


def _extract_cik(search_result):
    """Return the CIK from a search result, or None if it cannot be found.

    Accepts either a dict with a "cik" key, or a list/tuple whose first
    element is such a dict (the shape search_company_direct returns).
    """
    if isinstance(search_result, dict):
        return search_result.get("cik")
    if isinstance(search_result, (list, tuple)) and search_result:
        first_item = search_result[0]
        if isinstance(first_item, dict):
            return first_item.get("cik")
    return None


# ============================================================
# Convenience methods - company search
# ============================================================

def search_company_direct(company_input):
    """
    Search for a company (direct call) and return the result wrapped in a list.

    Uses the advanced_search_company tool, which accepts a company name,
    ticker or CIK.

    Args:
        company_input: Company name, ticker symbol or CIK code.

    Returns:
        ``[result]`` on success. On failure, an ``{"error": ...}`` dict that is
        NOT wrapped in a list, so callers can detect it with a key lookup.

    Example:
        result = search_company_direct("Apple")
        result = search_company_direct("AAPL")
        result = search_company_direct("0000320193")
    """
    if not MCP_DIRECT_AVAILABLE:
        return {"error": "MCP functions not available"}
    try:
        result = _advanced_search_company(company_input)
    except Exception as e:
        return {"error": str(e)}
    # Pass error payloads through unwrapped; wrapping them in a list would make
    # callers' `"error" in search_result` checks silently miss them.
    if _is_error_result(result):
        return result
    return [result]


def get_company_info_direct(cik):
    """
    Get detailed company information (direct call).

    Args:
        cik: Company CIK code.

    Returns:
        Whatever the MCP get_company_info function returns, or an
        ``{"error": ...}`` dict if MCP is unavailable or the call raises.

    Example:
        result = get_company_info_direct("0000320193")
    """
    if not MCP_DIRECT_AVAILABLE:
        return {"error": "MCP functions not available"}
    try:
        return _get_company_info(cik)
    except Exception as e:
        return {"error": str(e)}


def get_company_filings_direct(cik):
    """
    Get the company's SEC filing list (direct call).

    Args:
        cik: Company CIK code.

    Returns:
        The MCP get_company_filings result, or an ``{"error": ...}`` dict.

    Example:
        result = get_company_filings_direct("0000320193")
    """
    if not MCP_DIRECT_AVAILABLE:
        return {"error": "MCP functions not available"}
    try:
        return _get_company_filings(cik)
    except Exception as e:
        return {"error": str(e)}


def advanced_search_company_detailed(company_input):
    """
    Search a company by name, ticker or CIK and return the raw tool result.

    This calls the same advanced_search_company function as
    search_company_direct. The differences are that the result is returned
    as-is (not wrapped in a list) and that a raised exception is reported
    together with its traceback.

    Args:
        company_input: Company name ("Tesla", "Apple Inc"),
                       ticker ("TSLA", "AAPL", "MSFT") or
                       CIK ("0001318605", "0000320193").

    Returns:
        The tool's result. Per the example in format_search_result, this is a
        dict with keys such as cik, name, tickers (and possibly sic /
        sic_description); the exact schema is defined by the MCP module.
        On failure: ``{"error": ..., "traceback": ...}``.

    Example:
        result = advanced_search_company_detailed("Tesla")
        result = advanced_search_company_detailed("TSLA")
        result = advanced_search_company_detailed("0001318605")
    """
    if not MCP_DIRECT_AVAILABLE:
        return {"error": "MCP functions not available"}
    try:
        return _advanced_search_company(company_input)
    except Exception as e:
        import traceback
        return {
            "error": str(e),
            "traceback": traceback.format_exc()
        }


def format_search_result(search_result):
    """
    Normalize an advanced_search_company result into a list of records.

    Args:
        search_result: A dict such as
            ``{'cik': '...', 'name': '...', 'tickers': [...], ...}``,
            or a list of such dicts (nested lists are flattened).

    Returns:
        list[dict]: ``[{'company_name': str, 'cik': str, 'ticker': str}]``
        where ticker is the first entry of ``tickers`` (a bare string
        ``tickers`` value is used as-is). Returns ``[]`` for error payloads
        and anything that is not a dict or list.

    Example:
        search_result = {'cik': '0001577552', 'name': 'Alibaba Group Holding Ltd',
                         'tickers': ['BABA'], '_source': 'company_tickers_cache'}
        format_search_result(search_result)
        # -> [{'company_name': 'Alibaba Group Holding Ltd', 'cik': '0001577552', 'ticker': 'BABA'}]
    """
    if _is_error_result(search_result):
        return []

    if isinstance(search_result, list):
        formatted_list = []
        for item in search_result:
            formatted_list.extend(format_search_result(item))
        return formatted_list

    if not isinstance(search_result, dict):
        return []

    tickers = search_result.get('tickers', [])
    if isinstance(tickers, str):
        ticker = tickers
    elif isinstance(tickers, (list, tuple)) and tickers:
        ticker = tickers[0]
    else:
        ticker = ''

    return [{
        'company_name': search_result.get('name', ''),
        'cik': search_result.get('cik', ''),
        'ticker': ticker,
    }]


def _is_main_us_ticker(ticker):
    """Heuristically decide whether *ticker* is a primary US listing.

    Rules (heuristic, not an exchange lookup):
      1. more than 5 characters (dots ignored) -> rejected (treated as OTC/fund);
      2. exactly 5 characters ending in F/Y/Q/D/W/U/P -> rejected (common
         OTC / warrant / unit suffixes);
      3. everything else is accepted (e.g. GOOGL, BABA, BRK.B).
    """
    if not ticker:
        return False
    # Ignore dots so class shares such as BRK.B count by their letters only.
    ticker_clean = str(ticker).upper().strip().replace('.', '')
    if len(ticker_clean) > 5:
        return False
    if len(ticker_clean) == 5 and ticker_clean.endswith(('F', 'Y', 'Q', 'D', 'W', 'U', 'P')):
        return False
    return True


def format_search_result_for_display(search_result):
    """
    Format a search result as display strings.

    Args:
        search_result: An advanced_search_company result (see format_search_result).

    Returns:
        list[str]: ``"Company Name (TICKER)"`` for entries whose ticker passes
        _is_main_us_ticker, the bare company name for entries with no ticker;
        entries with a non-primary ticker are dropped.

    Example:
        format_search_result_for_display({'cik': '0001577552', 'name': 'Alibaba Group', 'tickers': ['BABA']})
        # -> ['Alibaba Group (BABA)']
    """
    display_list = []
    for item in format_search_result(search_result):
        company_name = item.get('company_name', 'Unknown')
        ticker = item.get('ticker', '')
        if not ticker:
            # No ticker at all: still show the company name.
            display_list.append(company_name)
        elif _is_main_us_ticker(ticker):
            display_list.append(f"{company_name} ({ticker})")
    return display_list


def search_and_format(company_input):
    """
    Search for a company and immediately format the result (one-step helper).

    Args:
        company_input: Company name, ticker or CIK.

    Returns:
        list[dict]: Same shape as format_search_result; ``[]`` on error.

    Example:
        search_and_format('BABA')
        # -> [{'company_name': 'Alibaba Group Holding Ltd', 'cik': '0001577552', 'ticker': 'BABA'}]
    """
    search_result = advanced_search_company_detailed(company_input)
    if _is_error_result(search_result):
        return []
    return format_search_result(search_result)


# ============================================================
# Convenience methods - financial data
# ============================================================

def get_latest_financial_data_direct(cik):
    """
    Get the company's latest financial data (direct call).

    Args:
        cik: Company CIK code.

    Returns:
        The MCP get_latest_financial_data result, or an ``{"error": ...}`` dict.

    Example:
        result = get_latest_financial_data_direct("0000320193")
    """
    if not MCP_DIRECT_AVAILABLE:
        return {"error": "MCP functions not available"}
    try:
        return _get_latest_financial_data(cik)
    except Exception as e:
        return {"error": str(e)}


def extract_financial_metrics_direct(cik, years=5):
    """
    Extract multi-year financial metric trends (direct call).

    Args:
        cik: Company CIK code.
        years: Number of years to retrieve (default 5).

    Returns:
        The MCP extract_financial_metrics result, or an ``{"error": ...}`` dict.

    Example:
        result = extract_financial_metrics_direct("0000320193", years=5)
    """
    if not MCP_DIRECT_AVAILABLE:
        return {"error": "MCP functions not available"}
    try:
        return _extract_financial_metrics(cik, years)
    except Exception as e:
        return {"error": str(e)}


# ============================================================
# Advanced methods - combined queries
# ============================================================

def query_company_direct(company_input, get_filings=True, get_metrics=True, metrics_years=3):
    """
    Run a combined company query (direct call).

    Performs search -> company info -> (optionally) filings -> (optionally)
    financial metrics. Failures of the later steps are recorded in
    ``errors`` without aborting the query.

    Args:
        company_input: Company name or code.
        get_filings: Whether to fetch the filing list.
        get_metrics: Whether to fetch financial metrics.
        metrics_years: Years of metrics to fetch when get_metrics is True
            (default 3, the previously hard-coded value).

    Returns:
        dict with keys timestamp, query_input, status ("success"/"error"),
        data (company_search, company_info, filings, metrics) and errors.

    Example:
        result = query_company_direct("Apple", get_filings=True, get_metrics=True)
    """
    from datetime import datetime

    result = {
        "timestamp": datetime.now().isoformat(),
        "query_input": company_input,
        "status": "success",
        "data": {
            "company_search": None,
            "company_info": None,
            "filings": None,
            "metrics": None
        },
        "errors": []
    }

    if not MCP_DIRECT_AVAILABLE:
        result["status"] = "error"
        result["errors"].append("MCP functions not available")
        return result

    try:
        # 1. Search for the company.
        search_result = search_company_direct(company_input)
        if _is_error_result(search_result):
            result["errors"].append(f"Search error: {search_result['error']}")
            result["status"] = "error"
            return result
        result["data"]["company_search"] = search_result

        cik = _extract_cik(search_result)
        if not cik:
            result["errors"].append("Could not extract CIK from search result")
            result["status"] = "error"
            return result

        # 2. Company info.
        company_info = get_company_info_direct(cik)
        if _is_error_result(company_info):
            result["errors"].append(f"Failed to get company info: {company_info.get('error')}")
        else:
            result["data"]["company_info"] = company_info

        # 3. Filing list.
        if get_filings:
            filings = get_company_filings_direct(cik)
            if _is_error_result(filings):
                result["errors"].append(f"Failed to get filings: {filings.get('error')}")
            else:
                result["data"]["filings"] = filings

        # 4. Financial metrics.
        if get_metrics:
            metrics = extract_financial_metrics_direct(cik, years=metrics_years)
            if _is_error_result(metrics):
                result["errors"].append(f"Failed to get metrics: {metrics.get('error')}")
            else:
                result["data"]["metrics"] = metrics

    except Exception as e:
        import traceback
        result["status"] = "error"
        result["errors"].append(f"Exception: {str(e)}")
        result["errors"].append(traceback.format_exc())

    return result


# ============================================================
# LLM configuration and initialization
# ============================================================

# Model used for every chat-completion call in this module.
LLM_MODEL = "Qwen/Qwen2.5-72B-Instruct"


def _init_llm_client():
    """Create the module-level ``llm_client`` from HF_TOKEN / HUGGING_FACE_HUB_TOKEN.

    Returns:
        True if a client was created; False if no token is set or
        construction raised (``llm_client`` is left as None).
    """
    global llm_client
    hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    llm_client = None

    try:
        if hf_token:
            llm_client = InferenceClient(api_key=hf_token)
            print("[FinancialAI] ✓ LLM client initialized with HF_TOKEN")
            return True
        print("[FinancialAI] ⚠ Warning: HF_TOKEN not found, LLM features disabled")
        return False
    except Exception as e:
        print(f"[FinancialAI] ✗ Failed to initialize LLM client: {e}")
        return False


# Module-level LLM client; None when LLM features are disabled.
llm_client = None
_init_llm_client()


def get_system_prompt():
    """Build the system prompt, including today's date (local time)."""
    from datetime import datetime
    current_date = datetime.now().strftime("%Y-%m-%d")
    return f"""You are a financial analysis expert. Today is {current_date}.

Your role:
- Analyze company financial data, reports, and market news
- Provide investment insights based on factual data
- Be concise, objective, and data-driven
- Always include disclaimers about market risks

⚠️ IMPORTANT: You have a maximum of 5 tool calls. Choose the MOST RELEVANT tools carefully:
- Use 'advanced_search_company' ONLY if you need to find a company's CIK
- Use 'extract_financial_metrics' for comprehensive multi-year financial analysis (RECOMMENDED for most queries)
- Use 'get_latest_financial_data' for quick recent snapshot
- Use 'get_quote' for real-time stock price
- Use 'get_company_news' for company-specific news
- Use 'get_market_news' for general market trends

Prioritize the most important tools for the user's question. Avoid redundant calls.

Output should be in English."""


def analyze_company_with_llm(company_input, analysis_type="summary"):
    """
    Analyze a company with the LLM, using get_company_summary_direct data.

    Args:
        company_input: Company name or code.
        analysis_type: "investment", "risks", or anything else for a summary.

    Returns:
        dict with company, analysis_type, analysis and data_used on success,
        or ``{"error": ...}``.

    Example:
        result = analyze_company_with_llm("Apple", "investment")
    """
    if not llm_client:
        return {"error": "LLM client not available"}
    if not MCP_DIRECT_AVAILABLE:
        return {"error": "MCP functions not available"}

    try:
        company_data = get_company_summary_direct(company_input)
        if company_data["status"] == "error":
            return {"error": f"Failed to fetch company data: {company_data['errors']}"}

        # default=str keeps non-JSON values (dates, decimals, ...) from
        # aborting the analysis; they are rendered via str() in the prompt.
        data_str = json.dumps(company_data["data"], ensure_ascii=False, indent=2, default=str)

        if analysis_type == "investment":
            prompt = f"""
Based on the following company financial data, provide an investment recommendation:

{data_str}

Provide:
1. Investment Recommendation (Buy/Hold/Sell)
2. Key Strengths and Weaknesses
3. Price Target Range
4. Risk Assessment
5. Risk Disclaimer
"""
        elif analysis_type == "risks":
            prompt = f"""
Based on the following company data, analyze the key risks:

{data_str}

Identify:
1. Financial Risks
2. Market Risks
3. Operational Risks
4. Mitigation Strategies
5. Risk Disclaimer
"""
        else:  # summary
            prompt = f"""
Provide a financial summary of the following company:

{data_str}

Include:
1. Company Overview
2. Financial Health
3. Recent Performance
4. Investment Outlook
"""

        response = llm_client.chat.completions.create(
            model=LLM_MODEL,
            messages=[
                {"role": "system", "content": get_system_prompt()},
                {"role": "user", "content": prompt}
            ],
            max_tokens=1500,
            temperature=0.7,
            top_p=0.95,
            stream=False
        )

        return {
            "company": company_input,
            "analysis_type": analysis_type,
            "analysis": response.choices[0].message.content,
            "data_used": company_data["data"]
        }

    except Exception as e:
        return {"error": f"LLM analysis failed: {str(e)}"}


# ============================================================
# Convenience methods - single-period financial data
# ============================================================

def get_financial_data_direct(cik, period):
    """
    Get financial data for a specific period (direct call).

    Args:
        cik: Company CIK code.
        period: Period string, e.g. "2024" or "2024Q3".

    Returns:
        The MCP get_financial_data result, or an ``{"error": ...}`` dict.

    Example:
        result = get_financial_data_direct("0000320193", "2024")
    """
    if not MCP_DIRECT_AVAILABLE:
        return {"error": "MCP functions not available"}
    try:
        return _get_financial_data(cik, period)
    except Exception as e:
        return {"error": str(e)}


# ============================================================
# Convenience methods - filing list by form type
# ============================================================

def get_company_filings_with_form_direct(cik, form_types=None):
    """
    Get the company's SEC filings, filtered by form type (direct call).

    Args:
        cik: Company CIK code.
        form_types: List of form types, e.g. ["10-K", "10-Q"]; None passes
            through to the MCP function unchanged.

    Returns:
        The MCP get_company_filings result, or an ``{"error": ...}`` dict.

    Example:
        result = get_company_filings_with_form_direct("0000320193", ["10-K"])
    """
    if not MCP_DIRECT_AVAILABLE:
        return {"error": "MCP functions not available"}
    try:
        return _get_company_filings(cik, form_types)
    except Exception as e:
        return {"error": str(e)}


# ============================================================
# Convenience methods - lightweight queries
# ============================================================

def get_company_summary_direct(company_input):
    """
    Get a brief company summary (lightweight: search + company info only).

    Args:
        company_input: Company name or code.

    Returns:
        dict with timestamp, query_input, status, data (company_search,
        company_info) and errors.

    Example:
        result = get_company_summary_direct("Apple")
    """
    from datetime import datetime

    result = {
        "timestamp": datetime.now().isoformat(),
        "query_input": company_input,
        "status": "success",
        "data": {
            "company_search": None,
            "company_info": None
        },
        "errors": []
    }

    if not MCP_DIRECT_AVAILABLE:
        result["status"] = "error"
        result["errors"].append("MCP functions not available")
        return result

    try:
        # 1. Search for the company.
        search_result = search_company_direct(company_input)
        if _is_error_result(search_result):
            result["errors"].append(f"Search error: {search_result['error']}")
            result["status"] = "error"
            return result
        result["data"]["company_search"] = search_result

        cik = _extract_cik(search_result)
        if not cik:
            result["errors"].append("Could not extract CIK from search result")
            result["status"] = "error"
            return result

        # 2. Company info (failure is recorded but not fatal).
        company_info = get_company_info_direct(cik)
        if _is_error_result(company_info):
            result["errors"].append(f"Failed to get company info: {company_info.get('error')}")
        else:
            result["data"]["company_info"] = company_info

    except Exception as e:
        import traceback
        result["status"] = "error"
        result["errors"].append(f"Exception: {str(e)}")
        result["errors"].append(traceback.format_exc())

    return result


def get_financial_metrics_only_direct(company_input, years=5):
    """
    Get financial metric trends only (no filing list).

    Args:
        company_input: Company name or code.
        years: Number of years (default 5).

    Returns:
        dict with timestamp, query_input, years, status, data (the metrics,
        or None) and errors.

    Example:
        result = get_financial_metrics_only_direct("Apple", years=5)
    """
    from datetime import datetime

    result = {
        "timestamp": datetime.now().isoformat(),
        "query_input": company_input,
        "years": years,
        "status": "success",
        "data": None,
        "errors": []
    }

    if not MCP_DIRECT_AVAILABLE:
        result["status"] = "error"
        result["errors"].append("MCP functions not available")
        return result

    try:
        # 1. Search for the company.
        search_result = search_company_direct(company_input)
        if _is_error_result(search_result):
            result["errors"].append(f"Search error: {search_result['error']}")
            result["status"] = "error"
            return result

        cik = _extract_cik(search_result)
        if not cik:
            result["errors"].append("Could not extract CIK from search result")
            result["status"] = "error"
            return result

        # 2. Financial metrics (failure here is fatal for this query).
        metrics = extract_financial_metrics_direct(cik, years=years)
        if _is_error_result(metrics):
            result["errors"].append(f"Failed to get metrics: {metrics['error']}")
            result["status"] = "error"
        else:
            result["data"] = metrics

    except Exception as e:
        import traceback
        result["status"] = "error"
        result["errors"].append(f"Exception: {str(e)}")
        result["errors"].append(traceback.format_exc())

    return result


# ============================================================
# Full conversation engine - chatbot_response
# ============================================================

# Token / size limits.
MAX_TOTAL_TOKENS = 6000
MAX_TOOL_RESULT_CHARS = 1500
MAX_HISTORY_CHARS = 500
MAX_HISTORY_TURNS = 2
MAX_TOOL_ITERATIONS = 5  # max LLM rounds that may request tools
MAX_OUTPUT_TOKENS = 2000

# Tool schemas offered to the LLM: financial-data tools and market/news tools.
MCP_TOOLS = [
    # Financial data tools (EasyReportDataMCP)
    {"type": "function", "function": {
        "name": "advanced_search_company",
        "description": "Search US companies by name, ticker, or CIK. Returns company information including CIK, name, tickers, and industry classification.",
        "parameters": {"type": "object", "properties": {"company_input": {"type": "string", "description": "Company name (e.g., 'Tesla'), ticker symbol (e.g., 'TSLA'), or CIK code (e.g., '0001318605')"}}, "required": ["company_input"]}}},
    {"type": "function", "function": {
        "name": "get_latest_financial_data",
        "description": "Get the most recent financial data for a company including revenue, net income, EPS, operating expenses, and cash flow.",
        "parameters": {"type": "object", "properties": {"cik": {"type": "string", "description": "Company CIK code (10-digit format, e.g., '0001318605')"}}, "required": ["cik"]}}},
    {"type": "function", "function": {
        "name": "extract_financial_metrics",
        "description": "Extract multi-year financial metrics trends showing historical performance over specified years.",
        "parameters": {"type": "object", "properties": {"cik": {"type": "string", "description": "Company CIK code (10-digit format)"}, "years": {"type": "integer", "description": "Number of years of data to retrieve (e.g., 3 or 5)", "default": 3}}, "required": ["cik", "years"]}}},
    # Market and news tools (MarketandStockMCP)
    {"type": "function", "function": {
        "name": "get_quote",
        "description": "Get real-time stock quote data including current price, daily change, high/low, and previous close. Use when users ask about current stock prices or market performance.",
        "parameters": {"type": "object", "properties": {"symbol": {"type": "string", "description": "Stock ticker symbol (e.g., 'AAPL', 'TSLA', 'MSFT')"}}, "required": ["symbol"]}}},
    {"type": "function", "function": {
        "name": "get_market_news",
        "description": "Get latest market news by category. Use when users ask about general market trends, forex, crypto, or M&A news.",
        "parameters": {"type": "object", "properties": {"category": {"type": "string", "enum": ["general", "forex", "crypto", "merger"], "description": "News category: general (stocks/economy), forex (currency), crypto (cryptocurrency), merger (M&A)", "default": "general"}, "min_id": {"type": "integer", "description": "Minimum news ID for pagination (default: 0)", "default": 0}}, "required": ["category"]}}},
    {"type": "function", "function": {
        "name": "get_company_news",
        "description": "Get company-specific news and announcements. Only available for North American companies. Use when users ask about specific company news.",
        "parameters": {"type": "object", "properties": {"symbol": {"type": "string", "description": "Company stock ticker symbol (e.g., 'AAPL', 'TSLA')"}, "from_date": {"type": "string", "description": "Start date in YYYY-MM-DD format (optional, defaults to 7 days ago)"}, "to_date": {"type": "string", "description": "End date in YYYY-MM-DD format (optional, defaults to today)"}}, "required": ["symbol"]}}},
]


def truncate_text(text, max_chars, suffix="...[truncated]"):
    """Convert *text* to str and cut it to *max_chars*, appending *suffix* if cut."""
    text = str(text)
    if len(text) <= max_chars:
        return text
    return text[:max_chars] + suffix


def call_mcp_tool(tool_name, arguments):
    """Call an MCP tool function in-process (no HTTP).

    Args:
        tool_name: One of the names declared in MCP_TOOLS.
        arguments: dict of arguments as produced by the LLM.

    Returns:
        The tool's result, ``{"error": "Unknown tool: ..."}`` for unknown
        names, or ``{"error": ..., "traceback": ...}`` (traceback capped at
        500 chars) if the call raises.
    """
    try:
        # Financial data tools - plain Python functions imported above.
        if tool_name == "advanced_search_company":
            return _advanced_search_company(arguments.get("company_input", ""))
        if tool_name == "get_latest_financial_data":
            return _get_latest_financial_data(arguments.get("cik", ""))
        if tool_name == "extract_financial_metrics":
            return _extract_financial_metrics(arguments.get("cik", ""), arguments.get("years", 3))

        # Market and news tools - imported lazily from MarketandStockMCP.
        if tool_name == "get_quote":
            from MarketandStockMCP.news_quote_mcp import get_quote
            return get_quote(arguments.get("symbol", ""))
        if tool_name == "get_market_news":
            from MarketandStockMCP.news_quote_mcp import get_market_news
            return get_market_news(arguments.get("category", "general"), arguments.get("min_id", 0))
        if tool_name == "get_company_news":
            from MarketandStockMCP.news_quote_mcp import get_company_news
            return get_company_news(
                arguments.get("symbol", ""),
                arguments.get("from_date"),
                arguments.get("to_date"),
            )

        return {"error": f"Unknown tool: {tool_name}"}

    except Exception as e:
        import traceback
        return {
            "error": f"{str(e)}",
            "traceback": traceback.format_exc()[:500]
        }


def _parse_tool_arguments(raw_arguments):
    """Normalize a tool call's ``function.arguments`` to a dict.

    huggingface_hub types this field loosely: depending on the backend it can
    be a JSON string or an already-decoded dict. Anything that does not
    yield a dict (invalid JSON, None, a JSON list, ...) becomes {}.
    """
    if isinstance(raw_arguments, dict):
        return raw_arguments
    try:
        parsed = json.loads(raw_arguments)
    except (TypeError, ValueError):  # ValueError covers json.JSONDecodeError
        return {}
    return parsed if isinstance(parsed, dict) else {}


def _tool_result_for_llm(tool_result):
    """Serialize a successful tool result for the LLM, shrinking large results.

    Results whose JSON exceeds MAX_TOOL_RESULT_CHARS are reduced to: the
    truncated "text" field if present; otherwise, for dicts, at most 8 keys
    with values cut to 300 chars; otherwise a truncated JSON preview. Shrunk
    results carry ``"_truncated": True``.
    """
    # default=str: a non-JSON value (date, Decimal, ...) should not abort the chat.
    result_str = json.dumps(tool_result, ensure_ascii=False, default=str)
    if len(result_str) <= MAX_TOOL_RESULT_CHARS:
        return result_str

    if isinstance(tool_result, dict) and "text" in tool_result:
        shrunk = {"text": truncate_text(tool_result["text"], MAX_TOOL_RESULT_CHARS - 50), "_truncated": True}
    elif isinstance(tool_result, dict):
        shrunk = {}
        char_count = 0
        for key, value in list(tool_result.items())[:8]:
            value_str = str(value)[:300]
            shrunk[key] = value_str
            char_count += len(str(key)) + len(value_str)
            if char_count > MAX_TOOL_RESULT_CHARS:
                break
        shrunk["_truncated"] = True
    else:
        shrunk = {"preview": truncate_text(result_str, MAX_TOOL_RESULT_CHARS), "_truncated": True}
    return json.dumps(shrunk, ensure_ascii=False, default=str)


def _render_tool_calls_log(tool_calls_log):
    """Build the "tools used" header shown before the assistant's answer.

    Returns "" when no tool was called. Each logged call contributes its
    1-based index, name, an error marker and a JSON preview of its result
    capped at 1500 characters.
    """
    if not tool_calls_log:
        return ""

    # NOTE: this counts individual tool calls, while MAX_TOOL_ITERATIONS caps
    # LLM rounds (one round may request several tools), so the count can
    # exceed the displayed maximum.
    tool_count = len(tool_calls_log)

    # NOTE(review): this header was meant to be a collapsible
    # <details>/<summary> panel with CSS for a rotating triangle, but these
    # literals contain no markup (the first is an empty CSS placeholder). The
    # markup appears to have been lost; restore it from version control, and
    # html.escape() the tool output if it is embedded in HTML again.
    # Tool arguments are recorded in the log but not rendered here.
    parts = [""" """]
    parts.append(f"""
🛠️ Tools Used ({tool_count}/{MAX_TOOL_ITERATIONS} calls)
""")
    for idx, tool_call in enumerate(tool_calls_log):
        result_json = json.dumps(tool_call.get('result', {}), ensure_ascii=False, indent=2, default=str)
        result_preview = result_json[:1500] + ('...' if len(result_json) > 1500 else '')
        error_indicator = " ❌ Error" if tool_call.get('error') else ""
        parts.append(f"""
📋 {idx+1}. {tool_call['name']}{error_indicator}
{result_preview}
""")
    # Closes the tool section (originally the outer details/div tags).
    parts.append("""
---
""")
    return "".join(parts)


def chatbot_response(message, history=None):
    """
    Main assistant entry point (full conversation engine).

    Supports multi-turn history, iterative LLM tool calling (up to
    MAX_TOOL_ITERATIONS rounds) and streamed output.

    Args:
        message: The user's message.
        history: Conversation history as ``[(user_msg, assistant_msg), ...]``;
            only the last MAX_HISTORY_TURNS pairs are used and assistant
            messages are truncated to MAX_HISTORY_CHARS.

    Yields:
        The cumulative response text: first the tool-usage header (possibly
        ""), then header + answer as it grows. Errors are yielded as
        "❌ Error: ..." strings rather than raised.

    Example:
        for response in chatbot_response("What's Apple's revenue?", []):
            print(response)
    """
    if not llm_client:
        yield "❌ Error: LLM client not available"
        return
    if not MCP_DIRECT_AVAILABLE:
        yield "❌ Error: MCP functions not available"
        return

    try:
        messages = [{"role": "system", "content": get_system_prompt()}]

        # Recent history only, with assistant turns truncated, to bound the context size.
        if history:
            for item in history[-MAX_HISTORY_TURNS:]:
                if isinstance(item, (list, tuple)) and len(item) == 2:
                    messages.append({"role": "user", "content": item[0]})
                    messages.append({"role": "assistant", "content": truncate_text(item[1], MAX_HISTORY_CHARS)})

        messages.append({"role": "user", "content": message})

        tool_calls_log = []
        final_response_content = None

        # Let the LLM request tools for up to MAX_TOOL_ITERATIONS rounds.
        for _ in range(MAX_TOOL_ITERATIONS):
            response = llm_client.chat.completions.create(
                model=LLM_MODEL,
                messages=messages,
                tools=MCP_TOOLS,  # type: ignore
                max_tokens=MAX_OUTPUT_TOKENS,
                temperature=0.7,
                tool_choice="auto",
                stream=False
            )
            choice = response.choices[0]

            if not choice.message.tool_calls:
                final_response_content = choice.message.content
                break

            messages.append(choice.message)
            for tool_call in choice.message.tool_calls:
                tool_name = tool_call.function.name
                tool_args = _parse_tool_arguments(tool_call.function.arguments)
                tool_result = call_mcp_tool(tool_name, tool_args)

                if _is_error_result(tool_result):
                    tool_calls_log.append({"name": tool_name, "arguments": tool_args, "result": tool_result, "error": True})
                    result_for_llm = json.dumps({"error": tool_result.get("error", "Unknown error")}, ensure_ascii=False, default=str)
                else:
                    tool_calls_log.append({"name": tool_name, "arguments": tool_args, "result": tool_result})
                    result_for_llm = _tool_result_for_llm(tool_result)

                messages.append({
                    "role": "tool",
                    "name": tool_name,
                    "content": result_for_llm,
                    "tool_call_id": tool_call.id
                })

        response_prefix = _render_tool_calls_log(tool_calls_log)
        yield response_prefix

        if final_response_content:
            yield response_prefix + final_response_content
            return

        # No final answer yet (tool rounds exhausted or empty content):
        # ask once more without tools, streaming; fall back to non-streaming.
        try:
            stream = llm_client.chat.completions.create(
                model=LLM_MODEL,
                messages=messages,
                tools=None,
                max_tokens=MAX_OUTPUT_TOKENS,
                temperature=0.7,
                stream=True
            )
            accumulated_text = ""
            for chunk in stream:
                if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content:
                    accumulated_text += chunk.choices[0].delta.content
                    yield response_prefix + accumulated_text
        except Exception:
            final_resp = llm_client.chat.completions.create(
                model=LLM_MODEL,
                messages=messages,
                tools=None,
                max_tokens=MAX_OUTPUT_TOKENS,
                temperature=0.7,
                stream=False
            )
            yield response_prefix + (final_resp.choices[0].message.content or "")

    except Exception as e:
        import traceback
        error_detail = str(e)
        if "500" in error_detail:
            yield f"❌ Error: 模型服务器错误\n\n{error_detail[:200]}"
        else:
            yield f"❌ Error: {error_detail}\n\n{traceback.format_exc()[:500]}"


# ============================================================
# Manual self-test (run this file as a script)
# ============================================================

def _run_self_test():
    """Exercise the main entry points against live services and print the results."""
    print("\n" + "="*60)
    print("Financial AI Assistant - Direct Method Test")
    print("="*60)

    # Test 1: company search
    print("\n1. 搜索公司 (Apple)...")
    result = search_company_direct("Apple")
    print(f" 结果: {result}")

    # Test 2: company summary
    print("\n2. 获取公司摘要信息 (Tesla)...")
    summary = get_company_summary_direct("Tesla")
    print(f" 状态: {summary['status']}")
    print(f" 数据: {summary['data']}")
    print(f" 错误: {summary['errors']}")

    # Test 3: financial metrics
    print("\n3. 获取财务指标 (Microsoft)...")
    metrics = get_financial_metrics_only_direct("Microsoft", years=3)
    print(f" 状态: {metrics['status']}")
    if metrics['status'] == 'success':
        print(f" 指标数据: {metrics['data']}")
    else:
        print(f" 错误: {metrics['errors']}")

    # Test 4: full query
    print("\n4. 获取 Amazon 完整信息...")
    full_query = query_company_direct("Amazon", get_filings=True, get_metrics=True)
    print(f" 状态: {full_query['status']}")
    print(f" 错误: {full_query['errors']}")

    # Tests 5-7: LLM analysis (summary, investment advice, risk assessment)
    llm_cases = [
        ("\n5. LLM 分析 - 公司摘要(Google)...", "Google", "summary"),
        ("\n6. LLM 分析 - 投资建议(NVIDIA)...", "NVIDIA", "investment"),
        ("\n7. LLM 分析 - 风险评估(Meta)...", "Meta", "risks"),
    ]
    for title, company, analysis_type in llm_cases:
        print(title)
        if not llm_client:
            print(" LLM 客户端不可用")
            continue
        llm_result = analyze_company_with_llm(company, analysis_type)
        if "error" in llm_result:
            print(f" 错误: {llm_result['error']}")
        else:
            print(f" 分析结果: {llm_result['analysis'][:200]}...")

    print("\n" + "="*60)


if __name__ == "__main__":
    _run_self_test()