# llm_router.py - UPDATED FOR LOCAL GPU MODEL LOADING
import logging
import asyncio
from typing import Dict, Optional

from .models_config import LLM_CONFIG

# Import GatedRepoError for handling gated repositories
try:
    from huggingface_hub.exceptions import GatedRepoError
except ImportError:
    # Fallback if huggingface_hub is not available
    GatedRepoError = Exception

logger = logging.getLogger(__name__)


class LLMRouter:
    def __init__(self, hf_token=None, use_local_models: bool = True):
        # hf_token kept for backward compatibility but not used for API calls
        # Only needed for downloading gated models from HuggingFace Hub
        self.hf_token = hf_token
        self.health_status = {}
        self.use_local_models = use_local_models
        self.local_loader = None

        logger.info("LLMRouter initialized (local models only, no API fallback)")
        if hf_token:
            logger.info("HF token available (for model download only)")
        else:
            logger.warning("HF_TOKEN not set - may be needed for gated model access")

        # Initialize local model loader - REQUIRED
        if self.use_local_models:
            try:
                from .local_model_loader import LocalModelLoader
                self.local_loader = LocalModelLoader()
                logger.info("✓ Local model loader initialized (GPU-based inference)")
                # Note: Pre-loading will happen on first request (lazy loading)
                # Models will be loaded on-demand to avoid blocking startup
                logger.info("Models will be loaded on-demand for faster startup")
            except Exception as e:
                logger.error(f"❌ CRITICAL: Could not initialize local model loader: {e}")
                logger.error("Local models are required - API fallback has been removed")
                raise RuntimeError(
                    "Local model loader is required but could not be initialized. "
                    "Please ensure transformers and torch are installed."
                ) from e
        else:
            logger.error("use_local_models=False but API fallback removed - this will fail")
            raise ValueError("use_local_models must be True - API fallback has been removed")

    async def route_inference(self, task_type: str, prompt: str, **kwargs):
        """
        Smart routing based on task specialization.
        Uses ONLY local models - no API fallback.
        """
        logger.info(f"Routing inference for task: {task_type}")

        model_config = self._select_model(task_type)
        logger.info(f"Selected model: {model_config['model_id']}")

        # Use local models only
        if not self.local_loader:
            raise RuntimeError("Local model loader not available - cannot perform inference")

        try:
            # Handle embedding generation separately
            if task_type == "embedding_generation":
                result = await self._call_local_embedding(model_config, prompt, **kwargs)
            else:
                result = await self._call_local_model(model_config, prompt, task_type, **kwargs)

            if result is None:
                logger.error(f"Local model returned None for task: {task_type}")
                raise RuntimeError(f"Inference failed for task: {task_type}")

            logger.info(f"Inference complete for {task_type} (local model)")
            return result

        except Exception as e:
            logger.error(f"Local model inference failed: {e}", exc_info=True)

            # Try fallback model if configured
            fallback_model_id = model_config.get("fallback")
            if fallback_model_id and fallback_model_id != model_config["model_id"]:
                logger.warning(f"Attempting fallback model: {fallback_model_id}")
                try:
                    fallback_config = model_config.copy()
                    fallback_config["model_id"] = fallback_model_id
                    fallback_config.pop("fallback", None)  # Prevent infinite recursion

                    if task_type == "embedding_generation":
                        result = await self._call_local_embedding(fallback_config, prompt, **kwargs)
                    else:
                        result = await self._call_local_model(
                            fallback_config, prompt, task_type,
                            **{**kwargs, '_is_fallback': True}
                        )

                    if result is not None:
                        logger.info(f"Inference complete using fallback model: {fallback_model_id}")
                        return result
                except Exception as fallback_error:
                    logger.error(f"Fallback model also failed: {fallback_error}")

            # No API fallback - raise error
            raise RuntimeError(
                f"Inference failed for task: {task_type}. "
                f"Local models are required - ensure models are properly loaded and accessible."
            ) from e

    async def _call_local_model(self, model_config: dict, prompt: str, task_type: str, **kwargs) -> Optional[str]:
        """Call local model for inference."""
        if not self.local_loader:
            return None

        # Check if this is already a fallback attempt (prevent infinite loops)
        is_fallback_attempt = kwargs.get('_is_fallback', False)

        model_id = model_config["model_id"]
        max_tokens = kwargs.get('max_tokens', 512)
        temperature = kwargs.get('temperature', 0.7)

        try:
            # Ensure model is loaded
            if model_id not in self.local_loader.loaded_models:
                logger.info(f"Loading model {model_id} on demand...")

                # Check if model config specifies quantization
                use_4bit = model_config.get("use_4bit_quantization", False)
                use_8bit = model_config.get("use_8bit_quantization", False)

                # Fall back to default quantization settings if not specified
                if not use_4bit and not use_8bit:
                    quantization_config = LLM_CONFIG.get("quantization_settings", {})
                    use_4bit = quantization_config.get("default_4bit", True)
                    use_8bit = quantization_config.get("default_8bit", False)

                try:
                    self.local_loader.load_chat_model(
                        model_id,
                        load_in_8bit=use_8bit,
                        load_in_4bit=use_4bit
                    )
                except GatedRepoError as e:
                    logger.error(f"❌ Cannot access gated repository {model_id}")
                    logger.error(f"   Visit https://huggingface.co/{model_id.split(':')[0] if ':' in model_id else model_id} to request access.")

                    # Prevent infinite loops: if this is already a fallback attempt, don't try another fallback
                    if is_fallback_attempt:
                        logger.error("❌ Fallback model also failed with gated repository error")
                        raise RuntimeError("Both primary and fallback models are gated repositories") from e

                    # Try fallback model if available and this is not already a fallback attempt
                    fallback_model_id = model_config.get("fallback")
                    if fallback_model_id and fallback_model_id != model_id:  # Ensure fallback is different
                        logger.warning(f"Attempting fallback model: {fallback_model_id}")
                        try:
                            # Create fallback config without a fallback entry to prevent loops
                            fallback_config = model_config.copy()
                            fallback_config["model_id"] = fallback_model_id
                            fallback_config.pop("fallback", None)  # Remove fallback to prevent infinite recursion

                            # Retry with fallback model (mark as fallback attempt)
                            return await self._call_local_model(
                                fallback_config, prompt, task_type,
                                **{**kwargs, '_is_fallback': True}
                            )
                        except GatedRepoError as fallback_gated_error:
                            logger.error(f"❌ Fallback model {fallback_model_id} is also gated")
                            raise RuntimeError("Both primary and fallback models are gated repositories") from fallback_gated_error
                        except Exception as fallback_error:
                            logger.error(f"Fallback model also failed: {fallback_error}")
                            raise
                    else:
                        raise RuntimeError(f"Model {model_id} is a gated repository and no fallback available") from e

            # Format as chat messages if needed
            messages = [{"role": "user", "content": prompt}]

            # Generate using local model
            result = await asyncio.to_thread(
                self.local_loader.generate_chat_completion,
                model_id=model_id,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature
            )

            logger.info(f"Local model {model_id} generated response (length: {len(result)})")
            logger.info("=" * 80)
            logger.info("LOCAL MODEL RESPONSE:")
            logger.info("=" * 80)
            logger.info(f"Model: {model_id}")
            logger.info(f"Task Type: {task_type}")
Type: {task_type}") logger.info(f"Response Length: {len(result)} characters") logger.info("-" * 40) logger.info("FULL RESPONSE CONTENT:") logger.info("-" * 40) logger.info(result) logger.info("-" * 40) logger.info("END OF RESPONSE") logger.info("=" * 80) return result except GatedRepoError: # Re-raise to be handled by caller raise except Exception as e: logger.error(f"Error calling local model: {e}", exc_info=True) raise async def _call_local_embedding(self, model_config: dict, text: str, **kwargs) -> Optional[list]: """Call local embedding model.""" if not self.local_loader: raise RuntimeError("Local model loader not available") model_id = model_config["model_id"] try: # Ensure model is loaded if model_id not in self.local_loader.loaded_embedding_models: logger.info(f"Loading embedding model {model_id} on demand...") try: self.local_loader.load_embedding_model(model_id) except GatedRepoError as e: logger.error(f"❌ Cannot access gated repository {model_id}") logger.error(f" Visit https://huggingface.co/{model_id.split(':')[0] if ':' in model_id else model_id} to request access.") raise RuntimeError(f"Embedding model {model_id} is a gated repository") from e # Generate embedding embedding = await asyncio.to_thread( self.local_loader.get_embedding, model_id=model_id, text=text ) logger.info(f"Local embedding model {model_id} generated vector (dim: {len(embedding)})") return embedding except Exception as e: logger.error(f"Error calling local embedding model: {e}", exc_info=True) raise def _select_model(self, task_type: str) -> dict: model_map = { "intent_classification": LLM_CONFIG["models"]["classification_specialist"], "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"], "safety_check": LLM_CONFIG["models"]["safety_checker"], "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"], "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"] } return model_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"]) # REMOVED: _is_model_healthy - no longer needed (local models only) # REMOVED: _get_fallback_model - no longer needed (local models only) # REMOVED: _call_hf_endpoint - HF API inference removed async def get_available_models(self): """ Get list of available models for testing """ return list(LLM_CONFIG["models"].keys()) async def health_check(self): """ Perform health check on local models only """ health_status = {} if not self.local_loader: return {"error": "Local model loader not available"} for model_name, model_config in LLM_CONFIG["models"].items(): model_id = model_config["model_id"] # Check if model is loaded (for chat models) is_loaded = model_id in self.local_loader.loaded_models or model_id in self.local_loader.loaded_embedding_models health_status[model_name] = { "model_id": model_id, "loaded": is_loaded, "healthy": is_loaded # Consider loaded models healthy } return health_status def prepare_context_for_llm(self, raw_context: Dict, max_tokens: int = 4000) -> str: """Smart context windowing for LLM calls""" try: from transformers import AutoTokenizer # Initialize tokenizer lazily if not hasattr(self, 'tokenizer'): try: # Use the primary model for tokenization primary_model_id = LLM_CONFIG["models"]["reasoning_primary"]["model_id"] # Strip API suffix if present (though we don't use them anymore) base_model_id = primary_model_id.split(':')[0] if ':' in primary_model_id else primary_model_id self.tokenizer = AutoTokenizer.from_pretrained(base_model_id) except GatedRepoError as e: logger.warning(f"Gated repository error loading tokenizer: 
{e}") logger.warning("Using character count estimation instead") self.tokenizer = None except Exception as e: logger.warning(f"Could not load tokenizer: {e}, using character count estimation") self.tokenizer = None except ImportError: logger.warning("transformers library not available, using character count estimation") self.tokenizer = None # Priority order for context elements priority_elements = [ ('current_query', 1.0), ('recent_interactions', 0.8), ('user_preferences', 0.6), ('session_summary', 0.4), ('historical_context', 0.2) ] formatted_context = [] total_tokens = 0 for element, priority in priority_elements: # Map element names to context keys element_key_map = { 'current_query': raw_context.get('user_input', ''), 'recent_interactions': raw_context.get('interaction_contexts', []), 'user_preferences': raw_context.get('preferences', {}), 'session_summary': raw_context.get('session_context', {}), 'historical_context': raw_context.get('user_context', '') } content = element_key_map.get(element, '') # Convert to string if needed if isinstance(content, dict): content = str(content) elif isinstance(content, list): content = "\n".join([str(item) for item in content[:10]]) # Limit to 10 items if not content: continue # Estimate tokens if self.tokenizer: try: tokens = len(self.tokenizer.encode(content)) except: # Fallback to character-based estimation (rough: 1 token ≈ 4 chars) tokens = len(content) // 4 else: # Character-based estimation (rough: 1 token ≈ 4 chars) tokens = len(content) // 4 if total_tokens + tokens <= max_tokens: formatted_context.append(f"=== {element.upper()} ===\n{content}") total_tokens += tokens elif priority > 0.5: # Critical elements - truncate if needed available = max_tokens - total_tokens if available > 100: # Only truncate if we have meaningful space truncated = self._truncate_to_tokens(content, available) formatted_context.append(f"=== {element.upper()} (TRUNCATED) ===\n{truncated}") break return "\n\n".join(formatted_context) def _truncate_to_tokens(self, content: str, max_tokens: int) -> str: """Truncate content to fit within token limit""" if not self.tokenizer: # Simple character-based truncation max_chars = max_tokens * 4 if len(content) <= max_chars: return content return content[:max_chars-3] + "..." try: # Tokenize and truncate tokens = self.tokenizer.encode(content) if len(tokens) <= max_tokens: return content truncated_tokens = tokens[:max_tokens-3] # Leave room for "..." truncated_text = self.tokenizer.decode(truncated_tokens) return truncated_text + "..." except Exception as e: logger.warning(f"Error truncating with tokenizer: {e}, using character truncation") max_chars = max_tokens * 4 if len(content) <= max_chars: return content return content[:max_chars-3] + "..."