# advanced_inference.py
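"""Advanced inference utilities for DeepXR/Helion-V2.0-Thinking.

Wraps Hugging Face transformers with optional 4-/8-bit quantization
(bitsandbytes), FlashAttention-2, torch.compile, streaming output, batched
generation, vision Q&A, tool calling, and basic latency/throughput metrics.
"""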
import torch
from transformers import (
AutoModelForCausalLM,
AutoProcessor,
BitsAndBytesConfig,
GenerationConfig
)
from PIL import Image
import json
from typing import Optional, List, Dict, Any, Union
import time
from dataclasses import dataclass
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class InferenceMetrics:
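    """Per-request generation statistics collected by AdvancedHelionInference."""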
latency_ms: float
tokens_generated: int
tokens_per_second: float
memory_used_gb: float
input_tokens: int
total_tokens: int
class AdvancedHelionInference:
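    """High-level wrapper around Helion-V2.0-Thinking for text and vision inference."""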
def __init__(
self,
model_name: str = "DeepXR/Helion-V2.0-Thinking",
quantization: Optional[str] = None,
device: str = "auto",
use_flash_attention: bool = True,
torch_compile: bool = False,
optimization_mode: str = "balanced"
):
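        """Load the processor and model with the requested precision and attention backend.

        Args:
            model_name: Hugging Face model id or local path.
            quantization: None, "4bit", or "8bit" (bitsandbytes).
            device: device_map passed to from_pretrained (e.g. "auto").
            use_flash_attention: request FlashAttention-2 when available, else SDPA.
            torch_compile: wrap the model in torch.compile (skipped when quantized).
            optimization_mode: label recorded and logged; per-call sampling presets
                are selected via the `mode` argument of generate().
        """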
logger.info(f"Initializing Helion-V2.0-Thinking with {optimization_mode} mode")
self.model_name = model_name
self.optimization_mode = optimization_mode
self.metrics_history = []
quantization_config = self._get_quantization_config(quantization)
logger.info("Loading processor...")
        self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
logger.info("Loading model...")
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=quantization_config,
device_map=device,
torch_dtype=torch.bfloat16 if quantization is None else None,
            attn_implementation="flash_attention_2" if use_flash_attention else "sdpa",
trust_remote_code=True,
low_cpu_mem_usage=True
)
if torch_compile and quantization is None:
logger.info("Compiling model with torch.compile...")
self.model = torch.compile(self.model, mode="reduce-overhead")
self.model.eval()
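        # Sampling presets keyed by the `mode` argument of generate():
        # "creative" samples more freely, "precise" and "code" use low
        # temperature, "balanced" sits in between.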
self.generation_configs = {
"creative": GenerationConfig(
do_sample=True,
temperature=0.9,
top_p=0.95,
top_k=50,
repetition_penalty=1.15,
max_new_tokens=2048
),
"precise": GenerationConfig(
do_sample=True,
temperature=0.3,
top_p=0.85,
top_k=40,
repetition_penalty=1.05,
max_new_tokens=1024
),
"balanced": GenerationConfig(
do_sample=True,
temperature=0.7,
top_p=0.9,
top_k=50,
repetition_penalty=1.1,
max_new_tokens=1024
),
"code": GenerationConfig(
do_sample=True,
temperature=0.2,
top_p=0.9,
top_k=40,
repetition_penalty=1.05,
max_new_tokens=2048
)
}
logger.info("Model loaded successfully!")
def _get_quantization_config(self, quantization: Optional[str]) -> Optional[BitsAndBytesConfig]:
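        """Map "4bit"/"8bit" to a BitsAndBytesConfig; returns None for unknown values."""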
if quantization is None:
return None
quantization_configs = {
"4bit": BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
),
"8bit": BitsAndBytesConfig(
load_in_8bit=True
)
}
return quantization_configs.get(quantization)
def generate(
self,
prompt: str,
images: Optional[Union[Image.Image, List[Image.Image]]] = None,
mode: str = "balanced",
max_new_tokens: Optional[int] = None,
temperature: Optional[float] = None,
stream: bool = False,
return_metrics: bool = False,
**kwargs
) -> Union[str, tuple[str, InferenceMetrics]]:
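        """Run a single generation for a text prompt, optionally conditioned on images.

        With stream=True a generator of decoded text chunks is returned instead
        of a string; with return_metrics=True (non-streaming) an
        (response, InferenceMetrics) tuple is returned.
        """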
if isinstance(images, Image.Image):
images = [images]
start_time = time.time()
initial_memory = torch.cuda.memory_allocated() / (1024**3) if torch.cuda.is_available() else 0
if images:
inputs = self.processor(
text=prompt,
images=images,
return_tensors="pt"
).to(self.model.device)
else:
inputs = self.processor(
text=prompt,
return_tensors="pt"
).to(self.model.device)
input_length = inputs['input_ids'].shape[1]
gen_config = self.generation_configs[mode].to_dict()
if max_new_tokens:
gen_config['max_new_tokens'] = max_new_tokens
if temperature:
gen_config['temperature'] = temperature
gen_config.update(kwargs)
with torch.no_grad():
if stream:
return self._generate_stream(inputs, gen_config, return_metrics)
else:
outputs = self.model.generate(
**inputs,
**gen_config,
pad_token_id=self.processor.tokenizer.eos_token_id
)
if torch.cuda.is_available():
torch.cuda.synchronize()
end_time = time.time()
latency = (end_time - start_time) * 1000
        # Decode only the newly generated tokens so the prompt (and any image
        # placeholder tokens) are not echoed back in the response.
        response = self.processor.decode(
            outputs[0][input_length:], skip_special_tokens=True
        ).strip()
        tokens_generated = outputs.shape[1] - input_length
        elapsed = max(end_time - start_time, 1e-6)
        tokens_per_second = tokens_generated / elapsed
final_memory = torch.cuda.memory_allocated() / (1024**3) if torch.cuda.is_available() else 0
memory_used = final_memory - initial_memory
metrics = InferenceMetrics(
latency_ms=latency,
tokens_generated=tokens_generated,
tokens_per_second=tokens_per_second,
memory_used_gb=memory_used,
input_tokens=input_length,
total_tokens=outputs.shape[1]
)
self.metrics_history.append(metrics)
if return_metrics:
return response, metrics
return response
def _generate_stream(self, inputs, gen_config, return_metrics):
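        """Yield decoded text chunks from a background generate() thread via TextIteratorStreamer."""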
from transformers import TextIteratorStreamer
from threading import Thread
streamer = TextIteratorStreamer(
self.processor.tokenizer,
skip_special_tokens=True,
skip_prompt=True
)
gen_config['streamer'] = streamer
        thread = Thread(
            target=self.model.generate,
            kwargs={**inputs, **gen_config, "pad_token_id": self.processor.tokenizer.eos_token_id}
        )
thread.start()
for new_text in streamer:
yield new_text
thread.join()
def batch_generate(
self,
prompts: List[str],
images_list: Optional[List[Optional[Union[Image.Image, List[Image.Image]]]]] = None,
mode: str = "balanced",
**kwargs
) -> List[str]:
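        """Generate responses for several prompts (optionally with images) in one padded batch."""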
if images_list is None:
images_list = [None] * len(prompts)
        # Decoder-only models need left padding so every prompt ends at the same
        # position before generation.
        self.processor.tokenizer.padding_side = "left"
        # Tokenize the whole batch in one processor call so all sequences are
        # padded to a common length; concatenating per-sample tensors of
        # different lengths would fail. Assumes the processor pairs the flat
        # image list with the prompts in order.
        flat_images = []
        for images in images_list:
            if images is None:
                continue
            flat_images.extend([images] if isinstance(images, Image.Image) else images)
        batch_inputs = self.processor(
            text=prompts,
            images=flat_images if flat_images else None,
            return_tensors="pt",
            padding=True
        ).to(self.model.device)
gen_config = self.generation_configs[mode].to_dict()
gen_config.update(kwargs)
with torch.no_grad():
outputs = self.model.generate(
**batch_inputs,
**gen_config,
pad_token_id=self.processor.tokenizer.eos_token_id
)
        # Decode only the generated continuation (prompts were left-padded to a
        # common input_length above).
        input_length = batch_inputs['input_ids'].shape[1]
        responses = [
            self.processor.decode(output[input_length:], skip_special_tokens=True).strip()
            for output in outputs
        ]
return responses
def vision_qa(
self,
image: Image.Image,
question: str,
mode: str = "precise"
) -> str:
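        """Answer a question about a single image using the "precise" preset by default."""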
prompt = f"Question: {question}\nAnswer:"
return self.generate(prompt, images=image, mode=mode)
def analyze_image(
self,
image: Image.Image,
analysis_type: str = "detailed"
) -> str:
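        """Describe an image using one of the preset analysis prompts ("detailed", "quick", "technical", "ocr")."""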
prompts = {
"detailed": "Provide a detailed description of this image, including objects, people, actions, setting, and any text visible.",
"quick": "Briefly describe what you see in this image.",
"technical": "Analyze this image from a technical perspective, including composition, lighting, colors, and quality.",
"ocr": "Extract all text visible in this image and organize it clearly."
}
prompt = prompts.get(analysis_type, prompts["detailed"])
return self.generate(prompt, images=image, mode="precise")
def code_generation(
self,
task: str,
language: str = "python",
include_tests: bool = False
) -> str:
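        """Generate code in the requested language, optionally asking for unit tests as well."""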
prompt = f"Write {language} code for the following task:\n{task}"
if include_tests:
prompt += "\n\nInclude unit tests for the code."
return self.generate(prompt, mode="code", max_new_tokens=2048)
def function_call(
self,
user_query: str,
available_tools: List[Dict[str, Any]]
) -> Dict[str, Any]:
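        """Ask the model to choose a tool and parameters for the query.

        Returns the parsed JSON object, or an error dict with the raw response
        when no valid JSON can be extracted.
        """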
tools_str = json.dumps(available_tools, indent=2)
prompt = f"""Available tools:
{tools_str}
User query: {user_query}
Respond with a JSON object specifying which tool to use and with what parameters:
{{"tool": "tool_name", "parameters": {{"param": "value"}}}}
Response:"""
response = self.generate(prompt, mode="precise", temperature=0.2)
try:
import re
json_match = re.search(r'\{.*\}', response, re.DOTALL)
if json_match:
return json.loads(json_match.group())
return {"error": "No valid JSON found", "raw": response}
except json.JSONDecodeError as e:
return {"error": str(e), "raw": response}
def multi_modal_rag(
self,
query: str,
documents: List[str],
images: Optional[List[Image.Image]] = None
) -> str:
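        """Answer a query grounded in the supplied documents and optional images (simple RAG prompt)."""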
context = "\n\n".join([f"Document {i+1}:\n{doc}" for i, doc in enumerate(documents)])
prompt = f"""Context:\n{context}\n\nQuestion: {query}\n\nAnswer based on the provided context:"""
return self.generate(prompt, images=images, mode="precise", max_new_tokens=1024)
def get_metrics_summary(self) -> Dict[str, float]:
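        """Return averages and totals over all requests recorded in metrics_history."""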
if not self.metrics_history:
return {}
return {
"avg_latency_ms": sum(m.latency_ms for m in self.metrics_history) / len(self.metrics_history),
"avg_tokens_per_second": sum(m.tokens_per_second for m in self.metrics_history) / len(self.metrics_history),
"avg_memory_used_gb": sum(m.memory_used_gb for m in self.metrics_history) / len(self.metrics_history),
"total_tokens_generated": sum(m.tokens_generated for m in self.metrics_history),
"num_requests": len(self.metrics_history)
}
def clear_cache(self):
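        """Free cached CUDA memory and any model-level cache, when available."""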
if torch.cuda.is_available():
torch.cuda.empty_cache()
        if hasattr(self.model, 'clear_cache'):
            self.model.clear_cache()
logger.info("Cache cleared")
def main():
import argparse
parser = argparse.ArgumentParser(description="Advanced Helion-V2.0-Thinking Inference")
parser.add_argument("--model", type=str, default="DeepXR/Helion-V2.0-Thinking")
parser.add_argument("--quantization", type=str, choices=["4bit", "8bit", None], default=None)
parser.add_argument("--mode", type=str, default="balanced", choices=["creative", "precise", "balanced", "code"])
parser.add_argument("--prompt", type=str, help="Text prompt")
parser.add_argument("--image", type=str, help="Path to image file")
parser.add_argument("--stream", action="store_true", help="Enable streaming output")
parser.add_argument("--torch-compile", action="store_true", help="Use torch.compile")
parser.add_argument("--benchmark", action="store_true", help="Run benchmark")
args = parser.parse_args()
model = AdvancedHelionInference(
model_name=args.model,
quantization=args.quantization,
torch_compile=args.torch_compile
)
if args.benchmark:
print("Running benchmark...")
test_prompts = [
"Explain quantum computing in simple terms.",
"Write a Python function to calculate fibonacci numbers.",
"What are the main causes of climate change?"
]
for prompt in test_prompts:
response, metrics = model.generate(
prompt,
mode=args.mode,
return_metrics=True
)
print(f"\nPrompt: {prompt}")
print(f"Response: {response[:100]}...")
print(f"Metrics: {metrics}")
summary = model.get_metrics_summary()
print(f"\nBenchmark Summary:")
for key, value in summary.items():
print(f" {key}: {value:.2f}")
elif args.prompt:
image = Image.open(args.image) if args.image else None
if args.stream:
print("Streaming response:")
for text in model.generate(args.prompt, images=image, mode=args.mode, stream=True):
print(text, end="", flush=True)
print()
else:
response, metrics = model.generate(
args.prompt,
images=image,
mode=args.mode,
return_metrics=True
)
print(f"Response: {response}")
print(f"\nMetrics:")
print(f" Latency: {metrics.latency_ms:.2f}ms")
print(f" Tokens/sec: {metrics.tokens_per_second:.2f}")
print(f" Tokens generated: {metrics.tokens_generated}")
else:
print("Please provide --prompt or use --benchmark")
if __name__ == "__main__":
main()
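
# Example invocations (illustrative; adjust model, image path, and flags to your setup):
#   python advanced_inference.py --prompt "Explain quantum computing simply." --mode precise
#   python advanced_inference.py --prompt "Describe this photo." --image ./photo.jpg --quantization 4bit
#   python advanced_inference.py --benchmark --mode balanced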