"""
HuggingFace Inference API client for remote model access.
"""

import asyncio
import functools
import os
from typing import Any, Callable, Dict, List, Optional

from huggingface_hub import InferenceClient

from app.models.base_llm import BaseLLM


class HuggingFaceInferenceAPI(BaseLLM):
    """
    Remote model access via HuggingFace Inference API.

    Best for larger models (7B+) that don't fit in local RAM.

    The underlying ``InferenceClient`` is synchronous; ``generate`` runs its
    network calls in the event loop's default executor so that awaiting a
    generation does not block other coroutines.
    """

    def __init__(self, name: str, model_id: str, token: Optional[str] = None) -> None:
        """
        Args:
            name: Display name used as a prefix in log messages.
            model_id: HuggingFace model identifier passed to ``InferenceClient``.
            token: HF API token; when falsy, the ``HF_TOKEN`` environment
                variable is used instead (may itself be unset -> ``None``).
        """
        super().__init__(name, model_id)
        self.token = token or os.getenv("HF_TOKEN")
        # Created lazily by initialize().
        self.client: Optional[InferenceClient] = None

    async def initialize(self) -> None:
        """Initialize the Inference API client.

        Idempotent: returns immediately if already initialized. Any error
        raised while constructing the client is logged and re-raised.
        """
        if self._initialized:
            return

        try:
            print(f"[{self.name}] Initializing Inference API for: {self.model_id}")
            self.client = InferenceClient(
                model=self.model_id,
                token=self.token,
            )
            self._initialized = True
            print(f"[{self.name}] Inference API ready")
        except Exception as e:
            print(f"[{self.name}] Failed to initialize: {e}")
            raise

    @staticmethod
    async def _run_blocking(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        """Run a blocking callable in the default executor and await its result."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, functools.partial(func, *args, **kwargs))

    async def generate(
        self,
        prompt: Optional[str] = None,
        chat_messages: Optional[List[Dict[str, str]]] = None,
        max_new_tokens: int = 150,
        temperature: float = 0.7,
        top_p: float = 0.9,
        **kwargs: Any,
    ) -> str:
        """Generate text using HuggingFace Inference API.

        ``chat_messages`` takes precedence: when it is non-empty the chat
        completion endpoint is used; otherwise ``prompt`` is sent to the
        text generation endpoint (with sampling enabled).

        Args:
            prompt: Raw prompt for text generation.
            chat_messages: Chat history passed as ``messages`` to
                ``chat_completion``.
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling probability mass.
            **kwargs: Ignored by this implementation (presumably accepted for
                signature compatibility with other BaseLLM backends).

        Returns:
            The generated text, stripped of surrounding whitespace. A chat
            response with no content yields ``""``.

        Raises:
            RuntimeError: If ``initialize()`` has not completed.
            ValueError: If neither ``prompt`` nor ``chat_messages`` is given.
            Exception: Any error from the Inference API is logged and re-raised.
        """
        client = self.client
        if not self._initialized or client is None:
            raise RuntimeError(f"[{self.name}] Client not initialized")
        # Validate inputs before the try so a caller mistake is not logged
        # as a generation failure.
        if not chat_messages and not prompt:
            raise ValueError("Either prompt or chat_messages required")

        try:
            if chat_messages:
                response = await self._run_blocking(
                    client.chat_completion,
                    messages=chat_messages,
                    max_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                )
                # content may be None (e.g. empty or tool-call responses).
                content = response.choices[0].message.content
                return (content or "").strip()

            response = await self._run_blocking(
                client.text_generation,
                prompt=prompt,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
            )
            return response.strip()
        except Exception as e:
            print(f"[{self.name}] Generation error: {e}")
            raise

    def get_info(self) -> Dict[str, Any]:
        """Return model info."""
        return {
            "name": self.name,
            "model_id": self.model_id,
            "type": "inference_api",
            "initialized": self._initialized,
        }