| """LLM operations and text generation functionality""" | |
| import os | |
| import torch | |
| from src.utils.model_loader import ModelLoader | |
| from src.utils.prompt_formatter import PromptFormatter | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
class LLMManager:
    """Manages LLM model loading and text generation operations"""

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.model_name = None  # set by load_models()
        self.device, self.dtype = ModelLoader.get_device_and_dtype()

    def load_models(self, model_name="meta-llama/Llama-3.2-1B-Instruct"):
        """Load the LLM model and tokenizer"""
        self.model_name = model_name
        self.model, self.tokenizer, self.device, self.dtype = ModelLoader.load_model_and_tokenizer(model_name)

    def validate_user_input(self, user_input, max_tokens=5):
        """Validate that user input is within token limits"""
        if not self.tokenizer:
            raise RuntimeError("Tokenizer not loaded. Call load_models() first.")
        tokens = self.tokenizer.encode(user_input, add_special_tokens=False)
        return len(tokens) <= max_tokens

    def count_tokens(self, text):
        """Count tokens in the given text"""
        if not self.tokenizer:
            raise RuntimeError("Tokenizer not loaded. Call load_models() first.")
        tokens = self.tokenizer.encode(text, add_special_tokens=False)
        return len(tokens)

    def tokenize_for_visualization(self, text):
        """Tokenize text and return individual token representations"""
        if not self.tokenizer:
            raise RuntimeError("Tokenizer not loaded. Call load_models() first.")
        tokens = self.tokenizer.encode(text, add_special_tokens=False)
        token_texts = []
        for token_id in tokens:
            token_text = self.tokenizer.decode([token_id], skip_special_tokens=True)
            token_texts.append(token_text)
        return tokens, token_texts

    def extract_assistant_response(self, full_response: str) -> str:
        """Extract the assistant's response from the full generated text"""
        return PromptFormatter.extract_assistant_response(self.model_name, full_response)

    def format_prompt(self, prompt: str, partial_response: str, continuation: str) -> str:
        """Format the full prompt for generation"""
        return PromptFormatter.format_prompt(self.model_name, prompt, partial_response, continuation)

    def generate_response_from_user_input(self, prompt, partial_response, user_continuation):
        """Generate a full response from the user's continuation"""
        if not self.model or not self.tokenizer:
            raise RuntimeError("Models not loaded. Call load_models() first.")
        # TODO: make this more robust for multiple models, needs to be formatted correctly
        full_prompt = self.format_prompt(prompt, partial_response, user_continuation)
        inputs = self.tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=1000,
                do_sample=True,
                top_p=0.95,
                temperature=1.0,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        full_response = self.tokenizer.decode(outputs[0].cpu(), skip_special_tokens=True)
        assistant_part = self.extract_assistant_response(full_response)
        return assistant_part
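

# Minimal usage sketch (illustrative only): it assumes the ModelLoader and
# PromptFormatter helpers imported above are available and that the default
# meta-llama/Llama-3.2-1B-Instruct checkpoint can be loaded in this
# environment. The prompt strings below are placeholders, not part of the
# module's API.
if __name__ == "__main__":
    manager = LLMManager()
    manager.load_models()

    user_continuation = "the sky"
    if manager.validate_user_input(user_continuation, max_tokens=5):
        response = manager.generate_response_from_user_input(
            prompt="Complete the sentence.",
            partial_response="I looked up and saw",
            user_continuation=user_continuation,
        )
        print(response)
    else:
        print(f"Continuation too long ({manager.count_tokens(user_continuation)} tokens)")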