MedCodeMCP / app.py
app.py wrapper for Gradio was a very bad idea; reorganizing the project for clarity, the utils folder will be very important for separation of concerns
3b5fe24
import json

import gradio as gr
import torch
import torchaudio.transforms as T
from huggingface_hub import hf_hub_download
from llama_index.core import Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.llama_cpp import LlamaCPP

from src.parse_tabular import create_symptom_index
from utils import model_configuration_utils as mc
from utils import voice_input_utils as viu
# Set up model paths
MODEL_NAME, REPO_ID = mc.select_best_model()
# Ensure model is downloaded
model_path = mc.ensure_model()
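# Fallback sketch (assumptions: mc.ensure_model returns a falsy value on
# failure, and MODEL_NAME is the exact GGUF filename inside REPO_ID):
if not model_path:
    model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_NAME)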
# Configure local LLM with LlamaCPP
print("\nInitializing LLM...")
llm = LlamaCPP(
model_path=model_path,
temperature=0.7,
max_new_tokens=256,
context_window=2048,
    verbose=False  # reduce logging noise
    # Note: n_batch and n_threads are not valid LlamaCPP constructor arguments.
    # If you hit segmentation faults, reduce context_window or check system resources.
)
print("LLM initialized successfully")
# Configure global settings
print("\nConfiguring settings...")
Settings.llm = llm
Settings.embed_model = HuggingFaceEmbedding(
model_name="sentence-transformers/all-MiniLM-L6-v2"
)
print("Settings configured")
# Create the index at startup
print("\nCreating symptom index...")
symptom_index = create_symptom_index()
print("Index created successfully")
print("Loaded symptom_index:", type(symptom_index))
# --- System prompt ---
SYSTEM_PROMPT = """
You are a medical assistant helping a user narrow down to the most likely ICD-10 code.
At each turn, EITHER ask one focused clarifying question (e.g. "Is your cough dry or productive?")
or, if you have enough info, output a final JSON with fields:
{"diagnoses":[…], "confidences":[…]}.
"""
# Build enhanced Gradio interface
with gr.Blocks(theme="default") as demo:
gr.Markdown("""
# 🏥 Medical Symptom to ICD-10 Code Assistant
## About
This application is part of the Agents+MCP Hackathon. It helps medical professionals
and patients understand potential diagnoses based on described symptoms.
### How it works:
1. Either click the record button and describe your symptoms or type them into the textbox
2. The AI will analyze your description and suggest possible diagnoses
3. Answer follow-up questions to refine the diagnosis
""")
with gr.Row():
with gr.Column(scale=2):
# Add text input above microphone
with gr.Row():
text_input = gr.Textbox(
label="Type your symptoms",
placeholder="Or type your symptoms here...",
lines=3
)
submit_btn = gr.Button("Submit", variant="primary")
# Existing microphone row
with gr.Row():
microphone = gr.Audio(
sources=["microphone"],
streaming=True,
type="numpy",
label="Describe your symptoms"
)
transcript_box = gr.Textbox(
label="Transcribed Text",
interactive=False,
show_label=True
)
clear_btn = gr.Button("Clear Chat", variant="secondary")
chatbot = gr.Chatbot(
label="Medical Consultation",
height=500,
container=True,
type="messages" # This is now properly supported by our message format
)
with gr.Column(scale=1):
with gr.Accordion("Enter an API Key to give it more power!", open=False):
api_key = gr.Textbox(
label="OpenAI API Key (optional)",
type="password",
placeholder="sk-..."
)
with gr.Row():
with gr.Column():
modal_key = gr.Textbox(
label="Modal Labs API Key",
type="password",
placeholder="mk-..."
)
anthropic_key = gr.Textbox(
label="Anthropic API Key",
type="password",
placeholder="sk-ant-..."
)
mistral_key = gr.Textbox(
label="MistralAI API Key",
type="password",
placeholder="..."
)
with gr.Column():
nebius_key = gr.Textbox(
label="Nebius API Key",
type="password",
placeholder="..."
)
hyperbolic_key = gr.Textbox(
label="Hyperbolic Labs API Key",
type="password",
placeholder="hyp-..."
)
sambanova_key = gr.Textbox(
label="SambaNova API Key",
type="password",
placeholder="..."
)
with gr.Row():
model_selector = gr.Dropdown(
choices=["OpenAI", "Modal", "Anthropic", "MistralAI", "Nebius", "Hyperbolic", "SambaNova"],
value="OpenAI",
label="Model Provider"
)
temperature = gr.Slider(
minimum=0,
maximum=1,
value=0.7,
label="Temperature"
)
    # Self-promotion footer
gr.Markdown("""
---
### 👋 About the Creator
Hi! I'm Graham Paasch, an experienced technology professional!
🎥 **Check out my YouTube channel** for more tech content:
[Subscribe to my channel](https://www.youtube.com/channel/UCg3oUjrSYcqsL9rGk1g_lPQ)
💼 **Looking for a skilled developer?**
I'm currently seeking new opportunities! View my experience and connect on [LinkedIn](https://www.linkedin.com/in/grahampaasch/)
⭐ If you found this tool helpful, please consider:
- Subscribing to my YouTube channel
- Connecting on LinkedIn
- Sharing this tool with others in healthcare tech
""")
    # Event handlers
    # (clear_btn is wired up below, clearing the chat, transcript, and text input together)
microphone.stream(
fn=viu.enhanced_process_speech,
inputs=[microphone, chatbot, api_key, model_selector, temperature],
outputs=chatbot,
show_progress="hidden",
api_name=False,
queue=True # Enable queuing for better stream handling
)
def process_audio(audio_array, sample_rate):
"""Pre-process audio for Whisper."""
if audio_array.ndim > 1:
audio_array = audio_array.mean(axis=1)
# Convert to tensor for resampling
audio_tensor = torch.FloatTensor(audio_array)
# Resample to 16kHz if needed
if sample_rate != 16000:
resampler = T.Resample(sample_rate, 16000)
audio_tensor = resampler(audio_tensor)
        # Normalize to [-1, 1], guarding against division by zero on silence
        peak = torch.max(torch.abs(audio_tensor))
        if peak > 0:
            audio_tensor = audio_tensor / peak
        # Return the dict format the transformers ASR pipeline expects:
        # the "raw" and "sampling_rate" keys are required.
        return {
            "raw": audio_tensor.numpy(),
            "sampling_rate": 16000
        }
# Update transcription handler
def update_live_transcription(audio):
"""Real-time transcription updates."""
if not audio or not isinstance(audio, tuple):
return ""
try:
sample_rate, audio_array = audio
features = process_audio(audio_array, sample_rate)
asr = viu.get_asr_pipeline()
result = asr(features)
return result.get("text", "").strip() if isinstance(result, dict) else str(result).strip()
except Exception as e:
print(f"Transcription error: {str(e)}")
return ""
microphone.stream(
fn=update_live_transcription,
inputs=[microphone],
outputs=transcript_box,
show_progress="hidden",
queue=True
)
clear_btn.click(
fn=lambda: (None, "", ""),
outputs=[chatbot, transcript_box, text_input],
queue=False
)
def cleanup_memory():
"""Release unused memory (placeholder for future memory management)."""
import gc
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def process_text_input(text, history):
"""Process text input with memory management."""
print("process_text_input received:", text)
if not text:
return history, "" # Return tuple to clear input
# Process the symptoms using the configured LLM
prompt = f"""Given these symptoms: '{text}'
Please provide:
1. Most likely ICD-10 codes
2. Confidence levels for each diagnosis
3. Key follow-up questions
Format as JSON with diagnoses, confidences, and follow_up fields."""
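        # Illustrative only (hypothetical values): a well-formed reply would be
        #   {"diagnoses": ["R05"], "confidences": [0.6],
        #    "follow_up": "How long have you had the cough?"}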
response = llm.complete(prompt)
try:
# Try to parse as JSON first
result = json.loads(response.text)
except json.JSONDecodeError:
# If not JSON, wrap in our format
result = {
"diagnoses": [],
"confidences": [],
"follow_up": str(response.text)[:1000] # Limit response length
}
new_history = history + [
{"role": "user", "content": text},
{"role": "assistant", "content": viu.format_response_for_user(result)}
]
return new_history, "" # Return empty string to clear input
# Update the submit button handler
submit_btn.click(
fn=process_text_input,
inputs=[text_input, chatbot],
outputs=[chatbot, text_input],
queue=True
    ).success(  # run cleanup only after a successful submit (unlike .then)
fn=cleanup_memory,
inputs=None,
outputs=None,
queue=False
)
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=True, # Enable sharing via Gradio's temporary URLs
show_api=True # Shows the API documentation
)
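# With show_api=True the endpoints can be driven programmatically. A minimal
# sketch using gradio_client (the endpoint name below is an assumption; check
# the auto-generated docs at /?view=api for the real route):
#   from gradio_client import Client
#   client = Client("http://localhost:7860")
#   reply = client.predict("persistent dry cough for two weeks", [],
#                          api_name="/process_text_input")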