"""
HuggingFace Inference API client for remote model access.
"""

import os
from typing import List, Dict, Any, Optional
from huggingface_hub import InferenceClient

from app.models.base_llm import BaseLLM


class HuggingFaceInferenceAPI(BaseLLM):
    """
    Remote model access via HuggingFace Inference API.
    Best for larger models (7B+) that don't fit in local RAM.
    """
    
    def __init__(self, name: str, model_id: str, token: Optional[str] = None):
        super().__init__(name, model_id)
        # Fall back to the HF_TOKEN environment variable when no token is given.
        self.token = token or os.getenv("HF_TOKEN")
        self.client: Optional[InferenceClient] = None
    
    async def initialize(self) -> None:
        """Initialize the Inference API client."""
        if self._initialized:
            return
        
        try:
            print(f"[{self.name}] Initializing Inference API for: {self.model_id}")
            
            self.client = InferenceClient(
                model=self.model_id,
                token=self.token
            )
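            # Constructing InferenceClient does not contact the server, so an
            # invalid token or model id will only surface on the first request.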
            
            self._initialized = True
            print(f"[{self.name}] Inference API ready")
            
        except Exception as e:
            print(f"[{self.name}] Failed to initialize: {e}")
            raise
    
    async def generate(
        self,
        prompt: Optional[str] = None,
        chat_messages: Optional[List[Dict[str, str]]] = None,
        max_new_tokens: int = 150,
        temperature: float = 0.7,
        top_p: float = 0.9,
        **kwargs
    ) -> str:
        """Generate text using HuggingFace Inference API."""
        
        if not self._initialized or not self.client:
            raise RuntimeError(f"[{self.name}] Client not initialized")
        
        try:
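            # Note: InferenceClient's calls are synchronous and will block the
            # event loop inside this async method; huggingface_hub also provides
            # AsyncInferenceClient if non-blocking behavior is needed.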
            # Use chat completion if chat_messages provided
            if chat_messages:
                response = self.client.chat_completion(
                    messages=chat_messages,
                    max_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                )
                return response.choices[0].message.content.strip()
            
            # Otherwise use text generation
            elif prompt:
                response = self.client.text_generation(
                    prompt=prompt,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                )
                return response.strip()
            
            else:
                raise ValueError("Either prompt or chat_messages required")
                
        except Exception as e:
            print(f"[{self.name}] Generation error: {e}")
            raise
    
    def get_info(self) -> Dict[str, Any]:
        """Return model info."""
        return {
            "name": self.name,
            "model_id": self.model_id,
            "type": "inference_api",
            "initialized": self._initialized,
        }
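

# ---------------------------------------------------------------------------
# Usage sketch (not part of the class). Assumes HF_TOKEN is set in the
# environment; the model id below is a hypothetical example of a hosted chat
# model, so substitute any model your token can access on the Inference API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        llm = HuggingFaceInferenceAPI(
            name="demo",
            model_id="mistralai/Mistral-7B-Instruct-v0.2",  # hypothetical choice
        )
        await llm.initialize()
        reply = await llm.generate(
            chat_messages=[{"role": "user", "content": "Say hello in one sentence."}],
            max_new_tokens=50,
        )
        print(reply)

    asyncio.run(_demo())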