"""Advanced inference utilities for DeepXR/Helion-V2.0-Thinking: quantization,
streaming, batching, vision QA, and simple benchmarking."""

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    BitsAndBytesConfig,
    GenerationConfig
)
from PIL import Image
import json
from typing import Optional, List, Dict, Any, Union
import time
from dataclasses import dataclass
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class InferenceMetrics:
    """Per-request inference statistics collected by AdvancedHelionInference."""

    latency_ms: float
    tokens_generated: int
    tokens_per_second: float
    memory_used_gb: float
    input_tokens: int
    total_tokens: int

class AdvancedHelionInference:
    """Inference wrapper for Helion-V2.0-Thinking with quantization, streaming,
    batching, and basic performance metrics."""

    def __init__(
        self,
        model_name: str = "DeepXR/Helion-V2.0-Thinking",
        quantization: Optional[str] = None,
        device: str = "auto",
        use_flash_attention: bool = True,
        torch_compile: bool = False,
        optimization_mode: str = "balanced"
    ):
        logger.info(f"Initializing Helion-V2.0-Thinking with {optimization_mode} mode")

        self.model_name = model_name
        self.optimization_mode = optimization_mode
        self.metrics_history: List[InferenceMetrics] = []

        quantization_config = self._get_quantization_config(quantization)

        logger.info("Loading processor...")
        self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

        logger.info("Loading model...")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=quantization_config,
            device_map=device,
            torch_dtype=torch.bfloat16 if quantization is None else None,
            # `use_flash_attention_2=...` is deprecated in recent transformers
            # releases; `attn_implementation` is the supported way to request
            # FlashAttention 2 (None falls back to the default implementation).
            attn_implementation="flash_attention_2" if use_flash_attention else None,
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )

        if torch_compile and quantization is None:
            logger.info("Compiling model with torch.compile...")
            self.model = torch.compile(self.model, mode="reduce-overhead")

        self.model.eval()

        self.generation_configs = {
            "creative": GenerationConfig(
                do_sample=True,
                temperature=0.9,
                top_p=0.95,
                top_k=50,
                repetition_penalty=1.15,
                max_new_tokens=2048
            ),
            "precise": GenerationConfig(
                do_sample=True,
                temperature=0.3,
                top_p=0.85,
                top_k=40,
                repetition_penalty=1.05,
                max_new_tokens=1024
            ),
            "balanced": GenerationConfig(
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                top_k=50,
                repetition_penalty=1.1,
                max_new_tokens=1024
            ),
            "code": GenerationConfig(
                do_sample=True,
                temperature=0.2,
                top_p=0.9,
                top_k=40,
                repetition_penalty=1.05,
                max_new_tokens=2048
            )
        }

        logger.info("Model loaded successfully!")

    def _get_quantization_config(self, quantization: Optional[str]) -> Optional[BitsAndBytesConfig]:
        if quantization is None:
            return None

        quantization_configs = {
            "4bit": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            ),
            "8bit": BitsAndBytesConfig(
                load_in_8bit=True
            )
        }

        return quantization_configs.get(quantization)

    def generate(
        self,
        prompt: str,
        images: Optional[Union[Image.Image, List[Image.Image]]] = None,
        mode: str = "balanced",
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        stream: bool = False,
        return_metrics: bool = False,
        **kwargs
    ) -> Union[str, tuple[str, InferenceMetrics]]:
        """Generate a response for a text prompt, optionally conditioned on images."""
        if isinstance(images, Image.Image):
            images = [images]

        start_time = time.time()
        initial_memory = torch.cuda.memory_allocated() / (1024**3) if torch.cuda.is_available() else 0

        if images:
            inputs = self.processor(
                text=prompt,
                images=images,
                return_tensors="pt"
            ).to(self.model.device)
        else:
            inputs = self.processor(
                text=prompt,
                return_tensors="pt"
            ).to(self.model.device)

        input_length = inputs['input_ids'].shape[1]

        gen_config = self.generation_configs[mode].to_dict()

        if max_new_tokens:
            gen_config['max_new_tokens'] = max_new_tokens
        if temperature:
            gen_config['temperature'] = temperature

        gen_config.update(kwargs)
        # Keep pad_token_id inside the config dict: to_dict() already contains a
        # pad_token_id entry, so passing it again as a separate keyword would make
        # generate() receive the same argument twice.
        gen_config['pad_token_id'] = self.processor.tokenizer.eos_token_id

        if stream:
            # Streaming returns a generator of text chunks; metrics are not
            # collected on this path.
            return self._generate_stream(inputs, gen_config, return_metrics)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                **gen_config
            )

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        end_time = time.time()
        latency = (end_time - start_time) * 1000

        # Decode only the newly generated tokens so the prompt (and any image
        # placeholder tokens) is not echoed back in the response.
        response = self.processor.decode(outputs[0][input_length:], skip_special_tokens=True).strip()

        tokens_generated = outputs.shape[1] - input_length
        elapsed = end_time - start_time
        tokens_per_second = tokens_generated / elapsed if elapsed > 0 else float(tokens_generated)

        final_memory = torch.cuda.memory_allocated() / (1024**3) if torch.cuda.is_available() else 0
        memory_used = final_memory - initial_memory

        metrics = InferenceMetrics(
            latency_ms=latency,
            tokens_generated=tokens_generated,
            tokens_per_second=tokens_per_second,
            memory_used_gb=memory_used,
            input_tokens=input_length,
            total_tokens=outputs.shape[1]
        )

        self.metrics_history.append(metrics)

        if return_metrics:
            return response, metrics
        return response

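    # Example (sketch): consuming the streaming path. `engine` is an assumed
    # AdvancedHelionInference instance; chunks arrive incrementally from the
    # TextIteratorStreamer used below.
    #
    #   engine = AdvancedHelionInference(quantization="4bit")
    #   for chunk in engine.generate("Explain beam search.", stream=True):
    #       print(chunk, end="", flush=True)
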
    def _generate_stream(self, inputs, gen_config, return_metrics):
        """Yield generated text incrementally. `return_metrics` is accepted for
        interface symmetry with generate() but metrics are not tracked while
        streaming."""
        from transformers import TextIteratorStreamer
        from threading import Thread

        streamer = TextIteratorStreamer(
            self.processor.tokenizer,
            skip_special_tokens=True,
            skip_prompt=True
        )

        gen_config['streamer'] = streamer

        thread = Thread(
            target=self.model.generate,
            kwargs={**inputs, **gen_config}
        )
        thread.start()

        for new_text in streamer:
            yield new_text

        thread.join()

    def batch_generate(
        self,
        prompts: List[str],
        images_list: Optional[List[Optional[Union[Image.Image, List[Image.Image]]]]] = None,
        mode: str = "balanced",
        **kwargs
    ) -> List[str]:
        """Generate responses for several prompts in one forward pass.

        Batches are assumed to be uniform: either every prompt has images or none do.
        """
        if images_list is None:
            images_list = [None] * len(prompts)

        # Normalize single images to one-element lists so the processor sees a
        # consistent structure for every batch item.
        normalized_images = [
            [img] if isinstance(img, Image.Image) else img
            for img in images_list
        ]

        # Process the whole batch in a single call so padding is applied across
        # prompts; tokenizing prompts individually and concatenating the tensors
        # fails as soon as their lengths differ.
        if any(img is not None for img in normalized_images):
            batch_inputs = self.processor(
                text=prompts,
                images=normalized_images,
                return_tensors="pt",
                padding=True
            ).to(self.model.device)
        else:
            batch_inputs = self.processor(
                text=prompts,
                return_tensors="pt",
                padding=True
            ).to(self.model.device)

        gen_config = self.generation_configs[mode].to_dict()
        gen_config.update(kwargs)
        # As in generate(): keep pad_token_id inside the dict to avoid passing the
        # same keyword to generate() twice.
        gen_config['pad_token_id'] = self.processor.tokenizer.eos_token_id

        with torch.no_grad():
            outputs = self.model.generate(
                **batch_inputs,
                **gen_config
            )

        responses = [
            self.processor.decode(output, skip_special_tokens=True)
            for output in outputs
        ]

        return responses

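    # Example (sketch): text-only batching. `engine` is an assumed
    # AdvancedHelionInference instance.
    #
    #   answers = engine.batch_generate(
    #       ["Define entropy.", "Summarize the water cycle."],
    #       mode="precise",
    #       max_new_tokens=256,
    #   )
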
    def vision_qa(
        self,
        image: Image.Image,
        question: str,
        mode: str = "precise"
    ) -> str:
        prompt = f"Question: {question}\nAnswer:"
        return self.generate(prompt, images=image, mode=mode)

    def analyze_image(
        self,
        image: Image.Image,
        analysis_type: str = "detailed"
    ) -> str:
        prompts = {
            "detailed": "Provide a detailed description of this image, including objects, people, actions, setting, and any text visible.",
            "quick": "Briefly describe what you see in this image.",
            "technical": "Analyze this image from a technical perspective, including composition, lighting, colors, and quality.",
            "ocr": "Extract all text visible in this image and organize it clearly."
        }

        prompt = prompts.get(analysis_type, prompts["detailed"])
        return self.generate(prompt, images=image, mode="precise")

    def code_generation(
        self,
        task: str,
        language: str = "python",
        include_tests: bool = False
    ) -> str:
        prompt = f"Write {language} code for the following task:\n{task}"

        if include_tests:
            prompt += "\n\nInclude unit tests for the code."

        return self.generate(prompt, mode="code", max_new_tokens=2048)

    def function_call(
        self,
        user_query: str,
        available_tools: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        tools_str = json.dumps(available_tools, indent=2)

        prompt = f"""Available tools:
{tools_str}

User query: {user_query}

Respond with a JSON object specifying which tool to use and with what parameters:
{{"tool": "tool_name", "parameters": {{"param": "value"}}}}

Response:"""

        response = self.generate(prompt, mode="precise", temperature=0.2)

        import re
        try:
            json_match = re.search(r'\{.*\}', response, re.DOTALL)
            if json_match:
                return json.loads(json_match.group())
            return {"error": "No valid JSON found", "raw": response}
        except json.JSONDecodeError as e:
            return {"error": str(e), "raw": response}

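    # Example (sketch): a hypothetical tool description passed to function_call().
    # The schema below is an assumption; any JSON-serializable tool spec the model
    # was trained to follow will work.
    #
    #   tools = [{
    #       "name": "get_weather",
    #       "description": "Return the current weather for a city",
    #       "parameters": {"city": {"type": "string"}},
    #   }]
    #   call = engine.function_call("What's the weather in Paris?", tools)
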
    def multi_modal_rag(
        self,
        query: str,
        documents: List[str],
        images: Optional[List[Image.Image]] = None
    ) -> str:
        context = "\n\n".join([f"Document {i+1}:\n{doc}" for i, doc in enumerate(documents)])

        prompt = f"""Context:\n{context}\n\nQuestion: {query}\n\nAnswer based on the provided context:"""

        return self.generate(prompt, images=images, mode="precise", max_new_tokens=1024)

    def get_metrics_summary(self) -> Dict[str, float]:
        if not self.metrics_history:
            return {}

        return {
            "avg_latency_ms": sum(m.latency_ms for m in self.metrics_history) / len(self.metrics_history),
            "avg_tokens_per_second": sum(m.tokens_per_second for m in self.metrics_history) / len(self.metrics_history),
            "avg_memory_used_gb": sum(m.memory_used_gb for m in self.metrics_history) / len(self.metrics_history),
            "total_tokens_generated": sum(m.tokens_generated for m in self.metrics_history),
            "num_requests": len(self.metrics_history)
        }

    def clear_cache(self):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        if hasattr(self.model, 'clear_cache'):
            self.model.clear_cache()
        logger.info("Cache cleared")

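# Example usage sketch: the prompts and the "example.jpg" path below are
# illustrative placeholders, not assets that ship with the model.
def example_usage():
    engine = AdvancedHelionInference(
        model_name="DeepXR/Helion-V2.0-Thinking",
        quantization="4bit",  # assumes a CUDA GPU with bitsandbytes installed
    )

    # Plain text generation with metrics.
    answer, metrics = engine.generate(
        "Explain the difference between supervised and unsupervised learning.",
        mode="balanced",
        return_metrics=True,
    )
    print(answer)
    print(f"{metrics.tokens_per_second:.1f} tokens/s")

    # Visual question answering on a local image (hypothetical path).
    image = Image.open("example.jpg")
    print(engine.vision_qa(image, "What is shown in this picture?"))

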
def main():
    import argparse

    parser = argparse.ArgumentParser(description="Advanced Helion-V2.0-Thinking Inference")
    parser.add_argument("--model", type=str, default="DeepXR/Helion-V2.0-Thinking")
    # argparse validates choices against the parsed string, so None must not be
    # listed; omitting the flag already yields the default of None.
    parser.add_argument("--quantization", type=str, choices=["4bit", "8bit"], default=None)
    parser.add_argument("--mode", type=str, default="balanced", choices=["creative", "precise", "balanced", "code"])
    parser.add_argument("--prompt", type=str, help="Text prompt")
    parser.add_argument("--image", type=str, help="Path to image file")
    parser.add_argument("--stream", action="store_true", help="Enable streaming output")
    parser.add_argument("--torch-compile", action="store_true", help="Use torch.compile")
    parser.add_argument("--benchmark", action="store_true", help="Run benchmark")

    args = parser.parse_args()

    model = AdvancedHelionInference(
        model_name=args.model,
        quantization=args.quantization,
        torch_compile=args.torch_compile
    )

    if args.benchmark:
        print("Running benchmark...")

        test_prompts = [
            "Explain quantum computing in simple terms.",
            "Write a Python function to calculate fibonacci numbers.",
            "What are the main causes of climate change?"
        ]

        for prompt in test_prompts:
            response, metrics = model.generate(
                prompt,
                mode=args.mode,
                return_metrics=True
            )
            print(f"\nPrompt: {prompt}")
            print(f"Response: {response[:100]}...")
            print(f"Metrics: {metrics}")

        summary = model.get_metrics_summary()
        print("\nBenchmark Summary:")
        for key, value in summary.items():
            print(f"  {key}: {value:.2f}")

    elif args.prompt:
        image = Image.open(args.image) if args.image else None

        if args.stream:
            print("Streaming response:")
            for text in model.generate(args.prompt, images=image, mode=args.mode, stream=True):
                print(text, end="", flush=True)
            print()
        else:
            response, metrics = model.generate(
                args.prompt,
                images=image,
                mode=args.mode,
                return_metrics=True
            )
            print(f"Response: {response}")
            print("\nMetrics:")
            print(f"  Latency: {metrics.latency_ms:.2f}ms")
            print(f"  Tokens/sec: {metrics.tokens_per_second:.2f}")
            print(f"  Tokens generated: {metrics.tokens_generated}")
    else:
        print("Please provide --prompt or use --benchmark")


if __name__ == "__main__":
    main()
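
# Example CLI invocations (illustrative; the script filename and image path are
# placeholders for wherever this file lives in your project):
#   python advanced_inference.py --prompt "Summarize transformer attention." --mode precise
#   python advanced_inference.py --prompt "Describe this chart." --image chart.png --quantization 4bit
#   python advanced_inference.py --benchmark --torch-compile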