bielik_app_service / app /models /huggingface_inference_api.py
Patryk Studzinski
first-imrpvement-commit
a7fd202
raw
history blame
2.98 kB
"""
HuggingFace Inference API client for remote model access.
"""
import os
from typing import List, Dict, Any, Optional
from huggingface_hub import InferenceClient
from app.models.base_llm import BaseLLM
class HuggingFaceInferenceAPI(BaseLLM):
    """
    Remote model access via the HuggingFace Inference API.

    Best for larger models (7B+) that don't fit in local RAM: all compute
    happens on HuggingFace's servers, so only network access is required.
    """

    def __init__(self, name: str, model_id: str, token: Optional[str] = None):
        """
        Args:
            name: Human-readable label used in log prefixes.
            model_id: HuggingFace Hub model identifier (e.g. "org/model").
            token: HF API token; falls back to the HF_TOKEN environment
                variable when not supplied.
        """
        super().__init__(name, model_id)
        self.token = token or os.getenv("HF_TOKEN")
        # Created lazily in initialize(); None until then.
        self.client: Optional[InferenceClient] = None

    async def initialize(self) -> None:
        """Initialize the Inference API client (idempotent).

        Raises:
            Exception: Re-raises whatever InferenceClient construction raises,
                after logging it.
        """
        if self._initialized:
            return
        try:
            print(f"[{self.name}] Initializing Inference API for: {self.model_id}")
            self.client = InferenceClient(
                model=self.model_id,
                token=self.token
            )
            self._initialized = True
            print(f"[{self.name}] Inference API ready")
        except Exception as e:
            print(f"[{self.name}] Failed to initialize: {e}")
            raise

    async def generate(
        self,
        prompt: Optional[str] = None,
        chat_messages: Optional[List[Dict[str, str]]] = None,
        max_new_tokens: int = 150,
        temperature: float = 0.7,
        top_p: float = 0.9,
        **kwargs
    ) -> str:
        """Generate text using the HuggingFace Inference API.

        Uses the chat-completion endpoint when `chat_messages` is provided
        (non-empty), otherwise falls back to plain text generation on `prompt`.

        Args:
            prompt: Raw prompt for text generation.
            chat_messages: OpenAI-style message dicts ({"role", "content"});
                takes precedence over `prompt` when non-empty.
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling probability mass.
            **kwargs: Accepted for interface compatibility with other BaseLLM
                implementations; NOT forwarded to the API.

        Returns:
            The generated text, stripped of surrounding whitespace.

        Raises:
            RuntimeError: If initialize() has not been called successfully.
            ValueError: If neither `prompt` nor `chat_messages` is provided.
            Exception: Re-raises API/client errors after logging them.
        """
        if not self._initialized or not self.client:
            raise RuntimeError(f"[{self.name}] Client not initialized")
        try:
            # Use chat completion if chat_messages provided
            if chat_messages:
                response = self.client.chat_completion(
                    messages=chat_messages,
                    max_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                )
                return response.choices[0].message.content.strip()
            # Otherwise use text generation
            elif prompt:
                response = self.client.text_generation(
                    prompt=prompt,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                )
                return response.strip()
            else:
                raise ValueError("Either prompt or chat_messages required")
        except Exception as e:
            print(f"[{self.name}] Generation error: {e}")
            raise

    def get_info(self) -> Dict[str, Any]:
        """Return a summary dict: name, model_id, backend type, init state."""
        return {
            "name": self.name,
            "model_id": self.model_id,
            "type": "inference_api",
            "initialized": self._initialized,
        }