import logging
import asyncio
from typing import Dict, Optional

from .models_config import LLM_CONFIG

try:
    from huggingface_hub.exceptions import GatedRepoError
except ImportError:
    # huggingface_hub may not be installed; fall back to the base Exception
    # so the except clauses below still work.
    GatedRepoError = Exception

logger = logging.getLogger(__name__)


class LLMRouter:
    """Routes inference requests to task-specialized local models (no API fallback)."""

    def __init__(self, hf_token=None, use_local_models: bool = True):
        self.hf_token = hf_token
        self.health_status = {}
        self.use_local_models = use_local_models
        self.local_loader = None

        logger.info("LLMRouter initialized (local models only, no API fallback)")
        if hf_token:
            logger.info("HF token available (for model download only)")
        else:
            logger.warning("HF_TOKEN not set - may be needed for gated model access")

        if self.use_local_models:
            try:
                from .local_model_loader import LocalModelLoader
                self.local_loader = LocalModelLoader()
                logger.info("✓ Local model loader initialized (GPU-based inference)")
                logger.info("Models will be loaded on-demand for faster startup")
            except Exception as e:
                logger.error(f"❌ CRITICAL: Could not initialize local model loader: {e}")
                logger.error("Local models are required - API fallback has been removed")
                raise RuntimeError(
                    "Local model loader is required but could not be initialized. "
                    "Please ensure transformers and torch are installed."
                ) from e
        else:
            logger.error("use_local_models=False but API fallback removed - this will fail")
            raise ValueError("use_local_models must be True - API fallback has been removed")

    async def route_inference(self, task_type: str, prompt: str, **kwargs):
        """
        Smart routing based on task specialization.

        Uses ONLY local models - no API fallback.
        """
        logger.info(f"Routing inference for task: {task_type}")
        model_config = self._select_model(task_type)
logger.info(f"Selected model: {model_config['model_id']}") |
|
|
|
|
|
|
|
|
if not self.local_loader: |
|
|
raise RuntimeError("Local model loader not available - cannot perform inference") |
|
|
|
|
|
try: |
|
|
|
|
|
if task_type == "embedding_generation": |
|
|
result = await self._call_local_embedding(model_config, prompt, **kwargs) |
|
|
else: |
|
|
result = await self._call_local_model(model_config, prompt, task_type, **kwargs) |
|
|
|
|
|
if result is None: |
|
|
logger.error(f"Local model returned None for task: {task_type}") |
|
|
raise RuntimeError(f"Inference failed for task: {task_type}") |
|
|
|
|
|
logger.info(f"Inference complete for {task_type} (local model)") |
|
|
return result |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Local model inference failed: {e}", exc_info=True) |
|
|
|
|
|
fallback_model_id = model_config.get("fallback") |
|
|
if fallback_model_id and fallback_model_id != model_config["model_id"]: |
|
|
logger.warning(f"Attempting fallback model: {fallback_model_id}") |
|
|
try: |
|
|
fallback_config = model_config.copy() |
|
|
fallback_config["model_id"] = fallback_model_id |
|
|
fallback_config.pop("fallback", None) |
|
|
|
|
|
if task_type == "embedding_generation": |
|
|
result = await self._call_local_embedding(fallback_config, prompt, **kwargs) |
|
|
else: |
|
|
result = await self._call_local_model(fallback_config, prompt, task_type, **{**kwargs, '_is_fallback': True}) |
|
|
|
|
|
if result is not None: |
|
|
logger.info(f"Inference complete using fallback model: {fallback_model_id}") |
|
|
return result |
|
|
except Exception as fallback_error: |
|
|
logger.error(f"Fallback model also failed: {fallback_error}") |
|
|
|
|
|
|
|
|
raise RuntimeError( |
|
|
f"Inference failed for task: {task_type}. " |
|
|
f"Local models are required - ensure models are properly loaded and accessible." |
|
|
) from e |
|
|
|
|
|
async def _call_local_model(self, model_config: dict, prompt: str, task_type: str, **kwargs) -> Optional[str]: |
|
|
"""Call local model for inference.""" |
|
|
if not self.local_loader: |
|
|
return None |
|
|
|
|
|
|
|
|
is_fallback_attempt = kwargs.get('_is_fallback', False) |
|
|
|
|
|
model_id = model_config["model_id"] |
|
|
        max_tokens = kwargs.get('max_tokens', 512)
        temperature = kwargs.get('temperature', 0.7)

        try:
            if model_id not in self.local_loader.loaded_models:
                logger.info(f"Loading model {model_id} on demand...")

                use_4bit = model_config.get("use_4bit_quantization", False)
                use_8bit = model_config.get("use_8bit_quantization", False)

                # Fall back to the global quantization defaults when the model config does not specify any.
                if not use_4bit and not use_8bit:
                    quantization_config = LLM_CONFIG.get("quantization_settings", {})
                    use_4bit = quantization_config.get("default_4bit", True)
                    use_8bit = quantization_config.get("default_8bit", False)

                try:
                    self.local_loader.load_chat_model(
                        model_id,
                        load_in_8bit=use_8bit,
                        load_in_4bit=use_4bit
                    )
                except GatedRepoError as e:
                    logger.error(f"❌ Cannot access gated repository {model_id}")
                    logger.error(f"   Visit https://huggingface.co/{model_id.split(':')[0] if ':' in model_id else model_id} to request access.")

                    if is_fallback_attempt:
                        logger.error("❌ Fallback model also failed with gated repository error")
                        raise RuntimeError("Both primary and fallback models are gated repositories") from e

                    fallback_model_id = model_config.get("fallback")
                    if fallback_model_id and fallback_model_id != model_id:
                        logger.warning(f"Attempting fallback model: {fallback_model_id}")
                        try:
                            fallback_config = model_config.copy()
                            fallback_config["model_id"] = fallback_model_id
                            fallback_config.pop("fallback", None)

                            return await self._call_local_model(
                                fallback_config,
                                prompt,
                                task_type,
                                **{**kwargs, '_is_fallback': True}
                            )
                        except GatedRepoError as fallback_gated_error:
                            logger.error(f"❌ Fallback model {fallback_model_id} is also gated")
                            raise RuntimeError("Both primary and fallback models are gated repositories") from fallback_gated_error
                        except Exception as fallback_error:
                            logger.error(f"Fallback model also failed: {fallback_error}")
                            raise
                    else:
                        raise RuntimeError(f"Model {model_id} is a gated repository and no fallback available") from e

            messages = [{"role": "user", "content": prompt}]

            # Run the blocking generation call in a worker thread so the event loop stays responsive.
            result = await asyncio.to_thread(
                self.local_loader.generate_chat_completion,
                model_id=model_id,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature
            )

            logger.info(f"Local model {model_id} generated response (length: {len(result)})")
            logger.info("=" * 80)
            logger.info("LOCAL MODEL RESPONSE:")
            logger.info("=" * 80)
            logger.info(f"Model: {model_id}")
            logger.info(f"Task Type: {task_type}")
            logger.info(f"Response Length: {len(result)} characters")
            logger.info("-" * 40)
            logger.info("FULL RESPONSE CONTENT:")
            logger.info("-" * 40)
            logger.info(result)
            logger.info("-" * 40)
            logger.info("END OF RESPONSE")
            logger.info("=" * 80)

            return result

        except GatedRepoError:
            # Already handled and logged above; let the caller decide what to do.
            raise
        except Exception as e:
            logger.error(f"Error calling local model: {e}", exc_info=True)
            raise

    async def _call_local_embedding(self, model_config: dict, text: str, **kwargs) -> Optional[list]:
        """Call a local embedding model, loading it on demand if needed."""
        if not self.local_loader:
raise RuntimeError("Local model loader not available") |
|
|
|
|
|
model_id = model_config["model_id"] |
|
|
|
|
|
try: |
|
|
|
|
|
if model_id not in self.local_loader.loaded_embedding_models: |
|
|
logger.info(f"Loading embedding model {model_id} on demand...") |
|
|
try: |
|
|
self.local_loader.load_embedding_model(model_id) |
|
|
except GatedRepoError as e: |
|
|
logger.error(f"❌ Cannot access gated repository {model_id}") |
|
|
logger.error(f" Visit https://huggingface.co/{model_id.split(':')[0] if ':' in model_id else model_id} to request access.") |
|
|
raise RuntimeError(f"Embedding model {model_id} is a gated repository") from e |
|
|
|
|
|
|
|
|
embedding = await asyncio.to_thread( |
|
|
self.local_loader.get_embedding, |
|
|
model_id=model_id, |
|
|
text=text |
|
|
) |
|
|
|
|
|
logger.info(f"Local embedding model {model_id} generated vector (dim: {len(embedding)})") |
|
|
return embedding |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error calling local embedding model: {e}", exc_info=True) |
|
|
raise |
|
|
|
|
|
def _select_model(self, task_type: str) -> dict: |
|
|
model_map = { |
|
|
"intent_classification": LLM_CONFIG["models"]["classification_specialist"], |
|
|
"embedding_generation": LLM_CONFIG["models"]["embedding_specialist"], |
|
|
"safety_check": LLM_CONFIG["models"]["safety_checker"], |
|
|
"general_reasoning": LLM_CONFIG["models"]["reasoning_primary"], |
|
|
"response_synthesis": LLM_CONFIG["models"]["reasoning_primary"] |
|
|
} |
|
|
return model_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_available_models(self): |
|
|
""" |
|
|
Get list of available models for testing |
|
|
""" |
|
|
return list(LLM_CONFIG["models"].keys()) |
|
|
|
|
|
async def health_check(self): |
|
|
""" |
|
|
Perform health check on local models only |
|
|
""" |
|
|
health_status = {} |
|
|
if not self.local_loader: |
|
|
return {"error": "Local model loader not available"} |
|
|
|
|
|
for model_name, model_config in LLM_CONFIG["models"].items(): |
|
|
model_id = model_config["model_id"] |
|
|
|
|
|
is_loaded = model_id in self.local_loader.loaded_models or model_id in self.local_loader.loaded_embedding_models |
|
|
health_status[model_name] = { |
|
|
"model_id": model_id, |
|
|
"loaded": is_loaded, |
|
|
"healthy": is_loaded |
|
|
} |
|
|
|
|
|
return health_status |
|
|
|
|
|
def prepare_context_for_llm(self, raw_context: Dict, max_tokens: int = 4000) -> str: |
|
|
"""Smart context windowing for LLM calls""" |
|
|
|
|
|
try: |
|
|
            from transformers import AutoTokenizer

            # Lazily load a tokenizer once so token counts match the primary reasoning model.
            if not hasattr(self, 'tokenizer'):
                try:
                    primary_model_id = LLM_CONFIG["models"]["reasoning_primary"]["model_id"]
                    base_model_id = primary_model_id.split(':')[0] if ':' in primary_model_id else primary_model_id
                    self.tokenizer = AutoTokenizer.from_pretrained(base_model_id)
                except GatedRepoError as e:
                    logger.warning(f"Gated repository error loading tokenizer: {e}")
                    logger.warning("Using character count estimation instead")
                    self.tokenizer = None
                except Exception as e:
                    logger.warning(f"Could not load tokenizer: {e}, using character count estimation")
                    self.tokenizer = None
        except ImportError:
            logger.warning("transformers library not available, using character count estimation")
            self.tokenizer = None

        # Context elements in priority order (highest first).
        priority_elements = [
            ('current_query', 1.0),
            ('recent_interactions', 0.8),
            ('user_preferences', 0.6),
            ('session_summary', 0.4),
            ('historical_context', 0.2)
        ]

        formatted_context = []
        total_tokens = 0

        for element, priority in priority_elements:
            element_key_map = {
                'current_query': raw_context.get('user_input', ''),
                'recent_interactions': raw_context.get('interaction_contexts', []),
                'user_preferences': raw_context.get('preferences', {}),
                'session_summary': raw_context.get('session_context', {}),
                'historical_context': raw_context.get('user_context', '')
            }

            content = element_key_map.get(element, '')

            # Normalize structured values to text; cap lists at the first 10 items.
            if isinstance(content, dict):
                content = str(content)
            elif isinstance(content, list):
                content = "\n".join([str(item) for item in content[:10]])

            if not content:
                continue

            if self.tokenizer:
                try:
                    tokens = len(self.tokenizer.encode(content))
                except Exception:
                    # Rough estimate: about 4 characters per token.
                    tokens = len(content) // 4
            else:
                tokens = len(content) // 4

            if total_tokens + tokens <= max_tokens:
                formatted_context.append(f"=== {element.upper()} ===\n{content}")
                total_tokens += tokens
            elif priority > 0.5:
                # High-priority content that does not fit is truncated into the remaining budget.
                available = max_tokens - total_tokens
                if available > 100:
                    truncated = self._truncate_to_tokens(content, available)
                    formatted_context.append(f"=== {element.upper()} (TRUNCATED) ===\n{truncated}")
                break

        return "\n\n".join(formatted_context)

    def _truncate_to_tokens(self, content: str, max_tokens: int) -> str:
        """Truncate content to fit within a token limit."""
        if not self.tokenizer:
            # No tokenizer available: estimate about 4 characters per token.
            max_chars = max_tokens * 4
            if len(content) <= max_chars:
                return content
            return content[:max_chars - 3] + "..."

        try:
            tokens = self.tokenizer.encode(content)
            if len(tokens) <= max_tokens:
                return content

            truncated_tokens = tokens[:max_tokens - 3]
            truncated_text = self.tokenizer.decode(truncated_tokens)
            return truncated_text + "..."
        except Exception as e:
            logger.warning(f"Error truncating with tokenizer: {e}, using character truncation")
            max_chars = max_tokens * 4
            if len(content) <= max_chars:
                return content
            return content[:max_chars - 3] + "..."
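

# Minimal usage sketch (illustrative addition, not part of the original module). It assumes
# the package's LLM_CONFIG and LocalModelLoader are importable, that at least the
# "reasoning_primary" model can be loaded locally, and should be run as a module
# (e.g. `python -m <your_package>.<this_module>`; the module path is an assumption)
# so the relative imports above resolve. The task name mirrors the keys in _select_model.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    async def _demo() -> None:
        # Raises RuntimeError if the local model loader cannot be initialized.
        router = LLMRouter(hf_token=None, use_local_models=True)

        # Report which configured models are currently loaded.
        status = await router.health_check()
        logger.info(f"Health check: {status}")

        # Route a prompt through the primary reasoning model.
        answer = await router.route_inference(
            task_type="general_reasoning",
            prompt="Summarize the role of an LLM router in one sentence.",
            max_tokens=128,
            temperature=0.2,
        )
        logger.info(f"Demo answer: {answer}")

    asyncio.run(_demo())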