# bielik_app_service/app/models/huggingface_service.py
import asyncio
from typing import Optional

import torch
from fastapi import HTTPException
from transformers import pipeline, AutoTokenizer
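
# This module wraps a Hugging Face text-generation pipeline for the FastAPI
# app: blocking transformers calls are pushed to worker threads via
# asyncio.to_thread so the event loop stays responsive while loading and
# generating.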


class HuggingFaceTextGenerationService:
    def __init__(self, model_name_or_path: str, device: Optional[str] = None, task: str = "text-generation"):
        self.model_name_or_path = model_name_or_path
        self.task = task
        self.pipeline = None
        self.tokenizer = None

        # Resolve the device index that transformers.pipeline expects:
        # 0 = first CUDA device, -1 = CPU.
        if device is None:
            self.device_index = 0 if torch.cuda.is_available() else -1
        elif device == "cuda" and torch.cuda.is_available():
            self.device_index = 0
        else:
            # "cpu" or any unrecognized device string falls back to CPU.
            self.device_index = -1

        if self.device_index == 0:
            print("CUDA (GPU) is available. Using GPU.")
        else:
            # device_index can only be -1 here.
            print("Device set to use cpu")

    async def initialize(self):
        try:
            print(f"Initializing Hugging Face pipeline for model: {self.model_name_or_path} on device index: {self.device_index}")
            # Load the tokenizer and pipeline in worker threads so model
            # loading does not block the event loop.
            self.tokenizer = await asyncio.to_thread(
                AutoTokenizer.from_pretrained, self.model_name_or_path, trust_remote_code=True
            )
            self.pipeline = await asyncio.to_thread(
                pipeline,
                self.task,
                model=self.model_name_or_path,
                tokenizer=self.tokenizer,
                device=self.device_index,
                # Prefer bfloat16 on GPUs that support it; otherwise use float32.
                torch_dtype=(
                    torch.bfloat16
                    if self.device_index != -1 and torch.cuda.is_available() and torch.cuda.is_bf16_supported()
                    else torch.float32
                ),
                trust_remote_code=True,
            )
            print(f"Pipeline for model {self.model_name_or_path} initialized successfully.")
        except Exception as e:
            print(f"Error initializing HuggingFace pipeline: {e}")
            raise HTTPException(status_code=503, detail=f"LLM (HuggingFace) model could not be loaded: {str(e)}")

    async def generate_text(
        self,
        prompt_text: Optional[str] = None,
        chat_template_messages: Optional[list] = None,
        max_new_tokens: int = 250,
        temperature: float = 0.7,
        top_p: float = 0.9,
        do_sample: bool = True,
        **kwargs,
    ) -> str:
        if not self.pipeline or not self.tokenizer:
            raise Exception("Pipeline is not initialized. Call initialize() first.")

        # Build the model input: prefer the tokenizer's chat template,
        # falling back to the raw prompt string.
        formatted_prompt_input = ""
        if chat_template_messages:
            try:
                formatted_prompt_input = self.tokenizer.apply_chat_template(
                    chat_template_messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
            except Exception as e:
                print(f"Could not apply chat template, falling back to raw prompt if available. Error: {e}")
                if prompt_text:
                    formatted_prompt_input = prompt_text
                else:
                    raise ValueError("Cannot generate text without a valid prompt or chat_template_messages.")
        elif prompt_text:
            formatted_prompt_input = prompt_text
        else:
            raise ValueError("Either prompt_text or chat_template_messages must be provided.")
        try:
            # The pipeline call is blocking; run it in a worker thread.
            generated_outputs = await asyncio.to_thread(
                self.pipeline,
                formatted_prompt_input,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p,
                eos_token_id=self.tokenizer.eos_token_id,
                # Use the EOS token for padding when the tokenizer defines no pad token.
                pad_token_id=self.tokenizer.eos_token_id if self.tokenizer.pad_token_id is None else self.tokenizer.pad_token_id,
                **kwargs,
            )

            if generated_outputs and isinstance(generated_outputs, list) and "generated_text" in generated_outputs[0]:
                full_generated_sequence = generated_outputs[0]["generated_text"]
                assistant_response = ""
                # The pipeline echoes the prompt before the completion; strip the prompt prefix.
                if full_generated_sequence.startswith(formatted_prompt_input):
                    assistant_response = full_generated_sequence[len(formatted_prompt_input):]
                else:
                    # Fallback: split on the last ChatML assistant marker.
                    assistant_marker = "<|im_start|>assistant\n"
                    last_marker_pos = full_generated_sequence.rfind(assistant_marker)
                    if last_marker_pos != -1:
                        assistant_response = full_generated_sequence[last_marker_pos + len(assistant_marker):]
                        print("Warning: Used fallback parsing for assistant response.")
                    else:
                        print("Error: Could not isolate assistant response from the full generated sequence.")
                        assistant_response = full_generated_sequence

                # Trim a trailing ChatML end-of-turn marker, if present.
                if assistant_response.endswith("<|im_end|>"):
                    assistant_response = assistant_response[: -len("<|im_end|>")]
                return assistant_response.strip()
            else:
                print(f"Unexpected output format from pipeline: {generated_outputs}")
                return "Error: Could not parse generated text from pipeline output."
        except Exception as e:
            print(f"Error during text generation with {self.model_name_or_path}: {e}")
            raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
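

# --- Usage sketch (illustrative, not part of the service above) ---
# A minimal way to exercise the class outside FastAPI. The model name below is
# an assumption for demonstration; substitute whatever checkpoint this app
# actually serves (the repo name suggests a Bielik model).
if __name__ == "__main__":
    async def _demo():
        # Hypothetical model name; replace with the app's configured model.
        service = HuggingFaceTextGenerationService("speakleash/Bielik-7B-Instruct-v0.1")
        await service.initialize()
        reply = await service.generate_text(
            chat_template_messages=[{"role": "user", "content": "Introduce yourself in one sentence."}],
            max_new_tokens=100,
        )
        print(reply)

    asyncio.run(_demo())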