Spaces: Running on Zero

More UI changes and samples loading fix

app.py CHANGED
@@ -195,44 +195,54 @@ def prompt_lm(
     return results


-def add_user_query(
-    chatbot_history: list[dict], audio_input: str, chat_input: str
-) -> list[dict]:
+def make_spectrogram_figure(audio_input: str) -> list[dict]:
+    # Load audio with torchaudio and compute spectrogram
+    if not audio_input:
+        # Return an empty figure if no audio input is provided
+        return get_spectrogram(torch.zeros(1, SAMPLE_RATE))
+
+    # Check if file exists and is accessible
+    try:
+        if not Path(audio_input).exists():
+            print(f"Audio file does not exist: {audio_input}")
+            return get_spectrogram(torch.zeros(1, SAMPLE_RATE))
+
+        if not Path(audio_input).is_file():
+            print(f"Path is not a valid file: {audio_input}")
+            return get_spectrogram(torch.zeros(1, SAMPLE_RATE))
+
+        audio_tensor, sample_rate = torchaudio.load(audio_input)
+        spectrogram_fig = get_spectrogram(audio_tensor)
+        return spectrogram_fig
+    except Exception as e:
+        print(f"Error loading audio file {audio_input}: {e}")
+        # Return an empty spectrogram on error
+        return get_spectrogram(torch.zeros(1, SAMPLE_RATE))
+
+
+def add_user_query(chatbot_history: list[dict], chat_input: str) -> list[dict]:
     """Add user message to chat and get model response"""
     # Validate input
     if not chat_input.strip():
         return chatbot_history

-    if not audio_input:
-        chatbot_history.append({"role": "user", "content": chat_input.strip()})
-        chatbot_history.append({"role": "assistant", "content": "Thinking..."})
-        return chatbot_history
-
-    # Load audio with torchaudio and compute spectrogram
-    audio_tensor, sample_rate = torchaudio.load(audio_input)
-    spectrogram_fig = get_spectrogram(audio_tensor)
-    # Add gr.Plot to chatbot history
-    chatbot_history.append(
-        {"role": "user", "content": gr.Plot(spectrogram_fig, label="Spectrogram")}
-    )
-    # Add user message to chat history first
     chatbot_history.append({"role": "user", "content": chat_input.strip()})
-    chatbot_history.append({"role": "assistant", "content": "Thinking..."})
     return chatbot_history


-def get_response(
-    chatbot_history: list[dict], audio_input: str, chat_input: str
-) -> list[dict]:
+def get_response(chatbot_history: list[dict], audio_input: str) -> list[dict]:
    """Generate response from the model based on user input and audio file"""
     try:
+        # Get the last user message from chat history
+        last_user_message = ""
+        for message in reversed(chatbot_history):
+            if message["role"] == "user":
+                last_user_message = message["content"]
+                break
+
         response = prompt_lm(
             audios=[audio_input],
-            queries=[chat_input.strip()],
+            queries=[last_user_message.strip()],
             window_length_seconds=100_000,
             hop_length_seconds=100_000,
         )
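
Note: `get_spectrogram`, `SAMPLE_RATE`, and the `torch`/`torchaudio` imports used above are defined elsewhere in app.py and are not part of this diff. A minimal sketch of what such a helper could look like, assuming torchaudio and matplotlib (the names and parameters below are assumptions, not the Space's actual code):

```python
import matplotlib.pyplot as plt
import torch
import torchaudio

SAMPLE_RATE = 32_000  # assumed value; the real constant lives elsewhere in app.py


def get_spectrogram(audio_tensor: torch.Tensor):
    """Return a matplotlib figure showing a log-power spectrogram."""
    mono = audio_tensor.mean(dim=0)  # mix down to mono
    spec = torchaudio.transforms.Spectrogram(n_fft=1024, hop_length=256)(mono)
    fig, ax = plt.subplots()
    # Log scale with a small epsilon so silent frames do not produce -inf
    ax.imshow((spec + 1e-10).log10().numpy(), origin="lower", aspect="auto")
    ax.set_xlabel("Frame")
    ax.set_ylabel("Frequency bin")
    return fig
```

Under this reading, `get_spectrogram(torch.zeros(1, SAMPLE_RATE))` yields the blank one-second placeholder figure the diff uses as its empty/error fallback.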
@@ -250,11 +260,6 @@ def get_response(
         return chatbot_history


-def temp_func(chatbot_history: list[dict]):
-    # Search for the last user message that
-    pass
-
-
 def main(
     assets_dir: Path,
     cfg_path: str | Path,
@@ -283,31 +288,18 @@ def main(

     examples = {
         "Caption the audio (Lazuli Bunting)": [
-            [
-                user_message({"path": str(laz_audio)}),
-                user_message("Caption the audio."),
-            ]
+            str(laz_audio),
+            "What is the common name for the focal species in the audio?",
         ],
         "Caption the audio (Green Tree Frog)": [
-            [
-                user_message({"path": str(frog_audio)}),
-                user_message(
-                    "Caption the audio, using the common name for any animal species."
-                ),
-            ]
+            str(frog_audio),
+            "Caption the audio, using the common name for any animal species.",
         ],
         "Caption the audio (American Robin)": [
-            [
-                user_message({"path": str(robin_audio)}),
-                user_message("Caption the audio."),
-            ]
-        ],
-        "Caption the audio (Warbling Vireo)": [
-            [
-                user_message({"path": str(vireo_audio)}),
-                user_message("Caption the audio."),
-            ]
+            str(robin_audio),
+            "Caption the audio, using the scientific name for any animal species.",
         ],
+        "Caption the audio (Warbling Vireo)": [str(vireo_audio), "Caption the audio."],
     }

     with gr.Blocks(
@@ -325,12 +317,12 @@ def main(
         with gr.Tab("Analyze Audio"):
             uploaded_audio = gr.State()
             # Status indicator
-            status_text = gr.Textbox(
-                value=model_manager.get_status(),
-                label="Model Status",
-                interactive=False,
-                visible=True,
-            )
+            # status_text = gr.Textbox(
+            #     value=model_manager.get_status(),
+            #     label="Model Status",
+            #     interactive=False,
+            #     visible=True,
+            # )

             with gr.Column(visible=True) as onboarding_message:
                 gr.HTML(
@@ -383,22 +375,12 @@ def main(
                     sources=["upload"],
                 )
                 with gr.Group(visible=False) as chat:
-                    chatbot = gr.Chatbot(
-                        elem_id="chatbot",
-                        type="messages",
-                        label="Chat",
-                        render_markdown=False,
-                        group_consecutive_messages=False,
-                        feedback_options=[
-                            "like",
-                            "dislike",
-                            "wrong species",
-                            "incorrect response",
-                            "other",
-                        ],
-                        resizeable=True,
+                    plotter = gr.Plot(
+                        get_spectrogram(torch.zeros(1, SAMPLE_RATE)),
+                        label="Spectrogram",
+                        visible=False,
+                        elem_id="spectrogram-plot",
                     )
-                    gr.Markdown("### Your Query")
                     task_dropdown = gr.Dropdown(
                         [
                             "What are the common names for the species in the audio, if any?",
@@ -418,21 +400,37 @@ def main(
                         info="Select a task or enter a custom query below",
                         value=None,
                     )
+                    chatbot = gr.Chatbot(
+                        elem_id="chatbot",
+                        type="messages",
+                        label="Chat",
+                        render_markdown=False,
+                        group_consecutive_messages=False,
+                        feedback_options=[
+                            "like",
+                            "dislike",
+                            "wrong species",
+                            "incorrect response",
+                            "other",
+                        ],
+                        resizeable=True,
+                    )
+                    gr.Markdown("### Your Query")

-                    def validate_and_submit(chatbot_history, audio_input, chat_input):
+                    def validate_and_submit(chatbot_history, chat_input):
                         if not chat_input or not chat_input.strip():
                             gr.Warning("Please enter a query before sending.")
                             return chatbot_history, chat_input

+                        updated_history = add_user_query(chatbot_history, chat_input)
+                        return updated_history, ""
+
+                    def update_current_audio(audio_input):
                         # if this audio_input is the same as the CURRENT_AUDIO, set it None
                         # else update CURRENT_AUDIO
                         global CURRENT_AUDIO
-                        if audio_input == CURRENT_AUDIO:
-                            audio_in = None
-                        else:
+                        if audio_input != CURRENT_AUDIO:
                             CURRENT_AUDIO = audio_input
-                            audio_in = audio_input
-                        return add_user_query(chatbot_history, audio_in, chat_input)

                     chat_input = gr.Textbox(
                         placeholder="Enter a query and press Shift+Enter to send",
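
In the rewritten `validate_and_submit`, returning `(updated_history, "")` lets a single handler both append to the chat and clear the textbox, while the invalid-input branch returns the original text so the user can correct it. A self-contained sketch of this clear-on-submit pattern, assuming gradio 4.x (component names here are illustrative, not the Space's):

```python
import gradio as gr


def submit(history, text):
    if not text or not text.strip():
        gr.Warning("Please enter a query before sending.")
        return history, text  # keep the text so the user can fix it
    history = (history or []) + [{"role": "user", "content": text.strip()}]
    return history, ""  # returning "" clears the Textbox


with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
    box = gr.Textbox(placeholder="Type a query and press Enter")
    box.submit(submit, inputs=[chatbot, box], outputs=[chatbot, box])

if __name__ == "__main__":
    demo.launch()
```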
@@ -458,7 +456,8 @@ def main(
                     )

                     clear_button = gr.ClearButton(
-                        components=[chatbot, chat_input, audio_input],
+                        components=[chatbot, chat_input, audio_input, plotter],
+                        visible=False,
                     )

                     def start_chat_interface(audio_path):
@@ -466,26 +465,42 @@ def main(
                             gr.update(visible=False), # hide onboarding message
                             gr.update(visible=True), # show upload section
                             gr.update(visible=True), # show chat box
+                            gr.update(visible=True), # show plotter
                         )

+                    # When audio added, set spectrogram
                     audio_input.change(
                         fn=start_chat_interface,
                         inputs=[audio_input],
-                        outputs=[onboarding_message, upload_section, chat],
+                        outputs=[onboarding_message, upload_section, chat, plotter],
+                    ).then(
+                        fn=update_current_audio,
+                        inputs=[audio_input],
+                        outputs=[],
+                    ).then(
+                        fn=make_spectrogram_figure,
+                        inputs=[audio_input],
+                        outputs=[plotter],
                     )

+                    # When submit clicked first:
+                    # 1. Validate and add user query to chat history
+                    # 2. Get response from model
+                    # 3. Clear the chat input box
+                    # 4. Show clear button
                     chat_input.submit(
                         validate_and_submit,
-                        inputs=[chatbot, audio_input, chat_input],
-                        outputs=[chatbot],
+                        inputs=[chatbot, chat_input],
+                        outputs=[chatbot, chat_input],
                     ).then(
                         get_response,
-                        inputs=[chatbot, audio_input, chat_input],
+                        inputs=[chatbot, audio_input],
                         outputs=[chatbot],
-                    ).then(
-                        lambda: gr.ClearButton(visible=True), None, [clear_button],
-                    )
+                    ).then(
+                        lambda: gr.update(visible=True), # Show clear button
+                        None,
+                        [clear_button],
+                    )

                     clear_button.click(
                         lambda: gr.ClearButton(visible=False), None, [clear_button]
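
The reworked `audio_input.change` chain relies on Gradio's `.then()`, which runs each handler only after the previous one finishes: reveal the panels first, record the current audio, then render the (potentially slow) spectrogram into `plotter`. A stripped-down sketch of that chaining, assuming gradio 4.x (the handlers are placeholders):

```python
import gradio as gr


def reveal(path):
    # Step 1: show the panel as soon as a file arrives
    return gr.update(visible=bool(path))


def describe(path):
    # Step 2: runs only after reveal() has completed
    return f"Loaded: {path}" if path else ""


with gr.Blocks() as demo:
    audio = gr.Audio(type="filepath", sources=["upload"])
    info = gr.Textbox(visible=False, label="Info")
    audio.change(fn=reveal, inputs=[audio], outputs=[info]).then(
        fn=describe, inputs=[audio], outputs=[info]
    )

if __name__ == "__main__":
    demo.launch()
```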
@@ -493,10 +508,11 @@ def main(

         with gr.Tab("Sample Library"):
             gr.Markdown("## Sample Library\n\nExplore example audio files below.")
+
             gr.Examples(
                 list(examples.values()),
-                [chatbot],
-                [chatbot],
+                [audio_input, chat_input],
+                [audio_input, chat_input],
                 example_labels=list(examples.keys()),
                 examples_per_page=20,
             )
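
The samples-loading fix hinges on `gr.Examples` receiving the input components as its second argument: each example row is now a flat `[audio_path, query]` pair that maps onto `[audio_input, chat_input]`, so clicking a sample populates both components and, via `audio_input.change`, kicks off the spectrogram chain above. A minimal sketch, assuming gradio 4.x (the audio path is a placeholder):

```python
import gradio as gr

with gr.Blocks() as demo:
    audio_input = gr.Audio(type="filepath")
    chat_input = gr.Textbox(label="Query")
    gr.Examples(
        [["sample.wav", "Caption the audio."]],  # one [audio, query] row
        [audio_input, chat_input],  # components populated on click
        example_labels=["Caption the audio (sample)"],
    )

if __name__ == "__main__":
    demo.launch()
```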