"""Advanced inference utilities for DeepXR/Helion-V2.0-Thinking.

Provides quantized loading, streaming and batch generation, vision QA helpers,
and simple latency/throughput metrics.
"""

import json
import logging
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import torch
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    BitsAndBytesConfig,
    GenerationConfig,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class InferenceMetrics:
    """Per-request latency, throughput, and memory statistics."""
    latency_ms: float
    tokens_generated: int
    tokens_per_second: float
    memory_used_gb: float
    input_tokens: int
    total_tokens: int


class AdvancedHelionInference:
    def __init__(
        self,
        model_name: str = "DeepXR/Helion-V2.0-Thinking",
        quantization: Optional[str] = None,
        device: str = "auto",
        use_flash_attention: bool = True,
        torch_compile: bool = False,
        optimization_mode: str = "balanced"
    ):
        logger.info(f"Initializing Helion-V2.0-Thinking with {optimization_mode} mode")

        self.model_name = model_name
        self.optimization_mode = optimization_mode
        self.metrics_history: List[InferenceMetrics] = []

        quantization_config = self._get_quantization_config(quantization)

        logger.info("Loading processor...")
        self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

        logger.info("Loading model...")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=quantization_config,
            device_map=device,
            torch_dtype=torch.bfloat16 if quantization is None else None,
            attn_implementation="flash_attention_2" if use_flash_attention else "sdpa",
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )

        if torch_compile and quantization is None:
            logger.info("Compiling model with torch.compile...")
            self.model = torch.compile(self.model, mode="reduce-overhead")

        self.model.eval()

        # Preset sampling configurations for different workloads.
        self.generation_configs = {
            "creative": GenerationConfig(
                do_sample=True,
                temperature=0.9,
                top_p=0.95,
                top_k=50,
                repetition_penalty=1.15,
                max_new_tokens=2048
            ),
            "precise": GenerationConfig(
                do_sample=True,
                temperature=0.3,
                top_p=0.85,
                top_k=40,
                repetition_penalty=1.05,
                max_new_tokens=1024
            ),
            "balanced": GenerationConfig(
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                top_k=50,
                repetition_penalty=1.1,
                max_new_tokens=1024
            ),
            "code": GenerationConfig(
                do_sample=True,
                temperature=0.2,
                top_p=0.9,
                top_k=40,
                repetition_penalty=1.05,
                max_new_tokens=2048
            )
        }

        logger.info("Model loaded successfully!")

    def _get_quantization_config(self, quantization: Optional[str]) -> Optional[BitsAndBytesConfig]:
        if quantization is None:
            return None

        quantization_configs = {
            "4bit": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            ),
            "8bit": BitsAndBytesConfig(
                load_in_8bit=True
            )
        }
        return quantization_configs.get(quantization)

    def generate(
        self,
        prompt: str,
        images: Optional[Union[Image.Image, List[Image.Image]]] = None,
        mode: str = "balanced",
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        stream: bool = False,
        return_metrics: bool = False,
        **kwargs
    ) -> Union[str, tuple[str, InferenceMetrics]]:
        if isinstance(images, Image.Image):
            images = [images]

        start_time = time.time()
        initial_memory = torch.cuda.memory_allocated() / (1024**3) if torch.cuda.is_available() else 0

        if images:
            inputs = self.processor(
                text=prompt,
                images=images,
                return_tensors="pt"
            ).to(self.model.device)
        else:
            inputs = self.processor(
                text=prompt,
                return_tensors="pt"
            ).to(self.model.device)

        input_length = inputs["input_ids"].shape[1]

        # Per-call overrides take precedence over the preset mode config.
        gen_config = self.generation_configs[mode].to_dict()
        if max_new_tokens:
            gen_config["max_new_tokens"] = max_new_tokens
        if temperature:
            gen_config["temperature"] = temperature
        gen_config.update(kwargs)
        # Streaming returns a generator of text chunks; metrics are not
        # collected on this path.
        if stream:
            return self._generate_stream(inputs, gen_config)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                **gen_config,
                pad_token_id=self.processor.tokenizer.eos_token_id
            )

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        end_time = time.time()
        latency = (end_time - start_time) * 1000

        response = self.processor.decode(outputs[0], skip_special_tokens=True)
        # Strip the echoed prompt if the processor returns it in the output.
        if response.startswith(prompt):
            response = response[len(prompt):].strip()

        tokens_generated = outputs.shape[1] - input_length
        elapsed = end_time - start_time
        tokens_per_second = tokens_generated / elapsed if elapsed > 0 else float(tokens_generated)

        final_memory = torch.cuda.memory_allocated() / (1024**3) if torch.cuda.is_available() else 0
        memory_used = final_memory - initial_memory

        metrics = InferenceMetrics(
            latency_ms=latency,
            tokens_generated=tokens_generated,
            tokens_per_second=tokens_per_second,
            memory_used_gb=memory_used,
            input_tokens=input_length,
            total_tokens=outputs.shape[1]
        )
        self.metrics_history.append(metrics)

        if return_metrics:
            return response, metrics
        return response

    def _generate_stream(self, inputs, gen_config):
        from threading import Thread

        from transformers import TextIteratorStreamer

        streamer = TextIteratorStreamer(
            self.processor.tokenizer,
            skip_special_tokens=True,
            skip_prompt=True
        )
        gen_config["streamer"] = streamer
        gen_config["pad_token_id"] = self.processor.tokenizer.eos_token_id

        # Run generation in a background thread and yield decoded chunks
        # as they become available.
        thread = Thread(
            target=self.model.generate,
            kwargs={**inputs, **gen_config}
        )
        thread.start()

        for new_text in streamer:
            yield new_text

        thread.join()

    def batch_generate(
        self,
        prompts: List[str],
        images_list: Optional[List[Optional[Union[Image.Image, List[Image.Image]]]]] = None,
        mode: str = "balanced",
        **kwargs
    ) -> List[str]:
        if images_list is None:
            images_list = [None] * len(prompts)

        # Multimodal requests are processed one at a time: tensors from
        # separate processor calls cannot simply be concatenated.
        if any(images is not None for images in images_list):
            return [
                self.generate(prompt, images=images, mode=mode, **kwargs)
                for prompt, images in zip(prompts, images_list)
            ]

        # Text-only prompts can be padded and batched in a single call.
        batch_inputs = self.processor(
            text=prompts,
            return_tensors="pt",
            padding=True
        ).to(self.model.device)

        gen_config = self.generation_configs[mode].to_dict()
        gen_config.update(kwargs)

        with torch.no_grad():
            outputs = self.model.generate(
                **batch_inputs,
                **gen_config,
                pad_token_id=self.processor.tokenizer.eos_token_id
            )

        return [
            self.processor.decode(output, skip_special_tokens=True)
            for output in outputs
        ]

    def vision_qa(
        self,
        image: Image.Image,
        question: str,
        mode: str = "precise"
    ) -> str:
        prompt = f"Question: {question}\nAnswer:"
        return self.generate(prompt, images=image, mode=mode)

    def analyze_image(
        self,
        image: Image.Image,
        analysis_type: str = "detailed"
    ) -> str:
        prompts = {
            "detailed": "Provide a detailed description of this image, including objects, people, actions, setting, and any text visible.",
            "quick": "Briefly describe what you see in this image.",
            "technical": "Analyze this image from a technical perspective, including composition, lighting, colors, and quality.",
            "ocr": "Extract all text visible in this image and organize it clearly."
        }
        prompt = prompts.get(analysis_type, prompts["detailed"])
        return self.generate(prompt, images=image, mode="precise")

    def code_generation(
        self,
        task: str,
        language: str = "python",
        include_tests: bool = False
    ) -> str:
        prompt = f"Write {language} code for the following task:\n{task}"
        if include_tests:
            prompt += "\n\nInclude unit tests for the code."
return self.generate(prompt, mode="code", max_new_tokens=2048) def function_call( self, user_query: str, available_tools: List[Dict[str, Any]] ) -> Dict[str, Any]: tools_str = json.dumps(available_tools, indent=2) prompt = f"""Available tools: {tools_str} User query: {user_query} Respond with a JSON object specifying which tool to use and with what parameters: {{"tool": "tool_name", "parameters": {{"param": "value"}}}} Response:""" response = self.generate(prompt, mode="precise", temperature=0.2) try: import re json_match = re.search(r'\{.*\}', response, re.DOTALL) if json_match: return json.loads(json_match.group()) return {"error": "No valid JSON found", "raw": response} except json.JSONDecodeError as e: return {"error": str(e), "raw": response} def multi_modal_rag( self, query: str, documents: List[str], images: Optional[List[Image.Image]] = None ) -> str: context = "\n\n".join([f"Document {i+1}:\n{doc}" for i, doc in enumerate(documents)]) prompt = f"""Context:\n{context}\n\nQuestion: {query}\n\nAnswer based on the provided context:""" return self.generate(prompt, images=images, mode="precise", max_new_tokens=1024) def get_metrics_summary(self) -> Dict[str, float]: if not self.metrics_history: return {} return { "avg_latency_ms": sum(m.latency_ms for m in self.metrics_history) / len(self.metrics_history), "avg_tokens_per_second": sum(m.tokens_per_second for m in self.metrics_history) / len(self.metrics_history), "avg_memory_used_gb": sum(m.memory_used_gb for m in self.metrics_history) / len(self.metrics_history), "total_tokens_generated": sum(m.tokens_generated for m in self.metrics_history), "num_requests": len(self.metrics_history) } def clear_cache(self): if torch.cuda.is_available(): torch.cuda.empty_cache() self.model.clear_cache() if hasattr(self.model, 'clear_cache') else None logger.info("Cache cleared") def main(): import argparse parser = argparse.ArgumentParser(description="Advanced Helion-V2.0-Thinking Inference") parser.add_argument("--model", type=str, default="DeepXR/Helion-V2.0-Thinking") parser.add_argument("--quantization", type=str, choices=["4bit", "8bit", None], default=None) parser.add_argument("--mode", type=str, default="balanced", choices=["creative", "precise", "balanced", "code"]) parser.add_argument("--prompt", type=str, help="Text prompt") parser.add_argument("--image", type=str, help="Path to image file") parser.add_argument("--stream", action="store_true", help="Enable streaming output") parser.add_argument("--torch-compile", action="store_true", help="Use torch.compile") parser.add_argument("--benchmark", action="store_true", help="Run benchmark") args = parser.parse_args() model = AdvancedHelionInference( model_name=args.model, quantization=args.quantization, torch_compile=args.torch_compile ) if args.benchmark: print("Running benchmark...") test_prompts = [ "Explain quantum computing in simple terms.", "Write a Python function to calculate fibonacci numbers.", "What are the main causes of climate change?" 
        ]

        for prompt in test_prompts:
            response, metrics = model.generate(
                prompt,
                mode=args.mode,
                return_metrics=True
            )
            print(f"\nPrompt: {prompt}")
            print(f"Response: {response[:100]}...")
            print(f"Metrics: {metrics}")

        summary = model.get_metrics_summary()
        print("\nBenchmark Summary:")
        for key, value in summary.items():
            print(f"  {key}: {value:.2f}")

    elif args.prompt:
        image = Image.open(args.image) if args.image else None

        if args.stream:
            print("Streaming response:")
            for text in model.generate(args.prompt, images=image, mode=args.mode, stream=True):
                print(text, end="", flush=True)
            print()
        else:
            response, metrics = model.generate(
                args.prompt,
                images=image,
                mode=args.mode,
                return_metrics=True
            )
            print(f"Response: {response}")
            print("\nMetrics:")
            print(f"  Latency: {metrics.latency_ms:.2f}ms")
            print(f"  Tokens/sec: {metrics.tokens_per_second:.2f}")
            print(f"  Tokens generated: {metrics.tokens_generated}")
    else:
        print("Please provide --prompt or use --benchmark")


if __name__ == "__main__":
    main()
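
# --- Example programmatic usage (illustrative sketch, not executed) ---
# The file path "example.jpg" below is a placeholder; adjust it and the
# quantization setting to your environment before running.
#
#   engine = AdvancedHelionInference(quantization="4bit")
#
#   # Vision question answering on a local image
#   answer = engine.vision_qa(Image.open("example.jpg"), "What objects are visible?")
#   print(answer)
#
#   # Text generation with per-request metrics
#   reply, stats = engine.generate(
#       "Summarize the benefits of model quantization.",
#       mode="balanced",
#       return_metrics=True
#   )
#   print(reply)
#   print(f"{stats.tokens_per_second:.1f} tokens/s")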