"""
Helion-V2.0-Thinking Inference Script

A comprehensive example showing different ways to use the multimodal model
with vision, tool use, and structured output capabilities.
"""
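# Example command lines (the script filename is illustrative; the flags match the
# argparse options defined in main() below):
#
#   python helion_inference.py --prompt "Explain quantum entanglement" --max-tokens 256
#   python helion_inference.py --image photo.jpg --prompt "What is in this picture?"
#   python helion_inference.py --interactive --load-in-4bit
#   python helion_inference.py --demo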
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoProcessor,
    BitsAndBytesConfig
)
from PIL import Image
import requests
from typing import Optional, List, Dict, Any
import argparse
import json
import re

class HelionInference:
    """Wrapper class for Helion-V2.0-Thinking multimodal model inference"""

    def __init__(
        self,
        model_name: str = "DeepXR/Helion-V2.0-Thinking",
        device: str = "auto",
        load_in_8bit: bool = False,
        load_in_4bit: bool = False,
        use_flash_attention: bool = True
    ):
        """
        Initialize the model, tokenizer, and processor

        Args:
            model_name: HuggingFace model identifier
            device: Device to load model on ('auto', 'cuda', 'cpu')
            load_in_8bit: Enable 8-bit quantization
            load_in_4bit: Enable 4-bit quantization
            use_flash_attention: Use Flash Attention 2 for efficiency
        """
        print(f"Loading {model_name}...")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.processor = AutoProcessor.from_pretrained(model_name)

        # Optional quantization config; 4-bit takes precedence over 8-bit.
        quantization_config = None
        if load_in_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
        elif load_in_8bit:
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)

        # Load the model weights. attn_implementation replaces the deprecated
        # use_flash_attention_2 flag in recent transformers releases.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map=device,
            quantization_config=quantization_config,
            attn_implementation="flash_attention_2" if use_flash_attention else "eager",
            trust_remote_code=True
        )

        self.model.eval()
        print("Model loaded successfully!")

        # Default tool definitions available for function calling.
        self.tools = self._initialize_tools()

    def _initialize_tools(self) -> List[Dict[str, Any]]:
        """Initialize available tools for function calling"""
        return [
            {
                "name": "calculator",
                "description": "Perform mathematical calculations",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "expression": {
                            "type": "string",
                            "description": "Mathematical expression to evaluate"
                        }
                    },
                    "required": ["expression"]
                }
            },
            {
                "name": "web_search",
                "description": "Search the web for current information",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "The search query"
                        }
                    },
                    "required": ["query"]
                }
            },
            {
                "name": "code_executor",
                "description": "Execute Python code safely",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "code": {
                            "type": "string",
                            "description": "Python code to execute"
                        }
                    },
                    "required": ["code"]
                }
            }
        ]
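
    # The tool entries above follow a JSON-Schema-style layout, so callers can pass
    # their own list to call_function(). A hypothetical extra tool (not defined in
    # this script), assuming `helion` is an instance of HelionInference:
    #
    #   weather_tool = {
    #       "name": "get_weather",
    #       "description": "Look up the current weather for a city",
    #       "parameters": {
    #           "type": "object",
    #           "properties": {"city": {"type": "string"}},
    #           "required": ["city"]
    #       }
    #   }
    #   helion.call_function("What's the weather in Paris?", tools=[weather_tool])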

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        repetition_penalty: float = 1.1,
        do_sample: bool = True,
        images: Optional[List[Image.Image]] = None
    ) -> str:
        """
        Generate text from a prompt with optional images

        Args:
            prompt: Input text
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            top_p: Nucleus sampling threshold
            top_k: Top-k sampling parameter
            repetition_penalty: Penalty for repeating tokens
            do_sample: Use sampling vs greedy decoding
            images: Optional list of PIL images

        Returns:
            Generated text
        """
        # Use the multimodal processor when images are supplied, otherwise the
        # plain text tokenizer.
        if images:
            inputs = self.processor(
                text=prompt,
                images=images,
                return_tensors="pt"
            ).to(self.model.device)
        else:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens so the prompt (including any
        # chat-template special tokens) is not echoed back in the result.
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        if images:
            generated_text = self.processor.decode(new_tokens, skip_special_tokens=True)
        else:
            generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

        return generated_text.strip()

    def analyze_image(
        self,
        image: Image.Image,
        query: str = "Describe this image in detail.",
        max_new_tokens: int = 512
    ) -> str:
        """
        Analyze an image with a specific query

        Args:
            image: PIL Image object
            query: Question or instruction about the image
            max_new_tokens: Maximum tokens to generate

        Returns:
            Image analysis response
        """
        return self.generate(
            prompt=query,
            images=[image],
            max_new_tokens=max_new_tokens,
            temperature=0.7
        )

    def extract_text_from_image(
        self,
        image: Image.Image
    ) -> str:
        """
        Perform OCR on an image

        Args:
            image: PIL Image object

        Returns:
            Extracted text
        """
        # A low temperature keeps the transcription close to the source text.
        prompt = "Extract all text from this image. Return only the text content without any additional commentary."
        return self.generate(
            prompt=prompt,
            images=[image],
            max_new_tokens=1024,
            temperature=0.3
        )

    def call_function(
        self,
        prompt: str,
        tools: Optional[List[Dict[str, Any]]] = None
    ) -> Dict[str, Any]:
        """
        Use function calling to determine which tool to use

        Args:
            prompt: User query
            tools: List of available tools (uses default if None)

        Returns:
            Dict with tool name and parameters
        """
        if tools is None:
            tools = self.tools

        system_prompt = f"""You are a helpful assistant with access to the following tools:
{json.dumps(tools, indent=2)}

To use a tool, respond with ONLY a JSON object in this exact format:
{{"tool": "tool_name", "parameters": {{"param": "value"}}}}

Do not include any other text or explanation."""

        full_prompt = f"{system_prompt}\n\nUser query: {prompt}\n\nTool call:"

        # Greedy decoding keeps the tool-call JSON deterministic.
        response = self.generate(
            prompt=full_prompt,
            max_new_tokens=256,
            temperature=0.2,
            do_sample=False
        )

        # Extract the first JSON object from the response and parse it.
        try:
            json_match = re.search(r'\{.*\}', response, re.DOTALL)
            if json_match:
                tool_call = json.loads(json_match.group())
                return tool_call
            else:
                return {"error": "No valid JSON found in response", "raw": response}
        except json.JSONDecodeError as e:
            return {"error": f"JSON decode error: {str(e)}", "raw": response}
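
    # A minimal sketch of dispatching the result of call_function(), assuming
    # `helion` is an instance of HelionInference; the handler functions named
    # here (evaluate_expression, run_search, run_code) are hypothetical
    # placeholders, not part of this script:
    #
    #   tool_call = helion.call_function("What is 45 multiplied by 23, plus 156?")
    #   handlers = {
    #       "calculator": lambda p: evaluate_expression(p["expression"]),
    #       "web_search": lambda p: run_search(p["query"]),
    #       "code_executor": lambda p: run_code(p["code"]),
    #   }
    #   if "tool" in tool_call:
    #       result = handlers[tool_call["tool"]](tool_call["parameters"])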

    def structured_output(
        self,
        prompt: str,
        schema: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Generate structured JSON output matching a schema

        Args:
            prompt: Input prompt
            schema: JSON schema for the output

        Returns:
            Parsed JSON response
        """
        full_prompt = f"""Generate a JSON response matching this schema:
{json.dumps(schema, indent=2)}

User request: {prompt}

Return ONLY valid JSON, no other text:"""

        response = self.generate(
            prompt=full_prompt,
            max_new_tokens=1024,
            temperature=0.2,
            do_sample=False
        )

        # Strip any markdown code fences before parsing the JSON.
        try:
            if "```json" in response:
                json_str = response.split("```json")[-1].split("```")[0].strip()
            elif "```" in response:
                json_str = response.split("```")[1].strip()
            else:
                json_str = response.strip()

            return json.loads(json_str)
        except json.JSONDecodeError as e:
            return {"error": f"JSON decode error: {str(e)}", "raw": response}

    def chat(
        self,
        messages: List[Dict[str, Any]],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9
    ) -> str:
        """
        Chat interface using conversation format with support for images

        Args:
            messages: List of message dicts with 'role', 'content', and optional 'images' keys
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            top_p: Nucleus sampling threshold

        Returns:
            Assistant's response
        """
        # Collect every image attached to the conversation so far.
        all_images = []
        for msg in messages:
            if "images" in msg and msg["images"]:
                all_images.extend(msg["images"])

        # Render the conversation with the model's chat template.
        prompt = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        return self.generate(
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            images=all_images if all_images else None
        )
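
    # Example of the message format accepted by chat(), assuming `helion` is an
    # instance of HelionInference and `chart_img` is a PIL.Image the caller has
    # already opened (both names are illustrative):
    #
    #   messages = [
    #       {"role": "user", "content": "What does this chart show?", "images": [chart_img]},
    #   ]
    #   reply = helion.chat(messages)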

    def interactive_chat(self):
        """Run an interactive chat session with multimodal support"""
        print("\n" + "="*60)
        print("Helion-V2.0-Thinking Interactive Chat")
        print("Commands:")
        print("  - Type 'exit', 'quit', or 'q' to end")
        print("  - Type 'image <path>' to add an image")
        print("  - Type 'clear' to reset conversation")
        print("="*60 + "\n")

        conversation_history = []

        while True:
            user_input = input("You: ").strip()

            if user_input.lower() in ['exit', 'quit', 'q']:
                print("Goodbye!")
                break

            if user_input.lower() == 'clear':
                conversation_history = []
                print("Conversation cleared.\n")
                continue

            if not user_input:
                continue

            # Handle the 'image <path>' command: load the file, then prompt for a question.
            images = []
            if user_input.lower().startswith('image '):
                image_path = user_input[6:].strip()
                try:
                    image = Image.open(image_path)
                    images.append(image)
                    print(f"Image loaded: {image_path}")
                    user_input = input("Your question about the image: ").strip()
                except Exception as e:
                    print(f"Error loading image: {e}")
                    continue

            # Append the user turn (with any attached images) to the history.
            message = {
                "role": "user",
                "content": user_input
            }
            if images:
                message["images"] = images

            conversation_history.append(message)

            # Generate a reply and record it so later turns keep the full context.
            try:
                response = self.chat(conversation_history)

                conversation_history.append({
                    "role": "assistant",
                    "content": response
                })

                print(f"\nAssistant: {response}\n")
            except Exception as e:
                print(f"Error generating response: {e}\n")
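
# A minimal programmatic-usage sketch of the class above; the variable names and
# file path are illustrative:
#
#   helion = HelionInference(load_in_4bit=True)
#   print(helion.generate("Write a haiku about the ocean.", max_new_tokens=64))
#   img = Image.open("chart.png")
#   print(helion.analyze_image(img, "Summarize the trend shown in this chart."))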

def main():
    parser = argparse.ArgumentParser(
        description="Helion-V2.0-Thinking Multimodal Inference"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="DeepXR/Helion-V2.0-Thinking",
        help="Model name or path"
    )
    parser.add_argument(
        "--prompt",
        type=str,
        help="Input prompt for generation"
    )
    parser.add_argument(
        "--image",
        type=str,
        help="Path to image file"
    )
    parser.add_argument(
        "--interactive",
        action="store_true",
        help="Start interactive chat mode"
    )
    parser.add_argument(
        "--load-in-8bit",
        action="store_true",
        help="Load model in 8-bit precision"
    )
    parser.add_argument(
        "--load-in-4bit",
        action="store_true",
        help="Load model in 4-bit precision"
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=512,
        help="Maximum tokens to generate"
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature"
    )
    parser.add_argument(
        "--demo",
        action="store_true",
        help="Run demonstration examples"
    )

    args = parser.parse_args()

    # Load the model once, honoring any quantization flags.
    model = HelionInference(
        model_name=args.model,
        load_in_8bit=args.load_in_8bit,
        load_in_4bit=args.load_in_4bit
    )

    # Dispatch to the requested mode.
    if args.interactive:
        model.interactive_chat()
    elif args.demo:
        print("\n" + "="*60)
        print("Running Demonstration Examples")
        print("="*60 + "\n")

        # 1. Plain text generation
        print("1. Text Generation:")
        print("-" * 40)
        response = model.generate(
            "Explain quantum entanglement in simple terms:",
            max_new_tokens=256
        )
        print(f"Response: {response}\n")

        # 2. Function calling
        print("2. Function Calling:")
        print("-" * 40)
        tool_call = model.call_function(
            "What is 45 multiplied by 23, plus 156?"
        )
        print(f"Tool call: {json.dumps(tool_call, indent=2)}\n")

        # 3. Structured output
        print("3. Structured Output:")
        print("-" * 40)
        schema = {
            "type": "object",
            "properties": {
                "summary": {"type": "string"},
                "sentiment": {"type": "string", "enum": ["positive", "negative", "neutral"]},
                "key_points": {"type": "array", "items": {"type": "string"}}
            }
        }
        structured = model.structured_output(
            "Analyze this: The new product launch was highly successful.",
            schema
        )
        print(f"Structured output: {json.dumps(structured, indent=2)}\n")

    elif args.image:
        # Analyze a single image from disk.
        try:
            image = Image.open(args.image)
            prompt = args.prompt or "Describe this image in detail."
            response = model.analyze_image(image, prompt, args.max_tokens)
            print(f"\nImage: {args.image}")
            print(f"Query: {prompt}")
            print(f"Response: {response}\n")
        except Exception as e:
            print(f"Error processing image: {e}")

    elif args.prompt:
        response = model.generate(
            prompt=args.prompt,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature
        )
        print(f"\nPrompt: {args.prompt}")
        print(f"Response: {response}\n")

    else:
        print("Please specify --interactive, --demo, --prompt, or --image")
        print("Use --help for more information")


if __name__ == "__main__":
    main()