from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import subprocess
import signal
import os
import requests
import time
from typing import Optional
from duckduckgo_search import DDGS
from bs4 import BeautifulSoup

app = FastAPI()
# Predefined list of available models (TheBloke only - verified, fits 18GB Space)
AVAILABLE_MODELS = {
    # === General Purpose (Default) ===
    "deepseek-chat": "TheBloke/deepseek-llm-7B-chat-GGUF:deepseek-llm-7b-chat.Q4_K_M.gguf",
    # === Financial & Summarization Models ===
    "mistral-7b": "TheBloke/Mistral-7B-Instruct-v0.2-GGUF:mistral-7b-instruct-v0.2.Q4_K_M.gguf",
    "openhermes-7b": "TheBloke/OpenHermes-2.5-Mistral-7B-GGUF:openhermes-2.5-mistral-7b.Q4_K_M.gguf",
    # === Coding Models ===
    "deepseek-coder": "TheBloke/deepseek-coder-6.7B-instruct-GGUF:deepseek-coder-6.7b-instruct.Q4_K_M.gguf",
    # === Lightweight/Fast ===
    "llama-7b": "TheBloke/Llama-2-7B-Chat-GGUF:llama-2-7b-chat.Q4_K_M.gguf",
}
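# Each value is "<HF repo>:<GGUF filename>"; start_llama_server() passes it verbatim to
# llama-server via its -hf flag, which downloads the file from the Hub on first use
# (hence the "2-3 minutes" note printed at startup).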
# Global state
current_model = "deepseek-chat"  # Default model
llama_process: Optional[subprocess.Popen] = None
LLAMA_SERVER_PORT = 8080
LLAMA_SERVER_URL = f"http://localhost:{LLAMA_SERVER_PORT}"


class ModelSwitchRequest(BaseModel):
    model_name: str


class ChatCompletionRequest(BaseModel):
    messages: list[dict]
    max_tokens: int = 256
    temperature: float = 0.7


class WebChatRequest(BaseModel):
    messages: list[dict]
    max_tokens: int = 512
    temperature: float = 0.7
    max_search_results: int = 5
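# Illustrative request body for ChatCompletionRequest / WebChatRequest (field values here
# are only an example, not part of the API contract):
#   {"messages": [{"role": "user", "content": "What's new in llama.cpp?"}],
#    "max_tokens": 256, "temperature": 0.7}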
def start_llama_server(model_id: str) -> subprocess.Popen:
    """Start llama-server with specified model (optimized for speed)."""
    cmd = [
        "llama-server",
        "-hf", model_id,
        "--host", "0.0.0.0",
        "--port", str(LLAMA_SERVER_PORT),
        "-c", "2048",       # Context size
        "-t", "4",          # CPU threads (adjust based on cores)
        "-ngl", "0",        # GPU layers (0 for CPU-only)
        "--cont-batching",  # Enable continuous batching for speed
        "-b", "512",        # Batch size
    ]
    print(f"Starting llama-server with model: {model_id}")
    print("This may take 2-3 minutes to download and load the model...")
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        preexec_fn=os.setsid if os.name != 'nt' else None,
        text=True,
        bufsize=1,
    )

    # Wait for server to be ready (increased timeout for model download)
    max_retries = 300  # 5 minutes
    for i in range(max_retries):
        # Check if process died
        if process.poll() is not None:
            stdout, _ = process.communicate()
            print(f"llama-server exited with code {process.returncode}")
            print(f"Output: {stdout}")
            raise RuntimeError("llama-server process died")
        try:
            # Try root endpoint instead of /health
            response = requests.get(f"{LLAMA_SERVER_URL}/", timeout=2)
            if response.status_code in [200, 404]:  # 404 is ok, means server is up
                print(f"llama-server ready after {i+1} seconds")
                return process
        except requests.exceptions.ConnectionError:
            # Server not ready yet
            pass
        except Exception:
            # Other errors, keep waiting
            pass
        time.sleep(1)
    raise RuntimeError("llama-server failed to start within 5 minutes")
def stop_llama_server():
    """Stop the running llama-server."""
    global llama_process
    if llama_process:
        print("Stopping llama-server...")
        try:
            if os.name != 'nt':
                os.killpg(os.getpgid(llama_process.pid), signal.SIGTERM)
            else:
                llama_process.terminate()
            llama_process.wait(timeout=10)
        except Exception:
            # Graceful shutdown failed or timed out; force kill
            if os.name != 'nt':
                os.killpg(os.getpgid(llama_process.pid), signal.SIGKILL)
            else:
                llama_process.kill()
        llama_process = None
        time.sleep(2)  # Give it time to fully shut down
@app.on_event("startup")
async def startup_event():
    """Start with default model."""
    global llama_process
    model_id = AVAILABLE_MODELS[current_model]
    llama_process = start_llama_server(model_id)


@app.on_event("shutdown")
async def shutdown_event():
    """Clean shutdown."""
    stop_llama_server()
@app.get("/")
async def root():
    return {
        "status": "DeepSeek API with dynamic model switching",
        "current_model": current_model,
        "available_models": list(AVAILABLE_MODELS.keys())
    }


@app.get("/models")  # NOTE: route path is an assumption
async def list_models():
    """List all available models."""
    return {
        "current_model": current_model,
        "available_models": list(AVAILABLE_MODELS.keys())
    }
@app.post("/switch-model")  # NOTE: route path is an assumption
async def switch_model(request: ModelSwitchRequest):
    """Switch to a different model."""
    global current_model, llama_process
    if request.model_name not in AVAILABLE_MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Model '{request.model_name}' not found. Available: {list(AVAILABLE_MODELS.keys())}"
        )
    if request.model_name == current_model:
        return {"message": f"Already using model: {current_model}"}

    # Stop current server
    stop_llama_server()

    # Start with new model
    model_id = AVAILABLE_MODELS[request.model_name]
    llama_process = start_llama_server(model_id)
    current_model = request.model_name

    return {
        "message": f"Switched to model: {current_model}",
        "model": current_model
    }
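# Illustrative usage of the switch endpoint. The /switch-model path is the assumption made
# above, and port 7860 assumes the FastAPI app is served on the usual Hugging Face Space port:
#   curl -X POST http://localhost:7860/switch-model \
#        -H "Content-Type: application/json" \
#        -d '{"model_name": "mistral-7b"}'
# The request blocks until the new llama-server instance has downloaded and loaded the model.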
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """OpenAI-compatible chat completions endpoint."""
    try:
        # Forward to llama-server
        response = requests.post(
            f"{LLAMA_SERVER_URL}/v1/chat/completions",
            json={
                "messages": request.messages,
                "max_tokens": request.max_tokens,
                "temperature": request.temperature,
            },
            timeout=300
        )
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        raise HTTPException(status_code=500, detail=f"llama-server error: {str(e)}")
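# Illustrative usage (same port assumption as above):
#   curl -X POST http://localhost:7860/v1/chat/completions \
#        -H "Content-Type: application/json" \
#        -d '{"messages": [{"role": "user", "content": "Explain GGUF quantization briefly."}]}'
# The response body is llama-server's OpenAI-style JSON, returned unchanged.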
def search_web(query: str, max_results: int = 5) -> list[dict]:
    """Search the web using DuckDuckGo and return results."""
    try:
        with DDGS() as ddgs:
            results = list(ddgs.text(query, max_results=max_results))
        return results
    except Exception as e:
        print(f"Search error: {e}")
        return []


def format_search_context(query: str, search_results: list[dict]) -> str:
    """Format search results into context for the LLM."""
    if not search_results:
        return f"No web results found for: {query}"

    context = f"# Web Search Results for: {query}\n\n"
    for i, result in enumerate(search_results, 1):
        title = result.get("title", "No title")
        body = result.get("body", "No description")
        url = result.get("href", "")
        context += f"## Result {i}: {title}\n"
        context += f"{body}\n"
        if url:
            context += f"Source: {url}\n"
        context += "\n"
    return context
@app.post("/v1/web/chat/completions")  # NOTE: route path is an assumption
async def web_chat_completions(request: WebChatRequest):
    """
    Chat completions with web search augmentation.

    The last user message is used as the search query.
    Search results are injected into the context before sending to the LLM.
    """
    try:
        # Get the last user message as search query
        user_messages = [msg for msg in request.messages if msg.get("role") == "user"]
        if not user_messages:
            raise HTTPException(status_code=400, detail="No user message found")
        search_query = user_messages[-1].get("content", "")

        # Perform web search
        print(f"Searching web for: {search_query}")
        search_results = search_web(search_query, request.max_search_results)

        # Format search results as context
        web_context = format_search_context(search_query, search_results)

        # Create augmented messages with web context
        augmented_messages = request.messages.copy()

        # Insert web context as a system message before the last user message
        system_prompt = {
            "role": "system",
            "content": f"""You are a helpful assistant with access to current web information.

{web_context}

Use the above search results to provide accurate, up-to-date information in your response.
Always cite sources when using information from the search results."""
        }
        augmented_messages.insert(-1, system_prompt)

        # Forward to llama-server with augmented context
        response = requests.post(
            f"{LLAMA_SERVER_URL}/v1/chat/completions",
            json={
                "messages": augmented_messages,
                "max_tokens": request.max_tokens,
                "temperature": request.temperature,
            },
            timeout=300
        )
        response.raise_for_status()
        result = response.json()

        # Add metadata about search results
        result["web_search"] = {
            "query": search_query,
            "results_count": len(search_results),
            "sources": [r.get("href", "") for r in search_results if r.get("href")]
        }
        return result
    except HTTPException:
        # Propagate explicit HTTP errors (e.g. the 400 above) instead of converting them to 500s
        raise
    except requests.exceptions.RequestException as e:
        raise HTTPException(status_code=500, detail=f"llama-server error: {str(e)}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
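

# A minimal entry-point sketch for running the API locally. The uvicorn invocation and
# port 7860 are assumptions; on a Hugging Face Space the container's CMD may launch the
# app differently.
if __name__ == "__main__":
    import uvicorn

    # Serve the FastAPI app defined above; bind to 0.0.0.0 so the container can expose it
    uvicorn.run(app, host="0.0.0.0", port=7860)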