Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -486,5 +486,239 @@ def start_answer_generation(model_choice: str):
|
|
| 486 |
thread.start()
|
| 487 |
|
| 488 |
return f"Answer generation started using {model_choice}. Check progress."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 489 |
|
| 490 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 486 |
thread.start()
|
| 487 |
|
| 488 |
return f"Answer generation started using {model_choice}. Check progress."
|
| 489 |
+
def get_generation_progress():
|
| 490 |
+
"""
|
| 491 |
+
Get the current progress of answer generation.
|
| 492 |
+
"""
|
| 493 |
+
if not processing_status["is_processing"] and processing_status["progress"] == 0:
|
| 494 |
+
return "Not started"
|
| 495 |
+
|
| 496 |
+
if processing_status["is_processing"]:
|
| 497 |
+
progress = processing_status["progress"]
|
| 498 |
+
total = processing_status["total"]
|
| 499 |
+
status_msg = f"Generating answers... {progress}/{total} completed"
|
| 500 |
+
return status_msg
|
| 501 |
+
else:
|
| 502 |
+
# Generation completed
|
| 503 |
+
if cached_answers:
|
| 504 |
+
# Create DataFrame with results
|
| 505 |
+
display_data = []
|
| 506 |
+
for task_id, data in cached_answers.items():
|
| 507 |
+
display_data.append({
|
| 508 |
+
"Task ID": task_id,
|
| 509 |
+
"Question": data["question"][:100] + "..." if len(data["question"]) > 100 else data["question"],
|
| 510 |
+
"Generated Answer": data["answer"][:200] + "..." if len(data["answer"]) > 200 else data["answer"]
|
| 511 |
+
})
|
| 512 |
+
|
| 513 |
+
df = pd.DataFrame(display_data)
|
| 514 |
+
status_msg = f"Answer generation completed! {len(cached_answers)} answers ready for submission."
|
| 515 |
+
return status_msg, df
|
| 516 |
+
else:
|
| 517 |
+
return "Answer generation completed but no answers were generated."
|
| 518 |
+
|
| 519 |
+
def submit_cached_answers(profile: gr.OAuthProfile | None):
|
| 520 |
+
"""
|
| 521 |
+
Submit the cached answers to the evaluation API.
|
| 522 |
+
"""
|
| 523 |
+
global cached_answers
|
| 524 |
+
|
| 525 |
+
if not profile:
|
| 526 |
+
return "Please log in to Hugging Face first.", None
|
| 527 |
+
|
| 528 |
+
if not cached_answers:
|
| 529 |
+
return "No cached answers available. Please generate answers first.", None
|
| 530 |
+
|
| 531 |
+
username = profile.username
|
| 532 |
+
space_id = os.getenv("SPACE_ID")
|
| 533 |
+
agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main" if space_id else "Unknown"
|
| 534 |
+
|
| 535 |
+
# Prepare submission payload
|
| 536 |
+
answers_payload = []
|
| 537 |
+
for task_id, data in cached_answers.items():
|
| 538 |
+
answers_payload.append({
|
| 539 |
+
"task_id": task_id,
|
| 540 |
+
"submitted_answer": data["answer"]
|
| 541 |
+
})
|
| 542 |
+
|
| 543 |
+
submission_data = {
|
| 544 |
+
"username": username.strip(),
|
| 545 |
+
"agent_code": agent_code,
|
| 546 |
+
"answers": answers_payload
|
| 547 |
+
}
|
| 548 |
+
|
| 549 |
+
# Submit to API
|
| 550 |
+
api_url = DEFAULT_API_URL
|
| 551 |
+
submit_url = f"{api_url}/submit"
|
| 552 |
+
|
| 553 |
+
print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
|
| 554 |
+
|
| 555 |
+
try:
|
| 556 |
+
response = requests.post(submit_url, json=submission_data, timeout=60)
|
| 557 |
+
response.raise_for_status()
|
| 558 |
+
result_data = response.json()
|
| 559 |
+
|
| 560 |
+
final_status = (
|
| 561 |
+
f"Submission Successful!\n"
|
| 562 |
+
f"User: {result_data.get('username')}\n"
|
| 563 |
+
f"Overall Score: {result_data.get('score', 'N/A')}% "
|
| 564 |
+
f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
|
| 565 |
+
f"Message: {result_data.get('message', 'No message received.')}"
|
| 566 |
+
)
|
| 567 |
+
|
| 568 |
+
# Create results DataFrame
|
| 569 |
+
results_log = []
|
| 570 |
+
for task_id, data in cached_answers.items():
|
| 571 |
+
results_log.append({
|
| 572 |
+
"Task ID": task_id,
|
| 573 |
+
"Question": data["question"],
|
| 574 |
+
"Submitted Answer": data["answer"]
|
| 575 |
+
})
|
| 576 |
+
|
| 577 |
+
results_df = pd.DataFrame(results_log)
|
| 578 |
+
return final_status, results_df
|
| 579 |
+
|
| 580 |
+
except requests.exceptions.HTTPError as e:
|
| 581 |
+
error_detail = f"Server responded with status {e.response.status_code}."
|
| 582 |
+
try:
|
| 583 |
+
error_json = e.response.json()
|
| 584 |
+
error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
|
| 585 |
+
except:
|
| 586 |
+
error_detail += f" Response: {e.response.text[:500]}"
|
| 587 |
+
return f"Submission Failed: {error_detail}", None
|
| 588 |
+
|
| 589 |
+
except requests.exceptions.Timeout:
|
| 590 |
+
return "Submission Failed: The request timed out.", None
|
| 591 |
+
|
| 592 |
+
except Exception as e:
|
| 593 |
+
return f"Submission Failed: {e}", None
|
| 594 |
|
| 595 |
+
def clear_cache():
|
| 596 |
+
"""
|
| 597 |
+
Clear all cached data.
|
| 598 |
+
"""
|
| 599 |
+
global cached_answers, cached_questions, processing_status
|
| 600 |
+
cached_answers = {}
|
| 601 |
+
cached_questions = []
|
| 602 |
+
processing_status = {"is_processing": False, "progress": 0, "total": 0}
|
| 603 |
+
return "Cache cleared successfully.", None
|
| 604 |
+
|
| 605 |
+
def test_media_processing(image_files, audio_files, question):
|
| 606 |
+
"""
|
| 607 |
+
Test the media processing functionality with uploaded files.
|
| 608 |
+
"""
|
| 609 |
+
if not question:
|
| 610 |
+
question = "What can you tell me about the uploaded media?"
|
| 611 |
+
|
| 612 |
+
agent = IntelligentAgent(debug=True)
|
| 613 |
+
|
| 614 |
+
# Convert file paths to lists
|
| 615 |
+
image_paths = [img.name for img in image_files] if image_files else None
|
| 616 |
+
audio_paths = [aud.name for aud in audio_files] if audio_files else None
|
| 617 |
+
|
| 618 |
+
try:
|
| 619 |
+
result = agent(question, image_files=image_paths, audio_files=audio_paths)
|
| 620 |
+
return result
|
| 621 |
+
except Exception as e:
|
| 622 |
+
return f"Error processing media: {e}"
|
| 623 |
+
|
| 624 |
+
# --- Enhanced Gradio Interface ---
|
| 625 |
+
with gr.Blocks(title="Intelligent Agent with Media Processing") as demo:
|
| 626 |
+
gr.Markdown("# Intelligent Agent with Conditional Search and Media Processing")
|
| 627 |
+
gr.Markdown("This agent can process images and audio files, uses an LLM to decide when search is needed, optimizing for both accuracy and efficiency.")
|
| 628 |
+
|
| 629 |
+
with gr.Row():
|
| 630 |
+
gr.LoginButton()
|
| 631 |
+
clear_btn = gr.Button("Clear Cache", variant="secondary")
|
| 632 |
+
|
| 633 |
+
with gr.Tab("Media Processing Test"):
|
| 634 |
+
gr.Markdown("### Test Image and Audio Processing")
|
| 635 |
+
|
| 636 |
+
with gr.Row():
|
| 637 |
+
with gr.Column():
|
| 638 |
+
image_upload = gr.File(
|
| 639 |
+
label="Upload Images",
|
| 640 |
+
file_types=["image"],
|
| 641 |
+
file_count="multiple"
|
| 642 |
+
)
|
| 643 |
+
audio_upload = gr.File(
|
| 644 |
+
label="Upload Audio Files",
|
| 645 |
+
file_types=["audio"],
|
| 646 |
+
file_count="multiple"
|
| 647 |
+
)
|
| 648 |
+
|
| 649 |
+
with gr.Column():
|
| 650 |
+
test_question = gr.Textbox(
|
| 651 |
+
label="Question about the media",
|
| 652 |
+
placeholder="What can you tell me about these files?",
|
| 653 |
+
lines=3
|
| 654 |
+
)
|
| 655 |
+
test_btn = gr.Button("Process Media", variant="primary")
|
| 656 |
+
|
| 657 |
+
test_output = gr.Textbox(
|
| 658 |
+
label="Processing Result",
|
| 659 |
+
lines=10,
|
| 660 |
+
interactive=False
|
| 661 |
+
)
|
| 662 |
+
|
| 663 |
+
test_btn.click(
|
| 664 |
+
fn=test_media_processing,
|
| 665 |
+
inputs=[image_upload, audio_upload, test_question],
|
| 666 |
+
outputs=test_output
|
| 667 |
+
)
|
| 668 |
+
|
| 669 |
+
with gr.Tab("Step 1: Fetch Questions"):
|
| 670 |
+
gr.Markdown("### Fetch Questions from API")
|
| 671 |
+
fetch_btn = gr.Button("Fetch Questions", variant="primary")
|
| 672 |
+
fetch_status = gr.Textbox(label="Fetch Status", lines=2, interactive=False)
|
| 673 |
+
questions_table = gr.DataFrame(label="Available Questions", wrap=True)
|
| 674 |
+
|
| 675 |
+
fetch_btn.click(
|
| 676 |
+
fn=fetch_questions,
|
| 677 |
+
outputs=[fetch_status, questions_table]
|
| 678 |
+
)
|
| 679 |
+
|
| 680 |
+
with gr.Tab("Step 2: Generate Answers"):
|
| 681 |
+
gr.Markdown("### Generate Answers with Intelligent Search Decision")
|
| 682 |
+
|
| 683 |
+
with gr.Row():
|
| 684 |
+
model_choice = gr.Dropdown(
|
| 685 |
+
choices=["Llama 3.1 8B", "Mistral 7B"],
|
| 686 |
+
value="Llama 3.1 8B",
|
| 687 |
+
label="Select Model"
|
| 688 |
+
)
|
| 689 |
+
generate_btn = gr.Button("Start Answer Generation", variant="primary")
|
| 690 |
+
refresh_btn = gr.Button("Refresh Progress", variant="secondary")
|
| 691 |
+
|
| 692 |
+
generation_status = gr.Textbox(label="Generation Status", lines=2, interactive=False)
|
| 693 |
+
answers_table = gr.DataFrame(label="Generated Answers", wrap=True)
|
| 694 |
+
|
| 695 |
+
generate_btn.click(
|
| 696 |
+
fn=start_answer_generation,
|
| 697 |
+
inputs=[model_choice],
|
| 698 |
+
outputs=generation_status
|
| 699 |
+
)
|
| 700 |
+
|
| 701 |
+
refresh_btn.click(
|
| 702 |
+
fn=get_generation_progress,
|
| 703 |
+
outputs=[generation_status, answers_table]
|
| 704 |
+
)
|
| 705 |
+
|
| 706 |
+
with gr.Tab("Step 3: Submit Results"):
|
| 707 |
+
gr.Markdown("### Submit Generated Answers")
|
| 708 |
+
submit_btn = gr.Button("Submit Answers", variant="primary")
|
| 709 |
+
submit_status = gr.Textbox(label="Submission Status", lines=4, interactive=False)
|
| 710 |
+
results_table = gr.DataFrame(label="Submission Results", wrap=True)
|
| 711 |
+
|
| 712 |
+
submit_btn.click(
|
| 713 |
+
fn=submit_cached_answers,
|
| 714 |
+
outputs=[submit_status, results_table]
|
| 715 |
+
)
|
| 716 |
+
|
| 717 |
+
# Clear cache functionality
|
| 718 |
+
clear_btn.click(
|
| 719 |
+
fn=clear_cache,
|
| 720 |
+
outputs=[fetch_status, questions_table]
|
| 721 |
+
)
|
| 722 |
+
|
| 723 |
+
if __name__ == "__main__":
|
| 724 |
+
demo.launch()
|