Spaces:
Paused
Paused
| import asyncio | |
| import concurrent.futures | |
| import json | |
| from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union | |
| import litellm | |
| from litellm._logging import verbose_logger | |
| from litellm.llms.base_llm.realtime.transformation import BaseRealtimeConfig | |
| from litellm.types.llms.openai import ( | |
| OpenAIRealtimeEvents, | |
| OpenAIRealtimeOutputItemDone, | |
| OpenAIRealtimeResponseDelta, | |
| OpenAIRealtimeStreamResponseBaseObject, | |
| OpenAIRealtimeStreamSessionEvents, | |
| ) | |
| from litellm.types.realtime import ALL_DELTA_TYPES | |
| from .litellm_logging import Logging as LiteLLMLogging | |
# Alias the websockets client-connection type for annotations only; at runtime
# fall back to Any so the third-party class is not required to import this module.
if TYPE_CHECKING:
    from websockets.asyncio.client import ClientConnection

    CLIENT_CONNECTION_CLASS = ClientConnection
else:
    CLIENT_CONNECTION_CLASS = Any

# Create a thread pool with a maximum of 10 threads
# (used to run the synchronous success handler off the event loop).
executor = concurrent.futures.ThreadPoolExecutor(max_workers=10)

# Event types stored for logging when litellm.logged_real_time_event_types is
# not configured; a configured value of "*" stores every event instead.
DefaultLoggedRealTimeEventTypes = [
    "session.created",
    "response.create",
    "response.done",
]
class RealTimeStreaming:
    """Bidirectional proxy for a realtime websocket session.

    Forwards client messages to the backend model provider websocket and
    backend events back to the client, optionally applying provider-specific
    request/response transformations, and stores selected events for logging.
    """

    def __init__(
        self,
        websocket: Any,
        backend_ws: CLIENT_CONNECTION_CLASS,
        logging_obj: LiteLLMLogging,
        provider_config: Optional[BaseRealtimeConfig] = None,
        model: str = "",
    ):
        self.websocket = websocket  # client-facing websocket (send_text/receive_text API)
        self.backend_ws = backend_ws  # provider-facing websockets ClientConnection
        self.logging_obj = logging_obj
        self.messages: List[OpenAIRealtimeEvents] = []
        self.input_message: Dict = {}

        # Which event types get stored for logging; falls back to the module
        # default when the user has not configured it.
        _logged_real_time_event_types = litellm.logged_real_time_event_types
        if _logged_real_time_event_types is None:
            _logged_real_time_event_types = DefaultLoggedRealTimeEventTypes
        self.logged_real_time_event_types = _logged_real_time_event_types

        self.provider_config = provider_config
        self.model = model

        # Rolling state threaded through the provider's realtime response
        # transformation on every backend message.
        self.current_delta_chunks: Optional[List[OpenAIRealtimeResponseDelta]] = None
        self.current_output_item_id: Optional[str] = None
        self.current_response_id: Optional[str] = None
        self.current_conversation_id: Optional[str] = None
        self.current_item_chunks: Optional[List[OpenAIRealtimeOutputItemDone]] = None
        self.current_delta_type: Optional[ALL_DELTA_TYPES] = None
        self.session_configuration_request: Optional[str] = None

    def _should_store_message(
        self,
        message_obj: Union[dict, OpenAIRealtimeEvents],
    ) -> bool:
        """Return True if this event's type is configured for logging.

        "*" logs every event; otherwise the event's "type" field must appear
        in the configured list.
        """
        _msg_type = message_obj["type"] if "type" in message_obj else None
        if self.logged_real_time_event_types == "*":
            return True
        if _msg_type and _msg_type in self.logged_real_time_event_types:
            return True
        return False

    def store_message(self, message: Union[str, bytes, OpenAIRealtimeEvents]):
        """Parse an event into its typed form and store it for logging.

        Re-raises any parsing error after logging it at debug level.
        """
        if isinstance(message, bytes):
            message = message.decode("utf-8")
        if isinstance(message, dict):
            message_obj = message
        else:
            message_obj = json.loads(message)
        try:
            # BUG FIX: the original branched on `not isinstance(message, dict) or
            # type == session.*`, which routed EVERY string event (regardless of
            # type) into the session-events wrapper and left the second branch
            # unreachable. Only session lifecycle events belong in the session
            # wrapper; other raw (non-dict) events use the generic base object.
            if (
                message_obj.get("type") == "session.created"
                or message_obj.get("type") == "session.updated"
            ):
                message_obj = OpenAIRealtimeStreamSessionEvents(**message_obj)  # type: ignore
            elif not isinstance(message, dict):
                message_obj = OpenAIRealtimeStreamResponseBaseObject(**message_obj)  # type: ignore
        except Exception as e:
            verbose_logger.debug(f"Error parsing message for logging: {e}")
            raise e
        if self._should_store_message(message_obj):
            self.messages.append(message_obj)

    def store_input(self, message: dict):
        """Store the client's input message and fire the pre-call logging hook."""
        self.input_message = message
        if self.logging_obj:
            self.logging_obj.pre_call(input=message, api_key="")

    async def log_messages(self):
        """Flush stored messages through the async and sync success handlers."""
        if self.logging_obj:
            ## ASYNC LOGGING
            asyncio.create_task(self.logging_obj.async_success_handler(self.messages))
            ## SYNC LOGGING
            # BUG FIX: the original passed `success_handler(self.messages)` —
            # calling the handler synchronously on the event loop and submitting
            # its return value to the pool. Pass the callable and its argument
            # separately so the handler actually runs on a worker thread.
            executor.submit(self.logging_obj.success_handler, self.messages)

    async def backend_to_client_send_messages(self):
        """Relay backend events to the client until the connection closes.

        Each event is optionally transformed by the provider config (which also
        updates the rolling transformation state), stored for logging, and
        forwarded to the client. Logging is flushed on exit in all cases.
        """
        import websockets

        try:
            while True:
                try:
                    raw_response = await self.backend_ws.recv(
                        decode=False
                    )  # improves performance
                except TypeError:
                    # Older websockets versions don't accept `decode`.
                    raw_response = await self.backend_ws.recv()  # type: ignore[assignment]

                if self.provider_config:
                    returned_object = self.provider_config.transform_realtime_response(
                        raw_response,
                        self.model,
                        self.logging_obj,
                        realtime_response_transform_input={
                            "session_configuration_request": self.session_configuration_request,
                            "current_output_item_id": self.current_output_item_id,
                            "current_response_id": self.current_response_id,
                            "current_delta_chunks": self.current_delta_chunks,
                            "current_conversation_id": self.current_conversation_id,
                            "current_item_chunks": self.current_item_chunks,
                            "current_delta_type": self.current_delta_type,
                        },
                    )

                    transformed_response = returned_object["response"]
                    # Carry the transformation state forward for the next event.
                    self.current_output_item_id = returned_object[
                        "current_output_item_id"
                    ]
                    self.current_response_id = returned_object["current_response_id"]
                    self.current_delta_chunks = returned_object["current_delta_chunks"]
                    self.current_conversation_id = returned_object[
                        "current_conversation_id"
                    ]
                    self.current_item_chunks = returned_object["current_item_chunks"]
                    self.current_delta_type = returned_object["current_delta_type"]
                    self.session_configuration_request = returned_object[
                        "session_configuration_request"
                    ]

                    if isinstance(transformed_response, list):
                        for event in transformed_response:
                            event_str = json.dumps(event)
                            ## LOGGING
                            self.store_message(event_str)
                            await self.websocket.send_text(event_str)
                    else:
                        event_str = json.dumps(transformed_response)
                        ## LOGGING
                        self.store_message(event_str)
                        await self.websocket.send_text(event_str)
                else:
                    ## LOGGING
                    self.store_message(raw_response)
                    await self.websocket.send_text(raw_response)
        except websockets.exceptions.ConnectionClosed as e:  # type: ignore
            verbose_logger.exception(
                f"Connection closed in backend to client send messages - {e}"
            )
        except Exception as e:
            verbose_logger.exception(f"Error in backend to client send messages: {e}")
        finally:
            await self.log_messages()

    async def client_ack_messages(self):
        """Relay client messages to the backend until the client disconnects."""
        try:
            while True:
                message = await self.websocket.receive_text()
                ## LOGGING
                self.store_input(message=message)
                ## FORWARD TO BACKEND
                if self.provider_config:
                    # The provider transform may split one client message into
                    # several backend messages.
                    message = self.provider_config.transform_realtime_request(
                        message, self.model
                    )
                    for msg in message:
                        await self.backend_ws.send(msg)
                else:
                    await self.backend_ws.send(message)
        except Exception as e:
            verbose_logger.debug(f"Error in client ack messages: {e}")

    async def bidirectional_forward(self):
        """Run both forwarding directions; cancel the backend relay on exit."""
        import websockets

        forward_task = asyncio.create_task(self.backend_to_client_send_messages())
        try:
            await self.client_ack_messages()
        except websockets.exceptions.ConnectionClosed:  # type: ignore
            # BUG FIX: the original caught `self.websocket.exceptions.ConnectionClosed`;
            # the client websocket object has no `exceptions` attribute, so
            # evaluating that except clause would itself raise AttributeError.
            verbose_logger.debug("Connection closed")
            forward_task.cancel()
        finally:
            if not forward_task.done():
                forward_task.cancel()
                try:
                    await forward_task
                except asyncio.CancelledError:
                    pass