// ATOM-WebGPU / index.js
import { WhisperForConditionalGeneration, WhisperProcessor } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers';
// Get DOM elements
const status = document.getElementById('status');
const startBtn = document.getElementById('startBtn');
const stopBtn = document.getElementById('stopBtn');
const clearBtn = document.getElementById('clearBtn');
const transcriptionContainer = document.getElementById('transcriptionContainer');
const chunkLengthSelect = document.getElementById('chunkLength');
const useWebGPUCheckbox = document.getElementById('useWebGPU');
const chunkCountDisplay = document.getElementById('chunkCount');
const recordingTimeDisplay = document.getElementById('recordingTime');
const visualizerBars = document.querySelectorAll('.bar');
// State
let model = null;
let processor = null;
let mediaStream = null;
let audioContext = null;
let mediaRecorder = null;
let recordedChunks = [];
let isRecording = false;
let chunkCount = 0;
let recordingStartTime = null;
let recordingInterval = null;
let analyser = null;
let animationId = null;
// Initialize the ATOM model
async function initModel() {
try {
status.textContent = 'Loading ATOM model with custom tokenizer... This may take a minute.';
status.className = 'loading';
const device = useWebGPUCheckbox.checked ? 'webgpu' : 'wasm';
const dtype = 'fp32'; // same precision for both the WebGPU and WASM backends
// Load processor (includes the custom Armenian tokenizer)
status.textContent = 'Loading custom Armenian processor/tokenizer...';
processor = await WhisperProcessor.from_pretrained('Chillarmo/ATOM', {
progress_callback: (progress) => {
if (progress.status === 'downloading') {
const percent = Math.round((progress.loaded / progress.total) * 100);
status.textContent = `Downloading ${progress.file}: ${percent}%`;
}
}
});
console.log('✓ ATOM Processor loaded (includes custom tokenizer)');
// Load model
status.textContent = 'Loading ATOM model...';
model = await WhisperForConditionalGeneration.from_pretrained('Chillarmo/ATOM', {
device: device,
dtype: dtype,
progress_callback: (progress) => {
if (progress.status === 'downloading') {
const percent = Math.round((progress.loaded / progress.total) * 100);
status.textContent = `Downloading model ${progress.file}: ${percent}%`;
} else if (progress.status === 'loading') {
status.textContent = `Loading ${progress.file}...`;
}
}
});
console.log('✓ ATOM Model loaded');
console.log('Model config:', model.config);
console.log('Processor:', processor);
status.textContent = 'ATOM ready! Model + custom tokenizer loaded successfully.';
status.className = 'ready';
startBtn.disabled = false;
} catch (error) {
console.error('Model loading error:', error);
status.textContent = `Error loading model: ${error.message}`;
status.className = 'error';
console.error('Full error details:', error);
}
}
// Format time as MM:SS
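// e.g. formatTime(75) -> "01:15"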
function formatTime(seconds) {
const mins = Math.floor(seconds / 60);
const secs = Math.floor(seconds % 60);
return `${mins.toString().padStart(2, '0')}:${secs.toString().padStart(2, '0')}`;
}
// Update recording time
function updateRecordingTime() {
if (recordingStartTime) {
const elapsed = (Date.now() - recordingStartTime) / 1000;
recordingTimeDisplay.textContent = formatTime(elapsed);
}
}
// Visualize audio
function visualizeAudio() {
if (!analyser || !isRecording) return;
const dataArray = new Uint8Array(analyser.frequencyBinCount);
analyser.getByteFrequencyData(dataArray);
// Sample the data for visualization
const barCount = visualizerBars.length;
const step = Math.floor(dataArray.length / barCount);
visualizerBars.forEach((bar, index) => {
const value = dataArray[index * step];
const height = (value / 255) * 70 + 4; // 4px minimum, 74px maximum
bar.style.height = `${height}px`;
});
animationId = requestAnimationFrame(visualizeAudio);
}
// Start recording
async function startRecording() {
try {
// Request microphone access
mediaStream = await navigator.mediaDevices.getUserMedia({
audio: {
channelCount: 1,
sampleRate: 16000,
}
});
// Set up audio context for visualization
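// Whisper-family models expect 16 kHz mono audio, so the context is created at that rate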
audioContext = new AudioContext({ sampleRate: 16000 });
const source = audioContext.createMediaStreamSource(mediaStream);
analyser = audioContext.createAnalyser();
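// fftSize of 256 gives frequencyBinCount = 128 frequency bins for the visualizer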
analyser.fftSize = 256;
source.connect(analyser);
// Set up MediaRecorder
mediaRecorder = new MediaRecorder(mediaStream);
recordedChunks = [];
mediaRecorder.ondataavailable = (event) => {
if (event.data.size > 0) {
recordedChunks.push(event.data);
}
};
mediaRecorder.onstop = async () => {
if (recordedChunks.length > 0) {
await processAudioChunk(recordedChunks);
recordedChunks = [];
}
};
// Start recording
const chunkDuration = parseInt(chunkLengthSelect.value) * 1000;
mediaRecorder.start();
// Schedule automatic chunk processing
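// Restarting the recorder (stop/start) makes each chunk a self-contained WebM file
// with its own header, so decodeAudioData can decode it on its own. Using
// mediaRecorder.start(timeslice) instead would emit fragments that generally only
// decode when concatenated from the first blob onward.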
const chunkInterval = setInterval(() => {
if (!isRecording) {
clearInterval(chunkInterval);
return;
}
mediaRecorder.stop();
mediaRecorder.start();
}, chunkDuration);
isRecording = true;
recordingStartTime = Date.now();
recordingInterval = setInterval(updateRecordingTime, 100);
status.textContent = 'Recording... Speak in Armenian';
status.className = 'recording';
startBtn.disabled = true;
stopBtn.disabled = false;
// Start visualization
visualizeAudio();
} catch (error) {
console.error('Error starting recording:', error);
status.textContent = `Error: ${error.message}`;
status.className = 'error';
}
}
// Stop recording
function stopRecording() {
isRecording = false;
if (mediaRecorder && mediaRecorder.state !== 'inactive') {
mediaRecorder.stop();
}
if (mediaStream) {
mediaStream.getTracks().forEach(track => track.stop());
}
if (audioContext) {
audioContext.close();
}
if (recordingInterval) {
clearInterval(recordingInterval);
}
if (animationId) {
cancelAnimationFrame(animationId);
}
// Reset visualizer
visualizerBars.forEach(bar => {
bar.style.height = '4px';
});
status.textContent = 'Recording stopped. Ready for next recording.';
status.className = 'ready';
startBtn.disabled = false;
stopBtn.disabled = true;
}
// Process audio chunk
async function processAudioChunk(chunks) {
try {
status.textContent = 'Processing audio...';
status.className = 'processing';
// Create audio blob
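// Chromium-based browsers typically record audio/webm (Opus); the exact container
// depends on the browser, but decodeAudioData below handles whatever was produced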
const audioBlob = new Blob(chunks, { type: 'audio/webm' });
// Convert to array buffer
const arrayBuffer = await audioBlob.arrayBuffer();
// Decode audio; creating the context at 16 kHz makes decodeAudioData resample
// the recording to the rate the model's feature extractor expects
const tempAudioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: 16000 });
const audioBuffer = await tempAudioContext.decodeAudioData(arrayBuffer);
// Get audio data as Float32Array
const audioData = audioBuffer.getChannelData(0);
console.log('Processing audio chunk:', audioData.length, 'samples at', audioBuffer.sampleRate, 'Hz');
// Process audio with the processor (includes custom tokenizer)
const inputs = await processor(audioData, {
sampling_rate: audioBuffer.sampleRate,
});
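// For Whisper-style models, inputs.input_features is a log-mel spectrogram tensor
// (typically shape [1, 80, 3000], i.e. 80 mel bins over a 30-second padded window)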
console.log('Processor output:', inputs);
// Generate with the model
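// Default generation settings are used; options such as max_new_tokens, language
// or task could also be passed here (not set in this app)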
const outputs = await model.generate({
...inputs,
});
console.log('Model outputs:', outputs);
// Decode the output tokens using the custom tokenizer
const decoded = processor.batch_decode(outputs, {
skip_special_tokens: true,
});
console.log('Decoded text:', decoded);
// Add to transcription
const text = decoded[0].trim();
if (text) {
addTranscription(text);
chunkCount++;
chunkCountDisplay.textContent = chunkCount;
}
if (isRecording) {
status.textContent = 'Recording... Speak in Armenian';
status.className = 'recording';
} else {
status.textContent = 'Ready for next recording.';
status.className = 'ready';
}
tempAudioContext.close();
} catch (error) {
console.error('Error processing audio:', error);
status.textContent = `Processing error: ${error.message}`;
status.className = 'error';
console.error('Full processing error:', error);
// Restore recording status if still recording
setTimeout(() => {
if (isRecording) {
status.textContent = 'Recording... Speak in Armenian';
status.className = 'recording';
}
}, 2000);
}
}
// Add transcription to UI
function addTranscription(text) {
// Remove empty state if present
const emptyState = transcriptionContainer.querySelector('.empty-state');
if (emptyState) {
emptyState.remove();
}
// Create transcription item
const item = document.createElement('div');
item.className = 'transcription-item';
const timestamp = document.createElement('div');
timestamp.className = 'timestamp';
timestamp.textContent = new Date().toLocaleTimeString();
const textDiv = document.createElement('div');
textDiv.className = 'text';
textDiv.textContent = text;
item.appendChild(timestamp);
item.appendChild(textDiv);
transcriptionContainer.appendChild(item);
// Auto-scroll to bottom
transcriptionContainer.scrollTop = transcriptionContainer.scrollHeight;
}
// Clear transcriptions
function clearTranscriptions() {
transcriptionContainer.innerHTML = `
<div class="empty-state">
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 11a7 7 0 01-7 7m0 0a7 7 0 01-7-7m7 7v4m0 0H8m4 0h4m-4-8a3 3 0 01-3-3V5a3 3 0 116 0v6a3 3 0 01-3 3z" />
</svg>
<p>Click "Start Recording" to begin transcribing Armenian speech</p>
</div>
`;
chunkCount = 0;
chunkCountDisplay.textContent = '0';
recordingTimeDisplay.textContent = '00:00';
}
// Event listeners
startBtn.addEventListener('click', startRecording);
stopBtn.addEventListener('click', stopRecording);
clearBtn.addEventListener('click', clearTranscriptions);
// Check WebGPU support
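// navigator.gpu is only defined when the browser exposes WebGPU (e.g. recent Chromium);
// otherwise fall back to the WASM backend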
if (useWebGPUCheckbox.checked && !navigator.gpu) {
status.textContent = 'WebGPU not supported, falling back to WASM';
status.className = 'error';
useWebGPUCheckbox.checked = false;
setTimeout(() => initModel(), 2000);
} else {
// Initialize model on load
initModel();
}
// Re-initialize if WebGPU setting changes
useWebGPUCheckbox.addEventListener('change', () => {
if (isRecording) {
alert('Please stop recording before changing acceleration settings');
useWebGPUCheckbox.checked = !useWebGPUCheckbox.checked;
return;
}
status.textContent = 'Reinitializing model...';
status.className = 'loading';
startBtn.disabled = true;
initModel();
});