| """ | |
| Model Registry - Central configuration and factory for all LLM models. | |
| Supports lazy loading and on/off mechanism for memory management. | |
| """ | |
| import os | |
| import gc | |
| from typing import Dict, List, Any, Optional | |
| from app.models.base_llm import BaseLLM | |
| from app.models.huggingface_local import HuggingFaceLocal | |
| from app.models.huggingface_inference_api import HuggingFaceInferenceAPI | |
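
# Note: the registry below only relies on the BaseLLM subclasses exposing an
# async initialize() method and, optionally, an async cleanup() method; those
# classes live in app.models and are not shown in this file.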

# Model configuration - 3 local + 1 API for Polish language comparison
MODEL_CONFIG = {
    "bielik-1.5b": {
        "id": "speakleash/Bielik-1.5B-v3.0-Instruct",
        "local_path": "bielik-1.5b",
        "type": "local",
        "polish_support": "excellent",
        "size": "1.5B",
    },
    "qwen2.5-3b": {
        "id": "Qwen/Qwen2.5-3B-Instruct",
        "local_path": "qwen2.5-3b",
        "type": "local",
        "polish_support": "good",
        "size": "3B",
    },
    "gemma-2-2b": {
        "id": "google/gemma-2-2b-it",
        "local_path": "gemma-2-2b",
        "type": "local",
        "polish_support": "medium",
        "size": "2B",
    },
    "pllum-12b": {
        "id": "CYFRAGOVPL/PLLuM-12B-instruct",
        "type": "inference_api",
        "polish_support": "excellent",
        "size": "12B",
    },
}

# Base path for pre-downloaded models in the container
LOCAL_MODEL_BASE = os.getenv("MODEL_DIR", "/app/pretrain_model")


class ModelRegistry:
    """
    Central registry for managing all LLM models.

    Supports lazy loading (load on first request) and unloading for memory
    management. Only one local model is loaded at a time to conserve memory.
    """

    def __init__(self):
        self._models: Dict[str, BaseLLM] = {}
        self._config = MODEL_CONFIG.copy()
        self._active_local_model: Optional[str] = None

    def _create_model(self, name: str) -> BaseLLM:
        """Factory method that creates a model instance from its config entry."""
        if name not in self._config:
            raise ValueError(f"Unknown model: {name}")
        config = self._config[name]
        model_type = config["type"]
        model_id = config["id"]

        # For local models, prefer a pre-downloaded copy if one exists
        if model_type == "local" and "local_path" in config:
            local_path = os.path.join(LOCAL_MODEL_BASE, config["local_path"])
            if os.path.exists(local_path):
                print(f"Using pre-downloaded model at: {local_path}")
                model_id = local_path
            else:
                print(f"Pre-downloaded model not found at {local_path}, will download from HuggingFace")

        if model_type == "local":
            return HuggingFaceLocal(
                name=name,
                model_id=model_id,
                device="cpu",
            )
        elif model_type == "inference_api":
            return HuggingFaceInferenceAPI(
                name=name,
                model_id=model_id,
            )
        else:
            raise ValueError(f"Unknown model type: {model_type}")

    async def _unload_model(self, name: str) -> None:
        """Unload a model from memory."""
        if name in self._models:
            model = self._models[name]
            # Call the model's cleanup hook if it provides one
            if hasattr(model, "cleanup"):
                await model.cleanup()
            del self._models[name]
            gc.collect()  # Force garbage collection to release the weights
            print(f"Model '{name}' unloaded from memory.")

    async def _unload_all_local_models(self) -> None:
        """Unload all local models to free memory."""
        local_models = [
            name for name, config in self._config.items()
            if config["type"] == "local" and name in self._models
        ]
        for name in local_models:
            await self._unload_model(name)
        self._active_local_model = None

    async def get_model(self, name: str) -> BaseLLM:
        """
        Get a model (lazy loading).

        For local models: unloads any previously loaded local model first.
        For API models: always available without affecting local models.
        """
        if name not in self._config:
            raise ValueError(f"Unknown model: {name}")
        config = self._config[name]

        # If it's a local model, ensure only one is loaded at a time
        if config["type"] == "local":
            # Unload the current local model if a different one is requested
            if self._active_local_model and self._active_local_model != name:
                print(f"Switching from '{self._active_local_model}' to '{name}'...")
                await self._unload_model(self._active_local_model)

            # Load the requested model if it is not already loaded
            if name not in self._models:
                print(f"Loading model '{name}'...")
                model = self._create_model(name)
                await model.initialize()
                self._models[name] = model
                self._active_local_model = name
                print(f"Model '{name}' loaded successfully.")

        # For API models, just create and cache the client (no memory concern)
        elif config["type"] == "inference_api":
            if name not in self._models:
                print(f"Initializing API model '{name}'...")
                model = self._create_model(name)
                await model.initialize()
                self._models[name] = model

        return self._models[name]

    async def load_model(self, name: str) -> Dict[str, Any]:
        """
        Explicitly load a model (unloading any other local model first).

        Returns the model's info dict.
        """
        await self.get_model(name)
        return self.get_model_info(name)

    async def unload_model(self, name: str) -> Dict[str, str]:
        """Explicitly unload a model from memory."""
        if name not in self._config:
            raise ValueError(f"Unknown model: {name}")
        if name not in self._models:
            return {"status": "not_loaded", "model": name}
        await self._unload_model(name)
        if self._active_local_model == name:
            self._active_local_model = None
        return {"status": "unloaded", "model": name}

    def get_model_info(self, name: str) -> Dict[str, Any]:
        """Get info about a specific model."""
        if name not in self._config:
            raise ValueError(f"Unknown model: {name}")
        config = self._config[name]
        return {
            "name": name,
            "model_id": config["id"],
            "type": config["type"],
            "polish_support": config["polish_support"],
            "size": config["size"],
            "loaded": name in self._models,
            "active": name == self._active_local_model if config["type"] == "local" else None,
        }

    def list_models(self) -> List[Dict[str, Any]]:
        """List all available models with their info."""
        return [self.get_model_info(name) for name in self._config.keys()]

    def get_available_model_names(self) -> List[str]:
        """Get the list of available model names."""
        return list(self._config.keys())

    def get_active_model(self) -> Optional[str]:
        """Get the name of the currently active (loaded) local model."""
        return self._active_local_model

    def get_loaded_models(self) -> List[str]:
        """Get the list of currently loaded model names."""
        return list(self._models.keys())


# Global registry instance
registry = ModelRegistry()
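

if __name__ == "__main__":
    # Minimal usage sketch, assuming the module is run standalone for a quick
    # smoke test. BaseLLM's generation API is not shown in this file, so the
    # demo only exercises lazy loading, local-model switching, and unloading.
    import asyncio

    async def _demo() -> None:
        print(registry.get_available_model_names())

        # First request lazily loads the local model
        await registry.get_model("bielik-1.5b")
        print(registry.get_active_model())  # -> "bielik-1.5b"

        # Requesting a different local model unloads the previous one first
        await registry.get_model("qwen2.5-3b")
        print(registry.get_loaded_models())  # only one local model at a time

        print(await registry.unload_model("qwen2.5-3b"))

    asyncio.run(_demo())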