import warnings
from collections import Counter
from pathlib import Path
from typing import Optional

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import spaces
import torch
import torchaudio

from NatureLM.config import Config
from NatureLM.infer import Pipeline
from NatureLM.models.NatureLM import NatureLM

warnings.filterwarnings("ignore")

SAMPLE_RATE = 16000  # Default sample rate for NatureLM-audio

def get_spectrogram(audio: torch.Tensor, sample_rate: int = SAMPLE_RATE) -> plt.Figure:
    """Generate a log-magnitude spectrogram figure from an audio tensor."""
    if audio.dim() == 1:
        audio = audio.unsqueeze(0)  # ensure (channels, samples)
    spectrogram = torchaudio.transforms.Spectrogram(n_fft=1024)(audio)
    spectrogram = spectrogram.numpy()[0].squeeze()  # first channel -> (freq, time)

    # Convert to a matplotlib figure with imshow, on a log scale so quieter
    # components remain visible
    fig, ax = plt.subplots(figsize=(13, 5))
    ax.imshow(np.log(spectrogram + 1e-3), aspect="auto", origin="lower", cmap="viridis")
    ax.set_title("Spectrogram")

    # Set x ticks to reflect 0 to audio duration in seconds. Use the actual
    # sample rate of the audio, not the model default, so the labels stay
    # correct for non-16 kHz uploads.
    duration = audio.size(1) / sample_rate
    ax.set_xlabel("Time")
    ax.set_xticks([0, spectrogram.shape[1]])
    ax.set_xticklabels(["0s", f"{duration:.2f}s"])

    # Set y ticks to reflect 0 to the Nyquist frequency (sample_rate / 2)
    nyquist_freq = sample_rate / 2
    ax.set_ylabel("Frequency")
    ax.set_yticks(
        [
            0,
            spectrogram.shape[0] // 4,
            spectrogram.shape[0] // 2,
            3 * spectrogram.shape[0] // 4,
            spectrogram.shape[0] - 1,
        ]
    )
    ax.set_yticklabels(
        [
            "0 Hz",
            f"{nyquist_freq / 4:.0f} Hz",
            f"{nyquist_freq / 2:.0f} Hz",
            f"{3 * nyquist_freq / 4:.0f} Hz",
            f"{nyquist_freq:.0f} Hz",
        ]
    )
    fig.tight_layout()
    return fig
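
# Example usage (a sketch): torchaudio.load returns a (channels, samples)
# float tensor plus the file's sample rate, which map onto the two arguments:
#   waveform, sr = torchaudio.load("some_recording.wav")
#   fig = get_spectrogram(waveform, sr)
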
class ModelManager:
    """Manages model loading and state."""

    def __init__(self):
        self.model: Optional[Pipeline] = None
        self.config: Optional[Config] = None
        self.is_loaded = False
        self.is_loading = False
        self.load_failed = False

    def check_availability(self) -> tuple[bool, str]:
        """Check whether the model repository can be reached."""
        try:
            from huggingface_hub import model_info

            model_info("EarthSpeciesProject/NatureLM-audio")
            return True, "Model is available"
        except Exception as e:
            return False, f"Model not available: {str(e)}"

    def reset_state(self):
        """Reset the loading state to allow retrying after a failure."""
        self.model = None
        self.is_loaded = False
        self.is_loading = False
        self.load_failed = False
        return self.get_status()

    def get_status(self) -> str:
        """Get the current model loading status."""
        if self.is_loaded:
            return "✅ Model loaded and ready"
        elif self.is_loading:
            return "🔄 Loading model... Please wait"
        elif self.load_failed:
            return "❌ Model failed to load. Please check the configuration."
        else:
            return "⏳ Ready to load model on first use"

    def load_model(self) -> Optional[Pipeline]:
        """Load the model and wrap it in an inference Pipeline, if needed."""
        if self.is_loaded:
            return self.model
        if self.is_loading or self.load_failed:
            return None
        try:
            self.is_loading = True
            print("Loading model...")

            # Check that the model repository is reachable first
            available, message = self.check_availability()
            if not available:
                raise Exception(message)

            model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
            model.to("cpu")
            model.eval()

            pipe = Pipeline(model)
            self.model = pipe
            self.is_loaded = True
            self.is_loading = False
            print("Model loaded successfully!")
            return pipe
        except Exception as e:
            print(f"Error loading model: {e}")
            self.is_loading = False
            self.load_failed = True
            return None

# Global model manager instance
model_manager = ModelManager()
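# Gradio runs every session's event handlers in the same process, so a single
# module-level manager means the checkpoint is downloaded and loaded only
# once, however many users open the Space.
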
def take_majority_vote(results: list[list[dict]]) -> list[str]:
    """For each audio file, take the majority vote of the labels across all windows."""
    outputs = []
    for result in results:
        predictions = [window["prediction"] for window in result]
        if not predictions:
            continue
        # Count occurrences of each label and keep the most common one
        counts = Counter(predictions)
        most_common_label, _ = counts.most_common(1)[0]
        outputs.append(most_common_label)
    return outputs
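
# ZeroGPU note ("Running on Zero"): GPU-backed work on a ZeroGPU Space must
# run inside a function decorated with `spaces.GPU`, which is what the
# `spaces` import above is for. Decorating the inference entry point here is
# an assumption; the original decorator may have been lost in extraction.
@spaces.GPU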
def prompt_lm(
    audios: list[str],
    queries: list[str] | str,
    window_length_seconds: float = 10.0,
    hop_length_seconds: float = 10.0,
    progress=gr.Progress(),
) -> list[list[dict]] | str:
    """Generate responses using the model.

    Args:
        audios (list[str]): List of audio file paths
        queries (list[str] | str): Query or list of queries to process
        window_length_seconds (float): Length of the window for processing audio
        hop_length_seconds (float): Hop length for processing audio

    Returns:
        list[list[dict]] | str: Per-window results for each audio-query pair,
            or a status message string if the model is not ready.
    """
    model = model_manager.load_model()
    if model is None:
        if model_manager.is_loading:
            return "🔄 Loading model... This may take a few minutes on first use. Please try again in a moment."
        elif model_manager.load_failed:
            return (
                "❌ Model failed to load. This could be due to:\n"
                "• No internet connection\n"
                "• Insufficient disk space\n"
                "• Model repository access issues\n\n"
                "Please check your connection and try again using the retry button."
            )
        else:
            return "Demo mode: Model not loaded. Please check the model configuration."

    results: list[list[dict]] = model(
        audios,
        queries,
        window_length_seconds=window_length_seconds,
        hop_length_seconds=hop_length_seconds,
        input_sample_rate=None,
        progress_bar=progress,
    )
    return results
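
# The pipeline yields one list of window-level results per input file, e.g.
# results[0] == [{"prediction": "..."}, ...]; this is the shape that
# take_majority_vote() (above) and the chat handler (below) consume.
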
def user_message(content):
    return {"role": "user", "content": content}


def add_message_and_get_response(
    chatbot_history: list[dict], audio_input: str, chat_input: str
) -> tuple[list[dict], str]:
    """Add the user message to the chat and get a model response."""
    # Load the audio with torchaudio and compute its spectrogram
    audio_tensor, sample_rate = torchaudio.load(audio_input)
    duration = audio_tensor.size(1) / sample_rate
    spectrogram_fig = get_spectrogram(audio_tensor, sample_rate)

    # Add the spectrogram plot to the chatbot history
    chatbot_history.append(
        {"role": "user", "content": gr.Plot(spectrogram_fig, label="Spectrogram")}
    )

    # Get the response, processing the entire file as a single window
    try:
        response = prompt_lm(
            audios=[audio_input],
            queries=[chat_input],
            window_length_seconds=duration,
            hop_length_seconds=duration,
        )
        if isinstance(response, list) and len(response) > 0:
            # First window of the first (and only) audio file
            response = response[0][0]["prediction"]
        elif not isinstance(response, str):
            response = "No response generated."
        # If response is a string, it is a status message from prompt_lm;
        # surface it to the user as-is.
    except Exception as e:
        print(f"Error generating response: {e}")
        response = "Error generating response. Please try again."

    # Add the user query and the model response to the chat history
    chatbot_history.append({"role": "user", "content": "Q: " + chat_input})
    chatbot_history.append({"role": "assistant", "content": response})
    return chatbot_history, ""

def main(
    assets_dir: Path,
    cfg_path: str | Path,
    options: list[str] | None = None,
):
    # Load configuration
    try:
        cfg = Config.from_sources(yaml_file=cfg_path, cli_args=options or [])
        model_manager.config = cfg
        print("Configuration loaded successfully")
    except Exception as e:
        print(f"Warning: Could not load config: {e}")
        print("Running in demo mode")
        model_manager.config = None

    # Create the assets directory if it does not exist
    if not assets_dir.exists():
        print(f"Warning: Assets directory {assets_dir} does not exist")
        assets_dir.mkdir(exist_ok=True)

    # Paths to the sample audio files used in the example library
    laz_audio = assets_dir / "Lazuli_Bunting_yell-YELLLAZB20160625SM303143.mp3"
    frog_audio = assets_dir / "nri-GreenTreeFrogEvergladesNP.mp3"
    robin_audio = assets_dir / "yell-YELLAMRO20160506SM3.mp3"
    vireo_audio = assets_dir / "yell-YELLWarblingVireoMammoth20150614T29ms.mp3"

    examples = {
        "Caption the audio (Lazuli Bunting)": [
            [
                user_message({"path": str(laz_audio)}),
                user_message("Caption the audio."),
            ]
        ],
        "Caption the audio (Green Tree Frog)": [
            [
                user_message({"path": str(frog_audio)}),
                user_message(
                    "Caption the audio, using the common name for any animal species."
                ),
            ]
        ],
        "Caption the audio (American Robin)": [
            [
                user_message({"path": str(robin_audio)}),
                user_message("Caption the audio."),
            ]
        ],
        "Caption the audio (Warbling Vireo)": [
            [
                user_message({"path": str(vireo_audio)}),
                user_message("Caption the audio."),
            ]
        ],
    }
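    # Each example above is a complete chat history (an audio message followed
    # by a text query) in the gr.Chatbot "messages" format, keyed by the label
    # shown in the Sample Library tab.
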
    with gr.Blocks(
        title="NatureLM-audio",
        theme=gr.themes.Base(
            primary_hue="blue", font=[gr.themes.GoogleFont("Noto Sans")]
        ),
    ) as app:
        header = gr.HTML("""
        <div style="display: flex; align-items: center; gap: 12px;"><h2 style="margin: 0;">NatureLM-audio<span style="font-size: 0.55em; color: #28a745; background: #e6f4ea; padding: 2px 6px; border-radius: 4px; margin-left: 8px; display: inline-block; vertical-align: top;">BETA</span></h2></div>
        """)
        with gr.Tabs():
            with gr.Tab("Analyze Audio"):
                uploaded_audio = gr.State()

                # Status indicator
                # status_text = gr.Textbox(
                #     value=model_manager.get_status(),
                #     label="Model Status",
                #     interactive=False,
                #     visible=True,
                # )
                with gr.Column(visible=True) as onboarding_message:
                    gr.HTML(
                        """
                        <div style="
                            background: transparent;
                            border: 1px solid #e5e7eb;
                            border-radius: 8px;
                            padding: 16px 20px;
                            display: flex;
                            align-items: center;
                            justify-content: space-between;
                            margin-bottom: 16px;
                            margin-left: 0;
                            margin-right: 0;
                            box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
                        ">
                            <div style="display: flex; padding: 0px; align-items: center; flex: 1;">
                                <div style="font-size: 20px; margin-right: 12px;">👋</div>
                                <div style="flex: 1;">
                                    <div style="font-size: 16px; font-weight: 600; color: #374151; margin-bottom: 4px;">Welcome to NatureLM-audio!</div>
                                    <div style="font-size: 14px; color: #6b7280; line-height: 1.4;">Upload your first audio file below or try a sample from our library.</div>
                                </div>
                            </div>
                            <a href="https://www.earthspecies.org/blog" target="_blank" style="
                                padding: 6px 12px;
                                border-radius: 6px;
                                font-size: 13px;
                                font-weight: 500;
                                cursor: pointer;
                                border: none;
                                background: #3b82f6;
                                color: white;
                                text-decoration: none;
                                display: inline-block;
                                transition: background 0.2s ease;
                            "
                            onmouseover="this.style.background='#2563eb';"
                            onmouseout="this.style.background='#3b82f6';"
                            >View Tutorial</a>
                        </div>
                        """,
                        padding=False,
                    )
                with gr.Column(visible=True) as upload_section:
                    audio_input = gr.Audio(
                        type="filepath",
                        container=True,
                        interactive=True,
                        sources=["upload"],
                    )

                with gr.Group(visible=False) as chat:
                    chatbot = gr.Chatbot(
                        elem_id="chatbot",
                        type="messages",
                        label="Chat",
                        render_markdown=False,
                        feedback_options=[
                            "like",
                            "dislike",
                            "wrong species",
                            "incorrect response",
                            "other",
                        ],
                        resizeable=True,
                    )

                    gr.Markdown("### Your Query")
                    task_dropdown = gr.Dropdown(
                        [
                            "What are the common names for the species in the audio, if any?",
                            "Caption the audio.",
                            "Caption the audio, using the scientific name for any animal species.",
                            "Caption the audio, using the common name for any animal species.",
                            "What is the scientific name for the focal species in the audio?",
                            "What is the common name for the focal species in the audio?",
                            "What is the family of the focal species in the audio?",
                            "What is the genus of the focal species in the audio?",
                            "What is the taxonomic name of the focal species in the audio?",
                            "What call types are heard from the focal species in the audio?",
                            "What is the life stage of the focal species in the audio?",
                        ],
                        label="Pre-configured Tasks",
                        allow_custom_value=True,
                        info="Select a task or enter a custom query below",
                    )
                    chat_input = gr.Textbox(
                        placeholder="e.g. 'Caption this audio'...",
                        type="text",
                        label="Query",
                        lines=2,
                        show_label=True,
                        container=False,
                        submit_btn="Send",
                        elem_id="chat-input",
                    )

                    # When a task is selected, copy it into the query box
                    def set_query(task):
                        if task:
                            return gr.update(value=task)
                        return gr.update(value="")

                    task_dropdown.change(
                        fn=set_query,
                        inputs=[task_dropdown],
                        outputs=[chat_input],
                    )

                    clear_button = gr.ClearButton(
                        components=[chatbot, chat_input, audio_input], visible=False
                    )

                # Reveal the chat interface once an audio file is present; if
                # the audio is cleared, bring the onboarding message back
                def start_chat_interface(audio_path):
                    has_audio = audio_path is not None
                    return (
                        gr.update(visible=not has_audio),  # onboarding message
                        gr.update(visible=True),  # upload section
                        gr.update(visible=has_audio),  # chat box
                    )

                audio_input.change(
                    fn=start_chat_interface,
                    inputs=[audio_input],
                    outputs=[onboarding_message, upload_section, chat],
                )

                # Submitting a query appends the exchange to the chat, then
                # reveals the Clear button via .then()
                chat_input.submit(
                    add_message_and_get_response,
                    inputs=[chatbot, audio_input, chat_input],
                    outputs=[chatbot, chat_input],
                ).then(lambda: gr.ClearButton(visible=True), None, [clear_button])

                clear_button.click(
                    lambda: gr.ClearButton(visible=False), None, [clear_button]
                )
            with gr.Tab("Sample Library"):
                gr.Markdown("## Sample Library\n\nExplore example audio files below.")
                gr.Examples(
                    list(examples.values()),
                    chatbot,
                    chatbot,
                    example_labels=list(examples.keys()),
                    examples_per_page=20,
                )

            with gr.Tab("💡 Help"):
                gr.Markdown("## User Guide")  # to fill out
                gr.Markdown("## Share Feedback")  # to fill out
                gr.Markdown("## FAQs")  # to fill out
    app.css = """
    .welcome-banner {
        background: transparent !important;
        border: 1px solid #e5e7eb !important;
        border-radius: 8px !important;
        padding: 16px 20px !important;
        margin-bottom: 16px !important;
        box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1) !important;
    }
    .welcome-banner > div {
        background: transparent !important;
    }
    .welcome-banner button {
        margin: 0 4px !important;
    }
    """

    # Disabling Batch and Long Recording tabs for now
    # with gr.Tab("Batch"):
    #     _batch_tab()
    # with gr.Tab("Long Recording"):
    #     _long_recording_tab()

    return app

# Create and launch the app
app = main(
    assets_dir=Path("assets"),
    cfg_path=Path("configs/inference.yml"),
    options=[],
)

if __name__ == "__main__":
    app.launch()