AlexGall committed · verified
Commit 6e610ec · 1 Parent(s): cdfeef5

Create advanced_inference.py

Files changed (1)
  1. advanced_inference.py +455 -0
advanced_inference.py ADDED
@@ -0,0 +1,455 @@
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    BitsAndBytesConfig,
    GenerationConfig
)
from PIL import Image
import json
from typing import Optional, List, Dict, Any, Union
import time
from dataclasses import dataclass
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

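# Assumed runtime dependencies (not pinned by this file): torch, transformers,
# and pillow; bitsandbytes is only needed when a `quantization` mode is
# requested, and flash-attn only when `use_flash_attention=True` is kept.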
@dataclass
class InferenceMetrics:
    """Per-request latency, throughput, and memory statistics."""
    latency_ms: float
    tokens_generated: int
    tokens_per_second: float
    memory_used_gb: float
    input_tokens: int
    total_tokens: int

class AdvancedHelionInference:
    """Convenience wrapper around Helion-V2.0-Thinking for multimodal inference."""

    def __init__(
        self,
        model_name: str = "DeepXR/Helion-V2.0-Thinking",
        quantization: Optional[str] = None,
        device: str = "auto",
        use_flash_attention: bool = True,
        torch_compile: bool = False,
        optimization_mode: str = "balanced"
    ):
        logger.info(f"Initializing Helion-V2.0-Thinking with {optimization_mode} mode")

        self.model_name = model_name
        self.optimization_mode = optimization_mode
        self.metrics_history = []

        quantization_config = self._get_quantization_config(quantization)

        logger.info("Loading processor...")
        self.processor = AutoProcessor.from_pretrained(model_name)

        logger.info("Loading model...")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=quantization_config,
            device_map=device,
            torch_dtype=torch.bfloat16 if quantization is None else None,
            use_flash_attention_2=use_flash_attention,
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )

        if torch_compile and quantization is None:
            logger.info("Compiling model with torch.compile...")
            self.model = torch.compile(self.model, mode="reduce-overhead")

        self.model.eval()

        # Preset decoding profiles, selectable via the `mode` argument of generate().
        self.generation_configs = {
            "creative": GenerationConfig(
                do_sample=True,
                temperature=0.9,
                top_p=0.95,
                top_k=50,
                repetition_penalty=1.15,
                max_new_tokens=2048
            ),
            "precise": GenerationConfig(
                do_sample=True,
                temperature=0.3,
                top_p=0.85,
                top_k=40,
                repetition_penalty=1.05,
                max_new_tokens=1024
            ),
            "balanced": GenerationConfig(
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                top_k=50,
                repetition_penalty=1.1,
                max_new_tokens=1024
            ),
            "code": GenerationConfig(
                do_sample=True,
                temperature=0.2,
                top_p=0.9,
                top_k=40,
                repetition_penalty=1.05,
                max_new_tokens=2048
            )
        }

        logger.info("Model loaded successfully!")

    def _get_quantization_config(self, quantization: Optional[str]) -> Optional[BitsAndBytesConfig]:
        if quantization is None:
            return None

        quantization_configs = {
            "4bit": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            ),
            "8bit": BitsAndBytesConfig(
                load_in_8bit=True
            )
        }

        return quantization_configs.get(quantization)

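    # Illustrative sketch: loading the wrapper with the 4-bit NF4 preset defined
    # above; `engine` is a placeholder name used only in these examples.
    #
    #   engine = AdvancedHelionInference(
    #       model_name="DeepXR/Helion-V2.0-Thinking",
    #       quantization="4bit",
    #   )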
    def generate(
        self,
        prompt: str,
        images: Optional[Union[Image.Image, List[Image.Image]]] = None,
        mode: str = "balanced",
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        stream: bool = False,
        return_metrics: bool = False,
        **kwargs
    ) -> Union[str, tuple[str, InferenceMetrics]]:
        """Run a single text or text+image request.

        Returns the decoded response, paired with an InferenceMetrics record when
        return_metrics=True; with stream=True a text generator is returned instead.
        """
        if isinstance(images, Image.Image):
            images = [images]

        start_time = time.time()
        initial_memory = torch.cuda.memory_allocated() / (1024**3) if torch.cuda.is_available() else 0

        if images:
            inputs = self.processor(
                text=prompt,
                images=images,
                return_tensors="pt"
            ).to(self.model.device)
        else:
            inputs = self.processor(
                text=prompt,
                return_tensors="pt"
            ).to(self.model.device)

        input_length = inputs['input_ids'].shape[1]

        gen_config = self.generation_configs[mode].to_dict()

        if max_new_tokens:
            gen_config['max_new_tokens'] = max_new_tokens
        if temperature:
            gen_config['temperature'] = temperature

        gen_config.update(kwargs)

        with torch.no_grad():
            if stream:
                return self._generate_stream(inputs, gen_config, return_metrics)
            else:
                outputs = self.model.generate(
                    **inputs,
                    **gen_config,
                    pad_token_id=self.processor.tokenizer.eos_token_id
                )

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        end_time = time.time()
        latency = (end_time - start_time) * 1000

        response = self.processor.decode(outputs[0], skip_special_tokens=True)

        if response.startswith(prompt):
            response = response[len(prompt):].strip()

        tokens_generated = outputs.shape[1] - input_length
        tokens_per_second = tokens_generated / ((end_time - start_time) if (end_time - start_time) > 0 else 1)

        final_memory = torch.cuda.memory_allocated() / (1024**3) if torch.cuda.is_available() else 0
        memory_used = final_memory - initial_memory

        metrics = InferenceMetrics(
            latency_ms=latency,
            tokens_generated=tokens_generated,
            tokens_per_second=tokens_per_second,
            memory_used_gb=memory_used,
            input_tokens=input_length,
            total_tokens=outputs.shape[1]
        )

        self.metrics_history.append(metrics)

        if return_metrics:
            return response, metrics
        return response

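    # Illustrative call pattern (file path and prompt are placeholders):
    #
    #   img = Image.open("chart.png")
    #   text, stats = engine.generate(
    #       "Describe this chart.", images=img, mode="precise", return_metrics=True
    #   )
    #   print(f"{stats.tokens_per_second:.1f} tokens/s")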
    def _generate_stream(self, inputs, gen_config, return_metrics):
        from transformers import TextIteratorStreamer
        from threading import Thread

        streamer = TextIteratorStreamer(
            self.processor.tokenizer,
            skip_special_tokens=True,
            skip_prompt=True
        )

        gen_config['streamer'] = streamer

        thread = Thread(
            target=self.model.generate,
            kwargs={**inputs, **gen_config}
        )
        thread.start()

        for new_text in streamer:
            yield new_text

        thread.join()

    def batch_generate(
        self,
        prompts: List[str],
        images_list: Optional[List[Optional[Union[Image.Image, List[Image.Image]]]]] = None,
        mode: str = "balanced",
        **kwargs
    ) -> List[str]:

        if images_list is None:
            images_list = [None] * len(prompts)

        all_inputs = []
        for prompt, images in zip(prompts, images_list):
            if images:
                if isinstance(images, Image.Image):
                    images = [images]
                inputs = self.processor(
                    text=prompt,
                    images=images,
                    return_tensors="pt",
                    padding=True
                )
            else:
                inputs = self.processor(
                    text=prompt,
                    return_tensors="pt",
                    padding=True
                )
            all_inputs.append(inputs)

        # Note: each prompt is processed in a separate call, so padding=True does not
        # align lengths across prompts; the concatenation below assumes the encoded
        # sequences end up the same length.
        batch_inputs = {
            k: torch.cat([inp[k] for inp in all_inputs], dim=0).to(self.model.device)
            for k in all_inputs[0].keys()
        }

        gen_config = self.generation_configs[mode].to_dict()
        gen_config.update(kwargs)

        with torch.no_grad():
            outputs = self.model.generate(
                **batch_inputs,
                **gen_config,
                pad_token_id=self.processor.tokenizer.eos_token_id
            )

        responses = [
            self.processor.decode(output, skip_special_tokens=True)
            for output in outputs
        ]

        return responses

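    # Illustrative batched call (prompts are placeholders; see the note above on
    # sequence lengths):
    #
    #   answers = engine.batch_generate(
    #       ["Summarize document A.", "Summarize document B."], mode="balanced"
    #   )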
    def vision_qa(
        self,
        image: Image.Image,
        question: str,
        mode: str = "precise"
    ) -> str:

        prompt = f"Question: {question}\nAnswer:"
        return self.generate(prompt, images=image, mode=mode)

    def analyze_image(
        self,
        image: Image.Image,
        analysis_type: str = "detailed"
    ) -> str:

        prompts = {
            "detailed": "Provide a detailed description of this image, including objects, people, actions, setting, and any text visible.",
            "quick": "Briefly describe what you see in this image.",
            "technical": "Analyze this image from a technical perspective, including composition, lighting, colors, and quality.",
            "ocr": "Extract all text visible in this image and organize it clearly."
        }

        prompt = prompts.get(analysis_type, prompts["detailed"])
        return self.generate(prompt, images=image, mode="precise")

    def code_generation(
        self,
        task: str,
        language: str = "python",
        include_tests: bool = False
    ) -> str:

        prompt = f"Write {language} code for the following task:\n{task}"

        if include_tests:
            prompt += "\n\nInclude unit tests for the code."

        return self.generate(prompt, mode="code", max_new_tokens=2048)

    def function_call(
        self,
        user_query: str,
        available_tools: List[Dict[str, Any]]
    ) -> Dict[str, Any]:

        tools_str = json.dumps(available_tools, indent=2)

        prompt = f"""Available tools:
{tools_str}

User query: {user_query}

Respond with a JSON object specifying which tool to use and with what parameters:
{{"tool": "tool_name", "parameters": {{"param": "value"}}}}

Response:"""

        response = self.generate(prompt, mode="precise", temperature=0.2)

        try:
            import re
            json_match = re.search(r'\{.*\}', response, re.DOTALL)
            if json_match:
                return json.loads(json_match.group())
            return {"error": "No valid JSON found", "raw": response}
        except json.JSONDecodeError as e:
            return {"error": str(e), "raw": response}

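    # Illustrative tool list (hypothetical tool name and fields; the method only
    # requires a JSON-serializable list of dicts):
    #
    #   tools = [{
    #       "name": "get_weather",
    #       "description": "Look up the current weather for a city",
    #       "parameters": {"city": {"type": "string"}}
    #   }]
    #   call = engine.function_call("What's the weather in Paris?", tools)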
    def multi_modal_rag(
        self,
        query: str,
        documents: List[str],
        images: Optional[List[Image.Image]] = None
    ) -> str:

        context = "\n\n".join([f"Document {i+1}:\n{doc}" for i, doc in enumerate(documents)])

        prompt = f"""Context:\n{context}\n\nQuestion: {query}\n\nAnswer based on the provided context:"""

        return self.generate(prompt, images=images, mode="precise", max_new_tokens=1024)

    def get_metrics_summary(self) -> Dict[str, float]:

        if not self.metrics_history:
            return {}

        return {
            "avg_latency_ms": sum(m.latency_ms for m in self.metrics_history) / len(self.metrics_history),
            "avg_tokens_per_second": sum(m.tokens_per_second for m in self.metrics_history) / len(self.metrics_history),
            "avg_memory_used_gb": sum(m.memory_used_gb for m in self.metrics_history) / len(self.metrics_history),
            "total_tokens_generated": sum(m.tokens_generated for m in self.metrics_history),
            "num_requests": len(self.metrics_history)
        }

    def clear_cache(self):

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        if hasattr(self.model, 'clear_cache'):
            self.model.clear_cache()
        logger.info("Cache cleared")

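# Minimal programmatic usage sketch (image path and question are placeholders;
# the 4-bit option assumes bitsandbytes is installed):
#
#   engine = AdvancedHelionInference(quantization="4bit")
#   print(engine.vision_qa(Image.open("photo.jpg"), "What is shown here?"))
#   print(engine.get_metrics_summary())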
def main():

    import argparse

    parser = argparse.ArgumentParser(description="Advanced Helion-V2.0-Thinking Inference")
    parser.add_argument("--model", type=str, default="DeepXR/Helion-V2.0-Thinking")
    parser.add_argument("--quantization", type=str, choices=["4bit", "8bit", None], default=None)
    parser.add_argument("--mode", type=str, default="balanced", choices=["creative", "precise", "balanced", "code"])
    parser.add_argument("--prompt", type=str, help="Text prompt")
    parser.add_argument("--image", type=str, help="Path to image file")
    parser.add_argument("--stream", action="store_true", help="Enable streaming output")
    parser.add_argument("--torch-compile", action="store_true", help="Use torch.compile")
    parser.add_argument("--benchmark", action="store_true", help="Run benchmark")

    args = parser.parse_args()

    model = AdvancedHelionInference(
        model_name=args.model,
        quantization=args.quantization,
        torch_compile=args.torch_compile
    )

    if args.benchmark:
        print("Running benchmark...")

        test_prompts = [
            "Explain quantum computing in simple terms.",
            "Write a Python function to calculate fibonacci numbers.",
            "What are the main causes of climate change?"
        ]

        for prompt in test_prompts:
            response, metrics = model.generate(
                prompt,
                mode=args.mode,
                return_metrics=True
            )
            print(f"\nPrompt: {prompt}")
            print(f"Response: {response[:100]}...")
            print(f"Metrics: {metrics}")

        summary = model.get_metrics_summary()
        print("\nBenchmark Summary:")
        for key, value in summary.items():
            print(f"  {key}: {value:.2f}")

    elif args.prompt:
        image = Image.open(args.image) if args.image else None

        if args.stream:
            print("Streaming response:")
            for text in model.generate(args.prompt, images=image, mode=args.mode, stream=True):
                print(text, end="", flush=True)
            print()
        else:
            response, metrics = model.generate(
                args.prompt,
                images=image,
                mode=args.mode,
                return_metrics=True
            )
            print(f"Response: {response}")
            print("\nMetrics:")
            print(f"  Latency: {metrics.latency_ms:.2f}ms")
            print(f"  Tokens/sec: {metrics.tokens_per_second:.2f}")
            print(f"  Tokens generated: {metrics.tokens_generated}")
    else:
        print("Please provide --prompt or use --benchmark")


if __name__ == "__main__":
    main()