"""
Model Registry - Central configuration and factory for all LLM models.
Supports lazy loading and on/off mechanism for memory management.
"""
import os
import gc
from typing import Dict, List, Any, Optional
from app.models.base_llm import BaseLLM
from app.models.huggingface_local import HuggingFaceLocal
from app.models.huggingface_inference_api import HuggingFaceInferenceAPI
# Model configuration - 3 local + 1 API for Polish language comparison
MODEL_CONFIG = {
    "bielik-1.5b": {
        "id": "speakleash/Bielik-1.5B-v3.0-Instruct",
        "local_path": "bielik-1.5b",
        "type": "local",
        "polish_support": "excellent",
        "size": "1.5B",
    },
    "qwen2.5-3b": {
        "id": "Qwen/Qwen2.5-3B-Instruct",
        "local_path": "qwen2.5-3b",
        "type": "local",
        "polish_support": "good",
        "size": "3B",
    },
    "gemma-2-2b": {
        "id": "google/gemma-2-2b-it",
        "local_path": "gemma-2-2b",
        "type": "local",
        "polish_support": "medium",
        "size": "2B",
    },
    "pllum-12b": {
        "id": "CYFRAGOVPL/PLLuM-12B-instruct",
        "type": "inference_api",
        "polish_support": "excellent",
        "size": "12B",
    },
}
# Base path for pre-downloaded models in container
LOCAL_MODEL_BASE = os.getenv("MODEL_DIR", "/app/pretrain_model")
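# Expected container layout (an assumption inferred from the `local_path` entries above;
# adjust MODEL_DIR if the image stores weights elsewhere):
#   /app/pretrain_model/bielik-1.5b/
#   /app/pretrain_model/qwen2.5-3b/
#   /app/pretrain_model/gemma-2-2b/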

class ModelRegistry:
    """
    Central registry for managing all LLM models.
    Supports lazy loading (load on first request) and unloading for memory management.
    Only one local model is loaded at a time to conserve memory.
    """

    def __init__(self):
        self._models: Dict[str, BaseLLM] = {}
        self._config = MODEL_CONFIG.copy()
        self._active_local_model: Optional[str] = None

    def _create_model(self, name: str) -> BaseLLM:
        """Factory method to create a model instance."""
        if name not in self._config:
            raise ValueError(f"Unknown model: {name}")
        config = self._config[name]
        model_type = config["type"]
        model_id = config["id"]
        # For local models, check if a pre-downloaded version exists
        if model_type == "local" and "local_path" in config:
            local_path = os.path.join(LOCAL_MODEL_BASE, config["local_path"])
            if os.path.exists(local_path):
                print(f"Using pre-downloaded model at: {local_path}")
                model_id = local_path
            else:
                print(f"Pre-downloaded model not found at {local_path}, will download from HuggingFace")
        if model_type == "local":
            return HuggingFaceLocal(
                name=name,
                model_id=model_id,
                device="cpu"
            )
        elif model_type == "inference_api":
            return HuggingFaceInferenceAPI(
                name=name,
                model_id=model_id
            )
        else:
            raise ValueError(f"Unknown model type: {model_type}")

    async def _unload_model(self, name: str) -> None:
        """Unload a model from memory."""
        if name in self._models:
            model = self._models[name]
            # Call cleanup if available
            if hasattr(model, 'cleanup'):
                await model.cleanup()
            del self._models[name]
            gc.collect()  # Force garbage collection
            print(f"Model '{name}' unloaded from memory.")

    async def _unload_all_local_models(self) -> None:
        """Unload all local models to free memory."""
        local_models = [
            name for name, config in self._config.items()
            if config["type"] == "local" and name in self._models
        ]
        for name in local_models:
            await self._unload_model(name)
        self._active_local_model = None

    async def get_model(self, name: str) -> BaseLLM:
        """
        Get a model (lazy loading).
        For local models: unloads any previously loaded local model first.
        For API models: always available without affecting local models.
        """
        if name not in self._config:
            raise ValueError(f"Unknown model: {name}")
        config = self._config[name]
        # If it's a local model, ensure only one is loaded at a time
        if config["type"] == "local":
            # Unload current local model if different
            if self._active_local_model and self._active_local_model != name:
                print(f"Switching from '{self._active_local_model}' to '{name}'...")
                await self._unload_model(self._active_local_model)
            # Load the requested model if not already loaded
            if name not in self._models:
                print(f"Loading model '{name}'...")
                model = self._create_model(name)
                await model.initialize()
                self._models[name] = model
                self._active_local_model = name
                print(f"Model '{name}' loaded successfully.")
        # For API models, just create/return (no memory concern)
        elif config["type"] == "inference_api":
            if name not in self._models:
                print(f"Initializing API model '{name}'...")
                model = self._create_model(name)
                await model.initialize()
                self._models[name] = model
        return self._models[name]

    async def load_model(self, name: str) -> Dict[str, Any]:
        """
        Explicitly load a model (unloads other local models first).
        Returns model info.
        """
        await self.get_model(name)
        return self.get_model_info(name)

    async def unload_model(self, name: str) -> Dict[str, str]:
        """
        Explicitly unload a model from memory.
        """
        if name not in self._config:
            raise ValueError(f"Unknown model: {name}")
        if name not in self._models:
            return {"status": "not_loaded", "model": name}
        await self._unload_model(name)
        if self._active_local_model == name:
            self._active_local_model = None
        return {"status": "unloaded", "model": name}

    def get_model_info(self, name: str) -> Dict[str, Any]:
        """Get info about a specific model."""
        if name not in self._config:
            raise ValueError(f"Unknown model: {name}")
        config = self._config[name]
        return {
            "name": name,
            "model_id": config["id"],
            "type": config["type"],
            "polish_support": config["polish_support"],
            "size": config["size"],
            "loaded": name in self._models,
            "active": name == self._active_local_model if config["type"] == "local" else None,
        }

    def list_models(self) -> List[Dict[str, Any]]:
        """List all available models with their info."""
        return [self.get_model_info(name) for name in self._config.keys()]

    def get_available_model_names(self) -> List[str]:
        """Get list of available model names."""
        return list(self._config.keys())

    def get_active_model(self) -> Optional[str]:
        """Get the currently active (loaded) local model name."""
        return self._active_local_model

    def get_loaded_models(self) -> List[str]:
        """Get list of currently loaded model names."""
        return list(self._models.keys())

# Global registry instance
registry = ModelRegistry()
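

# Minimal usage sketch: exercises only the registry API defined in this module.
# Running the file directly lists the configured models without loading any weights;
# the async wrapper is needed because get_model/load_model are coroutines. This block
# is illustrative and not part of the application wiring (assumption: the real app
# imports `registry` from its API layer instead).
if __name__ == "__main__":
    import asyncio
    import json

    async def _demo() -> None:
        print("Available models:", registry.get_available_model_names())
        print(json.dumps(registry.list_models(), indent=2))
        # To actually load a model (downloads/loads weights), uncomment:
        # await registry.load_model("bielik-1.5b")
        print("Loaded models:", registry.get_loaded_models())

    asyncio.run(_demo())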