| """ | |
| HuggingFace Inference API client for remote model access. | |
| """ | |
| import os | |
| from typing import List, Dict, Any, Optional | |
| from huggingface_hub import InferenceClient | |
| from app.models.base_llm import BaseLLM | |


class HuggingFaceInferenceAPI(BaseLLM):
    """
    Remote model access via the HuggingFace Inference API.
    Best for larger models (7B+) that don't fit in local RAM.
    """

    def __init__(self, name: str, model_id: str, token: Optional[str] = None):
        super().__init__(name, model_id)
        # Fall back to the HF_TOKEN environment variable when no token is passed
        self.token = token or os.getenv("HF_TOKEN")
        self.client: Optional[InferenceClient] = None

    async def initialize(self) -> None:
        """Initialize the Inference API client."""
        if self._initialized:
            return
        try:
            print(f"[{self.name}] Initializing Inference API for: {self.model_id}")
            self.client = InferenceClient(
                model=self.model_id,
                token=self.token,
            )
            self._initialized = True
            print(f"[{self.name}] Inference API ready")
        except Exception as e:
            print(f"[{self.name}] Failed to initialize: {e}")
            raise
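
    # Note: InferenceClient is synchronous, so the `async` methods here still
    # block the event loop while an HTTP request is in flight. A non-blocking
    # variant (a sketch, not something this module provides) could swap in
    # huggingface_hub.AsyncInferenceClient, whose methods are awaitable with
    # matching signatures, e.g.:
    #
    #     from huggingface_hub import AsyncInferenceClient
    #     client = AsyncInferenceClient(model=self.model_id, token=self.token)
    #     response = await client.chat_completion(messages=..., max_tokens=...)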

    async def generate(
        self,
        prompt: Optional[str] = None,
        chat_messages: Optional[List[Dict[str, str]]] = None,
        max_new_tokens: int = 150,
        temperature: float = 0.7,
        top_p: float = 0.9,
        **kwargs,
    ) -> str:
        """Generate text using the HuggingFace Inference API."""
        if not self._initialized or not self.client:
            raise RuntimeError(f"[{self.name}] Client not initialized")
        try:
            # Prefer the chat completion endpoint when chat messages are provided
            if chat_messages:
                response = self.client.chat_completion(
                    messages=chat_messages,
                    max_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                )
                return response.choices[0].message.content.strip()
            # Otherwise fall back to raw text generation on the plain prompt
            elif prompt:
                response = self.client.text_generation(
                    prompt=prompt,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                )
                return response.strip()
            else:
                raise ValueError("Either prompt or chat_messages is required")
        except Exception as e:
            print(f"[{self.name}] Generation error: {e}")
            raise

    def get_info(self) -> Dict[str, Any]:
        """Return model info."""
        return {
            "name": self.name,
            "model_id": self.model_id,
            "type": "inference_api",
            "initialized": self._initialized,
        }
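

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal async driver, assuming BaseLLM.__init__ stores `name` and
# `model_id` and initializes `self._initialized = False`, and that HF_TOKEN
# is set in the environment. The model_id below is a hypothetical example.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        llm = HuggingFaceInferenceAPI(
            name="remote-llm",
            model_id="mistralai/Mistral-7B-Instruct-v0.2",  # hypothetical choice
        )
        await llm.initialize()
        # Chat-style call
        reply = await llm.generate(
            chat_messages=[{"role": "user", "content": "Say hello in one sentence."}]
        )
        print(reply)
        # Plain-prompt call
        completion = await llm.generate(prompt="The capital of France is")
        print(completion)

    asyncio.run(_demo())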