"""
Helion-V2.0-Thinking Inference Script
A comprehensive example showing different ways to use the multimodal model
with vision, tool use, and structured output capabilities
"""
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
AutoProcessor,
BitsAndBytesConfig
)
from PIL import Image
import requests
from typing import Optional, List, Dict, Any
import argparse
import json
import re
class HelionInference:
"""Wrapper class for Helion-V2.0-Thinking multimodal model inference"""
def __init__(
self,
model_name: str = "DeepXR/Helion-V2.0-Thinking",
device: str = "auto",
load_in_8bit: bool = False,
load_in_4bit: bool = False,
use_flash_attention: bool = True
):
"""
Initialize the model, tokenizer, and processor
Args:
model_name: HuggingFace model identifier
device: Device to load model on ('auto', 'cuda', 'cpu')
load_in_8bit: Enable 8-bit quantization
load_in_4bit: Enable 4-bit quantization
use_flash_attention: Use Flash Attention 2 for efficiency
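        Example (illustrative; 4-bit loading assumes bitsandbytes is installed):
            >>> model = HelionInference(load_in_4bit=True)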
"""
print(f"Loading {model_name}...")
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.processor = AutoProcessor.from_pretrained(model_name)
# Configure quantization if requested
quantization_config = None
if load_in_4bit:
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
)
elif load_in_8bit:
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
# Load model
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
quantization_config=quantization_config,
            # `use_flash_attention_2` is deprecated in recent transformers releases;
            # `attn_implementation` is the supported way to request Flash Attention 2
            attn_implementation="flash_attention_2" if use_flash_attention else "sdpa",
trust_remote_code=True
)
self.model.eval()
print("Model loaded successfully!")
# Tool definitions
self.tools = self._initialize_tools()
def _initialize_tools(self) -> List[Dict[str, Any]]:
"""Initialize available tools for function calling"""
return [
{
"name": "calculator",
"description": "Perform mathematical calculations",
"parameters": {
"type": "object",
"properties": {
"expression": {
"type": "string",
"description": "Mathematical expression to evaluate"
}
},
"required": ["expression"]
}
},
{
"name": "web_search",
"description": "Search the web for current information",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The search query"
}
},
"required": ["query"]
}
},
{
"name": "code_executor",
"description": "Execute Python code safely",
"parameters": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "Python code to execute"
}
},
"required": ["code"]
}
}
]
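    # Illustrative tool call the model is prompted to emit for these tools
    # (format matches the instruction given in call_function below):
    #   {"tool": "calculator", "parameters": {"expression": "45 * 23 + 156"}}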
def generate(
self,
prompt: str,
max_new_tokens: int = 512,
temperature: float = 0.7,
top_p: float = 0.9,
top_k: int = 50,
repetition_penalty: float = 1.1,
do_sample: bool = True,
images: Optional[List[Image.Image]] = None
) -> str:
"""
Generate text from a prompt with optional images
Args:
prompt: Input text
max_new_tokens: Maximum tokens to generate
temperature: Sampling temperature
top_p: Nucleus sampling threshold
top_k: Top-k sampling parameter
repetition_penalty: Penalty for repeating tokens
do_sample: Use sampling vs greedy decoding
images: Optional list of PIL images
Returns:
Generated text
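        Example (illustrative prompt and settings):
            >>> model = HelionInference()
            >>> print(model.generate("Explain beam search in one sentence.",
            ...                      max_new_tokens=64, temperature=0.3))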
"""
if images:
inputs = self.processor(
text=prompt,
images=images,
return_tensors="pt"
).to(self.model.device)
else:
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
do_sample=do_sample,
pad_token_id=self.tokenizer.eos_token_id
)
        # Decode only the newly generated tokens so the prompt (including any
        # expanded image tokens) is not echoed back in the output
        generated_ids = outputs[0][inputs["input_ids"].shape[-1]:]
        if images:
            generated_text = self.processor.decode(generated_ids, skip_special_tokens=True)
        else:
            generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        return generated_text.strip()
def analyze_image(
self,
image: Image.Image,
query: str = "Describe this image in detail.",
max_new_tokens: int = 512
) -> str:
"""
Analyze an image with a specific query
Args:
image: PIL Image object
query: Question or instruction about the image
max_new_tokens: Maximum tokens to generate
Returns:
Image analysis response
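        Example (illustrative; 'photo.jpg' is a placeholder path):
            >>> img = Image.open("photo.jpg")
            >>> print(model.analyze_image(img, "What objects are visible?"))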
"""
return self.generate(
prompt=query,
images=[image],
max_new_tokens=max_new_tokens,
temperature=0.7
)
def extract_text_from_image(
self,
image: Image.Image
) -> str:
"""
Perform OCR on an image
Args:
image: PIL Image object
Returns:
Extracted text
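        Example (illustrative; 'receipt.png' is a placeholder path):
            >>> text = model.extract_text_from_image(Image.open("receipt.png"))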
"""
prompt = "Extract all text from this image. Return only the text content without any additional commentary."
return self.generate(
prompt=prompt,
images=[image],
max_new_tokens=1024,
temperature=0.3
)
def call_function(
self,
prompt: str,
tools: Optional[List[Dict[str, Any]]] = None
) -> Dict[str, Any]:
"""
Use function calling to determine which tool to use
Args:
prompt: User query
tools: List of available tools (uses default if None)
Returns:
Dict with tool name and parameters
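        Example (illustrative; the exact dict returned depends on the model's output):
            >>> call = model.call_function("What is 45 multiplied by 23, plus 156?")
            >>> call  # e.g. {"tool": "calculator", "parameters": {"expression": "45 * 23 + 156"}}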
"""
if tools is None:
tools = self.tools
system_prompt = f"""You are a helpful assistant with access to the following tools:
{json.dumps(tools, indent=2)}
To use a tool, respond with ONLY a JSON object in this exact format:
{{"tool": "tool_name", "parameters": {{"param": "value"}}}}
Do not include any other text or explanation."""
full_prompt = f"{system_prompt}\n\nUser query: {prompt}\n\nTool call:"
        response = self.generate(
            prompt=full_prompt,
            max_new_tokens=256,
            do_sample=False  # greedy decoding; temperature is ignored when sampling is off
        )
# Parse JSON response
try:
# Extract JSON from response
json_match = re.search(r'\{.*\}', response, re.DOTALL)
if json_match:
tool_call = json.loads(json_match.group())
return tool_call
else:
return {"error": "No valid JSON found in response", "raw": response}
except json.JSONDecodeError as e:
return {"error": f"JSON decode error: {str(e)}", "raw": response}
def structured_output(
self,
prompt: str,
schema: Dict[str, Any]
) -> Dict[str, Any]:
"""
Generate structured JSON output matching a schema
Args:
prompt: Input prompt
schema: JSON schema for the output
Returns:
Parsed JSON response
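        Example (illustrative schema; field values depend on the model's output):
            >>> schema = {"type": "object",
            ...           "properties": {"sentiment": {"type": "string"}}}
            >>> result = model.structured_output("The launch went very well.", schema)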
"""
full_prompt = f"""Generate a JSON response matching this schema:
{json.dumps(schema, indent=2)}
User request: {prompt}
Return ONLY valid JSON, no other text:"""
        response = self.generate(
            prompt=full_prompt,
            max_new_tokens=1024,
            do_sample=False  # greedy decoding; temperature is ignored when sampling is off
        )
# Parse JSON response
try:
# Try to extract JSON from markdown code blocks
if "```json" in response:
json_str = response.split("```json")[-1].split("```")[0].strip()
elif "```" in response:
json_str = response.split("```")[1].strip()
else:
json_str = response.strip()
return json.loads(json_str)
except json.JSONDecodeError as e:
return {"error": f"JSON decode error: {str(e)}", "raw": response}
def chat(
self,
messages: List[Dict[str, Any]],
max_new_tokens: int = 512,
temperature: float = 0.7,
top_p: float = 0.9
) -> str:
"""
Chat interface using conversation format with support for images
Args:
messages: List of message dicts with 'role', 'content', and optional 'images' keys
max_new_tokens: Maximum tokens to generate
temperature: Sampling temperature
top_p: Nucleus sampling threshold
Returns:
Assistant's response
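        Example (illustrative conversation; message format matches interactive_chat):
            >>> history = [{"role": "user", "content": "Summarize photosynthesis briefly."}]
            >>> reply = model.chat(history, max_new_tokens=128)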
"""
# Extract images from messages
all_images = []
for msg in messages:
if "images" in msg and msg["images"]:
all_images.extend(msg["images"])
# Apply chat template
prompt = self.processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
return self.generate(
prompt=prompt,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
images=all_images if all_images else None
)
def interactive_chat(self):
"""Run an interactive chat session with multimodal support"""
print("\n" + "="*60)
print("Helion-V2.0-Thinking Interactive Chat")
print("Commands:")
print(" - Type 'exit' or 'quit' to end")
print(" - Type 'image <path>' to add an image")
print(" - Type 'clear' to reset conversation")
print("="*60 + "\n")
conversation_history = []
while True:
user_input = input("You: ").strip()
if user_input.lower() in ['exit', 'quit', 'q']:
print("Goodbye!")
break
if user_input.lower() == 'clear':
conversation_history = []
print("Conversation cleared.\n")
continue
if not user_input:
continue
# Check for image command
images = []
if user_input.lower().startswith('image '):
image_path = user_input[6:].strip()
try:
image = Image.open(image_path)
images.append(image)
print(f"Image loaded: {image_path}")
user_input = input("Your question about the image: ").strip()
except Exception as e:
print(f"Error loading image: {e}")
continue
# Add user message to history
message = {
"role": "user",
"content": user_input
}
if images:
message["images"] = images
conversation_history.append(message)
# Generate response
try:
response = self.chat(conversation_history)
# Add assistant response to history
conversation_history.append({
"role": "assistant",
"content": response
})
print(f"\nAssistant: {response}\n")
except Exception as e:
print(f"Error generating response: {e}\n")
def main():
parser = argparse.ArgumentParser(
description="Helion-V2.0-Thinking Multimodal Inference"
)
parser.add_argument(
"--model",
type=str,
default="DeepXR/Helion-V2.0-Thinking",
help="Model name or path"
)
parser.add_argument(
"--prompt",
type=str,
help="Input prompt for generation"
)
parser.add_argument(
"--image",
type=str,
help="Path to image file"
)
parser.add_argument(
"--interactive",
action="store_true",
help="Start interactive chat mode"
)
parser.add_argument(
"--load-in-8bit",
action="store_true",
help="Load model in 8-bit precision"
)
parser.add_argument(
"--load-in-4bit",
action="store_true",
help="Load model in 4-bit precision"
)
parser.add_argument(
"--max-tokens",
type=int,
default=512,
help="Maximum tokens to generate"
)
parser.add_argument(
"--temperature",
type=float,
default=0.7,
help="Sampling temperature"
)
parser.add_argument(
"--demo",
action="store_true",
help="Run demonstration examples"
)
args = parser.parse_args()
# Initialize model
model = HelionInference(
model_name=args.model,
load_in_8bit=args.load_in_8bit,
load_in_4bit=args.load_in_4bit
)
# Run interactive mode or examples
if args.interactive:
model.interactive_chat()
elif args.demo:
print("\n" + "="*60)
print("Running Demonstration Examples")
print("="*60 + "\n")
# Text generation example
print("1. Text Generation:")
print("-" * 40)
response = model.generate(
"Explain quantum entanglement in simple terms:",
max_new_tokens=256
)
print(f"Response: {response}\n")
# Function calling example
print("2. Function Calling:")
print("-" * 40)
tool_call = model.call_function(
"What is 45 multiplied by 23, plus 156?"
)
print(f"Tool call: {json.dumps(tool_call, indent=2)}\n")
# Structured output example
print("3. Structured Output:")
print("-" * 40)
schema = {
"type": "object",
"properties": {
"summary": {"type": "string"},
"sentiment": {"type": "string", "enum": ["positive", "negative", "neutral"]},
"key_points": {"type": "array", "items": {"type": "string"}}
}
}
structured = model.structured_output(
"Analyze this: The new product launch was highly successful.",
schema
)
print(f"Structured output: {json.dumps(structured, indent=2)}\n")
elif args.image:
# Image analysis
try:
image = Image.open(args.image)
prompt = args.prompt or "Describe this image in detail."
response = model.analyze_image(image, prompt, args.max_tokens)
print(f"\nImage: {args.image}")
print(f"Query: {prompt}")
print(f"Response: {response}\n")
except Exception as e:
print(f"Error processing image: {e}")
elif args.prompt:
response = model.generate(
prompt=args.prompt,
max_new_tokens=args.max_tokens,
temperature=args.temperature
)
print(f"\nPrompt: {args.prompt}")
print(f"Response: {response}\n")
else:
print("Please specify --interactive, --demo, --prompt, or --image")
print("Use --help for more information")
if __name__ == "__main__":
main()