""" Helion-V1.5-XL Inference Script Supports multiple inference modes and optimization techniques """ import torch from transformers import ( AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GenerationConfig ) from typing import Optional, Dict, Any, List import argparse import json import time class HelionInference: """Inference wrapper for Helion-V1.5-XL""" def __init__( self, model_name: str = "DeepXR/Helion-V1.5-XL", load_in_4bit: bool = False, load_in_8bit: bool = False, device_map: str = "auto", torch_dtype: str = "bfloat16" ): """ Initialize the model and tokenizer Args: model_name: HuggingFace model identifier load_in_4bit: Enable 4-bit quantization load_in_8bit: Enable 8-bit quantization device_map: Device mapping strategy torch_dtype: PyTorch dtype for model weights """ self.model_name = model_name print(f"Loading model: {model_name}") # Setup dtype dtype_map = { "bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32 } torch_dtype = dtype_map.get(torch_dtype, torch.bfloat16) # Setup quantization config quantization_config = None if load_in_4bit: quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch_dtype, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4" ) elif load_in_8bit: quantization_config = BitsAndBytesConfig(load_in_8bit=True) # Load tokenizer self.tokenizer = AutoTokenizer.from_pretrained( model_name, trust_remote_code=True ) # Load model model_kwargs = { "device_map": device_map, "trust_remote_code": True, } if quantization_config: model_kwargs["quantization_config"] = quantization_config else: model_kwargs["torch_dtype"] = torch_dtype self.model = AutoModelForCausalLM.from_pretrained( model_name, **model_kwargs ) self.model.eval() print("Model loaded successfully!") def generate( self, prompt: str, max_new_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.1, do_sample: bool = True, num_return_sequences: int = 1, **kwargs ) -> List[str]: """ Generate text from a prompt Args: prompt: Input text prompt max_new_tokens: Maximum number of tokens to generate temperature: Sampling temperature (0.0 to 2.0) top_p: Nucleus sampling threshold top_k: Top-k sampling threshold repetition_penalty: Penalty for repetition do_sample: Whether to use sampling num_return_sequences: Number of sequences to generate Returns: List of generated text strings """ inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) generation_config = GenerationConfig( max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty, do_sample=do_sample, num_return_sequences=num_return_sequences, pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.eos_token_id, **kwargs ) start_time = time.time() with torch.no_grad(): outputs = self.model.generate( **inputs, generation_config=generation_config ) generation_time = time.time() - start_time # Decode outputs responses = [] for output in outputs: response = self.tokenizer.decode(output, skip_special_tokens=True) # Remove the prompt from response response = response[len(prompt):].strip() responses.append(response) # Calculate tokens per second total_tokens = sum(len(output) for output in outputs) tokens_per_sec = total_tokens / generation_time print(f"\nGeneration Stats:") print(f" Time: {generation_time:.2f}s") print(f" Tokens/sec: {tokens_per_sec:.2f}") return responses def chat( self, messages: List[Dict[str, str]], max_new_tokens: int = 
    def chat(
        self,
        messages: List[Dict[str, str]],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Generate a response in chat format

        Args:
            messages: List of message dicts with 'role' and 'content'
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            Generated response string
        """
        # Apply the model's chat template to build the prompt
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        responses = self.generate(
            prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            **kwargs
        )

        return responses[0]

    def batch_generate(
        self,
        prompts: List[str],
        max_new_tokens: int = 512,
        **kwargs
    ) -> List[str]:
        """
        Generate responses for multiple prompts in a single batch

        Args:
            prompts: List of input prompts
            max_new_tokens: Maximum tokens per generation

        Returns:
            List of generated responses
        """
        inputs = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(self.model.device)
        input_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                pad_token_id=self.tokenizer.pad_token_id,
                **kwargs
            )

        # With left padding, every row shares the same padded prompt length,
        # so each generated continuation starts at input_length
        responses = []
        for output in outputs:
            response = self.tokenizer.decode(
                output[input_length:],
                skip_special_tokens=True
            ).strip()
            responses.append(response)

        return responses


def main():
    parser = argparse.ArgumentParser(description="Helion-V1.5-XL Inference")
    parser.add_argument(
        "--model",
        type=str,
        default="DeepXR/Helion-V1.5-XL",
        help="Model name or path"
    )
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Input prompt"
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=512,
        help="Maximum tokens to generate"
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature"
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=0.9,
        help="Nucleus sampling threshold"
    )
    parser.add_argument(
        "--load-in-4bit",
        action="store_true",
        help="Load model in 4-bit quantization"
    )
    parser.add_argument(
        "--load-in-8bit",
        action="store_true",
        help="Load model in 8-bit quantization"
    )
    parser.add_argument(
        "--chat-mode",
        action="store_true",
        help="Use chat format"
    )
    args = parser.parse_args()

    # Initialize model
    inference = HelionInference(
        model_name=args.model,
        load_in_4bit=args.load_in_4bit,
        load_in_8bit=args.load_in_8bit
    )

    # Generate response
    if args.chat_mode:
        messages = [
            {"role": "user", "content": args.prompt}
        ]
        response = inference.chat(
            messages,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_p=args.top_p
        )
    else:
        responses = inference.generate(
            args.prompt,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_p=args.top_p
        )
        response = responses[0]

    print("\n" + "=" * 80)
    print("PROMPT:")
    print("=" * 80)
    print(args.prompt)
    print("\n" + "=" * 80)
    print("RESPONSE:")
    print("=" * 80)
    print(response)
    print("=" * 80)


if __name__ == "__main__":
    main()
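
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; assumes this file is saved as inference.py
# and that a GPU with enough memory is available -- the filename and prompts
# below are placeholders, not specified by the script itself):
#
#   # Single-turn generation with 4-bit quantization
#   python inference.py --prompt "Explain nucleus sampling." --load-in-4bit
#
#   # Chat-formatted generation with custom sampling settings
#   python inference.py --prompt "Write a haiku about GPUs." --chat-mode --temperature 0.9 --top-p 0.95
#
# Programmatic use under the same assumptions:
#
#   from inference import HelionInference
#   engine = HelionInference(load_in_4bit=True)
#   print(engine.chat([{"role": "user", "content": "Hello!"}]))
# ---------------------------------------------------------------------------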