datbkpro committed on
Commit
7d0b18d
·
verified ·
1 Parent(s): 6c27998

Create sambanova_voice_service.py

Browse files
Files changed (1) hide show
  1. services/sambanova_voice_service.py +150 -0
services/sambanova_voice_service.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
+ import os
4
+ from pathlib import Path
5
+ import gradio as gr
6
+ import numpy as np
7
+ import openai
8
+ from fastapi import FastAPI
9
+ from fastapi.responses import HTMLResponse, StreamingResponse
10
+ from fastrtc import (
11
+ AdditionalOutputs,
12
+ ReplyOnStopWords,
13
+ Stream,
14
+ get_stt_model,
15
+ get_twilio_turn_credentials,
16
+ )
17
+ from gradio.utils import get_space
18
+ from pydantic import BaseModel
19
+
20
class SambanovaVoiceService:
    """Voice AI service backed by the SambaNova chat-completions API.

    The service bundles four pieces that are created in ``__init__``:

    * ``client`` -- an OpenAI-compatible client pointed at the SambaNova endpoint,
    * ``model`` -- the speech-to-text model returned by ``get_stt_model()``,
    * ``rtc_configuration`` -- TURN credentials, fetched only when
      ``get_space()`` is truthy,
    * ``app`` -- the FastAPI application that ``setup_fastapi_routes`` populates.
    """

    def __init__(self):
        """Create the API client, STT model, RTC configuration and FastAPI app."""
        # Directory containing this module; the HTML template is resolved
        # relative to it in ``setup_fastapi_routes``.
        self.curr_dir = Path(__file__).parent

        # SambaNova client (OpenAI-compatible API); the key is read from the
        # SAMBANOVA_API_KEY environment variable.
        self.client = openai.OpenAI(
            api_key=os.environ.get("SAMBANOVA_API_KEY"),
            base_url="https://api.sambanova.ai/v1",
        )

        # Speech-to-text model
        self.model = get_stt_model()

        # RTC configuration: Twilio TURN credentials are requested only when
        # get_space() is truthy (presumably when running on a Hugging Face
        # Space -- confirm against gradio.utils.get_space).
        self.rtc_configuration = get_twilio_turn_credentials() if get_space() else None

        # FastAPI app
        self.app = FastAPI()

    def create_response_handler(self):
        """Build the generator callback used for voice streaming.

        Returns:
            A generator function ``response(audio, gradio_chatbot,
            conversation_state)`` that transcribes the audio, queries the
            SambaNova chat model and yields ``AdditionalOutputs`` carrying the
            updated chatbot history and conversation state.
        """

        def response(
            audio: tuple[int, np.ndarray],
            gradio_chatbot: list[dict] | None = None,
            conversation_state: list[dict] | None = None,
        ):
            """Handle one utterance.

            Args:
                audio: ``(sample_rate, samples)`` pair for the utterance.
                gradio_chatbot: Chatbot message history; a new list is used
                    when it is ``None`` or empty.
                conversation_state: Messages sent to the LLM; a new list is
                    used when it is ``None`` or empty.

            Yields:
                ``AdditionalOutputs(gradio_chatbot, conversation_state)`` twice:
                once after the user's audio has been added to the chatbot, and
                once after the assistant reply has been appended to both lists.
            """
            gradio_chatbot = gradio_chatbot or []
            conversation_state = conversation_state or []

            # Speech to text
            text = self.model.stt(audio)
            print("🎤 STT Result:", text)

            # Show the user's audio in the chatbot right away, before the
            # (slower) LLM round trip.
            sample_rate, array = audio
            gradio_chatbot.append(
                {"role": "user", "content": gr.Audio((sample_rate, array.squeeze()))}
            )
            yield AdditionalOutputs(gradio_chatbot, conversation_state)

            # The LLM conversation stores the transcript, not the audio.
            conversation_state.append({"role": "user", "content": text})

            # Call the SambaNova API
            request = self.client.chat.completions.create(
                model="Meta-Llama-3.2-3B-Instruct",
                messages=conversation_state,
                temperature=0.1,
                top_p=0.1,
            )
            response_content = {
                "role": "assistant",
                "content": request.choices[0].message.content,
            }

            conversation_state.append(response_content)
            gradio_chatbot.append(response_content)

            yield AdditionalOutputs(gradio_chatbot, conversation_state)

        return response

    def create_stream(self):
        """Create the FastRTC stream.

        Returns:
            A new ``Stream`` whose handler is the response callback wrapped in
            ``ReplyOnStopWords``. Each call builds a fresh handler and stream;
            nothing is cached on the instance.
        """
        response_handler = self.create_response_handler()

        return Stream(
            ReplyOnStopWords(
                response_handler,
                stop_words=["computer", "hey", "hello", "xin chào"],
                input_sample_rate=16000,
            ),
            mode="send",
            modality="audio",
            additional_inputs=[gr.Chatbot(type="messages", value=[]), gr.State(value=[])],
            additional_outputs=[gr.Chatbot(type="messages", value=[]), gr.State(value=[])],
            # Returns positional arguments 2 and 3 -- presumably the newly
            # yielded chatbot/state values that follow the two current
            # component values; confirm against the FastRTC documentation.
            additional_outputs_handler=lambda *a: (a[2], a[3]),
            # Concurrency and session-time limits apply only when get_space()
            # is truthy.
            concurrency_limit=5 if get_space() else None,
            time_limit=90 if get_space() else None,
            rtc_configuration=self.rtc_configuration,
        )

    def setup_fastapi_routes(self):
        """Register the HTTP routes on ``self.app``.

        Routes:
            ``GET /`` -- serves ``templates/sambanova_index.html`` with the
            ``__RTC_CONFIGURATION__`` placeholder replaced by the JSON-encoded
            RTC configuration.
            ``POST /input_hook`` -- validates the payload and acknowledges it.
            ``GET /outputs`` -- server-sent-events stream; currently emits a
            single placeholder event.

        Returns:
            The FastAPI application (``self.app``).
        """

        class Message(BaseModel):
            role: str
            content: str

        class InputData(BaseModel):
            webrtc_id: str
            chatbot: list[Message]
            state: list[Message]

        @self.app.get("/")
        async def home():
            # Credentials are requested on every page load rather than reusing
            # self.rtc_configuration.
            rtc_config = get_twilio_turn_credentials() if get_space() else None
            # Decode explicitly as UTF-8: the default encoding of read_text()
            # depends on the host locale and can fail or garble non-ASCII
            # characters in the template.
            template_path = self.curr_dir / "templates" / "sambanova_index.html"
            html_content = template_path.read_text(encoding="utf-8")
            html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
            return HTMLResponse(content=html_content)

        @self.app.post("/input_hook")
        async def input_hook(data: InputData):
            # TODO: forward the validated payload to the FastRTC stream once a
            # stream instance is wired into this service, e.g.
            #   body = data.model_dump()
            #   stream.set_input(data.webrtc_id, body["chatbot"], body["state"])
            return {"status": "ok"}

        def audio_to_base64(file_path):
            """Return the file at *file_path* as a ``data:audio/wav`` base64 URI.

            Not called yet; it is kept for the real output stream sketched in
            the TODO inside ``outputs`` below.
            """
            audio_format = "wav"
            with open(file_path, "rb") as audio_file:
                encoded_audio = base64.b64encode(audio_file.read()).decode("utf-8")
            return f"data:audio/{audio_format};base64,{encoded_audio}"

        @self.app.get("/outputs")
        async def outputs(webrtc_id: str):
            async def output_stream():
                # TODO: replace the placeholder event with the real stream
                # output once a stream instance is available:
                # async for output in stream.output_stream(webrtc_id):
                #     chatbot = output.args[0]
                #     state = output.args[1]
                #     data = {
                #         "message": state[-1],
                #         "audio": audio_to_base64(chatbot[-1]["content"].value["path"])
                #         if chatbot[-1]["role"] == "user"
                #         else None,
                #     }
                #     yield f"event: output\ndata: {json.dumps(data)}\n\n"
                yield f"event: output\ndata: {json.dumps({'message': 'Stream ready'})}\n\n"

            return StreamingResponse(output_stream(), media_type="text/event-stream")

        return self.app