File size: 4,156 Bytes
54ed3c0
 
 
 
 
 
 
 
 
 
 
 
28519bb
 
54ed3c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d020ce8
 
1e9d7e6
54ed3c0
 
 
 
 
 
 
 
 
 
 
 
d020ce8
 
 
 
 
 
 
 
1e9d7e6
d020ce8
 
 
 
 
93d97ba
1e9d7e6
d020ce8
 
 
 
f391bbe
 
 
 
d020ce8
1e9d7e6
d020ce8
f391bbe
 
 
 
 
 
1e9d7e6
d020ce8
 
 
 
 
 
1e9d7e6
d020ce8
1e9d7e6
d020ce8
 
 
 
 
 
 
 
 
 
1e9d7e6
d020ce8
 
 
 
f15c490
d020ce8
 
 
 
1e9d7e6
d020ce8
35f4b88
 
 
 
 
 
1e9d7e6
54ed3c0
 
6b12f2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#!/usr/bin/env python

import os
import pathlib

import gradio as gr
import librosa
import spaces
import torch
from transformers import KyutaiSpeechToTextForConditionalGeneration, KyutaiSpeechToTextProcessor

# Prefer the GPU when CUDA is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hugging Face Hub checkpoint, loaded through the Transformers Kyutai STT classes below.
model_id = "kyutai/stt-2.6b-en-trfs"
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
    model_id,
    device_map=device,
    torch_dtype="auto",  # let Transformers pick the dtype recorded for the checkpoint
)
processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)

# Sampling rate (Hz) that audio is loaded at before it reaches the processor.
SAMPLE_RATE = 24000
# Longest clip, in seconds, that will be transcribed; override with the MAX_DURATION env var.
MAX_DURATION = int(os.environ.get("MAX_DURATION", "60"))
# The same limit expressed as a number of samples at SAMPLE_RATE.
MAX_SAMPLE_SIZE = MAX_DURATION * SAMPLE_RATE


@spaces.GPU
def transcribe(audio_path: str) -> str:
    """Transcribe an English audio file to text.

    Only the first ``MAX_DURATION`` seconds are transcribed; longer audio is
    truncated and the user is notified with an info toast.

    Args:
        audio_path (str): The path to the audio file. The audio must contain English speech.

    Returns:
        str: The transcription of the English audio file, or an empty string when
        no file is given or the file contains no audio samples.
    """
    if not audio_path:
        return ""

    # Decode at most one second past the limit. That is enough to tell whether
    # the file exceeds MAX_DURATION (len > MAX_SAMPLE_SIZE iff the real audio is
    # longer), without decoding and resampling an arbitrarily long upload in full.
    data, _ = librosa.load(audio_path, sr=SAMPLE_RATE, duration=MAX_DURATION + 1)
    if len(data) > MAX_SAMPLE_SIZE:
        data = data[:MAX_SAMPLE_SIZE]
        gr.Info(f"Audio file is too long. Truncating to {MAX_DURATION} seconds.")

    # Nothing to transcribe: skip the processor/model call for an empty signal.
    if len(data) == 0:
        return ""

    inputs = processor(data)
    inputs.to(device)
    output_tokens = model.generate(**inputs)
    # batch_decode returns one string per batch item; a single clip was passed in.
    output = processor.batch_decode(output_tokens, skip_special_tokens=True)
    return output[0]


with gr.Blocks(fill_height=False) as demo:
    # Header
    gr.HTML("""
        <div class="header-container">
            <h1 class="header-title">🎙️ Kyutai Speech-to-Text</h1>
            <p class="header-subtitle">Advanced English Audio Transcription powered by AI</p>
        </div>
    """)

    # Info banner
    gr.HTML(f"""
        <div class="info-banner">
            ℹ️ Upload or record audio in English (max {MAX_DURATION} seconds). Supports WAV, MP3, and other common formats.
        </div>
    """)  # noqa: RUF001

    # Main content
    with gr.Group(elem_classes="main-card"):
        # Audio input
        audio = gr.Audio(
            label="🎵 Audio Input",
            type="filepath",
            sources=["upload", "microphone"],
            elem_classes="audio-container",
        )

        # Transcribe button
        transcribe_btn = gr.Button(
            "✨ Transcribe Audio",
            variant="primary",
            size="lg",
            elem_classes="primary-button",
        )

        # Output
        output = gr.Textbox(
            label="📝 Transcription",
            placeholder="Your transcription will appear here...",
            lines=6,
            max_lines=12,
            elem_classes="transcription-output",
        )

    # Example clips shipped next to the app, resolved relative to the working
    # directory. Globbing a missing directory simply yields nothing.
    example_dir = pathlib.Path("assets")
    example_paths = sorted(example_dir.glob("*.wav"))

    # Examples section: only rendered when there is at least one clip, so the
    # heading never appears without any examples underneath it.
    if example_paths:
        with gr.Group(elem_classes="examples-container"):
            gr.Markdown("### 💡 Try These Examples")
            gr.Examples(
                examples=example_paths,
                inputs=audio,
                outputs=output,
                fn=transcribe,
                examples_per_page=5,
            )

    # Footer: the model link is built from model_id so it always matches the
    # checkpoint that is actually loaded.
    gr.HTML(f"""
        <div class="footer-container">
            <p>
                Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" class="footer-link" target="_blank">anycoder</a> •
                Powered by <a href="https://huggingface.co/{model_id}" class="footer-link" target="_blank">Kyutai STT 2.6B</a>
            </p>
        </div>
    """)

    # Event handlers
    transcribe_btn.click(
        fn=transcribe,
        inputs=audio,
        outputs=output,
        api_name="transcribe",
    )


if __name__ == "__main__":
    # Base look: the Soft theme with blue accents over slate neutrals,
    # the Inter typeface, large text and generously rounded corners.
    base_theme = gr.themes.Soft(
        font=gr.themes.GoogleFont("Inter"),
        primary_hue="blue",
        secondary_hue="slate",
        neutral_hue="slate",
        text_size="lg",
        spacing_size="md",
        radius_size="lg",
    )

    # Fine-grained overrides: darker primary buttons and heavier block labels.
    theme_overrides = {
        "button_primary_background_fill": "*primary_600",
        "button_primary_background_fill_hover": "*primary_700",
        "block_title_text_weight": "600",
        "block_label_text_weight": "500",
    }
    theme = base_theme.set(**theme_overrides)

    # Serve the UI with the custom stylesheet and also expose it as an MCP server.
    demo.launch(theme=theme, css_paths="style.css", mcp_server=True)