import os
import time
import asyncio  # NOTE(review): not referenced in this module — verify before removing
import importlib
from typing import Optional, List

from fastapi import FastAPI, HTTPException, Depends, Body
from fastapi.middleware.cors import CORSMiddleware
from pydantic import ValidationError

from app.models.registry import registry, MODEL_CONFIG
from app.schemas.schemas import (
    EnhancedDescriptionResponse,
    CompareRequest,
    CompareResponse,
    ModelResult,
    ModelInfo,
    InfillRequest,
    InfillResponse,
    InfillResult,
    GapFill,
    CompareInfillRequest,
    CompareInfillResponse,
    ModelInfillResult,
)
from app.logic.infill_utils import (
    detect_gaps,  # NOTE(review): imported but unused here — verify before removing
    parse_infill_json,
    apply_fills,
    build_fills_dict,  # NOTE(review): imported but unused here — verify before removing
    normalize_gaps_to_tagged,
)
from app.auth.placeholder_auth import get_authenticated_user

app = FastAPI(
    title="Multi-Model Description Enhancer",
    description="AI-powered service for enhancing descriptions using multiple LLMs for A/B testing",
    version="3.0.0",
)

# CORS configuration: local dev frontends plus the deployed frontend.
# dict.fromkeys() de-duplicates while preserving order — FRONTEND_URL
# defaults to an origin already listed explicitly.
_allowed_origins = list(
    dict.fromkeys(
        [
            "http://localhost:5173",
            "http://localhost:5174",
            os.getenv("FRONTEND_URL", "http://localhost:5173"),
        ]
    )
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=_allowed_origins,
    allow_credentials=True,
    allow_methods=["POST", "GET"],
    allow_headers=["*"],
)


# NOTE(review): @app.on_event is deprecated in FastAPI >= 0.93 in favor of
# lifespan handlers; migrate when the FastAPI version floor allows.
@app.on_event("startup")
async def startup_event():
    """
    Startup event - models are loaded lazily on first request.
    No models are pre-loaded to conserve memory.
    """
    print("Application started. Models will be loaded lazily on first request.")
    print(f"Available models: {registry.get_available_model_names()}")


# --- Helper function to load domain logic ---
def get_domain_config(domain: str):
    """Dynamically import app.domains.<domain>.config and return its domain_config.

    Raises HTTPException(404) when the module is missing or lacks the
    `domain_config` attribute.
    """
    try:
        module = importlib.import_module(f"app.domains.{domain}.config")
        return module.domain_config
    except (ImportError, AttributeError):
        raise HTTPException(
            status_code=404,
            detail=f"Domain '{domain}' not found or not configured correctly.",
        )


# --- API Endpoints ---
@app.get("/")
async def read_root():
    """Welcome banner pointing at the interactive docs."""
    # BUG FIX: the original string literal contained a raw line break
    # (a SyntaxError in a single-quoted string); rejoined into one literal.
    return {
        "message": (
            "Welcome to the Multi-Model Description Enhancer API! "
            "Go to /docs for documentation."
        )
    }


@app.get("/health")
async def health_check():
    """Check API health and model status."""
    models = registry.list_models()
    loaded_models = registry.get_loaded_models()
    active_model = registry.get_active_model()
    return {
        "status": "ok",
        "available_models": len(models),
        "loaded_models": loaded_models,
        "active_local_model": active_model,
    }


@app.get("/models", response_model=List[ModelInfo])
async def list_models():
    """List all available models with their load status."""
    return registry.list_models()


@app.post("/models/{model_name}/load")
async def load_model(model_name: str):
    """
    Explicitly load a model into memory.
    For local models: unloads any previously loaded local model first.
    """
    if model_name not in registry.get_available_model_names():
        raise HTTPException(status_code=404, detail=f"Unknown model: {model_name}")
    try:
        info = await registry.load_model(model_name)
        return {"status": "loaded", "model": info}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to load model: {str(e)}")


@app.post("/models/{model_name}/unload")
async def unload_model(model_name: str):
    """
    Explicitly unload a model from memory to free resources.
    """
    if model_name not in registry.get_available_model_names():
        raise HTTPException(status_code=404, detail=f"Unknown model: {model_name}")
    try:
        result = await registry.unload_model(model_name)
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to unload model: {str(e)}")


@app.post("/enhance-description", response_model=EnhancedDescriptionResponse)
async def enhance_description(
    domain: str = Body(..., embed=True),
    data: dict = Body(..., embed=True),
    model: str = Body("bielik-1.5b", embed=True),
    user: Optional[dict] = Depends(get_authenticated_user),
):
    """
    Generate an enhanced description using a single model.

    - **domain**: The name of the domain (e.g., 'cars').
    - **data**: A dictionary with the data for the description.
    - **model**: Model to use (default: bielik-1.5b)
    """
    start_time = time.time()

    # Validate model
    if model not in registry.get_available_model_names():
        raise HTTPException(status_code=400, detail=f"Unknown model: {model}")

    # Load Domain Configuration
    domain_config = get_domain_config(domain)
    DomainSchema = domain_config["schema"]
    create_prompt = domain_config["create_prompt"]

    # Validate Input Data against the domain's pydantic schema
    try:
        validated_data = DomainSchema(**data)
    except ValidationError as e:
        raise HTTPException(
            status_code=422, detail=f"Invalid data for domain '{domain}': {e}"
        )

    # Prompt Construction
    chat_messages = create_prompt(validated_data)

    # Text Generation
    try:
        llm = await registry.get_model(model)
        generated_description = await llm.generate(
            chat_messages=chat_messages,
            max_new_tokens=150,
            temperature=0.75,
            top_p=0.9,
        )
    except Exception as e:
        print(f"Error during text generation with {model}: {e}")
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")

    generation_time = time.time() - start_time
    user_email = user['email'] if user else "anonymous"
    return EnhancedDescriptionResponse(
        description=generated_description,
        model_used=MODEL_CONFIG[model]["id"],
        generation_time=round(generation_time, 2),
        user_email=user_email,
    )


@app.post("/compare", response_model=CompareResponse)
async def compare_models(
    request: CompareRequest,
    user: Optional[dict] = Depends(get_authenticated_user),
):
    """
    Compare outputs from multiple models for the same input.
    Returns results from all specified models (or all available if not specified).
    """
    total_start = time.time()

    # Get models to compare (default: every registered model)
    available_models = registry.get_available_model_names()
    models_to_use = request.models if request.models else available_models

    # Validate requested models
    for model in models_to_use:
        if model not in available_models:
            raise HTTPException(status_code=400, detail=f"Unknown model: {model}")

    # Load Domain Configuration
    domain_config = get_domain_config(request.domain)
    DomainSchema = domain_config["schema"]
    create_prompt = domain_config["create_prompt"]

    # Validate Input Data
    try:
        validated_data = DomainSchema(**request.data)
    except ValidationError as e:
        raise HTTPException(status_code=422, detail=f"Invalid data: {e}")

    # Prompt Construction (shared by every model for a fair comparison)
    chat_messages = create_prompt(validated_data)

    results = []

    async def generate_with_model(model_name: str) -> ModelResult:
        """Run one model; generation failures are captured in the result, not raised."""
        start_time = time.time()
        try:
            llm = await registry.get_model(model_name)
            output = await llm.generate(
                chat_messages=chat_messages,
                max_new_tokens=150,
                temperature=0.75,
                top_p=0.9,
            )
            return ModelResult(
                model=model_name,
                output=output,
                time=round(time.time() - start_time, 2),
                type=MODEL_CONFIG[model_name]["type"],
                error=None,
            )
        except Exception as e:
            return ModelResult(
                model=model_name,
                output="",
                time=round(time.time() - start_time, 2),
                type=MODEL_CONFIG[model_name]["type"],
                error=str(e),
            )

    # Run all models (sequentially to avoid memory issues)
    for model_name in models_to_use:
        result = await generate_with_model(model_name)
        results.append(result)

    return CompareResponse(
        domain=request.domain,
        results=results,
        total_time=round(time.time() - total_start, 2),
    )


@app.get("/user/me")
async def get_user_info(user: dict = Depends(get_authenticated_user)):
    """Get current authenticated user information"""
    if not user:
        raise HTTPException(status_code=401, detail="Not authenticated")
    return {
        "user_id": user['user_id'],
        "email": user['email'],
        "name": user.get('name', 'Unknown'),
    }


# --- Batch Infill Endpoints ---
@app.post("/infill", response_model=InfillResponse)
async def batch_infill(
    request: InfillRequest,
    user: Optional[dict] = Depends(get_authenticated_user),
):
    """
    Batch gap-filling for ads using a single model.
    Accepts items with [GAP:n] markers or ___ and returns filled text
    with per-gap choices and alternatives.

    NOTE: For texts > 6000 chars, consider chunking (not yet implemented).
    """
    total_start = time.time()

    # Validate model
    if request.model not in registry.get_available_model_names():
        raise HTTPException(status_code=400, detail=f"Unknown model: {request.model}")

    # Load domain config for infill prompt
    domain_config = get_domain_config(request.domain)
    if "create_infill_prompt" not in domain_config:
        raise HTTPException(
            status_code=400,
            detail=f"Domain '{request.domain}' does not support infill operations",
        )
    create_infill_prompt = domain_config["create_infill_prompt"]

    # Process each item; individual failures are recorded, not raised
    results = []
    error_count = 0
    for item in request.items:
        result = await process_infill_item(
            item=item,
            model_name=request.model,
            options=request.options,
            create_infill_prompt=create_infill_prompt,
        )
        results.append(result)
        if result.status == "error":
            error_count += 1

    return InfillResponse(
        model=request.model,
        results=results,
        total_time=round(time.time() - total_start, 2),
        processed_count=len(results),
        error_count=error_count,
    )


@app.post("/compare-infill", response_model=CompareInfillResponse)
async def compare_infill(
    request: CompareInfillRequest,
    user: Optional[dict] = Depends(get_authenticated_user),
):
    """
    Multi-model batch gap-filling comparison for A/B testing.
    Runs the same batch of items through multiple models and returns
    per-model results for comparison.
    """
    total_start = time.time()

    # Get models to compare (default: every registered model)
    available_models = registry.get_available_model_names()
    models_to_use = request.models if request.models else available_models

    # Validate requested models
    for model in models_to_use:
        if model not in available_models:
            raise HTTPException(status_code=400, detail=f"Unknown model: {model}")

    # Load domain config
    domain_config = get_domain_config(request.domain)
    if "create_infill_prompt" not in domain_config:
        raise HTTPException(
            status_code=400,
            detail=f"Domain '{request.domain}' does not support infill operations",
        )
    create_infill_prompt = domain_config["create_infill_prompt"]

    # Process with each model (sequentially for memory safety)
    model_results = []
    for model_name in models_to_use:
        model_start = time.time()
        results = []
        error_count = 0
        for item in request.items:
            result = await process_infill_item(
                item=item,
                model_name=model_name,
                options=request.options,
                create_infill_prompt=create_infill_prompt,
            )
            results.append(result)
            if result.status == "error":
                error_count += 1
        model_results.append(
            ModelInfillResult(
                model=model_name,
                type=MODEL_CONFIG[model_name]["type"],
                results=results,
                time=round(time.time() - model_start, 2),
                error_count=error_count,
            )
        )

    return CompareInfillResponse(
        domain=request.domain,
        models=model_results,
        total_time=round(time.time() - total_start, 2),
    )


async def process_infill_item(
    item,
    model_name: str,
    options,
    create_infill_prompt,
) -> InfillResult:
    """
    Process a single infill item.
    Returns InfillResult with status, filled_text, and gaps.
    Never raises: any exception is folded into an "error" result.
    """
    try:
        # Normalize gaps to [GAP:n] format
        normalized_text, gaps = normalize_gaps_to_tagged(item.text_with_gaps)
        if not gaps:
            # No gaps found, return original text
            return InfillResult(
                id=item.id,
                status="ok",
                filled_text=item.text_with_gaps,
                gaps=[],
                error=None,
            )

        # Build prompt
        chat_messages = create_infill_prompt(normalized_text, options)

        # Generate
        llm = await registry.get_model(model_name)
        raw_output = await llm.generate(
            chat_messages=chat_messages,
            max_new_tokens=options.max_new_tokens,
            temperature=options.temperature,
            top_p=0.9,
        )

        # Parse JSON from output
        parsed = parse_infill_json(raw_output)
        if not parsed:
            # JSON parsing failed — surface a truncated sample of the raw output
            return InfillResult(
                id=item.id,
                status="error",
                filled_text=None,
                gaps=[],
                error=f"Failed to parse JSON from model output: {raw_output[:200]}...",
            )

        # Extract gaps and build result
        gap_fills = []
        fills_dict = {}
        for gap_data in parsed.get("gaps", []):
            gap_fill = GapFill(
                index=gap_data.get("index", 0),
                marker=gap_data.get("marker", ""),
                choice=gap_data.get("choice", ""),
                alternatives=gap_data.get("alternatives", []),
            )
            gap_fills.append(gap_fill)
            fills_dict[gap_fill.index] = gap_fill.choice

        # Get filled text - prefer model's version, fallback to reconstruction
        filled_text = parsed.get("filled_text")
        if not filled_text and fills_dict:
            filled_text = apply_fills(normalized_text, gaps, fills_dict)

        return InfillResult(
            id=item.id,
            status="ok",
            filled_text=filled_text,
            gaps=gap_fills,
            error=None,
        )
    except Exception as e:
        return InfillResult(
            id=item.id,
            status="error",
            filled_text=None,
            gaps=[],
            error=str(e),
        )