from typing import Dict, Any, List

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer


class EndpointHandler:
    def __init__(self, path: str = ""):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        try:
            self.model = T5ForConditionalGeneration.from_pretrained(path).to(self.device)
            self.model.eval()
            self.tokenizer = T5Tokenizer.from_pretrained(path)
        except Exception as e:
            # Re-raise so a broken deployment fails fast instead of leaving
            # self.model / self.tokenizer undefined for later calls.
            raise RuntimeError(
                f"Error loading model or tokenizer from path {path!r}: {e}"
            ) from e

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        inputs = data.get("inputs", "")
        if not inputs:
            return [{"error": "No inputs provided"}]

        # Tokenize, truncating to the 512-token encoder limit. Padding to
        # max_length is unnecessary for a single sequence, so it is omitted.
        tokenized_input = self.tokenizer(
            inputs,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        ).to(self.device)

        # Generate a summary with nucleus (top-p) sampling; no_grad avoids
        # building an autograd graph during inference.
        with torch.no_grad():
            summary_ids = self.model.generate(
                **tokenized_input,
                max_length=400,
                do_sample=True,
                top_p=0.8,
            )

        summary_text = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        return [{"summary": summary_text}]
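# --- Usage sketch ---
# A minimal local smoke test of the handler, mirroring the payload shape the
# endpoint receives. The checkpoint directory "./t5-summarizer" is a
# hypothetical path; point it at wherever your fine-tuned T5 weights live.
if __name__ == "__main__":
    handler = EndpointHandler(path="./t5-summarizer")
    result = handler({"inputs": "summarize: The quick brown fox jumped over the lazy dog."})
    print(result)  # e.g. [{"summary": "..."}]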