gagannarula committed on
Commit
a774f0a
·
verified ·
1 Parent(s): 3522665

More UI changes and samples loading fix

Browse files
Files changed (1) hide show
  1. app.py +102 -86
app.py CHANGED
@@ -195,44 +195,54 @@ def prompt_lm(
195
  return results
196
 
197
 
198
- def user_message(content):
199
- return {"role": "user", "content": content}
 
 
 
 
 
 
 
 
 
200
 
 
 
 
201
 
202
- def add_user_query(
203
- chatbot_history: list[dict], audio_input: str | None, chat_input: str
204
- ) -> list[dict]:
 
 
 
 
 
 
 
205
  """Add user message to chat and get model response"""
206
  # Validate input
207
  if not chat_input.strip():
208
  return chatbot_history
209
 
210
- if not audio_input:
211
- chatbot_history.append({"role": "user", "content": chat_input.strip()})
212
- chatbot_history.append({"role": "assistant", "content": "Thinking..."})
213
- return chatbot_history
214
-
215
- # Load audio with torchaudio and compute spectrogram
216
- audio_tensor, sample_rate = torchaudio.load(audio_input)
217
- spectrogram_fig = get_spectrogram(audio_tensor)
218
- # Add gr.Plot to chatbot history
219
- chatbot_history.append(
220
- {"role": "user", "content": gr.Plot(spectrogram_fig, label="Spectrogram")}
221
- )
222
- # Add user message to chat history first
223
  chatbot_history.append({"role": "user", "content": chat_input.strip()})
224
- chatbot_history.append({"role": "assistant", "content": "Thinking..."})
225
  return chatbot_history
226
 
227
 
228
- def get_response(
229
- chatbot_history: list[dict], audio_input: str, chat_input: str
230
- ) -> list[dict]:
231
  """Generate response from the model based on user input and audio file"""
232
  try:
 
 
 
 
 
 
 
233
  response = prompt_lm(
234
  audios=[audio_input],
235
- queries=[chat_input.strip()],
236
  window_length_seconds=100_000,
237
  hop_length_seconds=100_000,
238
  )
@@ -250,11 +260,6 @@ def get_response(
250
  return chatbot_history
251
 
252
 
253
- def temp_func(chatbot_history: list[dict]):
254
- # Search for the last user message that
255
- pass
256
-
257
-
258
  def main(
259
  assets_dir: Path,
260
  cfg_path: str | Path,
@@ -283,31 +288,18 @@ def main(
283
 
284
  examples = {
285
  "Caption the audio (Lazuli Bunting)": [
286
- [
287
- user_message({"path": str(laz_audio)}),
288
- user_message("Caption the audio."),
289
- ]
290
  ],
291
  "Caption the audio (Green Tree Frog)": [
292
- [
293
- user_message({"path": str(frog_audio)}),
294
- user_message(
295
- "Caption the audio, using the common name for any animal species."
296
- ),
297
- ]
298
  ],
299
  "Caption the audio (American Robin)": [
300
- [
301
- user_message({"path": str(robin_audio)}),
302
- user_message("Caption the audio."),
303
- ]
304
- ],
305
- "Caption the audio (Warbling Vireo)": [
306
- [
307
- user_message({"path": str(vireo_audio)}),
308
- user_message("Caption the audio."),
309
- ]
310
  ],
 
311
  }
312
 
313
  with gr.Blocks(
@@ -325,12 +317,12 @@ def main(
325
  with gr.Tab("Analyze Audio"):
326
  uploaded_audio = gr.State()
327
  # Status indicator
328
- status_text = gr.Textbox(
329
- value=model_manager.get_status(),
330
- label="Model Status",
331
- interactive=False,
332
- visible=True,
333
- )
334
 
335
  with gr.Column(visible=True) as onboarding_message:
336
  gr.HTML(
@@ -383,22 +375,12 @@ def main(
383
  sources=["upload"],
384
  )
385
  with gr.Group(visible=False) as chat:
386
- chatbot = gr.Chatbot(
387
- elem_id="chatbot",
388
- type="messages",
389
- label="Chat",
390
- render_markdown=False,
391
- group_consecutive_messages=False,
392
- feedback_options=[
393
- "like",
394
- "dislike",
395
- "wrong species",
396
- "incorrect response",
397
- "other",
398
- ],
399
- resizeable=True,
400
  )
401
- gr.Markdown("### Your Query")
402
  task_dropdown = gr.Dropdown(
403
  [
404
  "What are the common names for the species in the audio, if any?",
@@ -418,21 +400,37 @@ def main(
418
  info="Select a task or enter a custom query below",
419
  value=None,
420
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
421
 
422
- def validate_and_submit(chatbot_history, audio_input, chat_input):
423
  if not chat_input or not chat_input.strip():
424
  gr.Warning("Please enter a query before sending.")
425
  return chatbot_history, chat_input
426
 
 
 
 
 
427
  # if this audio_input is the same as the CURRENT_AUDIO, set it None
428
  # else update CURRENT_AUDIO
429
  global CURRENT_AUDIO
430
- if audio_input == CURRENT_AUDIO:
431
- audio_in = None
432
- else:
433
  CURRENT_AUDIO = audio_input
434
- audio_in = audio_input
435
- return add_user_query(chatbot_history, audio_in, chat_input)
436
 
437
  chat_input = gr.Textbox(
438
  placeholder="Enter a query and press Shift+Enter to send",
@@ -458,7 +456,8 @@ def main(
458
  )
459
 
460
  clear_button = gr.ClearButton(
461
- components=[chatbot, chat_input, audio_input], visible=False
 
462
  )
463
 
464
  def start_chat_interface(audio_path):
@@ -466,26 +465,42 @@ def main(
466
  gr.update(visible=False), # hide onboarding message
467
  gr.update(visible=True), # show upload section
468
  gr.update(visible=True), # show chat box
 
469
  )
470
 
 
471
  audio_input.change(
472
  fn=start_chat_interface,
473
  inputs=[audio_input],
474
- outputs=[onboarding_message, upload_section, chat],
 
 
 
 
 
 
 
 
475
  )
476
 
 
 
 
 
 
477
  chat_input.submit(
478
  validate_and_submit,
479
- inputs=[chatbot, audio_input, chat_input],
480
- outputs=[chatbot],
481
  ).then(
482
  get_response,
483
- inputs=[chatbot, audio_input, chat_input],
484
  outputs=[chatbot],
485
- ).then(lambda: gr.ClearButton(visible=True), None, [clear_button])
486
-
487
- if model_manager.is_loaded:
488
- status_text.update(value=model_manager.get_status())
 
489
 
490
  clear_button.click(
491
  lambda: gr.ClearButton(visible=False), None, [clear_button]
@@ -493,10 +508,11 @@ def main(
493
 
494
  with gr.Tab("Sample Library"):
495
  gr.Markdown("## Sample Library\n\nExplore example audio files below.")
 
496
  gr.Examples(
497
  list(examples.values()),
498
- chatbot,
499
- chatbot,
500
  example_labels=list(examples.keys()),
501
  examples_per_page=20,
502
  )
 
195
  return results
196
 
197
 
198
def make_spectrogram_figure(audio_input: str | None):
    """Build the spectrogram figure shown in the UI for an audio file.

    Args:
        audio_input: Filesystem path of the audio file, or a falsy value
            (``None`` / ``""``) when no audio has been provided.

    Returns:
        The object produced by ``get_spectrogram`` (it is passed to
        ``gr.Plot``). When the path is missing, is not a regular file, or
        loading fails, the spectrogram of a silent ``(1, SAMPLE_RATE)``
        tensor is returned instead so the plot component always receives a
        valid value.
    """

    def _blank_spectrogram():
        # Placeholder figure: spectrogram of an all-zero (silent) signal.
        return get_spectrogram(torch.zeros(1, SAMPLE_RATE))

    if not audio_input:
        # Nothing uploaded / component was cleared.
        return _blank_spectrogram()

    # Best-effort: this runs inside a UI callback, so any failure falls back
    # to the blank figure rather than surfacing an exception to the user.
    try:
        audio_path = Path(audio_input)
        if not audio_path.exists():
            print(f"Audio file does not exist: {audio_input}")
            return _blank_spectrogram()

        if not audio_path.is_file():
            print(f"Path is not a valid file: {audio_input}")
            return _blank_spectrogram()

        # The sample rate returned by the loader is not needed here.
        audio_tensor, _ = torchaudio.load(audio_input)
        return get_spectrogram(audio_tensor)
    except Exception as e:
        print(f"Error loading audio file {audio_input}: {e}")
        # Return an empty spectrogram on error
        return _blank_spectrogram()
221
+
222
+
223
+ def add_user_query(chatbot_history: list[dict], chat_input: str) -> list[dict]:
224
  """Add user message to chat and get model response"""
225
  # Validate input
226
  if not chat_input.strip():
227
  return chatbot_history
228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  chatbot_history.append({"role": "user", "content": chat_input.strip()})
 
230
  return chatbot_history
231
 
232
 
233
+ def get_response(chatbot_history: list[dict], audio_input: str) -> list[dict]:
 
 
234
  """Generate response from the model based on user input and audio file"""
235
  try:
236
+ # Get the last user message from chat history
237
+ last_user_message = ""
238
+ for message in reversed(chatbot_history):
239
+ if message["role"] == "user":
240
+ last_user_message = message["content"]
241
+ break
242
+
243
  response = prompt_lm(
244
  audios=[audio_input],
245
+ queries=[last_user_message.strip()],
246
  window_length_seconds=100_000,
247
  hop_length_seconds=100_000,
248
  )
 
260
  return chatbot_history
261
 
262
 
 
 
 
 
 
263
  def main(
264
  assets_dir: Path,
265
  cfg_path: str | Path,
 
288
 
289
  examples = {
290
  "Caption the audio (Lazuli Bunting)": [
291
+ str(laz_audio),
292
+ "What is the common name for the focal species in the audio?",
 
 
293
  ],
294
  "Caption the audio (Green Tree Frog)": [
295
+ str(frog_audio),
296
+ "Caption the audio, using the common name for any animal species.",
 
 
 
 
297
  ],
298
  "Caption the audio (American Robin)": [
299
+ str(robin_audio),
300
+ "Caption the audio, using the scientific name for any animal species.",
 
 
 
 
 
 
 
 
301
  ],
302
+ "Caption the audio (Warbling Vireo)": [str(vireo_audio), "Caption the audio."],
303
  }
304
 
305
  with gr.Blocks(
 
317
  with gr.Tab("Analyze Audio"):
318
  uploaded_audio = gr.State()
319
  # Status indicator
320
+ # status_text = gr.Textbox(
321
+ # value=model_manager.get_status(),
322
+ # label="Model Status",
323
+ # interactive=False,
324
+ # visible=True,
325
+ # )
326
 
327
  with gr.Column(visible=True) as onboarding_message:
328
  gr.HTML(
 
375
  sources=["upload"],
376
  )
377
  with gr.Group(visible=False) as chat:
378
+ plotter = gr.Plot(
379
+ get_spectrogram(torch.zeros(1, SAMPLE_RATE)),
380
+ label="Spectrogram",
381
+ visible=False,
382
+ elem_id="spectrogram-plot",
 
 
 
 
 
 
 
 
 
383
  )
 
384
  task_dropdown = gr.Dropdown(
385
  [
386
  "What are the common names for the species in the audio, if any?",
 
400
  info="Select a task or enter a custom query below",
401
  value=None,
402
  )
403
+ chatbot = gr.Chatbot(
404
+ elem_id="chatbot",
405
+ type="messages",
406
+ label="Chat",
407
+ render_markdown=False,
408
+ group_consecutive_messages=False,
409
+ feedback_options=[
410
+ "like",
411
+ "dislike",
412
+ "wrong species",
413
+ "incorrect response",
414
+ "other",
415
+ ],
416
+ resizeable=True,
417
+ )
418
+ gr.Markdown("### Your Query")
419
 
420
def validate_and_submit(chatbot_history, chat_input):
    """Validate the query box and, if it has text, add it to the chat.

    Returns a ``(history, textbox_value)`` pair. For a usable query the
    history gains the user's message and the textbox value is ``""``;
    for a blank query a warning is raised in the UI and both inputs are
    returned unchanged.
    """
    has_text = bool(chat_input) and bool(chat_input.strip())
    if has_text:
        # Record the user turn and hand back an empty string for the textbox.
        return add_user_query(chatbot_history, chat_input), ""

    gr.Warning("Please enter a query before sending.")
    return chatbot_history, chat_input
427
+
428
def update_current_audio(audio_input):
    """Remember the most recently selected audio in ``CURRENT_AUDIO``.

    Args:
        audio_input: Current value of the audio component.

    Returns:
        None. The only effect is updating the module-level
        ``CURRENT_AUDIO`` when the value differs from the stored one.
    """
    global CURRENT_AUDIO
    # Only rebind when the selection actually changed; an unchanged value
    # leaves CURRENT_AUDIO untouched.
    if audio_input != CURRENT_AUDIO:
        CURRENT_AUDIO = audio_input
 
 
434
 
435
  chat_input = gr.Textbox(
436
  placeholder="Enter a query and press Shift+Enter to send",
 
456
  )
457
 
458
  clear_button = gr.ClearButton(
459
+ components=[chatbot, chat_input, audio_input, plotter],
460
+ visible=False,
461
  )
462
 
463
  def start_chat_interface(audio_path):
 
465
  gr.update(visible=False), # hide onboarding message
466
  gr.update(visible=True), # show upload section
467
  gr.update(visible=True), # show chat box
468
+ gr.update(visible=True), # show plotter
469
  )
470
 
471
+ # When audio added, set spectrogram
472
  audio_input.change(
473
  fn=start_chat_interface,
474
  inputs=[audio_input],
475
+ outputs=[onboarding_message, upload_section, chat, plotter],
476
+ ).then(
477
+ fn=update_current_audio,
478
+ inputs=[audio_input],
479
+ outputs=[],
480
+ ).then(
481
+ fn=make_spectrogram_figure,
482
+ inputs=[audio_input],
483
+ outputs=[plotter],
484
  )
485
 
486
+ # When submit clicked first:
487
+ # 1. Validate and add user query to chat history
488
+ # 2. Get response from model
489
+ # 3. Clear the chat input box
490
+ # 4. Show clear button
491
  chat_input.submit(
492
  validate_and_submit,
493
+ inputs=[chatbot, chat_input],
494
+ outputs=[chatbot, chat_input],
495
  ).then(
496
  get_response,
497
+ inputs=[chatbot, audio_input],
498
  outputs=[chatbot],
499
+ ).then(
500
+ lambda: gr.update(visible=True), # Show clear button
501
+ None,
502
+ [clear_button],
503
+ )
504
 
505
  clear_button.click(
506
  lambda: gr.ClearButton(visible=False), None, [clear_button]
 
508
 
509
  with gr.Tab("Sample Library"):
510
  gr.Markdown("## Sample Library\n\nExplore example audio files below.")
511
+
512
  gr.Examples(
513
  list(examples.values()),
514
+ [audio_input, chat_input],
515
+ [audio_input, chat_input],
516
  example_labels=list(examples.keys()),
517
  examples_per_page=20,
518
  )