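"""Gradio demo app for NatureLM-audio on HuggingFace Spaces.

Provides three tabs: an interactive chat over uploaded audio, batch inference over
many files, and chunked inference over a long recording with Raven-format export.
"""
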
import re
import tempfile
from collections import Counter
from pathlib import Path
from typing import Literal

import gradio as gr
import torch

from NatureLM.config import Config
from NatureLM.models.NatureLM import NatureLM
from NatureLM.utils import generate_sample_batches, prepare_sample_waveforms

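# Lazily-initialised global state. The model is only downloaded and loaded on the first
# request (see load_model_if_needed); the flags below track the loading state machine:
# loaded / currently loading / failed.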
CONFIG: Config = None
MODEL: NatureLM = None
MODEL_LOADED = False
MODEL_LOADING = False
MODEL_LOAD_FAILED = False

def check_model_availability():
    """Check if the model is available for download"""
    try:
        from huggingface_hub import model_info

        model_info("EarthSpeciesProject/NatureLM-audio")
        return True, "Model is available"
    except Exception as e:
        return False, f"Model not available: {str(e)}"

def reset_model_state():
    """Reset the model loading state to allow retrying after a failure"""
    global MODEL, MODEL_LOADED, MODEL_LOADING, MODEL_LOAD_FAILED
    MODEL = None
    MODEL_LOADED = False
    MODEL_LOADING = False
    MODEL_LOAD_FAILED = False
    return get_model_status()

def get_model_status():
    """Get the current model loading status"""
    if MODEL_LOADED:
        return "✅ Model loaded and ready"
    elif MODEL_LOADING:
        return "🔄 Loading model... Please wait"
    elif MODEL_LOAD_FAILED:
        return "❌ Model failed to load. Please check the configuration."
    else:
        return "⏳ Ready to load model on first use"

def load_model_if_needed():
    """Lazy load the model when first needed"""
    global MODEL, MODEL_LOADED, MODEL_LOADING, MODEL_LOAD_FAILED

    if MODEL_LOADED:
        return MODEL
    if MODEL_LOADING:
        # Model is currently loading, return a message to try again
        return None
    if MODEL_LOAD_FAILED:
        # Model has already failed to load, don't try again
        return None

    if MODEL is None:
        try:
            MODEL_LOADING = True
            print("Loading model...")

            # Check if model is available first
            available, message = check_model_availability()
            if not available:
                raise Exception(f"Model not available: {message}")

            model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
            model.to("cpu")  # Use CPU for HuggingFace Spaces
            model.eval()

            MODEL = model
            MODEL_LOADED = True
            MODEL_LOADING = False
            print("Model loaded successfully!")
            return MODEL
        except Exception as e:
            print(f"Error loading model: {e}")
            MODEL_LOADING = False
            MODEL_LOAD_FAILED = True
            return None

    return MODEL

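# prompt_lm: turns the chat history into a single Llama chat-template prompt (with the
# auto-injected system header stripped out), prepares the uploaded audio as model inputs,
# and runs NatureLM generation, returning the first decoded answer.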
def prompt_lm(audios: list[str], messages: list[dict[str, str]]):
    # Always try to load the model if needed
    model = load_model_if_needed()
    if model is None:
        if MODEL_LOADING:
            return "🔄 Loading model... This may take a few minutes on first use. Please try again in a moment."
        elif MODEL_LOAD_FAILED:
            return "❌ Model failed to load. This could be due to:\n• No internet connection\n• Insufficient disk space\n• Model repository access issues\n\nPlease check your connection and try again using the retry button."
        else:
            return "Demo mode: Model not loaded. Please check the model configuration."

    cuda_enabled = torch.cuda.is_available()
    samples = prepare_sample_waveforms(audios, cuda_enabled)

    prompt_text = model.llama_tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    ).removeprefix(model.llama_tokenizer.bos_token)
    prompt_text = re.sub(
        r"<\|start_header_id\|>system<\|end_header_id\|>\n\nCutting Knowledge Date: [^\n]+\nToday Date: [^\n]+\n\n<\|eot_id\|>",
        "",
        prompt_text,
    )  # exclude the system header from the prompt
    prompt_text = re.sub("\\n", r"\\n", prompt_text)  # FIXME this is a hack to fix the issue #34
    print(f"{prompt_text=}")

    with torch.cuda.amp.autocast(dtype=torch.float16):
        llm_answer = model.generate(samples, CONFIG.generate, prompts=[prompt_text])
    return llm_answer[0]

def _multimodal_textbox_factory():
    return gr.MultimodalTextbox(
        value=None,
        interactive=True,
        file_count="multiple",
        placeholder="Enter message or upload file...",
        show_label=False,
        submit_btn="Add input",
        file_types=["audio"],
    )

def user_message(content):
    return {"role": "user", "content": content}


def add_message(history, message):
    for x in message["files"]:
        history.append(user_message({"path": x}))
    if message["text"]:
        history.append(user_message(message["text"]))
    return history, _multimodal_textbox_factory()

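# combine_model_inputs: rewrites the Gradio chat history into NatureLM's expected form.
# Audio attachments are replaced by the "<Audio><AudioHere></Audio>" placeholder and their
# file paths collected separately; consecutive messages from the same role are then merged
# so each turn contains both the placeholder(s) and the text instruction.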
def combine_model_inputs(msgs: list[dict[str, str]]) -> dict[str, list[str]]:
    messages = []
    files = []
    for msg in msgs:
        print(msg, messages, files)
        match msg:
            case {"content": (path,)}:
                messages.append({"role": msg["role"], "content": "<Audio><AudioHere></Audio> "})
                files.append(path)
            case _:
                messages.append(msg)

    joined_messages = []
    # join consecutive messages from the same role
    for msg in messages:
        if joined_messages and joined_messages[-1]["role"] == msg["role"]:
            joined_messages[-1]["content"] += msg["content"]
        else:
            joined_messages.append(msg)

    return {"messages": joined_messages, "files": files}

def bot_response(history: list):
    print(type(history))
    combined_inputs = combine_model_inputs(history)
    response = prompt_lm(combined_inputs["files"], combined_inputs["messages"])
    history.append({"role": "assistant", "content": response})
    return history

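# _chat_tab: builds the chat UI (model status box, a retry button shown only after a failed
# load, the chatbot history, a multimodal input for text and audio attachments, and example
# prompts) and wires the "Send all" button to bot_response.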
def _chat_tab(examples):
    # Add status indicator
    status_text = gr.Textbox(
        value=get_model_status(),
        label="Model Status",
        interactive=False,
        visible=True,
    )

    # Add retry button that only shows when model failed to load
    retry_button = gr.Button(
        "🔄 Retry Loading Model",
        visible=False,
        variant="secondary",
    )

    chatbot = gr.Chatbot(
        label="Model inputs",
        elem_id="chatbot",
        bubble_full_width=False,
        type="messages",
        render_markdown=False,
        # editable="user",  # disable because of https://github.com/gradio-app/gradio/issues/10320
        resizeable=True,
    )

    chat_input = _multimodal_textbox_factory()
    send_all = gr.Button("Send all", elem_id="send-all")
    clear_button = gr.ClearButton(components=[chatbot, chat_input], visible=False)

    chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])

    bot_msg = send_all.click(
        bot_response,
        [chatbot],
        [chatbot],
        api_name="bot_response",
    )

    # Update status after bot response
    bot_msg.then(lambda: get_model_status(), None, [status_text])
    bot_msg.then(lambda: gr.ClearButton(visible=True), None, [clear_button])
    clear_button.click(lambda: gr.ClearButton(visible=False), None, [clear_button])

    # Handle retry button
    retry_button.click(
        reset_model_state,
        None,
        [status_text],
    )

    # Show/hide retry button based on model status
    def update_retry_button_visibility():
        return gr.Button(visible=MODEL_LOAD_FAILED)

    # Update retry button visibility when status changes
    bot_msg.then(update_retry_button_visibility, None, [retry_button])
    retry_button.click(update_retry_button_visibility, None, [retry_button])

    gr.Examples(
        list(examples.values()),
        chatbot,
        chatbot,
        example_labels=list(examples.keys()),
        examples_per_page=20,
    )

def summarize_batch_results(results):
    summary = Counter(results)
    summary_str = "\n".join(f"{k}: {v}" for k, v in summary.most_common())
    return summary_str

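# run_batch_inference: runs the same task prompt over each uploaded file, one file at a
# time, and returns a summary counting how often each distinct answer occurred.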
def run_batch_inference(files, task, progress=gr.Progress()) -> str:
    model = load_model_if_needed()
    if model is None:
        if MODEL_LOADING:
            return "🔄 Loading model... This may take a few minutes on first use. Please try again in a moment."
        elif MODEL_LOAD_FAILED:
            return "❌ Model failed to load. This could be due to:\n• No internet connection\n• Insufficient disk space\n• Model repository access issues\n\nPlease check your connection and try again."
        else:
            return "Demo mode: Model not loaded. Please check the model configuration."

    outputs = []
    prompt = "<Audio><AudioHere></Audio> " + task
    for file in progress.tqdm(files):
        outputs.append(prompt_lm([file], [{"role": "user", "content": prompt}]))

    batch_summary: str = summarize_batch_results(outputs)
    report = f"Batch summary:\n{batch_summary}\n\n"
    return report

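# multi_extension_glob_mask builds a single glob pattern that matches several file
# extensions by combining their characters position-by-position into character classes,
# appending "*" when the extensions have different lengths. For example,
# multi_extension_glob_mask("**.", "mp3", "flac", "wav") yields something like
# "**.[mfw][pla][3av]*" (bracket ordering may vary because sets are unordered), which
# matches .mp3, .flac and .wav files for the FileExplorer below. Note the pattern is
# permissive: it also admits other combinations of those characters.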
def multi_extension_glob_mask(mask_base, *extensions):
    mask_ext = ["[{}]".format("".join(set(c))) for c in zip(*extensions)]
    if not mask_ext or len(set(len(e) for e in extensions)) > 1:
        mask_ext.append("*")
    return mask_base + "".join(mask_ext)

def _batch_tab(file_selection: Literal["upload", "explorer"] = "upload"):
    if file_selection == "explorer":
        files = gr.FileExplorer(
            glob=multi_extension_glob_mask("**.", "mp3", "flac", "wav"),
            label="Select audio files",
            file_count="multiple",
        )
    elif file_selection == "upload":
        files = gr.Files(label="Uploaded files", file_types=["audio"], height=300)

    task = gr.Textbox(label="Task", placeholder="Enter task...", show_label=True)
    process_btn = gr.Button("Process")
    output = gr.TextArea()

    process_btn.click(
        run_batch_inference,
        [files, task],
        [output],
    )

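# to_raven_format: converts the per-chunk answers (keyed by start offset in seconds) into a
# Raven selection table: a tab-separated header row plus one row per selection, where
# consecutive chunks sharing the same label are merged into a single selection and a "None"
# answer closes the current selection. Frequency bounds are fixed at 0-8000 Hz.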
def to_raven_format(outputs: dict[int, str], chunk_len: int = 10) -> str:
    def get_line(row, start, end, annotation):
        return f"{row}\tSpectrogram 1\t1\t{start}\t{end}\t0\t8000\t{annotation}"

    raven_output = [
        "Selection\tView\tChannel\tBegin Time (s)\tEnd Time (s)\tLow Freq (Hz)\tHigh Freq (Hz)\tAnnotation"
    ]
    current_offset = 0
    last_label = ""
    row = 1

    # The "Selection" column is just the row number.
    # The "view" column will always say "Spectrogram 1".
    # Channel can always be "1".
    # For the frequency bounds we can just use 0 and 1/2 the sample rate
    for offset, label in sorted(outputs.items()):
        if label != last_label and last_label:
            raven_output.append(get_line(row, current_offset, offset, last_label))
            current_offset = offset
            row += 1
        if not last_label:
            current_offset = offset
        if label != "None":
            last_label = label
        else:
            last_label = ""

    if last_label:
        raven_output.append(get_line(row, current_offset, current_offset + chunk_len, last_label))

    return "\n".join(raven_output)

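# _run_long_recording_inference: slides a chunk_len-second window with a hop_len-second hop
# across the recording, prompts the model on each batch of chunks, and stores each answer
# under its chunk start offset. The text report lists every chunk's answer, and the Raven
# selection table is written to a temporary file for download.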
def _run_long_recording_inference(file, task, chunk_len: int = 10, hop_len: int = 5, progress=gr.Progress()):
    model = load_model_if_needed()
    if model is None:
        if MODEL_LOADING:
            return "🔄 Loading model... This may take a few minutes on first use. Please try again in a moment.", None
        elif MODEL_LOAD_FAILED:
            return "❌ Model failed to load. This could be due to:\n• No internet connection\n• Insufficient disk space\n• Model repository access issues\n\nPlease check your connection and try again.", None
        else:
            return "Demo mode: Model not loaded. Please check the model configuration.", None

    cuda_enabled = torch.cuda.is_available()
    outputs = {}
    offset = 0
    prompt = f"<Audio><AudioHere></Audio> {task}"
    prompt = CONFIG.model.prompt_template.format(prompt)

    for batch in progress.tqdm(generate_sample_batches(file, cuda_enabled, chunk_len=chunk_len, hop_len=hop_len)):
        prompt_strs = [prompt] * len(batch["audio_chunk_sizes"])
        with torch.cuda.amp.autocast(dtype=torch.float16):
            llm_answers = model.generate(batch, CONFIG.generate, prompts=prompt_strs)
        for answer in llm_answers:
            outputs[offset] = answer
            offset += hop_len

    report = f"Number of chunks: {len(outputs)}\n\n"
    for offset in sorted(outputs.keys()):
        report += f"{offset:02d}s:\t{outputs[offset]}\n"

    raven_output = to_raven_format(outputs, chunk_len=chunk_len)
    with tempfile.NamedTemporaryFile(mode="w", prefix="raven-", suffix=".txt", delete=False) as f:
        f.write(raven_output)
        raven_file = f.name

    return report, raven_file

def _long_recording_tab():
    audio_input = gr.Audio(label="Upload audio file", type="filepath")
    task = gr.Dropdown(
        [
            "What are the common names for the species in the audio, if any?",
            "Caption the audio.",
            "Caption the audio, using the scientific name for any animal species.",
            "Caption the audio, using the common name for any animal species.",
            "What is the scientific name for the focal species in the audio?",
            "What is the common name for the focal species in the audio?",
            "What is the family of the focal species in the audio?",
            "What is the genus of the focal species in the audio?",
            "What is the taxonomic name of the focal species in the audio?",
            "What call types are heard from the focal species in the audio?",
            "What is the life stage of the focal species in the audio?",
        ],
        label="Tasks",
        allow_custom_value=True,
    )
    with gr.Accordion("Advanced options", open=False):
        hop_len = gr.Slider(1, 10, 5, label="Hop length (seconds)", step=1)
        chunk_len = gr.Slider(1, 10, 10, label="Chunk length (seconds)", step=1)
    process_btn = gr.Button("Process")
    output = gr.TextArea()
    download_raven = gr.DownloadButton("Download Raven file")

    process_btn.click(
        _run_long_recording_inference,
        [audio_input, task, chunk_len, hop_len],
        [output, download_raven],
    )

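# main: loads the inference config (falling back to demo mode if it cannot be read), sets
# up the example chat prompts, and assembles the Gradio Blocks app with the Chat, Batch
# and Long Recording tabs.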
def main(
    assets_dir: Path,
    cfg_path: str | Path,
    options: list[str] = [],
    device: str = "cpu",
):
    global CONFIG
    try:
        cfg = Config.from_sources(yaml_file=cfg_path, cli_args=options)
        CONFIG = cfg
        print("Configuration loaded successfully")
    except Exception as e:
        print(f"Warning: Could not load config: {e}")
        print("Running in demo mode")
        CONFIG = None

    # Check if assets directory exists, if not create a placeholder
    if not assets_dir.exists():
        print(f"Warning: Assets directory {assets_dir} does not exist")
        assets_dir.mkdir(exist_ok=True)
    # Paths to the example audio files expected in the assets directory
    laz_audio = assets_dir / "Lazuli_Bunting_yell-YELLLAZB20160625SM303143.mp3"
    frog_audio = assets_dir / "nri-GreenTreeFrogEvergladesNP.mp3"
    robin_audio = assets_dir / "yell-YELLAMRO20160506SM3.mp3"
    vireo_audio = assets_dir / "yell-YELLWarblingVireoMammoth20150614T29ms.mp3"
    examples = {
        "Caption the audio (Lazuli Bunting)": [
            [
                user_message({"path": str(laz_audio)}),
                user_message("Caption the audio."),
            ]
        ],
        "Caption the audio (Green Tree Frog)": [
            [
                user_message({"path": str(frog_audio)}),
                user_message("Caption the audio, using the common name for any animal species."),
            ]
        ],
        "Caption the audio (American Robin)": [
            [
                user_message({"path": str(robin_audio)}),
                user_message("Caption the audio."),
            ]
        ],
        "Caption the audio (Warbling Vireo)": [
            [
                user_message({"path": str(vireo_audio)}),
                user_message("Caption the audio."),
            ]
        ],
    }
    with gr.Blocks(title="NatureLM-audio", theme=gr.themes.Default(primary_hue="slate")) as app:
        with gr.Tabs():
            with gr.Tab("Chat"):
                _chat_tab(examples)
            with gr.Tab("Batch"):
                _batch_tab()
            with gr.Tab("Long Recording"):
                _long_recording_tab()

    return app

# At the bottom of the file:
app = main(
    assets_dir=Path("assets"),
    cfg_path=Path("configs/inference.yml"),
    options=[],
    device="cpu",  # TODO: from config depending on zerogpu! (to change)
)

# Launch the app
if __name__ == "__main__":
    app.launch()