"""LLM operations and text generation functionality"""

import os

import torch

from src.utils.model_loader import ModelLoader
from src.utils.prompt_formatter import PromptFormatter

# NOTE(review): presumably set to silence the tokenizers fork/parallelism
# warning -- confirm against the deployment setup.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


class LLMManager:
    """Manages LLM model loading and text generation operations.

    The manager starts empty: ``model``, ``tokenizer`` and ``model_name`` are
    ``None`` until ``load_models()`` succeeds. Every method that needs one of
    them raises ``RuntimeError`` when it is called too early.
    """

    def __init__(self):
        """Create an unloaded manager and record the default device/dtype."""
        self.model = None
        self.tokenizer = None
        # Initialised here so methods called before load_models() fail with a
        # clear RuntimeError instead of an AttributeError.
        self.model_name = None
        self.device, self.dtype = ModelLoader.get_device_and_dtype()

    def load_models(self, model_name="meta-llama/Llama-3.2-1B-Instruct"):
        """Load the LLM model and tokenizer.

        Args:
            model_name: Identifier passed to
                ``ModelLoader.load_model_and_tokenizer``.

        State is only committed after the loader returns, so a failed load
        leaves any previously loaded model and its name untouched.
        """
        model, tokenizer, device, dtype = ModelLoader.load_model_and_tokenizer(model_name)
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.dtype = dtype
        self.model_name = model_name

    def _require_tokenizer(self):
        """Return the loaded tokenizer or raise RuntimeError if there is none."""
        if self.tokenizer is None:
            raise RuntimeError("Tokenizer not loaded. Call load_models() first.")
        return self.tokenizer

    def _require_model_name(self):
        """Return the loaded model's name or raise RuntimeError if unset."""
        if self.model_name is None:
            raise RuntimeError("Models not loaded. Call load_models() first.")
        return self.model_name

    def _encode(self, text):
        """Encode ``text`` to token ids without adding special tokens."""
        return self._require_tokenizer().encode(text, add_special_tokens=False)

    def validate_user_input(self, user_input, max_tokens=5):
        """Validate that user input is within token limits.

        Args:
            user_input: Text to check.
            max_tokens: Inclusive upper bound on the token count.

        Returns:
            True when ``user_input`` encodes to at most ``max_tokens`` tokens.

        Raises:
            RuntimeError: If the tokenizer has not been loaded.
        """
        return self.count_tokens(user_input) <= max_tokens

    def count_tokens(self, text):
        """Count tokens in the given text (special tokens excluded).

        Raises:
            RuntimeError: If the tokenizer has not been loaded.
        """
        return len(self._encode(text))

    def tokenize_for_visualization(self, text):
        """Tokenize text and return individual token representations.

        Returns:
            A ``(tokens, token_texts)`` pair: the token ids for ``text`` and,
            for each id, the string obtained by decoding that id on its own.

        Raises:
            RuntimeError: If the tokenizer has not been loaded.
        """
        tokenizer = self._require_tokenizer()
        tokens = tokenizer.encode(text, add_special_tokens=False)
        # Decode one id at a time so each entry lines up with its token id.
        token_texts = [
            tokenizer.decode([token_id], skip_special_tokens=True)
            for token_id in tokens
        ]
        return tokens, token_texts

    def extract_assistant_response(self, full_response: str) -> str:
        """Extract the assistant's response from the full generated text.

        Delegates to ``PromptFormatter.extract_assistant_response`` using the
        currently loaded model's name.

        Raises:
            RuntimeError: If no model has been loaded yet.
        """
        return PromptFormatter.extract_assistant_response(
            self._require_model_name(), full_response
        )

    def format_prompt(self, prompt: str, partial_response: str, continuation: str) -> str:
        """Format the full prompt for generation.

        Delegates to ``PromptFormatter.format_prompt`` using the currently
        loaded model's name.

        Raises:
            RuntimeError: If no model has been loaded yet.
        """
        return PromptFormatter.format_prompt(
            self._require_model_name(), prompt, partial_response, continuation
        )

    def generate_response_from_user_input(self, prompt, partial_response, user_continuation):
        """Generate a full response from user's continuation.

        Builds the prompt via ``format_prompt``, samples up to 1000 new tokens
        from the model, decodes the first returned sequence and returns the
        assistant portion extracted by ``extract_assistant_response``.

        Raises:
            RuntimeError: If the model or tokenizer has not been loaded.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Models not loaded. Call load_models() first.")

        # TODO: make this more robust for multiple models, needs to be formatted correctly
        full_prompt = self.format_prompt(prompt, partial_response, user_continuation)

        inputs = self.tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Inference only: no gradients are needed for sampling.
        with torch.no_grad():
            outputs = self.model.generate(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_new_tokens=1000,
                do_sample=True,
                top_p=0.95,
                temperature=1.0,
                pad_token_id=self.tokenizer.eos_token_id
            )

        full_response = self.tokenizer.decode(outputs[0].cpu(), skip_special_tokens=True)
        assistant_part = self.extract_assistant_response(full_response)
        return assistant_part