Spaces:
Runtime error
Runtime error
Commit
Β·
f48d0d7
1
Parent(s):
a7f8e09
included new user and updated buttons
Browse files- interface.py +18 -17
- medrax/tools/llava_med.py +2 -0
interface.py
CHANGED
|
@@ -16,7 +16,8 @@ REPORT_DIR.mkdir(exist_ok=True)
|
|
| 16 |
SALT = b'$2b$12$MC7djiqmIR7154Syul5Wme'
|
| 17 |
|
| 18 |
USERS = {
|
| 19 |
-
'test_user': b'$2b$12$MC7djiqmIR7154Syul5WmeQwebwsNOK5svMX08zMYhvpF9P9IVXe6'
|
|
|
|
| 20 |
}
|
| 21 |
|
| 22 |
class ChatInterface:
|
|
@@ -73,7 +74,7 @@ class ChatInterface:
|
|
| 73 |
else:
|
| 74 |
self.display_file_path = str(saved_path)
|
| 75 |
|
| 76 |
-
return self.display_file_path, gr.update(interactive=True), gr.update(interactive=True)
|
| 77 |
|
| 78 |
def add_message(
|
| 79 |
self, message: str, display_image: str, history: List[dict]
|
|
@@ -266,7 +267,7 @@ def create_demo(agent, tools_dict):
|
|
| 266 |
)
|
| 267 |
with gr.Row():
|
| 268 |
analyze_btn = gr.Button("Analyze", interactive=False)
|
| 269 |
-
ground_btn = gr.Button("Ground", interactive=False)
|
| 270 |
segment_btn = gr.Button("Segment", interactive=False)
|
| 271 |
with gr.Row():
|
| 272 |
clear_btn = gr.Button("Clear Chat")
|
|
@@ -394,38 +395,38 @@ def create_demo(agent, tools_dict):
|
|
| 394 |
bot_msg.then(lambda: gr.Textbox(interactive=True), None, [txt])
|
| 395 |
|
| 396 |
analyze_btn.click(
|
| 397 |
-
lambda: gr.update(value="Analyze
|
| 398 |
).then(
|
| 399 |
-
interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
|
| 400 |
).then(
|
| 401 |
interface.process_message,
|
| 402 |
inputs=[txt, image_display, chatbot],
|
| 403 |
outputs=[chatbot, image_display, txt],
|
| 404 |
).then(lambda: gr.Textbox(interactive=True), None, [txt])
|
| 405 |
|
| 406 |
-
ground_btn.click(
|
| 407 |
-
|
| 408 |
-
).then(
|
| 409 |
-
|
| 410 |
-
).then(
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
).then(lambda: gr.Textbox(interactive=True), None, [txt])
|
| 415 |
|
| 416 |
segment_btn.click(
|
| 417 |
lambda: gr.update(value="Segment the major affected lung"), None, txt
|
| 418 |
).then(
|
| 419 |
-
interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
|
| 420 |
).then(
|
| 421 |
interface.process_message,
|
| 422 |
inputs=[txt, image_display, chatbot],
|
| 423 |
outputs=[chatbot, image_display, txt],
|
| 424 |
).then(lambda: gr.Textbox(interactive=True), None, [txt])
|
| 425 |
|
| 426 |
-
upload_button.upload(handle_file_upload, inputs=upload_button, outputs=[image_display, analyze_btn,
|
| 427 |
|
| 428 |
-
dicom_upload.upload(handle_file_upload, inputs=dicom_upload, outputs=[image_display, analyze_btn,
|
| 429 |
|
| 430 |
clear_btn.click(clear_chat, outputs=[chatbot, image_display])
|
| 431 |
new_thread_btn.click(new_thread, outputs=[chatbot, image_display])
|
|
|
|
| 16 |
SALT = b'$2b$12$MC7djiqmIR7154Syul5Wme'
|
| 17 |
|
| 18 |
USERS = {
|
| 19 |
+
'test_user': b'$2b$12$MC7djiqmIR7154Syul5WmeQwebwsNOK5svMX08zMYhvpF9P9IVXe6',
|
| 20 |
+
'pna': b'$2b$12$MC7djiqmIR7154Syul5WmeWTzYft1UnOV4uGVn54FGfmbH3dRNq1C'
|
| 21 |
}
|
| 22 |
|
| 23 |
class ChatInterface:
|
|
|
|
| 74 |
else:
|
| 75 |
self.display_file_path = str(saved_path)
|
| 76 |
|
| 77 |
+
return self.display_file_path, gr.update(interactive=True), gr.update(interactive=True)
|
| 78 |
|
| 79 |
def add_message(
|
| 80 |
self, message: str, display_image: str, history: List[dict]
|
|
|
|
| 267 |
)
|
| 268 |
with gr.Row():
|
| 269 |
analyze_btn = gr.Button("Analyze", interactive=False)
|
| 270 |
+
# ground_btn = gr.Button("Ground", interactive=False)
|
| 271 |
segment_btn = gr.Button("Segment", interactive=False)
|
| 272 |
with gr.Row():
|
| 273 |
clear_btn = gr.Button("Clear Chat")
|
|
|
|
| 395 |
bot_msg.then(lambda: gr.Textbox(interactive=True), None, [txt])
|
| 396 |
|
| 397 |
analyze_btn.click(
|
| 398 |
+
lambda: gr.update(value="Analyze this xray and give me a detailed response. Use the medgemma_xray_expert tool"), None, txt
|
| 399 |
).then(
|
| 400 |
+
interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
|
| 401 |
).then(
|
| 402 |
interface.process_message,
|
| 403 |
inputs=[txt, image_display, chatbot],
|
| 404 |
outputs=[chatbot, image_display, txt],
|
| 405 |
).then(lambda: gr.Textbox(interactive=True), None, [txt])
|
| 406 |
|
| 407 |
+
# ground_btn.click(
|
| 408 |
+
# lambda: gr.update(value="Ground the main disease in this CXR"), None, txt
|
| 409 |
+
# ).then(
|
| 410 |
+
# interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
|
| 411 |
+
# ).then(
|
| 412 |
+
# interface.process_message,
|
| 413 |
+
# inputs=[txt, image_display, chatbot],
|
| 414 |
+
# outputs=[chatbot, image_display, txt],
|
| 415 |
+
# ).then(lambda: gr.Textbox(interactive=True), None, [txt])
|
| 416 |
|
| 417 |
segment_btn.click(
|
| 418 |
lambda: gr.update(value="Segment the major affected lung"), None, txt
|
| 419 |
).then(
|
| 420 |
+
interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
|
| 421 |
).then(
|
| 422 |
interface.process_message,
|
| 423 |
inputs=[txt, image_display, chatbot],
|
| 424 |
outputs=[chatbot, image_display, txt],
|
| 425 |
).then(lambda: gr.Textbox(interactive=True), None, [txt])
|
| 426 |
|
| 427 |
+
upload_button.upload(handle_file_upload, inputs=upload_button, outputs=[image_display, analyze_btn, segment_btn])
|
| 428 |
|
| 429 |
+
dicom_upload.upload(handle_file_upload, inputs=dicom_upload, outputs=[image_display, analyze_btn, segment_btn])
|
| 430 |
|
| 431 |
clear_btn.click(clear_chat, outputs=[chatbot, image_display])
|
| 432 |
new_thread_btn.click(new_thread, outputs=[chatbot, image_display])
|
medrax/tools/llava_med.py
CHANGED
|
@@ -56,6 +56,7 @@ class LlavaMedTool(BaseTool):
|
|
| 56 |
def __init__(
|
| 57 |
self,
|
| 58 |
model_path: str = "microsoft/llava-med-v1.5-mistral-7b",
|
|
|
|
| 59 |
cache_dir: str = "/model-weights",
|
| 60 |
low_cpu_mem_usage: bool = True,
|
| 61 |
torch_dtype: torch.dtype = torch.bfloat16,
|
|
@@ -68,6 +69,7 @@ class LlavaMedTool(BaseTool):
|
|
| 68 |
self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
|
| 69 |
model_path=model_path,
|
| 70 |
model_base=None,
|
|
|
|
| 71 |
model_name=model_path,
|
| 72 |
load_in_4bit=load_in_4bit,
|
| 73 |
load_in_8bit=load_in_8bit,
|
|
|
|
| 56 |
def __init__(
|
| 57 |
self,
|
| 58 |
model_path: str = "microsoft/llava-med-v1.5-mistral-7b",
|
| 59 |
+
# model_path: str = "microsoft/llava-rad",
|
| 60 |
cache_dir: str = "/model-weights",
|
| 61 |
low_cpu_mem_usage: bool = True,
|
| 62 |
torch_dtype: torch.dtype = torch.bfloat16,
|
|
|
|
| 69 |
self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
|
| 70 |
model_path=model_path,
|
| 71 |
model_base=None,
|
| 72 |
+
# model_base="lmsys/vicuna-7b-v1.5",
|
| 73 |
model_name=model_path,
|
| 74 |
load_in_4bit=load_in_4bit,
|
| 75 |
load_in_8bit=load_in_8bit,
|