File size: 1,514 Bytes
89c09cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import gradio as gr
import torch
import torchaudio
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

# Load the feature extractor that belongs to the checkpoint.
model_name = "rakib730/finetuned-gtzan"
extractor = AutoFeatureExtractor.from_pretrained(model_name)

# Load the classifier and switch it to eval mode in one step;
# ``eval()`` returns the module itself, so ``model`` is the loaded model.
model = AutoModelForAudioClassification.from_pretrained(model_name).eval()

# Audio classification function
def classify_music(audio):
    """Predict the music genre of an audio clip supplied by Gradio.

    Args:
        audio: The value produced by ``gr.Audio(type="numpy")``: a
            ``(sample_rate, data)`` tuple, where ``data`` is an array of
            shape ``(samples,)`` or ``(samples, channels)``. ``None`` when
            no clip was uploaded.

    Returns:
        The predicted label looked up in ``model.config.id2label``, or a
        short message when no usable audio was provided.
    """
    if audio is None:
        return "No audio provided. Please upload a music clip."

    # Gradio's numpy format puts the sample rate FIRST, then the samples.
    sample_rate, samples = audio
    waveform = torch.tensor(samples)
    if waveform.numel() == 0:
        return "No audio provided. Please upload a music clip."

    # Gradio usually delivers integer PCM; the resampler needs floating
    # point input, so scale integer samples into [-1, 1].
    if waveform.is_floating_point():
        waveform = waveform.to(torch.float32)
    else:
        full_scale = float(torch.iinfo(waveform.dtype).max)
        waveform = waveform.to(torch.float32) / full_scale

    # Multi-channel clips arrive as (samples, channels): average the channels
    # so the time axis is the only (and therefore last) dimension.
    if waveform.ndim > 1:
        waveform = waveform.mean(dim=-1)

    # Resample to the rate the feature extractor expects (16 kHz fallback).
    target_rate = getattr(extractor, "sampling_rate", 16000)
    if sample_rate != target_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=target_rate
        )
        waveform = resampler(waveform)

    inputs = extractor(
        waveform.numpy(), sampling_rate=target_rate, return_tensors="pt"
    )

    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class_id = torch.argmax(logits, dim=1).item()
    return model.config.id2label[predicted_class_id]

# Gradio UI: one audio upload in, one text box with the predicted genre out.
audio_input = gr.Audio(type="numpy", label="Upload a Music Clip (WAV/MP3)")
genre_output = gr.Textbox(label="Predicted Genre")

demo = gr.Interface(
    fn=classify_music,
    inputs=audio_input,
    outputs=genre_output,
    title="🎵 Music Genre Classifier",
    description="Upload a short music clip and get the predicted genre using a fine-tuned GTZAN model.",
    live=False,
)
demo.launch()