kshitijthakkar committed on
Commit ddc41f5 · 1 Parent(s): 4ea319e

fixed the task dropdown on tab 2 and enhanced model handler

Files changed (3)
  1. enhanced_app.py +148 -113
  2. enhanced_model_handler.py +1297 -0
  3. model_handler.py +1 -1
enhanced_app.py CHANGED
@@ -6,7 +6,9 @@ import datetime
6
  import json
7
  import random
8
  import os
9
- from model_handler import generate_response, get_inference_configs
 
 
10
  import torch
11
 
12
  # Configuration for datasets
@@ -210,10 +212,18 @@ def get_eval_dataset_info(dataset_name):
210
  return "No dataset selected"
211
 
212
 
213
  def get_task_types_for_eval(dataset_name):
214
  """Get unique task types from selected eval dataset"""
215
  if dataset_name in EVAL_DATASETS and 'task_type' in EVAL_DATASETS[dataset_name].columns:
216
  task_types = EVAL_DATASETS[dataset_name]['task_type'].unique().tolist()
 
 
217
  return [str(t) for t in task_types if pd.notna(t)]
218
  return ["No task types available"]
219
 
@@ -235,36 +245,58 @@ def get_tasks_by_type_eval(dataset_name, task_type):
235
  return ["No tasks found"]
236
 
237
 
238
- def get_selected_row_data(dataset_name, task_type, selected_task):
239
- """Get all data for the selected row"""
240
- if not selected_task or selected_task == "No tasks found":
241
- return "", "", "", "", "", "",""
242
-
243
- try:
244
- # Extract row index from selected_task
245
- row_idx = int(selected_task.split("Row ")[1].split(":")[0])
246
-
247
- if dataset_name in EVAL_DATASETS:
248
- df = EVAL_DATASETS[dataset_name]
249
- if row_idx in df.index:
250
- row = df.loc[row_idx]
251
-
252
- # Extract all fields with safe handling for missing columns
253
- task = str(row.get('task', 'N/A'))
254
- task_type_val = str(row.get('task_type', 'N/A'))
255
- input_model = str(row.get('input_model', 'N/A'))
256
- expected_response = str(row.get('expected_response', 'N/A'))
257
- loggenix_output = str(row.get('loggenix_output', 'N/A'))
258
- output_model = str(row.get('output_model', 'N/A'))
259
- input_text = str(row.get('input', 'N/A'))
 
260
 
 
 
 
261
 
262
- return task_type_val, input_model, output_model, task, input_text, expected_response, loggenix_output
 
263
 
264
- except Exception as e:
265
- return f"Error: {str(e)}", "", "", "", "", "", "", ""
266
 
267
- return "", "", "", "", "", "", ""
268
 
269
  # ===== TAB 3: VIEW FLAGGED RESPONSES =====
270
 
@@ -457,46 +489,46 @@ def create_interface():
457
  )
458
 
459
  # TAB 2: EVAL SAMPLES
 
460
  with gr.Tab("📊 Eval Samples"):
461
  gr.Markdown("## Dataset Evaluation Samples")
462
-
463
- with gr.Row():
464
- with gr.Column(scale=1):
465
- eval_dataset_dropdown = gr.Dropdown(
466
- choices=list(EVAL_DATASETS.keys()),
467
- value=list(EVAL_DATASETS.keys())[0] if EVAL_DATASETS else None,
468
- label="Select Dataset",
469
- info="Choose evaluation dataset to view"
470
- )
471
-
472
- eval_dataset_info = gr.Markdown(
473
- get_eval_dataset_info(list(EVAL_DATASETS.keys())[0] if EVAL_DATASETS else "")
474
- )
475
-
476
- with gr.Row():
477
- eval_table = gr.Dataframe(
478
- value=update_eval_table(list(EVAL_DATASETS.keys())[0]) if EVAL_DATASETS else pd.DataFrame(),
479
- label="Dataset Table",
480
- max_height=800,
481
- min_width=800,
482
- interactive=True,
483
- wrap=True,
484
- show_fullscreen_button=True,
485
- show_copy_button=True,
486
- show_row_numbers=True,
487
- show_search="search",
488
- column_widths=["80px","80px","80px","150px","250px","250px","250px"]
489
- )
490
-
491
- # Event handlers for Tab 2
492
- eval_dataset_dropdown.change(
493
- fn=lambda x: (update_eval_table(x), get_eval_dataset_info(x)),
494
- inputs=[eval_dataset_dropdown],
495
- outputs=[eval_table, eval_dataset_info]
496
- )
497
- with gr.Tab("📊 Eval Samples 2"):
498
- gr.Markdown("## Dataset Evaluation Samples")
499
- gr.Markdown("Select dataset, task type, and specific task to view detailed information")
500
 
501
  with gr.Row():
502
  with gr.Column(scale=1):
@@ -510,13 +542,8 @@ def create_interface():
510
  eval_task_type_dropdown = gr.Dropdown(
511
  choices=[],
512
  label="Select Task Type",
513
- info="Choose task type from selected dataset"
514
- )
515
-
516
- eval_task_dropdown = gr.Dropdown(
517
- choices=[],
518
- label="Select Specific Task",
519
- info="Choose specific task to view details"
520
  )
521
 
522
  with gr.Column(scale=1):
@@ -526,39 +553,33 @@ def create_interface():
526
 
527
  # Task details section
528
  gr.Markdown("### Task Details")
529
-
530
  with gr.Row():
531
- with gr.Column():
532
- task_field = gr.Textbox(
533
- label="Task",
534
- lines=8,
535
- max_lines=12,
536
- interactive=False
537
- )
538
-
539
- task_type_field = gr.Textbox(
540
- label="Task Type",
541
- lines=1,
542
- interactive=False
543
- )
544
 
545
- input_model_field = gr.Textbox(
546
- label="input_model",
547
- lines=1,
548
- interactive=False
549
- )
 
550
 
551
- input_field = gr.Textbox(
552
- label="input",
553
- lines=8,
554
- max_lines=12,
555
- interactive=False
556
- )
557
- output_model_field = gr.Textbox(
558
- label="output_model",
559
- lines=1,
560
- interactive=False
561
- )
562
 
563
  # Large text fields for outputs side by side
564
  gr.Markdown("### Expected vs Actual Response Comparison")
@@ -579,25 +600,39 @@ def create_interface():
579
  )
580
 
581
  # Event handlers for Tab 2
 
582
  eval_dataset_dropdown.change(
583
- fn=lambda x: (get_eval_dataset_info(x), get_task_types_for_eval(x), []),
584
  inputs=[eval_dataset_dropdown],
585
- outputs=[eval_dataset_info, eval_task_type_dropdown, eval_task_dropdown]
586
  )
587
-
588
  eval_task_type_dropdown.change(
589
- fn=get_tasks_by_type_eval,
590
  inputs=[eval_dataset_dropdown, eval_task_type_dropdown],
591
- outputs=[eval_task_dropdown]
592
- )
593
-
594
- eval_task_dropdown.change(
595
- fn=get_selected_row_data,
596
- inputs=[eval_dataset_dropdown, eval_task_type_dropdown, eval_task_dropdown],
597
- outputs=[task_type_field, input_model_field, output_model_field, task_field, input_field,
598
  loggenix_output_field, expected_response_field]
599
  )
600
 
 
601
  # TAB 3: VIEW FLAGGED RESPONSES (RENAMED FROM TAB 4)
602
  with gr.Tab("👀 View Flagged Responses"):
603
  gr.Markdown("## Review Flagged Responses")
 
6
  import json
7
  import random
8
  import os
9
+ #from model_handler import generate_response, get_inference_configs
10
+ from enhanced_model_handler import generate_response, get_inference_configs
11
+
12
  import torch
13
 
14
  # Configuration for datasets
 
212
  return "No dataset selected"
213
 
214
 
215
+ # def get_task_types_for_eval(dataset_name):
216
+ # """Get unique task types from selected eval dataset"""
217
+ # if dataset_name in EVAL_DATASETS and 'task_type' in EVAL_DATASETS[dataset_name].columns:
218
+ # task_types = EVAL_DATASETS[dataset_name]['task_type'].unique().tolist()
219
+ # return [str(t) for t in task_types if pd.notna(t)]
220
+ # return ["No task types available"]
221
  def get_task_types_for_eval(dataset_name):
222
  """Get unique task types from selected eval dataset"""
223
  if dataset_name in EVAL_DATASETS and 'task_type' in EVAL_DATASETS[dataset_name].columns:
224
  task_types = EVAL_DATASETS[dataset_name]['task_type'].unique().tolist()
225
+ # The correct way is to return the list directly, not a joined string.
226
+ # The list comprehension `[str(t) for t in task_types if pd.notna(t)]` already does this.
227
  return [str(t) for t in task_types if pd.notna(t)]
228
  return ["No task types available"]
229
 
 
245
  return ["No tasks found"]
246
 
247
 
248
+ # def get_selected_row_data(dataset_name, task_type, selected_task):
249
+ # """Get all data for the selected row"""
250
+ # if not selected_task or selected_task == "No tasks found":
251
+ # return "", "", "", "", "", "",""
252
+ #
253
+ # try:
254
+ # # Extract row index from selected_task
255
+ # row_idx = int(selected_task.split("Row ")[1].split(":")[0])
256
+ #
257
+ # if dataset_name in EVAL_DATASETS:
258
+ # df = EVAL_DATASETS[dataset_name]
259
+ # if row_idx in df.index:
260
+ # row = df.loc[row_idx]
261
+ #
262
+ # # Extract all fields with safe handling for missing columns
263
+ # task = str(row.get('task', 'N/A'))
264
+ # task_type_val = str(row.get('task_type', 'N/A'))
265
+ # input_model = str(row.get('input_model', 'N/A'))
266
+ # expected_response = str(row.get('expected_response', 'N/A'))
267
+ # loggenix_output = str(row.get('loggenix_output', 'N/A'))
268
+ # output_model = str(row.get('output_model', 'N/A'))
269
+ # input_text = str(row.get('input', 'N/A'))
270
+ #
271
+ #
272
+ # return task_type_val, input_model, output_model, task, input_text, expected_response, loggenix_output
273
+ #
274
+ # except Exception as e:
275
+ # return f"Error: {str(e)}", "", "", "", "", "", "", ""
276
+ #
277
+ # return "", "", "", "", "", "", ""
278
+
279
+ def get_selected_row_data_by_type(dataset_name, task_type):
280
+ """Get all data for the first row of a selected dataset and task type"""
281
+ if (dataset_name in EVAL_DATASETS and
282
+ 'task_type' in EVAL_DATASETS[dataset_name].columns and
283
+ 'task' in EVAL_DATASETS[dataset_name].columns):
284
 
285
+ filtered = EVAL_DATASETS[dataset_name][EVAL_DATASETS[dataset_name]['task_type'] == task_type]
286
+ if len(filtered) > 0:
287
+ row = filtered.iloc[0] # Get the first row
288
 
289
+ # Extract all fields with safe handling for missing columns
290
+ task = str(row.get('task', 'N/A'))
291
+ input_model = str(row.get('input_model', 'N/A'))
292
+ expected_response = str(row.get('expected_response', 'N/A'))
293
+ loggenix_output = str(row.get('loggenix_output', 'N/A'))
294
+ output_model = str(row.get('output_model', 'N/A'))
295
+ input_text = str(row.get('input', 'N/A'))
296
 
297
+ return input_model, output_model, task, input_text, expected_response, loggenix_output
 
298
 
299
+ return "", "", "", "", "", ""
300
 
301
  # ===== TAB 3: VIEW FLAGGED RESPONSES =====
302
 
 
489
  )
490
 
491
  # TAB 2: EVAL SAMPLES
492
+ # with gr.Tab("📊 Eval Samples"):
493
+ # gr.Markdown("## Dataset Evaluation Samples")
494
+ #
495
+ # with gr.Row():
496
+ # with gr.Column(scale=1):
497
+ # eval_dataset_dropdown = gr.Dropdown(
498
+ # choices=list(EVAL_DATASETS.keys()),
499
+ # value=list(EVAL_DATASETS.keys())[0] if EVAL_DATASETS else None,
500
+ # label="Select Dataset",
501
+ # info="Choose evaluation dataset to view"
502
+ # )
503
+ #
504
+ # eval_dataset_info = gr.Markdown(
505
+ # get_eval_dataset_info(list(EVAL_DATASETS.keys())[0] if EVAL_DATASETS else "")
506
+ # )
507
+ #
508
+ # with gr.Row():
509
+ # eval_table = gr.Dataframe(
510
+ # value=update_eval_table(list(EVAL_DATASETS.keys())[0]) if EVAL_DATASETS else pd.DataFrame(),
511
+ # label="Dataset Table",
512
+ # max_height=800,
513
+ # min_width=800,
514
+ # interactive=True,
515
+ # wrap=True,
516
+ # show_fullscreen_button=True,
517
+ # show_copy_button=True,
518
+ # show_row_numbers=True,
519
+ # show_search="search",
520
+ # column_widths=["80px","80px","80px","150px","250px","250px","250px"]
521
+ # )
522
+ #
523
+ # # Event handlers for Tab 2
524
+ # eval_dataset_dropdown.change(
525
+ # fn=lambda x: (update_eval_table(x), get_eval_dataset_info(x)),
526
+ # inputs=[eval_dataset_dropdown],
527
+ # outputs=[eval_table, eval_dataset_info]
528
+ # )
529
  with gr.Tab("📊 Eval Samples"):
530
  gr.Markdown("## Dataset Evaluation Samples")
531
+ gr.Markdown("Select dataset and task type to view detailed information")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
532
 
533
  with gr.Row():
534
  with gr.Column(scale=1):
 
542
  eval_task_type_dropdown = gr.Dropdown(
543
  choices=[],
544
  label="Select Task Type",
545
+ info="Choose task type from selected dataset",
546
+ allow_custom_value=True
 
547
  )
548
 
549
  with gr.Column(scale=1):
 
553
 
554
  # Task details section
555
  gr.Markdown("### Task Details")
 
556
  with gr.Row():
557
+ input_model_field = gr.Textbox(
558
+ label="input_model",
559
+ lines=1,
560
+ interactive=False
561
+ )
 
562
 
563
+ output_model_field = gr.Textbox(
564
+ label="output_model",
565
+ lines=1,
566
+ interactive=False
567
+ )
568
+ with gr.Row():
569
+ task_field = gr.Textbox(
570
+ label="Task",
571
+ lines=2,
572
+ max_lines=5,
573
+ interactive=False
574
+ )
575
 
576
+ with gr.Row():
577
+ input_field = gr.Textbox(
578
+ label="input",
579
+ lines=8,
580
+ max_lines=20,
581
+ interactive=False
582
+ )
 
583
 
584
  # Large text fields for outputs side by side
585
  gr.Markdown("### Expected vs Actual Response Comparison")
 
600
  )
601
 
602
  # Event handlers for Tab 2
603
+ # eval_dataset_dropdown.change(
604
+ # fn=lambda x: (get_eval_dataset_info(x), get_task_types_for_eval(x), None),
605
+ # inputs=[eval_dataset_dropdown],
606
+ # outputs=[eval_dataset_info, eval_task_type_dropdown]
607
+ # )
608
+
609
+ # Event handlers for Tab 2
610
+ # eval_dataset_dropdown.change(
611
+ # fn=lambda x: (get_eval_dataset_info(x), get_task_types_for_eval(x)),
612
+ # inputs=[eval_dataset_dropdown],
613
+ # outputs=[eval_dataset_info, eval_task_type_dropdown]
614
+ # )
615
+ # Define a new function instead of lambda for clarity
616
+ def update_eval_components(dataset_name):
617
+ info = get_eval_dataset_info(dataset_name)
618
+ task_types = get_task_types_for_eval(dataset_name)
619
+ return info, gr.update(choices=task_types,
620
+ value=task_types[0] if task_types else "No task types available")
621
+
622
+ # In the event handlers for Tab 2, replace the existing .change with this:
623
  eval_dataset_dropdown.change(
624
+ fn=update_eval_components,
625
  inputs=[eval_dataset_dropdown],
626
+ outputs=[eval_dataset_info, eval_task_type_dropdown]
627
  )
 
628
  eval_task_type_dropdown.change(
629
+ fn=get_selected_row_data_by_type,
630
  inputs=[eval_dataset_dropdown, eval_task_type_dropdown],
631
+ outputs=[input_model_field, output_model_field, task_field, input_field,
 
632
  loggenix_output_field, expected_response_field]
633
  )
634
 
635
+ # NOTE: The get_tasks_by_type_eval and eval_task_dropdown.change handlers are removed as per request.
636
  # TAB 3: VIEW FLAGGED RESPONSES (RENAMED FROM TAB 4)
637
  with gr.Tab("👀 View Flagged Responses"):
638
  gr.Markdown("## Review Flagged Responses")
enhanced_model_handler.py ADDED
@@ -0,0 +1,1297 @@
 
1
+ import torch
2
+ import time
3
+ import gc
4
+ import json
5
+ import re
6
+ import logging
7
+ import traceback
8
+ import sys
9
+ from pathlib import Path
10
+ from typing import Dict, Any, Optional, Tuple
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
12
+
13
+
14
+ # Configure logging
15
+ def setup_logging(log_level=logging.INFO, log_file="model_inference.log"):
16
+ """Setup comprehensive logging configuration"""
17
+ # Create logs directory if it doesn't exist
18
+ log_dir = Path("logs")
19
+ log_dir.mkdir(exist_ok=True)
20
+
21
+ # Create formatter
22
+ formatter = logging.Formatter(
23
+ '%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s'
24
+ )
25
+
26
+ # Setup file handler
27
+ file_handler = logging.FileHandler(log_dir / log_file)
28
+ file_handler.setLevel(log_level)
29
+ file_handler.setFormatter(formatter)
30
+
31
+ # Setup console handler
32
+ console_handler = logging.StreamHandler(sys.stdout)
33
+ console_handler.setLevel(log_level)
34
+ console_handler.setFormatter(formatter)
35
+
36
+ # Setup logger
37
+ logger = logging.getLogger(__name__)
38
+ logger.setLevel(log_level)
39
+ logger.addHandler(file_handler)
40
+ logger.addHandler(console_handler)
41
+
42
+ # Prevent duplicate logs
43
+ logger.propagate = False
44
+
45
+ return logger
46
+
47
+
48
+ # Initialize logger
49
+ logger = setup_logging()
50
+
51
+ # Performance optimizations
52
+ try:
53
+ torch.backends.cudnn.benchmark = True
54
+ torch.backends.cuda.matmul.allow_tf32 = True
55
+ torch.backends.cudnn.allow_tf32 = True
56
+ logger.info("PyTorch optimizations enabled successfully")
57
+ except Exception as e:
58
+ logger.warning(f"Failed to enable some PyTorch optimizations: {e}")
59
+
60
+ # Global model and tokenizer variables
61
+ model = None
62
+ tokenizer = None
63
+ MODEL_ID = "kshitijthakkar/loggenix-moe-0.3B-A0.1B-e3-lr7e5-b16-4090-v6.3-finetuned-tool"
64
+
65
+ # Inference configurations
66
+ INFERENCE_CONFIGS = {
67
+ "Optimized for Speed": {
68
+ "max_new_tokens_base": 512,
69
+ "max_new_tokens_cap": 512,
70
+ "min_tokens": 50,
71
+ "temperature": 0.7,
72
+ "top_p": 0.9,
73
+ "do_sample": True,
74
+ "use_cache": False,
75
+ "description": "Fast responses with limited output length"
76
+ },
77
+ "Middle-ground": {
78
+ "max_new_tokens_base": 2048,
79
+ "max_new_tokens_cap": 2048,
80
+ "min_tokens": 50,
81
+ "temperature": 0.7,
82
+ "top_p": 0.9,
83
+ "do_sample": True,
84
+ "use_cache": False,
85
+ "description": "Balanced performance and output quality"
86
+ },
87
+ "Full Capacity": {
88
+ "max_new_tokens_base": 4096,
89
+ "max_new_tokens_cap": 4096,
90
+ "min_tokens": 1,
91
+ "temperature": 0.7,
92
+ "top_p": 0.9,
93
+ "do_sample": True,
94
+ "use_cache": False,
95
+ "description": "Maximum output length with dynamic allocation"
96
+ }
97
+ }
98
+
99
+
100
+ def validate_config(config_name: str) -> bool:
101
+ """Validate inference configuration"""
102
+ try:
103
+ if config_name not in INFERENCE_CONFIGS:
104
+ logger.error(f"Invalid config name: {config_name}. Available: {list(INFERENCE_CONFIGS.keys())}")
105
+ return False
106
+
107
+ config = INFERENCE_CONFIGS[config_name]
108
+ required_fields = ["max_new_tokens_base", "max_new_tokens_cap", "min_tokens", "temperature", "top_p"]
109
+
110
+ for field in required_fields:
111
+ if field not in config:
112
+ logger.error(f"Missing required field '{field}' in config '{config_name}'")
113
+ return False
114
+
115
+ logger.debug(f"Configuration '{config_name}' validated successfully")
116
+ return True
117
+ except Exception as e:
118
+ logger.error(f"Error validating config '{config_name}': {e}")
119
+ return False
120
+
121
+
122
+ def get_inference_configs():
123
+ """Get available inference configurations"""
124
+ try:
125
+ logger.debug("Retrieving inference configurations")
126
+ return INFERENCE_CONFIGS
127
+ except Exception as e:
128
+ logger.error(f"Error retrieving inference configurations: {e}")
129
+ return {}
130
+
131
+
132
+ def check_system_requirements() -> bool:
133
+ """Check if system meets requirements for model loading"""
134
+ try:
135
+ # Check CUDA availability
136
+ if not torch.cuda.is_available():
137
+ logger.warning("CUDA is not available. Model will run on CPU (much slower)")
138
+ return True # Still allow CPU execution
139
+
140
+ # Check GPU memory
141
+ gpu_count = torch.cuda.device_count()
142
+ logger.info(f"Found {gpu_count} GPU(s)")
143
+
144
+ for i in range(gpu_count):
145
+ gpu_props = torch.cuda.get_device_properties(i)
146
+ total_memory = gpu_props.total_memory / 1e9
147
+ logger.info(f"GPU {i}: {gpu_props.name}, Memory: {total_memory:.1f}GB")
148
+
149
+ if total_memory < 4.0: # Minimum 4GB for quantized model
150
+ logger.warning(f"GPU {i} has insufficient memory ({total_memory:.1f}GB < 4.0GB)")
151
+
152
+ return True
153
+ except Exception as e:
154
+ logger.error(f"Error checking system requirements: {e}")
155
+ return False
156
+
157
+
158
+ def load_model() -> Tuple[Optional[Any], Optional[Any]]:
159
+ """Load model and tokenizer with comprehensive error handling"""
160
+ global model, tokenizer
161
+
162
+ try:
163
+ if model is not None and tokenizer is not None:
164
+ logger.debug("Model and tokenizer already loaded")
165
+ return model, tokenizer
166
+
167
+ logger.info("Starting model loading process...")
168
+
169
+ # Check system requirements
170
+ if not check_system_requirements():
171
+ logger.error("System requirements check failed")
172
+ return None, None
173
+
174
+ # Load tokenizer with error handling
175
+ logger.info(f"Loading tokenizer from {MODEL_ID}...")
176
+ try:
177
+ tokenizer = AutoTokenizer.from_pretrained(
178
+ MODEL_ID,
179
+ trust_remote_code=True, # Add this for custom tokenizers
180
+ #cache_dir="./model_cache" # Use local cache
181
+ )
182
+ logger.info("Tokenizer loaded successfully")
183
+ except Exception as e:
184
+ logger.error(f"Failed to load tokenizer: {e}")
185
+ logger.error(f"Traceback: {traceback.format_exc()}")
186
+ return None, None
187
+
188
+ # Configure quantization
189
+ try:
190
+ quantization_config = BitsAndBytesConfig(
191
+ load_in_8bit=True,
192
+ llm_int8_threshold=6.0,
193
+ llm_int8_has_fp16_weight=False,
194
+ )
195
+ logger.info("8-bit quantization configuration created")
196
+ except Exception as e:
197
+ logger.error(f"Failed to create quantization config: {e}")
198
+ quantization_config = None
199
+
200
+ # Load model with extensive error handling
201
+ logger.info(f"Loading model from {MODEL_ID}...")
202
+ try:
203
+ model_kwargs = {
204
+ "device_map": "auto",
205
+ "dtype": torch.float16,
206
+ "use_cache": False,
207
+ "trust_remote_code": True,
208
+ #"cache_dir": "./model_cache"
209
+ }
210
+
211
+ # Add quantization if available
212
+ if quantization_config:
213
+ model_kwargs["quantization_config"] = quantization_config
214
+
215
+ # Try to use flash attention if available
216
+ try:
217
+ if hasattr(torch.nn, 'scaled_dot_product_attention'):
218
+ model_kwargs["attn_implementation"] = "flash_attention_2"
219
+ logger.info("Using Flash Attention 2")
220
+ except Exception as e:
221
+ logger.warning(f"Flash Attention 2 not available: {e}")
222
+
223
+ model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **model_kwargs)
224
+ model = model.eval()
225
+ logger.info("Model loaded successfully")
226
+
227
+ except torch.cuda.OutOfMemoryError:
228
+ logger.error("CUDA out of memory. Try reducing batch size or using CPU")
229
+ return None, None
230
+ except Exception as e:
231
+ logger.error(f"Failed to load model: {e}")
232
+ logger.error(f"Traceback: {traceback.format_exc()}")
233
+ return None, None
234
+
235
+ # Configure model settings with error handling
236
+ try:
237
+ # Enable gradient checkpointing if available
238
+ if hasattr(model, 'gradient_checkpointing_enable'):
239
+ model.gradient_checkpointing_enable()
240
+ logger.debug("Gradient checkpointing enabled")
241
+
242
+ # Set pad_token_id
243
+ if model.config.pad_token_id is None:
244
+ if tokenizer.pad_token_id is not None:
245
+ model.config.pad_token_id = tokenizer.pad_token_id
246
+ logger.debug("Set model pad_token_id from tokenizer")
247
+ else:
248
+ # Fallback to eos_token_id
249
+ model.config.pad_token_id = tokenizer.eos_token_id
250
+ tokenizer.pad_token_id = tokenizer.eos_token_id
251
+ logger.debug("Set pad_token_id to eos_token_id")
252
+
253
+ # Set padding side to left for better batching
254
+ tokenizer.padding_side = "left"
255
+ logger.debug("Set tokenizer padding side to left")
256
+
257
+ except Exception as e:
258
+ logger.warning(f"Error configuring model settings: {e}")
259
+
260
+ # Log memory usage
261
+ try:
262
+ if hasattr(model, 'get_memory_footprint'):
263
+ memory = model.get_memory_footprint() / 1e6
264
+ logger.info(f"Model memory footprint: {memory:,.1f} MB")
265
+ except Exception as e:
266
+ logger.warning(f"Could not get memory footprint: {e}")
267
+
268
+ logger.info("Model loading completed successfully")
269
+ return model, tokenizer
270
+
271
+ except Exception as e:
272
+ logger.error(f"Unexpected error in load_model: {e}")
273
+ logger.error(f"Traceback: {traceback.format_exc()}")
274
+ return None, None
275
+
276
+
277
+ # ===== TOOL DEFINITIONS =====
278
+
279
+ def calculate_numbers(operation: str, num1: float, num2: float) -> Dict[str, Any]:
280
+ """
281
+ Sample tool to perform basic mathematical operations on two numbers.
282
+
283
+ Args:
284
+ operation: The operation to perform ('add', 'subtract', 'multiply', 'divide')
285
+ num1: First number
286
+ num2: Second number
287
+
288
+ Returns:
289
+ Dictionary with result and operation details
290
+ """
291
+ try:
292
+ logger.debug(f"Calculating: {num1} {operation} {num2}")
293
+
294
+ # Validate inputs
295
+ if not isinstance(operation, str):
296
+ raise ValueError("Operation must be a string")
297
+
298
+ try:
299
+ num1, num2 = float(num1), float(num2)
300
+ except (ValueError, TypeError) as e:
301
+ logger.error(f"Invalid number format: num1={num1}, num2={num2}")
302
+ return {"error": f"Invalid number format: {str(e)}"}
303
+
304
+ operation = operation.lower().strip()
305
+
306
+ # Perform operation
307
+ if operation == 'add':
308
+ result = num1 + num2
309
+ elif operation == 'subtract':
310
+ result = num1 - num2
311
+ elif operation == 'multiply':
312
+ result = num1 * num2
313
+ elif operation == 'divide':
314
+ if num2 == 0:
315
+ logger.error("Division by zero attempted")
316
+ return {"error": "Division by zero is not allowed"}
317
+ result = num1 / num2
318
+ else:
319
+ logger.error(f"Unknown operation: {operation}")
320
+ return {"error": f"Unknown operation: {operation}. Supported: add, subtract, multiply, divide"}
321
+
322
+ response = {
323
+ "result": result,
324
+ "operation": operation,
325
+ "operands": [num1, num2],
326
+ "formatted": f"{num1} {operation} {num2} = {result}"
327
+ }
328
+
329
+ logger.debug(f"Calculation successful: {response['formatted']}")
330
+ return response
331
+
332
+ except Exception as e:
333
+ logger.error(f"Unexpected error in calculate_numbers: {e}")
334
+ return {"error": f"Calculation error: {str(e)}"}
335
+
336
+
337
+ # Tool registry
338
+ AVAILABLE_TOOLS = {
339
+ "calculate_numbers": {
340
+ "function": calculate_numbers,
341
+ "description": "Perform basic mathematical operations (add, subtract, multiply, divide) on two numbers",
342
+ "parameters": {
343
+ "operation": "The mathematical operation to perform",
344
+ "num1": "First number",
345
+ "num2": "Second number"
346
+ }
347
+ }
348
+ }
349
+
350
+
351
+ def execute_tool_call(tool_name: str, **kwargs) -> Dict[str, Any]:
352
+ """Execute a tool call with given parameters"""
353
+ try:
354
+ logger.info(f"Executing tool: {tool_name} with parameters: {kwargs}")
355
+
356
+ if not tool_name or not isinstance(tool_name, str):
357
+ logger.error(f"Invalid tool name: {tool_name}")
358
+ return {"error": "Invalid tool name"}
359
+
360
+ if tool_name not in AVAILABLE_TOOLS:
361
+ logger.error(f"Unknown tool: {tool_name}. Available: {list(AVAILABLE_TOOLS.keys())}")
362
+ return {"error": f"Unknown tool: {tool_name}"}
363
+
364
+ if not isinstance(kwargs, dict):
365
+ logger.error(f"Invalid parameters type: {type(kwargs)}")
366
+ return {"error": "Parameters must be a dictionary"}
367
+
368
+ tool_function = AVAILABLE_TOOLS[tool_name]["function"]
369
+ result = tool_function(**kwargs)
370
+
371
+ response = {
372
+ "tool_name": tool_name,
373
+ "parameters": kwargs,
374
+ "result": result
375
+ }
376
+
377
+ if "error" not in result:
378
+ logger.info(f"Tool execution successful: {tool_name}")
379
+ else:
380
+ logger.warning(f"Tool execution returned error: {result['error']}")
381
+
382
+ return response
383
+
384
+ except TypeError as e:
385
+ logger.error(f"Parameter error for tool '{tool_name}': {e}")
386
+ return {
387
+ "tool_name": tool_name,
388
+ "parameters": kwargs,
389
+ "error": f"Invalid parameters: {str(e)}"
390
+ }
391
+ except Exception as e:
392
+ logger.error(f"Tool execution failed: {str(e)}")
393
+ logger.error(f"Traceback: {traceback.format_exc()}")
394
+ return {
395
+ "tool_name": tool_name,
396
+ "parameters": kwargs,
397
+ "error": f"Tool execution error: {str(e)}"
398
+ }
399
+
400
+
401
+ def parse_tool_calls(text: str) -> list:
402
+ """
403
+ Parse tool calls from model output with comprehensive error handling.
404
+ Supports both formats:
405
+ - [TOOL_CALL:tool_name(param1=value1, param2=value2)]
406
+ - <tool_call>{"name": "tool_name", "parameters": {"param1": "value1", "param2": "value2"}}</tool_call>
407
+ """
408
+ try:
409
+ if not text or not isinstance(text, str):
410
+ logger.warning("Invalid text input for tool call parsing")
411
+ return []
412
+
413
+ tool_calls = []
414
+ logger.debug(f"Parsing tool calls from text: {text[:200]}...")
415
+
416
+ # Pattern for both formats
417
+ pattern = r'(\[TOOL_CALL:(\w+)\((.*?)\)\]|<tool_call>\s*{"name":\s*"(\w+)",\s*"parameters":\s*{([^}]*)}\s*}\s*</tool_call>)'
418
+ matches = re.findall(pattern, text)
419
+ logger.debug(f"Found {len(matches)} potential tool call matches")
420
+
421
+ for i, match in enumerate(matches):
422
+ try:
423
+ full_match, old_tool_name, old_params, json_tool_name, json_params = match
424
+
425
+ # Determine which format was matched
426
+ if old_tool_name: # Old format: [TOOL_CALL:tool_name(params)]
427
+ tool_name = old_tool_name
428
+ params_str = old_params
429
+ original_call = f"[TOOL_CALL:{tool_name}({params_str})]"
430
+
431
+ params = {}
432
+ if params_str.strip():
433
+ param_pairs = params_str.split(',')
434
+ for pair in param_pairs:
435
+ try:
436
+ if '=' in pair:
437
+ key, value = pair.split('=', 1)
438
+ key = key.strip()
439
+ value = value.strip().strip('"\'') # Remove quotes
440
+ params[key] = value
441
+ except Exception as e:
442
+ logger.warning(f"Error parsing parameter pair '{pair}': {e}")
443
+
444
+ logger.debug(f"Parsed old format tool call: {tool_name} with params: {params}")
445
+
446
+ elif json_tool_name: # JSON format: <tool_call>...</tool_call>
447
+ tool_name = json_tool_name
448
+ params_str = json_params
449
+ original_call = full_match
450
+
451
+ params = {}
452
+ if params_str.strip():
453
+ # Parse JSON-like parameters
454
+ param_pairs = params_str.split(',')
455
+ for pair in param_pairs:
456
+ try:
457
+ if ':' in pair:
458
+ key, value = pair.split(':', 1)
459
+ key = key.strip().strip('"\'') # Remove quotes and whitespace
460
+ value = value.strip().strip('"\'') # Remove quotes and whitespace
461
+ params[key] = value
462
+ except Exception as e:
463
+ logger.warning(f"Error parsing JSON parameter pair '{pair}': {e}")
464
+
465
+ logger.debug(f"Parsed JSON format tool call: {tool_name} with params: {params}")
466
+
467
+ else:
468
+ logger.warning(f"Could not determine tool call format for match {i}")
469
+ continue
470
+
471
+ # Validate tool call
472
+ if tool_name and isinstance(params, dict):
473
+ tool_calls.append({
474
+ "tool_name": tool_name,
475
+ "parameters": params,
476
+ "original_call": original_call
477
+ })
478
+ else:
479
+ logger.warning(f"Invalid tool call data: tool_name='{tool_name}', params={params}")
480
+
481
+ except Exception as e:
482
+ logger.error(f"Error parsing tool call match {i}: {e}")
483
+ continue
484
+
485
+ logger.info(f"Successfully parsed {len(tool_calls)} tool calls")
486
+ return tool_calls
487
+
488
+ except Exception as e:
489
+ logger.error(f"Unexpected error in parse_tool_calls: {e}")
490
+ logger.error(f"Traceback: {traceback.format_exc()}")
491
+ return []
492
+
493
+
494
+ def process_tool_calls(text: str) -> str:
495
+ """Process tool calls in the generated text and replace with results"""
496
+ try:
497
+ if not text:
498
+ logger.warning("Empty text provided to process_tool_calls")
499
+ return text
500
+
501
+ logger.debug("Processing tool calls in generated text")
502
+ tool_calls = parse_tool_calls(text)
503
+
504
+ if not tool_calls:
505
+ logger.debug("No tool calls found in text")
506
+ return text
507
+
508
+ processed_text = text
509
+ successful_calls = 0
510
+
511
+ for i, tool_call in enumerate(tool_calls):
512
+ try:
513
+ tool_name = tool_call["tool_name"]
514
+ parameters = tool_call["parameters"]
515
+ original_call = tool_call["original_call"]
516
+
517
+ logger.debug(f"Processing tool call {i + 1}/{len(tool_calls)}: {tool_name}")
518
+
519
+ # Validate parameters before execution
520
+ if not isinstance(parameters, dict):
521
+ logger.error(f"Invalid parameters for tool {tool_name}: {parameters}")
522
+ replacement = f"[TOOL_ERROR: Invalid parameters for tool {tool_name}]"
523
+ else:
524
+ # Execute tool
525
+ result = execute_tool_call(tool_name, **parameters)
526
+
527
+ # Create replacement text
528
+ if "error" in result:
529
+ replacement = f"[TOOL_ERROR: {result['error']}]"
530
+ logger.warning(f"Tool call failed: {result['error']}")
531
+ else:
532
+ if "result" in result["result"] and "formatted" in result["result"]:
533
+ replacement = f"[TOOL_RESULT: {result['result']['formatted']}]"
534
+ elif "result" in result:
535
+ replacement = f"[TOOL_RESULT: {result['result']}]"
536
+ else:
537
+ replacement = f"[TOOL_RESULT: Success]"
538
+
539
+ successful_calls += 1
540
+ logger.debug(f"Tool call successful: {replacement}")
541
+
542
+ # Replace tool call with result
543
+ processed_text = processed_text.replace(original_call, replacement)
544
+
545
+ except Exception as e:
546
+ logger.error(f"Error processing tool call {i + 1}: {e}")
547
+ tool_name = tool_call.get("tool_name", "unknown")
548
+ original_call = tool_call.get("original_call", "")
549
+ replacement = f"[TOOL_ERROR: Failed to process tool call: {str(e)}]"
550
+ if original_call:
551
+ processed_text = processed_text.replace(original_call, replacement)
552
+
553
+ logger.info(f"Processed {len(tool_calls)} tool calls ({successful_calls} successful)")
554
+ return processed_text
555
+
556
+ except Exception as e:
557
+ logger.error(f"Unexpected error in process_tool_calls: {e}")
558
+ logger.error(f"Traceback: {traceback.format_exc()}")
559
+ return text # Return original text if processing fails
560
+
561
+
562
+ def monitor_memory():
563
+ """Monitor and log memory usage"""
564
+ try:
565
+ if torch.cuda.is_available():
566
+ allocated = torch.cuda.memory_allocated() / 1e9
567
+ cached = torch.cuda.memory_reserved() / 1e9
568
+ max_allocated = torch.cuda.max_memory_allocated() / 1e9
569
+
570
+ logger.info(
571
+ f"GPU Memory - Allocated: {allocated:.2f}GB, Cached: {cached:.2f}GB, Max: {max_allocated:.2f}GB")
572
+
573
+ # Log warning if memory usage is high
574
+ total_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
575
+ if allocated / total_memory > 0.9:
576
+ logger.warning(f"High GPU memory usage: {allocated / total_memory * 100:.1f}%")
577
+
578
+ # Clean up cache if needed
579
+ torch.cuda.empty_cache()
580
+ else:
581
+ logger.debug("CUDA not available, skipping GPU memory monitoring")
582
+
583
+ # Clean up Python memory
584
+ gc.collect()
585
+ logger.debug("Resources cleaned up successfully")
586
+
587
+ except Exception as e:
588
+ logger.error(f"Error monitoring memory: {e}")
589
+
590
+
591
+ def get_model_info() -> Dict[str, Any]:
592
+ """Get information about the loaded model"""
593
+ try:
594
+ if model is None:
595
+ return {"status": "not_loaded"}
596
+
597
+ info = {
598
+ "status": "loaded",
599
+ "model_id": MODEL_ID,
600
+ "device": str(model.device) if hasattr(model, 'device') else "unknown",
601
+ "dtype": str(model.dtype) if hasattr(model, 'dtype') else "unknown"
602
+ }
603
+
604
+ # Add memory info if available
605
+ if hasattr(model, 'get_memory_footprint'):
606
+ try:
607
+ info["memory_footprint_mb"] = model.get_memory_footprint() / 1e6
608
+ except:
609
+ pass
610
+
611
+ # Add GPU info if available
612
+ if torch.cuda.is_available():
613
+ info["gpu_count"] = torch.cuda.device_count()
614
+ info["current_gpu"] = torch.cuda.current_device()
615
+ info["gpu_memory_allocated"] = torch.cuda.memory_allocated() / 1e9
616
+ info["gpu_memory_cached"] = torch.cuda.memory_reserved() / 1e9
617
+
618
+ return info
619
+ except Exception as e:
620
+ logger.error(f"Error getting model info: {e}")
621
+ return {"status": "error", "error": str(e)}
622
+
623
+
624
+ def health_check() -> Dict[str, Any]:
625
+ """Perform a health check of the system"""
626
+ try:
627
+ health_status = {
628
+ "timestamp": time.time(),
629
+ "torch_version": torch.__version__,
630
+ "cuda_available": torch.cuda.is_available(),
631
+ "model_loaded": model is not None,
632
+ "tokenizer_loaded": tokenizer is not None,
633
+ }
634
+
635
+ if torch.cuda.is_available():
636
+ health_status.update({
637
+ "cuda_version": torch.version.cuda,
638
+ "gpu_count": torch.cuda.device_count(),
639
+ "gpu_memory_total": torch.cuda.get_device_properties(0).total_memory / 1e9,
640
+ "gpu_memory_available": (torch.cuda.get_device_properties(
641
+ 0).total_memory - torch.cuda.memory_allocated()) / 1e9
642
+ })
643
+
644
+ # Test a simple generation if model is loaded
645
+ if model is not None and tokenizer is not None:
646
+ try:
647
+ test_response = generate_response(
648
+ "You are a helpful assistant.",
649
+ "Say hello",
650
+ "Optimized for Speed"
651
+ )
652
+ health_status["test_generation"] = "success" if test_response else "failed"
653
+ except Exception as e:
654
+ health_status["test_generation"] = f"error: {str(e)}"
655
+
656
+ logger.info(f"Health check completed: {health_status}")
657
+ return health_status
658
+ except Exception as e:
659
+ logger.error(f"Error during health check: {e}")
660
+ return {"status": "error", "error": str(e)}
661
+
662
+
663
+ def validate_inputs(system_prompt: str, user_input: str, config_name: str) -> bool:
664
+ """Validate inputs for generate_response"""
665
+ try:
666
+ if not isinstance(system_prompt, str) or not system_prompt.strip():
667
+ logger.error("System prompt must be a non-empty string")
668
+ return False
669
+
670
+ if not isinstance(user_input, str) or not user_input.strip():
671
+ logger.error("User input must be a non-empty string")
672
+ return False
673
+
674
+ if not validate_config(config_name):
675
+ return False
676
+
677
+ # Check input length
678
+ total_length = len(system_prompt) + len(user_input)
679
+ if total_length > 50000: # Reasonable limit
680
+ logger.warning(f"Input length is very long: {total_length} characters")
681
+
682
+ return True
683
+ except Exception as e:
684
+ logger.error(f"Error validating inputs: {e}")
685
+ return False
686
+
687
+
688
+ def generate_response(system_prompt: str, user_input: str, config_name: str = "Middle-ground") -> Optional[str]:
689
+ """
690
+ Run inference with comprehensive error handling and logging.
691
+
692
+ Args:
693
+ system_prompt: System message/prompt
694
+ user_input: User's input message
695
+ config_name: Name of the inference configuration to use
696
+
697
+ Returns:
698
+ Generated response text, or None if generation failed
699
+ """
700
+ try:
701
+ logger.info(f"Starting response generation with config: {config_name}")
702
+
703
+ # Validate inputs
704
+ if not validate_inputs(system_prompt, user_input, config_name):
705
+ logger.error("Input validation failed")
706
+ return None
707
+
708
+ # Load model
709
+ model, tokenizer = load_model()
710
+ if model is None or tokenizer is None:
711
+ logger.error("Failed to load model or tokenizer")
712
+ return None
713
+
714
+ # Get configuration
715
+ config = INFERENCE_CONFIGS[config_name]
716
+ logger.debug(f"Using config: {config}")
717
+
718
+ # Prepare messages
719
+ input_messages = [
720
+ {"role": "system", "content": system_prompt},
721
+ {"role": "user", "content": user_input}
722
+ ]
723
+
724
+ # Apply chat template
725
+ try:
726
+ prompt_text = tokenizer.apply_chat_template(
727
+ input_messages,
728
+ tokenize=False,
729
+ add_generation_prompt=True
730
+ )
731
+ logger.debug("Chat template applied successfully")
732
+ except Exception as e:
733
+ logger.error(f"Failed to apply chat template: {e}")
734
+ # Fallback to simple concatenation
735
+ prompt_text = f"System: {system_prompt}\nUser: {user_input}\nAssistant:"
736
+ logger.info("Using fallback prompt format")
737
+
738
+ # Tokenize input
739
+ try:
740
+ input_length = len(tokenizer.encode(prompt_text))
741
+ context_length = min(input_length, 3584) # Leave room for generation
742
+
743
+ inputs = tokenizer(
744
+ prompt_text,
745
+ return_tensors="pt",
746
+ truncation=True,
747
+ max_length=context_length,
748
+ padding=False
749
+ ).to(model.device)
750
+
751
+ logger.debug(f"Input tokenized: {inputs['input_ids'].shape[1]} tokens")
752
+
753
+ except Exception as e:
754
+ logger.error(f"Failed to tokenize input: {e}")
755
+ return None
756
+
757
+ # Calculate generation parameters
758
+ actual_input_length = inputs['input_ids'].shape[1]
759
+ max_new_tokens = min(config["max_new_tokens_cap"], 4096 - actual_input_length - 10)
760
+ max_new_tokens = max(config["min_tokens"], max_new_tokens)
761
+
762
+ logger.debug(f"Generation params - Input length: {actual_input_length}, Max new tokens: {max_new_tokens}")
763
+
764
+ # Monitor memory before generation
765
+ monitor_memory()
766
+
767
+ # Generate response
768
+ try:
769
+ with torch.no_grad():
770
+ start_time = time.time()
771
+
772
+ generation_kwargs = {
773
+ "do_sample": config["do_sample"],
774
+ "temperature": config["temperature"],
775
+ "top_p": config["top_p"],
776
+ "use_cache": config["use_cache"],
777
+ "max_new_tokens": max_new_tokens,
778
+ "pad_token_id": tokenizer.pad_token_id,
779
+ "eos_token_id": tokenizer.eos_token_id,
780
+ "output_attentions": False,
781
+ "output_hidden_states": False,
782
+ "return_dict_in_generate": False,
783
+ }
784
+
785
+ outputs = model.generate(**inputs, **generation_kwargs)
786
+ inference_time = time.time() - start_time
787
+
788
+ logger.info(f"Generation completed in {inference_time:.2f} seconds")
789
+
790
+ except torch.cuda.OutOfMemoryError:
791
+ logger.error("CUDA out of memory during generation")
792
+ # Try to free memory
793
+ gc.collect()
794
+ torch.cuda.empty_cache()
795
+ return None
796
+ except Exception as e:
797
+ logger.error(f"Generation failed: {e}")
798
+ logger.error(f"Traceback: {traceback.format_exc()}")
799
+ return None
800
+
801
+ # Monitor memory after generation
802
+ monitor_memory()
803
+
804
+ # Clean up GPU memory
805
+ try:
806
+ gc.collect()
807
+ if torch.cuda.is_available():
808
+ torch.cuda.empty_cache()
809
+ except Exception as e:
810
+ logger.warning(f"Error during cleanup: {e}")
811
+
812
+ # Decode response
813
+ try:
814
+ full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
815
+
816
+ # Extract generated response
817
+ if prompt_text in full_text:
818
+ response_start = full_text.find(prompt_text) + len(prompt_text)
819
+ generated_response = full_text[response_start:].strip()
820
+ else:
821
+ # More robust fallback
822
+ generated_response = full_text.strip()
823
+ try:
824
+ # Look for common assistant/response indicators
825
+ response_indicators = ["Assistant:", "<|assistant|>", "[/INST]", "Response:"]
826
+ for indicator in response_indicators:
827
+ if indicator in full_text:
828
+ parts = full_text.split(indicator)
829
+ if len(parts) > 1:
830
+ generated_response = parts[-1].strip()
831
+ break
832
+
833
+ # If no indicator found, try to remove the input part
834
+ if user_input in full_text:
835
+ parts = full_text.split(user_input)
836
+ if len(parts) > 1:
837
+ generated_response = parts[-1].strip()
838
+
839
+ except Exception as extract_error:
840
+ logger.warning(f"Error extracting response: {extract_error}")
841
+ generated_response = full_text.strip()
842
+
843
+ logger.debug(f"Extracted response: {generated_response[:100]}...")
844
+
845
+ except Exception as e:
846
+ logger.error(f"Failed to decode response: {e}")
847
+ return None
848
+
849
+ # Process tool calls
850
+ try:
851
+ processed_response = process_tool_calls(generated_response)
852
+ logger.debug("Tool call processing completed")
853
+ except Exception as e:
854
+ logger.error(f"Error processing tool calls: {e}")
855
+ processed_response = generated_response # Use original if tool processing fails
856
+
857
+ # Log final statistics
858
+ input_tokens = inputs['input_ids'].shape[1]
859
+ output_tokens = outputs.shape[1] - input_tokens
860
+ logger.info(
861
+ f"Generation stats - Input tokens: {input_tokens}, Output tokens: {output_tokens}, Time: {inference_time:.2f}s")
862
+
863
+ logger.info("Response generation completed successfully")
864
+ return processed_response
865
+
866
+ except Exception as e:
867
+ logger.error(f"Unexpected error in generate_response: {e}")
868
+ logger.error(f"Traceback: {traceback.format_exc()}")
869
+ return None
870
+
871
+
872
+ def safe_generate_response(system_prompt: str, user_input: str, config_name: str = "Middle-ground",
873
+ max_retries: int = 2) -> Optional[str]:
874
+ """
875
+ Generate response with retry logic and fallback options
876
+
877
+ Args:
878
+ system_prompt: System message/prompt
879
+ user_input: User's input message
880
+ config_name: Name of the inference configuration to use
881
+ max_retries: Maximum number of retry attempts
882
+
883
+ Returns:
884
+ Generated response text, or None if all attempts failed
885
+ """
886
+ for attempt in range(max_retries + 1):
887
+ try:
888
+ logger.info(f"Generation attempt {attempt + 1}/{max_retries + 1}")
889
+
890
+ response = generate_response(system_prompt, user_input, config_name)
891
+ if response is not None:
892
+ logger.info(f"Generation successful on attempt {attempt + 1}")
893
+ return response
894
+
895
+ if attempt < max_retries:
896
+ logger.warning(f"Generation failed on attempt {attempt + 1}, retrying...")
897
+ # Clean up before retry
898
+ gc.collect()
899
+ if torch.cuda.is_available():
900
+ torch.cuda.empty_cache()
901
+ time.sleep(1) # Brief pause before retry
902
+
903
+ except Exception as e:
904
+ logger.error(f"Error on generation attempt {attempt + 1}: {e}")
905
+ if attempt < max_retries:
906
+ logger.info("Cleaning up and retrying...")
907
+ try:
908
+ gc.collect()
909
+ if torch.cuda.is_available():
910
+ torch.cuda.empty_cache()
911
+ except:
912
+ pass
913
+ time.sleep(2) # Longer pause after error
914
+
915
+ logger.error(f"All {max_retries + 1} generation attempts failed")
916
+ return None
917
+
918
+
919
+ # Context manager for safe model operations
920
+ class ModelContext:
921
+ """Context manager for safe model operations with automatic cleanup"""
922
+
923
+ def __init__(self, auto_cleanup: bool = True):
924
+ self.auto_cleanup = auto_cleanup
925
+ self.original_model = None
926
+ self.original_tokenizer = None
927
+
928
+ def __enter__(self):
929
+ global model, tokenizer
930
+ self.original_model = model
931
+ self.original_tokenizer = tokenizer
932
+ logger.debug("Entered model context")
933
+ return self
934
+
935
+ def __exit__(self, exc_type, exc_val, exc_tb):
936
+ if exc_type is not None:
937
+ logger.error(f"Exception in model context: {exc_type.__name__}: {exc_val}")
938
+
939
+ if self.auto_cleanup:
940
+ try:
941
+ gc.collect()
942
+ if torch.cuda.is_available():
943
+ torch.cuda.empty_cache()
944
+ logger.debug("Model context cleanup completed")
945
+ except Exception as e:
946
+ logger.warning(f"Error during model context cleanup: {e}")
947
+
948
+ logger.debug("Exited model context")
949
+
950
+
951
+ def cleanup_resources():
952
+ """Clean up model resources"""
953
+ global model, tokenizer
954
+ try:
955
+ if model is not None:
956
+ del model
957
+ model = None
958
+ logger.info("Model removed from memory")
959
+ if tokenizer is not None:
960
+ del tokenizer
961
+ tokenizer = None
962
+ logger.info("Tokenizer removed from memory")
963
+
964
+ # Clean up GPU memory
965
+ gc.collect()
966
+ if torch.cuda.is_available():
967
+ torch.cuda.empty_cache()
968
+ torch.cuda.synchronize()
969
+ logger.info("GPU memory cleaned up")
970
+
971
+ logger.info("Resource cleanup completed")
972
+
973
+ except Exception as e:
974
+ logger.error(f"Error during resource cleanup: {e}")
975
+
976
+
977
+ def unload_model():
978
+ """Explicitly unload the model and tokenizer"""
979
+ try:
980
+ logger.info("Unloading model and tokenizer...")
981
+ cleanup_resources()
982
+ logger.info("Model and tokenizer unloaded successfully")
983
+ return True
984
+ except Exception as e:
985
+ logger.error(f"Error unloading model: {e}")
986
+ return False
987
+
988
+
989
+ def reload_model():
990
+ """Reload the model and tokenizer"""
991
+ try:
992
+ logger.info("Reloading model and tokenizer...")
993
+ # First clean up existing resources
994
+ cleanup_resources()
995
+ time.sleep(1) # Brief pause
996
+
997
+ # Load fresh model and tokenizer
998
+ model, tokenizer = load_model()
999
+ if model is not None and tokenizer is not None:
1000
+ logger.info("Model and tokenizer reloaded successfully")
1001
+ return True
1002
+ else:
1003
+ logger.error("Failed to reload model and tokenizer")
1004
+ return False
1005
+ except Exception as e:
1006
+ logger.error(f"Error reloading model: {e}")
1007
+ return False
1008
+
1009
+
1010
+ def get_available_tools() -> Dict[str, Any]:
1011
+ """Get information about available tools"""
1012
+ try:
1013
+ return {
1014
+ "tools": AVAILABLE_TOOLS,
1015
+ "count": len(AVAILABLE_TOOLS),
1016
+ "tool_names": list(AVAILABLE_TOOLS.keys())
1017
+ }
1018
+ except Exception as e:
1019
+ logger.error(f"Error getting available tools: {e}")
1020
+ return {"error": str(e)}
1021
+
1022
+
1023
+ def add_tool(tool_name: str, tool_function, description: str, parameters: Dict[str, str]):
1024
+ """Add a new tool to the registry"""
1025
+ try:
1026
+ if not tool_name or not isinstance(tool_name, str):
1027
+ raise ValueError("Tool name must be a non-empty string")
1028
+
1029
+ if not callable(tool_function):
1030
+ raise ValueError("Tool function must be callable")
1031
+
1032
+ if tool_name in AVAILABLE_TOOLS:
1033
+ logger.warning(f"Tool '{tool_name}' already exists, replacing...")
1034
+
1035
+ AVAILABLE_TOOLS[tool_name] = {
1036
+ "function": tool_function,
1037
+ "description": description,
1038
+ "parameters": parameters or {}
1039
+ }
1040
+
1041
+ logger.info(f"Tool '{tool_name}' added successfully")
1042
+ return True
1043
+ except Exception as e:
1044
+ logger.error(f"Error adding tool '{tool_name}': {e}")
1045
+ return False
1046
+
1047
+
1048
+ def remove_tool(tool_name: str):
1049
+ """Remove a tool from the registry"""
1050
+ try:
1051
+ if tool_name not in AVAILABLE_TOOLS:
1052
+ logger.warning(f"Tool '{tool_name}' not found")
1053
+ return False
1054
+
1055
+ del AVAILABLE_TOOLS[tool_name]
1056
+ logger.info(f"Tool '{tool_name}' removed successfully")
1057
+ return True
1058
+ except Exception as e:
1059
+ logger.error(f"Error removing tool '{tool_name}': {e}")
1060
+ return False
1061
+
1062
+
1063
+ # Example usage and testing functions
1064
+ def run_example():
1065
+ """Run an example to test the system"""
1066
+ try:
1067
+ logger.info("Running example test")
1068
+
1069
+ # Test health check
1070
+ health = health_check()
1071
+ logger.info(f"System health: {health}")
1072
+
1073
+ # Test model loading
1074
+ model_obj, tokenizer_obj = load_model()
1075
+ if model_obj is None or tokenizer_obj is None:
1076
+ logger.error("Failed to load model for example")
1077
+ return False
1078
+
1079
+ # Test generation
1080
+ with ModelContext():
1081
+ response = safe_generate_response(
1082
+ "You are a helpful mathematical assistant.",
1083
+ "What is 15 + 25? Use the calculate_numbers tool.",
1084
+ "Optimized for Speed"
1085
+ )
1086
+
1087
+ if response:
1088
+ logger.info(f"Example response: {response}")
1089
+ return True
1090
+ else:
1091
+ logger.error("Example generation failed")
1092
+ return False
1093
+
1094
+ except Exception as e:
1095
+ logger.error(f"Error in example: {e}")
1096
+ return False
1097
+
1098
+
1099
+ def run_batch_test():
1100
+ """Run batch test with multiple inputs"""
1101
+ try:
1102
+ logger.info("Running batch test")
1103
+
1104
+ test_cases = [
1105
+ {
1106
+ "system": "You are a helpful assistant.",
1107
+ "user": "Hello, how are you?",
1108
+ "config": "Optimized for Speed"
1109
+ },
1110
+ {
1111
+ "system": "You are a mathematical assistant.",
1112
+ "user": "Calculate 10 * 5 using the calculate_numbers tool.",
1113
+ "config": "Middle-ground"
1114
+ },
1115
+ {
1116
+ "system": "You are a helpful assistant.",
1117
+ "user": "Explain the concept of machine learning in simple terms.",
1118
+ "config": "Full Capacity"
1119
+ }
1120
+ ]
1121
+
1122
+ results = []
1123
+ for i, test_case in enumerate(test_cases):
1124
+ logger.info(f"Running test case {i + 1}/{len(test_cases)}")
1125
+
1126
+ with ModelContext():
1127
+ response = safe_generate_response(
1128
+ test_case["system"],
1129
+ test_case["user"],
1130
+ test_case["config"]
1131
+ )
1132
+
1133
+ results.append({
1134
+ "test_case": i + 1,
1135
+ "success": response is not None,
1136
+ "response": response[:100] + "..." if response and len(response) > 100 else response
1137
+ })
1138
+
1139
+ success_count = sum(1 for r in results if r["success"])
1140
+ logger.info(f"Batch test completed: {success_count}/{len(test_cases)} successful")
1141
+
1142
+ return results
1143
+
1144
+ except Exception as e:
1145
+ logger.error(f"Error in batch test: {e}")
1146
+ return []
1147
+
1148
+
1149
+ def benchmark_generation(num_runs: int = 5):
1150
+ """Benchmark generation performance"""
1151
+ try:
1152
+ logger.info(f"Running benchmark with {num_runs} iterations")
1153
+
1154
+ # Load model first
1155
+ model_obj, tokenizer_obj = load_model()
1156
+ if model_obj is None or tokenizer_obj is None:
1157
+ logger.error("Failed to load model for benchmark")
1158
+ return None
1159
+
1160
+ system_prompt = "You are a helpful assistant."
1161
+ user_input = "Explain the importance of renewable energy in 2-3 sentences."
1162
+
1163
+ times = []
1164
+ token_counts = []
1165
+
1166
+ for i in range(num_runs):
1167
+ logger.info(f"Benchmark run {i + 1}/{num_runs}")
1168
+
1169
+ start_time = time.time()
1170
+ response = generate_response(system_prompt, user_input, "Middle-ground")
1171
+ end_time = time.time()
1172
+
1173
+ if response:
1174
+ generation_time = end_time - start_time
1175
+ times.append(generation_time)
1176
+
1177
+ # Estimate token count (rough approximation)
1178
+ token_count = len(response.split()) * 1.3 # Rough tokens-to-words ratio
1179
+ token_counts.append(token_count)
1180
+
1181
+ logger.info(f"Run {i + 1}: {generation_time:.2f}s, ~{token_count:.0f} tokens")
1182
+ else:
1183
+ logger.warning(f"Run {i + 1} failed")
1184
+
1185
+ if times:
1186
+ avg_time = sum(times) / len(times)
1187
+ avg_tokens = sum(token_counts) / len(token_counts)
1188
+ tokens_per_sec = avg_tokens / avg_time if avg_time > 0 else 0
1189
+
1190
+ benchmark_results = {
1191
+ "runs": num_runs,
1192
+ "successful_runs": len(times),
1193
+ "avg_time": avg_time,
1194
+ "avg_tokens": avg_tokens,
1195
+ "tokens_per_second": tokens_per_sec,
1196
+ "min_time": min(times),
1197
+ "max_time": max(times)
1198
+ }
1199
+
1200
+ logger.info(f"Benchmark results: {benchmark_results}")
1201
+ return benchmark_results
1202
+ else:
1203
+ logger.error("All benchmark runs failed")
1204
+ return None
1205
+
1206
+ except Exception as e:
1207
+ logger.error(f"Error in benchmark: {e}")
1208
+ return None
1209
+
1210
+
1211
+ # API-like interface functions
1212
+ def initialize_system():
1213
+ """Initialize the inference system"""
1214
+ try:
1215
+ logger.info("Initializing inference system...")
1216
+
1217
+ # Check system requirements
1218
+ if not check_system_requirements():
1219
+ return {"status": "error", "message": "System requirements not met"}
1220
+
1221
+ # Load model and tokenizer
1222
+ model_obj, tokenizer_obj = load_model()
1223
+ if model_obj is None or tokenizer_obj is None:
1224
+ return {"status": "error", "message": "Failed to load model"}
1225
+
1226
+ # Run health check
1227
+ health = health_check()
1228
+ if "error" in health:
1229
+ return {"status": "warning", "message": "System initialized with warnings", "health": health}
1230
+
1231
+ logger.info("Inference system initialized successfully")
1232
+ return {"status": "success", "message": "System initialized successfully", "health": health}
1233
+
1234
+ except Exception as e:
1235
+ logger.error(f"Error initializing system: {e}")
1236
+ return {"status": "error", "message": str(e)}
1237
+
1238
+
1239
+ def shutdown_system():
1240
+ """Shutdown the inference system cleanly"""
1241
+ try:
1242
+ logger.info("Shutting down inference system...")
1243
+ cleanup_resources()
1244
+ logger.info("Inference system shutdown complete")
1245
+ return {"status": "success", "message": "System shutdown successfully"}
1246
+ except Exception as e:
1247
+ logger.error(f"Error during shutdown: {e}")
1248
+ return {"status": "error", "message": str(e)}
1249
+
1250
+
1251
+ if __name__ == "__main__":
1252
+ """Main entry point for testing"""
1253
+ try:
1254
+ logger.info("Starting model inference system")
1255
+
1256
+ # Initialize system
1257
+ init_result = initialize_system()
1258
+ logger.info(f"Initialization result: {init_result}")
1259
+
1260
+ if init_result["status"] != "error":
1261
+ # Run example
1262
+ success = run_example()
1263
+
1264
+ if success:
1265
+ logger.info("System test completed successfully")
1266
+
1267
+ # Optionally run additional tests
1268
+ print("\nWould you like to run additional tests? (y/n)")
1269
+ try:
1270
+ choice = input().lower().strip()
1271
+ if choice == 'y':
1272
+ logger.info("Running batch test...")
1273
+ batch_results = run_batch_test()
1274
+ logger.info(f"Batch test results: {batch_results}")
1275
+
1276
+ logger.info("Running benchmark...")
1277
+ benchmark_results = benchmark_generation(3)
1278
+ logger.info(f"Benchmark results: {benchmark_results}")
1279
+
1280
+ except (EOFError, KeyboardInterrupt):
1281
+ logger.info("Skipping additional tests")
1282
+ else:
1283
+ logger.error("System test failed")
1284
+
1285
+ # Shutdown
1286
+ shutdown_result = shutdown_system()
1287
+ logger.info(f"Shutdown result: {shutdown_result}")
1288
+
1289
+ except KeyboardInterrupt:
1290
+ logger.info("Interrupted by user")
1291
+ cleanup_resources()
1292
+ except Exception as e:
1293
+ logger.error(f"Unexpected error in main: {e}")
1294
+ logger.error(f"Traceback: {traceback.format_exc()}")
1295
+ cleanup_resources()
1296
+ finally:
1297
+ logger.info("Program terminated")
model_handler.py CHANGED
@@ -14,7 +14,7 @@ torch.backends.cudnn.allow_tf32 = True
14
  # Global model and tokenizer variables
15
  model = None
16
  tokenizer = None
17
- MODEL_ID = "kshitijthakkar/loggenix-moe-0.3B-A0.1B-e3-lr7e5-b16-4090-v6.2-finetuned-tool"
18
 
19
  # Inference configurations
20
  INFERENCE_CONFIGS = {
 
14
  # Global model and tokenizer variables
15
  model = None
16
  tokenizer = None
17
+ MODEL_ID = "kshitijthakkar/loggenix-moe-0.3B-A0.1B-e3-lr7e5-b16-4090-v6.3-finetuned-tool"
18
 
19
  # Inference configurations
20
  INFERENCE_CONFIGS = {