HonestAI / src /llm_router.py
JatsTheAIGen's picture
Phase 1: Remove HF API inference - Local models only
5787d0a
raw
history blame
18.5 kB
# llm_router.py - UPDATED FOR LOCAL GPU MODEL LOADING
import logging
import asyncio
from typing import Dict, Optional
from .models_config import LLM_CONFIG
# Import GatedRepoError for handling gated repositories.
# huggingface_hub has exposed this class from different modules over time
# (``huggingface_hub.errors`` in recent releases, re-exported from
# ``huggingface_hub.utils``), so try both locations.
try:
    from huggingface_hub.errors import GatedRepoError
except ImportError:
    try:
        from huggingface_hub.utils import GatedRepoError
    except ImportError:
        # huggingface_hub is unavailable: define a private sentinel that is
        # never raised, so ``except GatedRepoError`` clauses match nothing.
        # (Aliasing to ``Exception`` would make those clauses swallow every
        # error and misreport unrelated failures as gated-repo errors.)
        class GatedRepoError(Exception):
            """Placeholder used when huggingface_hub is not installed."""

logger = logging.getLogger(__name__)
class LLMRouter:
    """Route inference requests to locally loaded models.

    HF Inference API calls have been removed: every request is served by a
    ``LocalModelLoader``. Models are loaded on first use, and a per-model
    ``fallback`` model id from ``LLM_CONFIG`` is tried when the primary
    model fails.
    """

    # Task type -> key in LLM_CONFIG["models"]; unknown task types use
    # "reasoning_primary". Resolved per call so that a missing config entry
    # only breaks the task that actually needs it.
    _TASK_MODEL_KEYS = {
        "intent_classification": "classification_specialist",
        "embedding_generation": "embedding_specialist",
        "safety_check": "safety_checker",
        "general_reasoning": "reasoning_primary",
        "response_synthesis": "reasoning_primary",
    }

    # (section name, raw_context key, priority) in the order sections are
    # emitted by prepare_context_for_llm. Sections with priority > 0.5 are
    # truncated rather than dropped when they do not fit.
    _CONTEXT_PRIORITIES = (
        ("current_query", "user_input", 1.0),
        ("recent_interactions", "interaction_contexts", 0.8),
        ("user_preferences", "preferences", 0.6),
        ("session_summary", "session_context", 0.4),
        ("historical_context", "user_context", 0.2),
    )

    def __init__(self, hf_token=None, use_local_models: bool = True):
        """Create the router and its local model loader.

        Args:
            hf_token: HuggingFace token. Kept for backward compatibility; it
                is stored but not used for API calls.
            use_local_models: Must be True; API inference has been removed.

        Raises:
            RuntimeError: If ``LocalModelLoader`` cannot be imported/created.
            ValueError: If ``use_local_models`` is False.
        """
        # hf_token kept for backward compatibility but not used for API calls
        # Only needed for downloading gated models from HuggingFace Hub
        self.hf_token = hf_token
        self.health_status = {}
        self.use_local_models = use_local_models
        self.local_loader = None
        logger.info("LLMRouter initialized (local models only, no API fallback)")
        if hf_token:
            logger.info("HF token available (for model download only)")
        else:
            logger.warning("HF_TOKEN not set - may be needed for gated model access")

        if not self.use_local_models:
            logger.error("use_local_models=False but API fallback removed - this will fail")
            raise ValueError("use_local_models must be True - API fallback has been removed")

        # Initialize local model loader - REQUIRED
        try:
            from .local_model_loader import LocalModelLoader
            self.local_loader = LocalModelLoader()
            logger.info("✓ Local model loader initialized (GPU-based inference)")
            # Models are loaded lazily on first request to avoid blocking startup.
            logger.info("Models will be loaded on-demand for faster startup")
        except Exception as e:
            logger.error(f"❌ CRITICAL: Could not initialize local model loader: {e}")
            logger.error("Local models are required - API fallback has been removed")
            raise RuntimeError(
                "Local model loader is required but could not be initialized. "
                "Please ensure transformers and torch are installed."
            ) from e

    async def route_inference(self, task_type: str, prompt: str, **kwargs):
        """Run ``prompt`` through the local model configured for ``task_type``.

        ``embedding_generation`` goes to the embedding model; every other task
        goes to a chat model. If the primary model fails (raises or returns
        None) and its config names a different ``fallback`` model, that model
        is tried once.

        Args:
            task_type: Task name used to pick a model (see ``_select_model``).
            prompt: Text to generate from, or to embed.
            **kwargs: Forwarded to the model call (e.g. ``max_tokens``,
                ``temperature``).

        Returns:
            Generated text, or the embedding for ``embedding_generation``.

        Raises:
            RuntimeError: If no local loader is available, or if both the
                primary and the fallback model fail.
        """
        logger.info(f"Routing inference for task: {task_type}")
        model_config = self._select_model(task_type)
        logger.info(f"Selected model: {model_config['model_id']}")
        if not self.local_loader:
            raise RuntimeError("Local model loader not available - cannot perform inference")
        try:
            if task_type == "embedding_generation":
                result = await self._call_local_embedding(model_config, prompt, **kwargs)
            else:
                result = await self._call_local_model(model_config, prompt, task_type, **kwargs)
            if result is None:
                logger.error(f"Local model returned None for task: {task_type}")
                raise RuntimeError(f"Inference failed for task: {task_type}")
            logger.info(f"Inference complete for {task_type} (local model)")
            return result
        except Exception as e:
            logger.error(f"Local model inference failed: {e}", exc_info=True)
            fallback_model_id = model_config.get("fallback")
            if fallback_model_id and fallback_model_id != model_config["model_id"]:
                logger.warning(f"Attempting fallback model: {fallback_model_id}")
                fallback_config = self._make_fallback_config(model_config, fallback_model_id)
                try:
                    if task_type == "embedding_generation":
                        result = await self._call_local_embedding(fallback_config, prompt, **kwargs)
                    else:
                        result = await self._call_local_model(
                            fallback_config, prompt, task_type, **{**kwargs, '_is_fallback': True}
                        )
                    if result is not None:
                        logger.info(f"Inference complete using fallback model: {fallback_model_id}")
                        return result
                    logger.error(f"Fallback model {fallback_model_id} returned None for task: {task_type}")
                except Exception as fallback_error:
                    logger.error(f"Fallback model also failed: {fallback_error}")
            # No API fallback - raise error
            raise RuntimeError(
                f"Inference failed for task: {task_type}. "
                f"Local models are required - ensure models are properly loaded and accessible."
            ) from e

    async def _call_local_model(self, model_config: dict, prompt: str, task_type: str, **kwargs) -> Optional[str]:
        """Run a chat completion on a local model, loading it on first use.

        Args:
            model_config: Entry from ``LLM_CONFIG["models"]``; must contain
                ``model_id`` and may contain ``fallback`` and the
                ``use_4bit_quantization`` / ``use_8bit_quantization`` flags.
            prompt: Sent as a single user chat message.
            task_type: Used for logging only.
            **kwargs: ``max_tokens`` (default 512), ``temperature`` (default
                0.7) and the internal ``_is_fallback`` marker.

        Returns:
            The generated text, or None if no local loader is available or
            the loader returned no output.

        Raises:
            RuntimeError: If the model is a gated repository and no usable
                fallback exists.
            Exception: Other loading/generation errors are logged and re-raised.
        """
        if not self.local_loader:
            return None
        # Set when this call is itself a fallback attempt; prevents a gated
        # fallback from triggering yet another fallback (infinite loop).
        is_fallback_attempt = kwargs.get('_is_fallback', False)
        model_id = model_config["model_id"]
        max_tokens = kwargs.get('max_tokens', 512)
        temperature = kwargs.get('temperature', 0.7)
        try:
            if model_id not in self.local_loader.loaded_models:
                logger.info(f"Loading model {model_id} on demand...")
                use_4bit, use_8bit = self._resolve_quantization(model_config)
                try:
                    self.local_loader.load_chat_model(
                        model_id,
                        load_in_8bit=use_8bit,
                        load_in_4bit=use_4bit
                    )
                except GatedRepoError as e:
                    logger.error(f"❌ Cannot access gated repository {model_id}")
                    logger.error(f" Visit https://huggingface.co/{self._repo_id(model_id)} to request access.")
                    if is_fallback_attempt:
                        logger.error("❌ Fallback model also failed with gated repository error")
                        raise RuntimeError("Both primary and fallback models are gated repositories") from e
                    fallback_model_id = model_config.get("fallback")
                    if not fallback_model_id or fallback_model_id == model_id:
                        raise RuntimeError(f"Model {model_id} is a gated repository and no fallback available") from e
                    logger.warning(f"Attempting fallback model: {fallback_model_id}")
                    fallback_config = self._make_fallback_config(model_config, fallback_model_id)
                    try:
                        return await self._call_local_model(
                            fallback_config,
                            prompt,
                            task_type,
                            **{**kwargs, '_is_fallback': True}
                        )
                    except GatedRepoError as fallback_gated_error:
                        logger.error(f"❌ Fallback model {fallback_model_id} is also gated")
                        raise RuntimeError("Both primary and fallback models are gated repositories") from fallback_gated_error
                    except Exception as fallback_error:
                        logger.error(f"Fallback model also failed: {fallback_error}")
                        raise

            messages = [{"role": "user", "content": prompt}]
            # Generation is blocking; run it off the event loop.
            result = await asyncio.to_thread(
                self.local_loader.generate_chat_completion,
                model_id=model_id,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature
            )
            if result is None:
                # Checked before len(result) below; the caller treats None as failure.
                logger.warning(f"Local model {model_id} returned no output for task: {task_type}")
                return None
            logger.info(f"Local model {model_id} generated response (length: {len(result)})")
            self._log_response(model_id, task_type, result)
            return result
        except GatedRepoError:
            # Re-raise to be handled by caller
            raise
        except Exception as e:
            logger.error(f"Error calling local model: {e}", exc_info=True)
            raise

    @staticmethod
    def _resolve_quantization(model_config: dict) -> tuple[bool, bool]:
        """Return ``(use_4bit, use_8bit)`` for loading the model in ``model_config``.

        Per-model flags are used when either is truthy; otherwise the global
        ``LLM_CONFIG["quantization_settings"]`` defaults apply (4-bit unless
        configured otherwise).
        """
        use_4bit = model_config.get("use_4bit_quantization", False)
        use_8bit = model_config.get("use_8bit_quantization", False)
        if not use_4bit and not use_8bit:
            # NOTE(review): an explicit False for both flags is treated the
            # same as "not specified", so such a model still gets the global
            # default quantization - confirm this is intended.
            quantization_config = LLM_CONFIG.get("quantization_settings", {})
            use_4bit = quantization_config.get("default_4bit", True)
            use_8bit = quantization_config.get("default_8bit", False)
        return use_4bit, use_8bit

    @staticmethod
    def _make_fallback_config(model_config: dict, fallback_model_id: str) -> dict:
        """Copy ``model_config`` pointed at ``fallback_model_id``, with no further fallback."""
        fallback_config = model_config.copy()
        fallback_config["model_id"] = fallback_model_id
        fallback_config.pop("fallback", None)  # Prevent chained/infinite fallback
        return fallback_config

    @staticmethod
    def _repo_id(model_id: str) -> str:
        """Strip an optional ``:suffix`` (legacy API-routing suffix) from a model id."""
        return model_id.split(':', 1)[0]

    @staticmethod
    def _log_response(model_id: str, task_type: str, result: str) -> None:
        """Dump a full model response to the log at DEBUG level.

        Kept at DEBUG rather than INFO: responses can be long and may echo
        user-provided content.
        """
        if not logger.isEnabledFor(logging.DEBUG):
            return
        logger.debug("=" * 80)
        logger.debug("LOCAL MODEL RESPONSE:")
        logger.debug("=" * 80)
        logger.debug(f"Model: {model_id}")
        logger.debug(f"Task Type: {task_type}")
        logger.debug(f"Response Length: {len(result)} characters")
        logger.debug("-" * 40)
        logger.debug("FULL RESPONSE CONTENT:")
        logger.debug("-" * 40)
        logger.debug(result)
        logger.debug("-" * 40)
        logger.debug("END OF RESPONSE")
        logger.debug("=" * 80)

    async def _call_local_embedding(self, model_config: dict, text: str, **kwargs) -> Optional[list]:
        """Embed ``text`` with a local embedding model, loading it on first use.

        Args:
            model_config: Model entry; only ``model_id`` is used.
            text: Text to embed.
            **kwargs: Accepted for call-site symmetry; ignored.

        Returns:
            The value returned by the loader's ``get_embedding``, or None if
            it returned None.

        Raises:
            RuntimeError: If no loader is available or the model is gated.
        """
        if not self.local_loader:
            raise RuntimeError("Local model loader not available")
        model_id = model_config["model_id"]
        try:
            if model_id not in self.local_loader.loaded_embedding_models:
                logger.info(f"Loading embedding model {model_id} on demand...")
                try:
                    self.local_loader.load_embedding_model(model_id)
                except GatedRepoError as e:
                    logger.error(f"❌ Cannot access gated repository {model_id}")
                    logger.error(f" Visit https://huggingface.co/{self._repo_id(model_id)} to request access.")
                    raise RuntimeError(f"Embedding model {model_id} is a gated repository") from e
            # Embedding is blocking; run it off the event loop.
            embedding = await asyncio.to_thread(
                self.local_loader.get_embedding,
                model_id=model_id,
                text=text
            )
            if embedding is None:
                # Checked before len(embedding); the caller treats None as failure.
                logger.warning(f"Local embedding model {model_id} returned no output")
                return None
            logger.info(f"Local embedding model {model_id} generated vector (dim: {len(embedding)})")
            return embedding
        except Exception as e:
            logger.error(f"Error calling local embedding model: {e}", exc_info=True)
            raise

    def _select_model(self, task_type: str) -> dict:
        """Return the ``LLM_CONFIG["models"]`` entry used for ``task_type``.

        Unknown task types use the ``reasoning_primary`` model.
        """
        model_key = self._TASK_MODEL_KEYS.get(task_type, "reasoning_primary")
        return LLM_CONFIG["models"][model_key]

    # REMOVED: _is_model_healthy - no longer needed (local models only)
    # REMOVED: _get_fallback_model - no longer needed (local models only)
    # REMOVED: _call_hf_endpoint - HF API inference removed

    async def get_available_models(self):
        """Return the configured model names (keys of ``LLM_CONFIG["models"]``)."""
        return list(LLM_CONFIG["models"].keys())

    async def health_check(self):
        """Report, per configured model, whether it is currently loaded.

        Nothing is loaded by this call; a model counts as healthy iff it is
        already in the loader's chat or embedding model registry.

        Returns:
            ``{model_name: {"model_id", "loaded", "healthy"}}``, or
            ``{"error": ...}`` when no local loader is available.
        """
        if not self.local_loader:
            return {"error": "Local model loader not available"}
        health_status = {}
        for model_name, model_config in LLM_CONFIG["models"].items():
            model_id = model_config["model_id"]
            is_loaded = (
                model_id in self.local_loader.loaded_models
                or model_id in self.local_loader.loaded_embedding_models
            )
            health_status[model_name] = {
                "model_id": model_id,
                "loaded": is_loaded,
                "healthy": is_loaded  # Consider loaded models healthy
            }
        return health_status

    def prepare_context_for_llm(self, raw_context: Dict, max_tokens: int = 4000) -> str:
        """Build a prompt-context string that fits within ``max_tokens``.

        Sections come from ``raw_context`` in priority order: current query,
        recent interactions, user preferences, session summary, historical
        context. Empty values (including empty dicts/lists) are skipped.

        A section that does not fit is handled by its priority:
        - priority > 0.5: it is truncated if more than 100 tokens remain,
          and no further sections are added either way;
        - otherwise: it is dropped and lower-priority sections are still
          considered.

        Tokens are counted with the primary model's tokenizer when it can be
        loaded, otherwise estimated as ~4 characters per token.
        """
        if not hasattr(self, 'tokenizer'):
            # Load lazily, once; None means "use the character estimate".
            self.tokenizer = self._load_tokenizer()
        raw_context = raw_context or {}

        formatted_context = []
        total_tokens = 0
        for element, context_key, priority in self._CONTEXT_PRIORITIES:
            content = self._stringify_context(raw_context.get(context_key))
            if not content:
                continue
            tokens = self._estimate_tokens(content)
            if total_tokens + tokens <= max_tokens:
                formatted_context.append(f"=== {element.upper()} ===\n{content}")
                total_tokens += tokens
            elif priority > 0.5:  # Critical elements - truncate if needed
                available = max_tokens - total_tokens
                if available > 100:  # Only truncate if we have meaningful space
                    truncated = self._truncate_to_tokens(content, available)
                    formatted_context.append(f"=== {element.upper()} (TRUNCATED) ===\n{truncated}")
                break
        return "\n\n".join(formatted_context)

    @staticmethod
    def _load_tokenizer():
        """Load the primary reasoning model's tokenizer, or return None on failure."""
        try:
            from transformers import AutoTokenizer
        except ImportError:
            logger.warning("transformers library not available, using character count estimation")
            return None
        try:
            primary_model_id = LLM_CONFIG["models"]["reasoning_primary"]["model_id"]
            # Strip API suffix if present (though we don't use them anymore)
            base_model_id = primary_model_id.split(':', 1)[0]
            return AutoTokenizer.from_pretrained(base_model_id)
        except GatedRepoError as e:
            logger.warning(f"Gated repository error loading tokenizer: {e}")
            logger.warning("Using character count estimation instead")
            return None
        except Exception as e:
            logger.warning(f"Could not load tokenizer: {e}, using character count estimation")
            return None

    @staticmethod
    def _stringify_context(value) -> str:
        """Render one raw context value as text; empty/falsy values become ''.

        Dicts use ``str()``; lists are joined one item per line, keeping only
        the first 10 items.
        """
        # Emptiness is checked BEFORE conversion so that {} does not become "{}".
        if not value:
            return ''
        if isinstance(value, dict):
            return str(value)
        if isinstance(value, list):
            return "\n".join(str(item) for item in value[:10])  # Limit to 10 items
        return value if isinstance(value, str) else str(value)

    def _estimate_tokens(self, content: str) -> int:
        """Count tokens with the tokenizer if loaded, else estimate ~4 chars per token."""
        tokenizer = getattr(self, "tokenizer", None)
        if tokenizer:
            try:
                return len(tokenizer.encode(content, add_special_tokens=False))
            except Exception as e:
                logger.debug(f"Tokenizer encode failed ({e}); using character estimate")
        # Character-based estimation (rough: 1 token ≈ 4 chars)
        return len(content) // 4

    def _truncate_to_tokens(self, content: str, max_tokens: int) -> str:
        """Truncate ``content`` to about ``max_tokens`` tokens, appending "..." if cut.

        Uses the tokenizer when loaded (without special tokens, so no BOS/EOS
        markers end up in the text); otherwise cuts at ``max_tokens * 4``
        characters.
        """
        tokenizer = getattr(self, "tokenizer", None)
        if tokenizer:
            try:
                tokens = tokenizer.encode(content, add_special_tokens=False)
                if len(tokens) <= max_tokens:
                    return content
                # Leave room for "..."
                return tokenizer.decode(tokens[:max(max_tokens - 3, 0)]) + "..."
            except Exception as e:
                logger.warning(f"Error truncating with tokenizer: {e}, using character truncation")
        max_chars = max_tokens * 4
        if len(content) <= max_chars:
            return content
        return content[:max(max_chars - 3, 0)] + "..."