Trouter-Library committed on
Commit 1966e56 · verified · 1 Parent(s): 720e2e5

Create inference.py

Files changed (1)
  1. inference.py +553 -0
inference.py ADDED
@@ -0,0 +1,553 @@
+ """
+ Helion-V2.0-Thinking Inference Script
+ A comprehensive example showing different ways to use the multimodal model
+ with vision, tool use, and structured output capabilities
+ """
+
+ import torch
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     AutoProcessor,
+     BitsAndBytesConfig
+ )
+ from PIL import Image
+ import requests
+ from typing import Optional, List, Dict, Any
+ import argparse
+ import json
+ import re
+
+
+ class HelionInference:
+     """Wrapper class for Helion-V2.0-Thinking multimodal model inference"""
+
+     def __init__(
+         self,
+         model_name: str = "DeepXR/Helion-V2.0-Thinking",
+         device: str = "auto",
+         load_in_8bit: bool = False,
+         load_in_4bit: bool = False,
+         use_flash_attention: bool = True
+     ):
+         """
+         Initialize the model, tokenizer, and processor
+
+         Args:
+             model_name: HuggingFace model identifier
+             device: Device to load model on ('auto', 'cuda', 'cpu')
+             load_in_8bit: Enable 8-bit quantization
+             load_in_4bit: Enable 4-bit quantization
+             use_flash_attention: Use Flash Attention 2 for efficiency
+         """
+         print(f"Loading {model_name}...")
+
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         self.processor = AutoProcessor.from_pretrained(model_name)
+
+         # Configure quantization if requested
+         quantization_config = None
+         if load_in_4bit:
+             quantization_config = BitsAndBytesConfig(
+                 load_in_4bit=True,
+                 bnb_4bit_compute_dtype=torch.bfloat16,
+                 bnb_4bit_use_double_quant=True,
+                 bnb_4bit_quant_type="nf4"
+             )
+         elif load_in_8bit:
+             quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+
+         # Load model (attn_implementation is the current transformers kwarg;
+         # use_flash_attention_2 is deprecated)
+         self.model = AutoModelForCausalLM.from_pretrained(
+             model_name,
+             torch_dtype=torch.bfloat16,
+             device_map=device,
+             quantization_config=quantization_config,
+             attn_implementation="flash_attention_2" if use_flash_attention else None,
+             trust_remote_code=True
+         )
+
+         self.model.eval()
+         print("Model loaded successfully!")
+
+         # Tool definitions
+         self.tools = self._initialize_tools()
+
+     def _initialize_tools(self) -> List[Dict[str, Any]]:
+         """Initialize available tools for function calling"""
+         return [
+             {
+                 "name": "calculator",
+                 "description": "Perform mathematical calculations",
+                 "parameters": {
+                     "type": "object",
+                     "properties": {
+                         "expression": {
+                             "type": "string",
+                             "description": "Mathematical expression to evaluate"
+                         }
+                     },
+                     "required": ["expression"]
+                 }
+             },
+             {
+                 "name": "web_search",
+                 "description": "Search the web for current information",
+                 "parameters": {
+                     "type": "object",
+                     "properties": {
+                         "query": {
+                             "type": "string",
+                             "description": "The search query"
+                         }
+                     },
+                     "required": ["query"]
+                 }
+             },
+             {
+                 "name": "code_executor",
+                 "description": "Execute Python code safely",
+                 "parameters": {
+                     "type": "object",
+                     "properties": {
+                         "code": {
+                             "type": "string",
+                             "description": "Python code to execute"
+                         }
+                     },
+                     "required": ["code"]
+                 }
+             }
+         ]
+
+     def generate(
+         self,
+         prompt: str,
+         max_new_tokens: int = 512,
+         temperature: float = 0.7,
+         top_p: float = 0.9,
+         top_k: int = 50,
+         repetition_penalty: float = 1.1,
+         do_sample: bool = True,
+         images: Optional[List[Image.Image]] = None
+     ) -> str:
+         """
+         Generate text from a prompt with optional images
+
+         Args:
+             prompt: Input text
+             max_new_tokens: Maximum tokens to generate
+             temperature: Sampling temperature
+             top_p: Nucleus sampling threshold
+             top_k: Top-k sampling parameter
+             repetition_penalty: Penalty for repeating tokens
+             do_sample: Use sampling vs greedy decoding
+             images: Optional list of PIL images
+
+         Returns:
+             Generated text
+         """
+         if images:
+             inputs = self.processor(
+                 text=prompt,
+                 images=images,
+                 return_tensors="pt"
+             ).to(self.model.device)
+         else:
+             inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+
+         with torch.no_grad():
+             outputs = self.model.generate(
+                 **inputs,
+                 max_new_tokens=max_new_tokens,
+                 temperature=temperature,
+                 top_p=top_p,
+                 top_k=top_k,
+                 repetition_penalty=repetition_penalty,
+                 do_sample=do_sample,
+                 pad_token_id=self.tokenizer.eos_token_id
+             )
+
+         # Decode only the newly generated tokens so the prompt is not echoed back
+         input_length = inputs["input_ids"].shape[-1]
+         generated_tokens = outputs[0][input_length:]
+         if images:
+             generated_text = self.processor.decode(generated_tokens, skip_special_tokens=True)
+         else:
+             generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+         return generated_text.strip()
+
+     def analyze_image(
+         self,
+         image: Image.Image,
+         query: str = "Describe this image in detail.",
+         max_new_tokens: int = 512
+     ) -> str:
+         """
+         Analyze an image with a specific query
+
+         Args:
+             image: PIL Image object
+             query: Question or instruction about the image
+             max_new_tokens: Maximum tokens to generate
+
+         Returns:
+             Image analysis response
+         """
+         return self.generate(
+             prompt=query,
+             images=[image],
+             max_new_tokens=max_new_tokens,
+             temperature=0.7
+         )
+
+     def extract_text_from_image(
+         self,
+         image: Image.Image
+     ) -> str:
+         """
+         Perform OCR on an image
+
+         Args:
+             image: PIL Image object
+
+         Returns:
+             Extracted text
+         """
+         prompt = "Extract all text from this image. Return only the text content without any additional commentary."
+         return self.generate(
+             prompt=prompt,
+             images=[image],
+             max_new_tokens=1024,
+             temperature=0.3
+         )
+
+     def call_function(
+         self,
+         prompt: str,
+         tools: Optional[List[Dict[str, Any]]] = None
+     ) -> Dict[str, Any]:
+         """
+         Use function calling to determine which tool to use
+
+         Args:
+             prompt: User query
+             tools: List of available tools (uses default if None)
+
+         Returns:
+             Dict with tool name and parameters
+         """
+         if tools is None:
+             tools = self.tools
+
+         system_prompt = f"""You are a helpful assistant with access to the following tools:
+ {json.dumps(tools, indent=2)}
+
+ To use a tool, respond with ONLY a JSON object in this exact format:
+ {{"tool": "tool_name", "parameters": {{"param": "value"}}}}
+
+ Do not include any other text or explanation."""
+
+         full_prompt = f"{system_prompt}\n\nUser query: {prompt}\n\nTool call:"
+
+         response = self.generate(
+             prompt=full_prompt,
+             max_new_tokens=256,
+             temperature=0.2,
+             do_sample=False
+         )
+
+         # Parse JSON response
+         try:
+             # Extract JSON from response
+             json_match = re.search(r'\{.*\}', response, re.DOTALL)
+             if json_match:
+                 tool_call = json.loads(json_match.group())
+                 return tool_call
+             else:
+                 return {"error": "No valid JSON found in response", "raw": response}
+         except json.JSONDecodeError as e:
+             return {"error": f"JSON decode error: {str(e)}", "raw": response}
+
+     def structured_output(
+         self,
+         prompt: str,
+         schema: Dict[str, Any]
+     ) -> Dict[str, Any]:
+         """
+         Generate structured JSON output matching a schema
+
+         Args:
+             prompt: Input prompt
+             schema: JSON schema for the output
+
+         Returns:
+             Parsed JSON response
+         """
+         full_prompt = f"""Generate a JSON response matching this schema:
+ {json.dumps(schema, indent=2)}
+
+ User request: {prompt}
+
+ Return ONLY valid JSON, no other text:"""
+
+         response = self.generate(
+             prompt=full_prompt,
+             max_new_tokens=1024,
+             temperature=0.2,
+             do_sample=False
+         )
+
+         # Parse JSON response
+         try:
+             # Try to extract JSON from markdown code blocks
+             if "```json" in response:
+                 json_str = response.split("```json")[-1].split("```")[0].strip()
+             elif "```" in response:
+                 json_str = response.split("```")[1].strip()
+             else:
+                 json_str = response.strip()
+
+             return json.loads(json_str)
+         except json.JSONDecodeError as e:
+             return {"error": f"JSON decode error: {str(e)}", "raw": response}
+
+     def chat(
+         self,
+         messages: List[Dict[str, Any]],
+         max_new_tokens: int = 512,
+         temperature: float = 0.7,
+         top_p: float = 0.9
+     ) -> str:
+         """
+         Chat interface using conversation format with support for images
+
+         Args:
+             messages: List of message dicts with 'role', 'content', and optional 'images' keys
+             max_new_tokens: Maximum tokens to generate
+             temperature: Sampling temperature
+             top_p: Nucleus sampling threshold
+
+         Returns:
+             Assistant's response
+         """
+         # Extract images from messages
+         all_images = []
+         for msg in messages:
+             if "images" in msg and msg["images"]:
+                 all_images.extend(msg["images"])
+
+         # Apply chat template
+         prompt = self.processor.apply_chat_template(
+             messages,
+             tokenize=False,
+             add_generation_prompt=True
+         )
+
+         return self.generate(
+             prompt=prompt,
+             max_new_tokens=max_new_tokens,
+             temperature=temperature,
+             top_p=top_p,
+             images=all_images if all_images else None
+         )
+
+     def interactive_chat(self):
+         """Run an interactive chat session with multimodal support"""
+         print("\n" + "="*60)
+         print("Helion-V2.0-Thinking Interactive Chat")
+         print("Commands:")
+         print(" - Type 'exit' or 'quit' to end")
+         print(" - Type 'image <path>' to add an image")
+         print(" - Type 'clear' to reset conversation")
+         print("="*60 + "\n")
+
+         conversation_history = []
+
+         while True:
+             user_input = input("You: ").strip()
+
+             if user_input.lower() in ['exit', 'quit', 'q']:
+                 print("Goodbye!")
+                 break
+
+             if user_input.lower() == 'clear':
+                 conversation_history = []
+                 print("Conversation cleared.\n")
+                 continue
+
+             if not user_input:
+                 continue
+
+             # Check for image command
+             images = []
+             if user_input.lower().startswith('image '):
+                 image_path = user_input[6:].strip()
+                 try:
+                     image = Image.open(image_path)
+                     images.append(image)
+                     print(f"Image loaded: {image_path}")
+                     user_input = input("Your question about the image: ").strip()
+                 except Exception as e:
+                     print(f"Error loading image: {e}")
+                     continue
+
+             # Add user message to history
+             message = {
+                 "role": "user",
+                 "content": user_input
+             }
+             if images:
+                 message["images"] = images
+
+             conversation_history.append(message)
+
+             # Generate response
+             try:
+                 response = self.chat(conversation_history)
+
+                 # Add assistant response to history
+                 conversation_history.append({
+                     "role": "assistant",
+                     "content": response
+                 })
+
+                 print(f"\nAssistant: {response}\n")
+             except Exception as e:
+                 print(f"Error generating response: {e}\n")
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Helion-V2.0-Thinking Multimodal Inference"
+     )
+     parser.add_argument(
+         "--model",
+         type=str,
+         default="DeepXR/Helion-V2.0-Thinking",
+         help="Model name or path"
+     )
+     parser.add_argument(
+         "--prompt",
+         type=str,
+         help="Input prompt for generation"
+     )
+     parser.add_argument(
+         "--image",
+         type=str,
+         help="Path to image file"
+     )
+     parser.add_argument(
+         "--interactive",
+         action="store_true",
+         help="Start interactive chat mode"
+     )
+     parser.add_argument(
+         "--load-in-8bit",
+         action="store_true",
+         help="Load model in 8-bit precision"
+     )
+     parser.add_argument(
+         "--load-in-4bit",
+         action="store_true",
+         help="Load model in 4-bit precision"
+     )
+     parser.add_argument(
+         "--max-tokens",
+         type=int,
+         default=512,
+         help="Maximum tokens to generate"
+     )
+     parser.add_argument(
+         "--temperature",
+         type=float,
+         default=0.7,
+         help="Sampling temperature"
+     )
+     parser.add_argument(
+         "--demo",
+         action="store_true",
+         help="Run demonstration examples"
+     )
+
+     args = parser.parse_args()
+
+     # Initialize model
+     model = HelionInference(
+         model_name=args.model,
+         load_in_8bit=args.load_in_8bit,
+         load_in_4bit=args.load_in_4bit
+     )
+
+     # Run interactive mode or examples
+     if args.interactive:
+         model.interactive_chat()
+     elif args.demo:
+         print("\n" + "="*60)
+         print("Running Demonstration Examples")
+         print("="*60 + "\n")
+
+         # Text generation example
+         print("1. Text Generation:")
+         print("-" * 40)
+         response = model.generate(
+             "Explain quantum entanglement in simple terms:",
+             max_new_tokens=256
+         )
+         print(f"Response: {response}\n")
+
+         # Function calling example
+         print("2. Function Calling:")
+         print("-" * 40)
+         tool_call = model.call_function(
+             "What is 45 multiplied by 23, plus 156?"
+         )
+         print(f"Tool call: {json.dumps(tool_call, indent=2)}\n")
+
+         # Structured output example
+         print("3. Structured Output:")
+         print("-" * 40)
+         schema = {
+             "type": "object",
+             "properties": {
+                 "summary": {"type": "string"},
+                 "sentiment": {"type": "string", "enum": ["positive", "negative", "neutral"]},
+                 "key_points": {"type": "array", "items": {"type": "string"}}
+             }
+         }
+         structured = model.structured_output(
+             "Analyze this: The new product launch was highly successful.",
+             schema
+         )
+         print(f"Structured output: {json.dumps(structured, indent=2)}\n")
+
+     elif args.image:
+         # Image analysis
+         try:
+             image = Image.open(args.image)
+             prompt = args.prompt or "Describe this image in detail."
+             response = model.analyze_image(image, prompt, args.max_tokens)
+             print(f"\nImage: {args.image}")
+             print(f"Query: {prompt}")
+             print(f"Response: {response}\n")
+         except Exception as e:
+             print(f"Error processing image: {e}")
+
+     elif args.prompt:
+         response = model.generate(
+             prompt=args.prompt,
+             max_new_tokens=args.max_tokens,
+             temperature=args.temperature
+         )
+         print(f"\nPrompt: {args.prompt}")
+         print(f"Response: {response}\n")
+     else:
+         print("Please specify --interactive, --demo, --prompt, or --image")
+         print("Use --help for more information")
+
+
+ if __name__ == "__main__":
+     main()
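
Usage sketch: besides the CLI entry point in main(), the HelionInference class added above can be driven programmatically. The snippet below is a minimal example of that, assuming the committed file is importable as inference.py, the DeepXR/Helion-V2.0-Thinking checkpoint and its dependencies (torch, transformers, Pillow, and bitsandbytes for quantized loading) are available, and the image path and prompts are placeholders.

    from PIL import Image
    from inference import HelionInference

    # Load the model once; 4-bit quantization (requires bitsandbytes) keeps memory usage down
    model = HelionInference(load_in_4bit=True)

    # Plain text generation
    print(model.generate("Explain quantum entanglement in simple terms:", max_new_tokens=128))

    # Image analysis ("example.jpg" is a placeholder path)
    image = Image.open("example.jpg")
    print(model.analyze_image(image, "What objects are visible?"))

    # Tool selection and structured JSON output
    print(model.call_function("What is 45 multiplied by 23, plus 156?"))
    print(model.structured_output(
        "Summarize: the new product launch was highly successful.",
        {"type": "object", "properties": {"summary": {"type": "string"}}}
    ))

The same flows are reachable from the command line via the flags defined in main(), e.g. --demo, --interactive, or --prompt/--image combined with --load-in-4bit.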