datbkpro committed · verified
Commit b5e51ac · 1 Parent(s): c6367cc

Create speechbrain_vad.py

Files changed (1)
  1. core/speechbrain_vad.py +125 -132
core/speechbrain_vad.py CHANGED
@@ -1,154 +1,147 @@
  import torch
  import torchaudio
  import numpy as np
- from speechbrain.inference import VAD
- from typing import List, Tuple, Optional
- import queue
- import threading
- import time
  from config.settings import settings

  class SpeechBrainVAD:
      def __init__(self):
-         self.vad_model = None
          self.sample_rate = settings.SAMPLE_RATE
-         self.threshold = settings.VAD_THRESHOLD
-         self.min_silence_duration = settings.VAD_MIN_SILENCE_DURATION
-         self.speech_pad_duration = settings.VAD_SPEECH_PAD_DURATION
-         self.is_running = False
-         self.audio_queue = queue.Queue()
-         self.speech_buffer = []
-         self.silence_start_time = None
-         self.callback = None
-
          self._initialize_model()
-
      def _initialize_model(self):
-         """Initialize the VAD model from SpeechBrain"""
          try:
-             print("🔄 Loading the SpeechBrain VAD model...")
-             self.vad_model = VAD.from_hparams(
                  source=settings.VAD_MODEL,
-                 savedir=f"pretrained_models/{settings.VAD_MODEL}"
              )
-             print("✅ VAD model loaded successfully")
          except Exception as e:
-             print(f"❌ Error loading the VAD model: {e}")
-             self.vad_model = None
-
-     def preprocess_audio(self, audio_data: np.ndarray, original_sr: int) -> np.ndarray:
-         """Preprocess audio for VAD"""
-         if original_sr != self.sample_rate:
-             # Resample audio to VAD sample rate
-             audio_tensor = torch.from_numpy(audio_data).float()
-             if len(audio_tensor.shape) > 1:
-                 audio_tensor = audio_tensor.mean(dim=0)  # Convert to mono

-             resampler = torchaudio.transforms.Resample(
-                 orig_freq=original_sr,
-                 new_freq=self.sample_rate
              )
-             audio_tensor = resampler(audio_tensor)
-             audio_data = audio_tensor.numpy()
-
-         # Normalize audio
-         if np.max(np.abs(audio_data)) > 0:
-             audio_data = audio_data / np.max(np.abs(audio_data))
-
-         return audio_data
-
-     def detect_voice_activity(self, audio_chunk: np.ndarray) -> bool:
-         """Detect voice activity in an audio chunk"""
-         if self.vad_model is None:
-             # Fallback: simple energy-based VAD
-             return self._energy_based_vad(audio_chunk)

          try:
-             # Convert to tensor and add batch dimension
-             audio_tensor = torch.from_numpy(audio_chunk).float().unsqueeze(0)

-             # Get VAD probabilities
-             with torch.no_grad():
-                 prob = self.vad_model.get_speech_prob_chunk(audio_tensor)

-             return prob.item() > self.threshold

          except Exception as e:
-             print(f"❌ VAD detection error: {e}")
-             return self._energy_based_vad(audio_chunk)
-
-     def _energy_based_vad(self, audio_chunk: np.ndarray) -> bool:
-         """Fallback VAD based on audio energy"""
-         energy = np.mean(audio_chunk ** 2)
-         return energy > 0.01  # Simple threshold
-
-     def process_stream(self, audio_chunk: np.ndarray, original_sr: int):
-         """Process the audio stream in real time"""
-         if not self.is_running:
-             return
-
-         # Preprocess audio
-         processed_audio = self.preprocess_audio(audio_chunk, original_sr)
-
-         # Detect voice activity
-         is_speech = self.detect_voice_activity(processed_audio)
-
-         if is_speech:
-             self.silence_start_time = None
-             self.speech_buffer.extend(processed_audio)
-             print("🎤 Speaking...")
-         else:
-             # Silence detected
-             if self.silence_start_time is None:
-                 self.silence_start_time = time.time()
-             elif len(self.speech_buffer) > 0:
-                 silence_duration = time.time() - self.silence_start_time
-                 if silence_duration >= self.min_silence_duration:
-                     # End of speech segment
-                     self._process_speech_segment()
-
-         return is_speech
-
-     def _process_speech_segment(self):
-         """Process the speech segment when it ends"""
-         if len(self.speech_buffer) == 0:
-             return
-
-         # Convert buffer to numpy array
-         speech_audio = np.array(self.speech_buffer)
-
-         # Call callback with speech segment
-         if self.callback and callable(self.callback):
-             self.callback(speech_audio, self.sample_rate)
-
-         # Clear buffer
-         self.speech_buffer = []
-         self.silence_start_time = None
-
-         print("✅ Processed the speech segment")
-
-     def start_stream(self, callback: callable):
-         """Start stream processing"""
-         self.is_running = True
-         self.callback = callback
-         self.speech_buffer = []
-         self.silence_start_time = None
-         print("🎙️ Starting the VAD stream...")
-
-     def stop_stream(self):
-         """Stop stream processing"""
-         self.is_running = False
-         # Process any remaining speech
-         if len(self.speech_buffer) > 0:
-             self._process_speech_segment()
-         print("🛑 Stopped the VAD stream")
-
-     def get_audio_chunk_from_stream(self, stream, chunk_size: int = 1024):
-         """Read an audio chunk from a stream (for microphone input)"""
-         try:
-             data = stream.read(chunk_size, exception_on_overflow=False)
-             audio_data = np.frombuffer(data, dtype=np.int16)
-             return audio_data.astype(np.float32) / 32768.0  # Normalize to [-1, 1]
-         except Exception as e:
-             print(f"❌ Error reading the audio stream: {e}")
-             return None

  import torch
  import torchaudio
  import numpy as np
+ from typing import Optional, Callable
  from config.settings import settings

  class SpeechBrainVAD:
      def __init__(self):
+         self.model = None
          self.sample_rate = settings.SAMPLE_RATE
+         self.is_streaming = False
+         self.speech_callback = None
+         self.audio_buffer = []
          self._initialize_model()
+
      def _initialize_model(self):
+         """Initialize the VAD model from SpeechBrain"""
          try:
+             from speechbrain.pretrained import VAD
+             print("🔄 Loading the VAD model from SpeechBrain...")
+             self.model = VAD.from_hparams(
                  source=settings.VAD_MODEL,
+                 savedir=f"/tmp/{settings.VAD_MODEL.replace('/', '_')}"
              )
+             print("✅ VAD model loaded successfully")
          except Exception as e:
+             print(f"❌ Error loading the VAD model: {e}")
+             self.model = None
+
+     def start_stream(self, speech_callback: Callable):
+         """Start streaming with VAD"""
+         if self.model is None:
+             print("❌ The VAD model has not been initialized")
+             return False
+
+         self.is_streaming = True
+         self.speech_callback = speech_callback
+         self.audio_buffer = []
+         print("🎙️ Starting VAD streaming...")
+         return True
+
+     def stop_stream(self):
+         """Stop streaming"""
+         self.is_streaming = False
+         self.speech_callback = None
+         self.audio_buffer = []
+         print("🛑 Stopped VAD streaming")
+
+     def process_stream(self, audio_chunk: np.ndarray, sample_rate: int):
+         """Process an audio chunk with VAD"""
+         if not self.is_streaming or self.model is None:
+             return
+
+         try:
+             # Resample if needed
+             if sample_rate != self.sample_rate:
+                 audio_chunk = self._resample_audio(audio_chunk, sample_rate, self.sample_rate)
+
+             # Append to the buffer
+             self.audio_buffer.extend(audio_chunk)
+
+             # Process once the buffer is large enough (2 seconds)
+             buffer_duration = len(self.audio_buffer) / self.sample_rate
+             if buffer_duration >= 2.0:
+                 self._process_buffer()
+
+         except Exception as e:
+             print(f"❌ VAD processing error: {e}")
+
+     def _process_buffer(self):
+         """Process the audio buffer with VAD"""
+         try:
+             # Convert the buffer to a tensor
+             audio_tensor = torch.FloatTensor(self.audio_buffer).unsqueeze(0)

+             # Detect speech with VAD
+             boundaries = self.model.get_speech_segments(
+                 audio_tensor,
+                 # Tune the parameters to be more sensitive
+                 threshold=settings.VAD_THRESHOLD - 0.1,  # Lower the threshold
+                 min_silence_duration=settings.VAD_MIN_SILENCE_DURATION + 0.3,  # Increase the silence duration
+                 speech_pad_duration=settings.VAD_SPEECH_PAD_DURATION
              )
+
+             # Process the speech segments
+             if len(boundaries) > 0:
+                 for start, end in boundaries:
+                     start_sample = int(start * self.sample_rate)
+                     end_sample = int(end * self.sample_rate)
+
+                     # Extract the speech segment
+                     speech_audio = np.array(self.audio_buffer[start_sample:end_sample])
+
+                     if len(speech_audio) > self.sample_rate * 0.5:  # At least 0.5 seconds
+                         print(f"🎯 VAD detected speech: {len(speech_audio)/self.sample_rate:.2f}s")
+
+                         # Invoke the callback with the speech segment
+                         if self.speech_callback:
+                             self.speech_callback(speech_audio, self.sample_rate)
+
+             # Keep the last 0.5 seconds for overlap
+             keep_samples = int(self.sample_rate * 0.5)
+             if len(self.audio_buffer) > keep_samples:
+                 self.audio_buffer = self.audio_buffer[-keep_samples:]
+             else:
+                 self.audio_buffer = []
+
+         except Exception as e:
+             print(f"❌ VAD buffer processing error: {e}")
+             self.audio_buffer = []
+
+     def _resample_audio(self, audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
+         """Resample audio if needed"""
+         if orig_sr == target_sr:
+             return audio

          try:
+             audio_tensor = torch.FloatTensor(audio).unsqueeze(0)
+             resampler = torchaudio.transforms.Resample(orig_sr, target_sr)
+             resampled = resampler(audio_tensor)
+             return resampled.squeeze(0).numpy()
+         except Exception as e:
+             print(f"⚠️ Resample error: {e}")
+             return audio
+
+     def is_speech(self, audio_chunk: np.ndarray, sample_rate: int) -> bool:
+         """Check whether an audio chunk contains speech"""
+         if self.model is None:
+             return True  # Fallback: always treat it as speech

+         try:
+             # Resample if needed
+             if sample_rate != self.sample_rate:
+                 audio_chunk = self._resample_audio(audio_chunk, sample_rate, self.sample_rate)
+
+             # Convert to a tensor
+             audio_tensor = torch.FloatTensor(audio_chunk).unsqueeze(0)

+             # Detect speech
+             prob_speech = self.model.get_speech_prob_chunk(audio_tensor)
+
+             # Check against the threshold
+             return prob_speech.mean().item() > (settings.VAD_THRESHOLD - 0.1)

          except Exception as e:
+             print(f"❌ Speech check error: {e}")
+             return True
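
For reference, below is a minimal, hypothetical usage sketch of the streaming interface this commit introduces. It is not part of the commit: the on_speech callback and the fake_chunks silence generator are stand-ins for a real ASR callback and microphone reader, and it assumes config.settings provides SAMPLE_RATE (16 kHz in this sketch) along with the VAD_MODEL and other VAD_* values the class reads.

# Hypothetical usage sketch (not part of this commit).
import numpy as np
from core.speechbrain_vad import SpeechBrainVAD

def on_speech(speech_audio: np.ndarray, sample_rate: int) -> None:
    # Placeholder callback: a real app would hand the segment to ASR here.
    print(f"Speech segment received: {len(speech_audio) / sample_rate:.2f}s")

def fake_chunks(sample_rate: int = 16000, seconds: int = 5):
    # Stand-in audio source: yields 0.1 s chunks of silence; a real caller
    # would read microphone frames instead.
    chunk = np.zeros(int(sample_rate * 0.1), dtype=np.float32)
    for _ in range(seconds * 10):
        yield chunk, sample_rate

vad = SpeechBrainVAD()                     # loads the model via settings.VAD_MODEL
if vad.start_stream(on_speech):            # returns False if the model failed to load
    try:
        for chunk, sr in fake_chunks():
            vad.process_stream(chunk, sr)  # buffers ~2 s, then segments and fires the callback
    finally:
        vad.stop_stream()

Silence-only input will simply never trigger the callback; the sketch only illustrates the start_stream / process_stream / stop_stream lifecycle and the (audio, sample_rate) callback signature.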