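"""TCP entry point of the stream management manager.

Spawns one InferencePipelineManager process per pipeline, keeps the
process/queue triplets in PROCESSES_TABLE, and relays framed commands from
TCP clients to the matching process, sending the corresponding response back
over the same connection.
"""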
import os
import signal
import socket
import sys
from functools import partial
from multiprocessing import Process, Queue
from socketserver import BaseRequestHandler, BaseServer
from types import FrameType
from typing import Any, Dict, Optional, Tuple
from uuid import uuid4

from inference.core import logger
from inference.enterprise.stream_management.manager.communication import (
    receive_socket_data,
    send_data_trough_socket,
)
from inference.enterprise.stream_management.manager.entities import (
    PIPELINE_ID_KEY,
    STATUS_KEY,
    TYPE_KEY,
    CommandType,
    ErrorType,
    OperationStatus,
)
from inference.enterprise.stream_management.manager.errors import MalformedPayloadError
from inference.enterprise.stream_management.manager.inference_pipeline_manager import (
    InferencePipelineManager,
)
from inference.enterprise.stream_management.manager.serialisation import (
    describe_error,
    prepare_error_response,
    prepare_response,
)
from inference.enterprise.stream_management.manager.tcp_server import RoboflowTCPServer
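
# Registry of running pipelines: pipeline_id -> (manager process, command
# queue, responses queue). Shared by the request handler and the signal
# handlers below. HEADER_SIZE and SOCKET_BUFFER_SIZE are passed to the
# communication helpers that frame messages on the socket.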
PROCESSES_TABLE: Dict[str, Tuple[Process, Queue, Queue]] = {}
HEADER_SIZE = 4
SOCKET_BUFFER_SIZE = 16384
HOST = os.getenv("STREAM_MANAGER_HOST", "127.0.0.1")
PORT = int(os.getenv("STREAM_MANAGER_PORT", "7070"))
SOCKET_TIMEOUT = float(os.getenv("STREAM_MANAGER_SOCKET_TIMEOUT", "5.0"))

class InferencePipelinesManagerHandler(BaseRequestHandler):
    def __init__(
        self,
        request: socket.socket,
        client_address: Any,
        server: BaseServer,
        processes_table: Dict[str, Tuple[Process, Queue, Queue]],
    ):
        # The handler's state must be set before the superclass __init__,
        # as it invokes handle()
        self._processes_table = processes_table
        super().__init__(request, client_address, server)

    def handle(self) -> None:
        """Read one framed command from the socket, dispatch it and reply."""
        pipeline_id: Optional[str] = None
        request_id = str(uuid4())
        try:
            data = receive_socket_data(
                source=self.request,
                header_size=HEADER_SIZE,
                buffer_size=SOCKET_BUFFER_SIZE,
            )
            data[TYPE_KEY] = CommandType(data[TYPE_KEY])
            if data[TYPE_KEY] is CommandType.LIST_PIPELINES:
                return self._list_pipelines(request_id=request_id)
            if data[TYPE_KEY] is CommandType.INIT:
                return self._initialise_pipeline(request_id=request_id, command=data)
            pipeline_id = data[PIPELINE_ID_KEY]
            if data[TYPE_KEY] is CommandType.TERMINATE:
                self._terminate_pipeline(
                    request_id=request_id, pipeline_id=pipeline_id, command=data
                )
            else:
                response = handle_command(
                    processes_table=self._processes_table,
                    request_id=request_id,
                    pipeline_id=pipeline_id,
                    command=data,
                )
                serialised_response = prepare_response(
                    request_id=request_id, response=response, pipeline_id=pipeline_id
                )
                send_data_trough_socket(
                    target=self.request,
                    header_size=HEADER_SIZE,
                    data=serialised_response,
                    request_id=request_id,
                    pipeline_id=pipeline_id,
                )
        except (KeyError, ValueError, MalformedPayloadError) as error:
            logger.error(
                f"Invalid payload in processes manager. error={error} request_id={request_id}..."
            )
            payload = prepare_error_response(
                request_id=request_id,
                error=error,
                error_type=ErrorType.INVALID_PAYLOAD,
                pipeline_id=pipeline_id,
            )
            send_data_trough_socket(
                target=self.request,
                header_size=HEADER_SIZE,
                data=payload,
                request_id=request_id,
                pipeline_id=pipeline_id,
            )
        except Exception as error:
            logger.error(
                f"Internal error in processes manager. error={error} request_id={request_id}..."
            )
            payload = prepare_error_response(
                request_id=request_id,
                error=error,
                error_type=ErrorType.INTERNAL_ERROR,
                pipeline_id=pipeline_id,
            )
            send_data_trough_socket(
                target=self.request,
                header_size=HEADER_SIZE,
                data=payload,
                request_id=request_id,
                pipeline_id=pipeline_id,
            )

    def _list_pipelines(self, request_id: str) -> None:
        serialised_response = prepare_response(
            request_id=request_id,
            response={
                "pipelines": list(self._processes_table.keys()),
                STATUS_KEY: OperationStatus.SUCCESS,
            },
            pipeline_id=None,
        )
        send_data_trough_socket(
            target=self.request,
            header_size=HEADER_SIZE,
            data=serialised_response,
            request_id=request_id,
        )

    def _initialise_pipeline(self, request_id: str, command: dict) -> None:
        # Spawn a dedicated manager process and register its queues before
        # forwarding the INIT command to it.
        pipeline_id = str(uuid4())
        command_queue = Queue()
        responses_queue = Queue()
        inference_pipeline_manager = InferencePipelineManager.init(
            command_queue=command_queue,
            responses_queue=responses_queue,
        )
        inference_pipeline_manager.start()
        self._processes_table[pipeline_id] = (
            inference_pipeline_manager,
            command_queue,
            responses_queue,
        )
        command_queue.put((request_id, command))
        response = get_response_ignoring_thrash(
            responses_queue=responses_queue, matching_request_id=request_id
        )
        serialised_response = prepare_response(
            request_id=request_id, response=response, pipeline_id=pipeline_id
        )
        send_data_trough_socket(
            target=self.request,
            header_size=HEADER_SIZE,
            data=serialised_response,
            request_id=request_id,
            pipeline_id=pipeline_id,
        )

    def _terminate_pipeline(
        self, request_id: str, pipeline_id: str, command: dict
    ) -> None:
        response = handle_command(
            processes_table=self._processes_table,
            request_id=request_id,
            pipeline_id=pipeline_id,
            command=command,
        )
        if response[STATUS_KEY] is OperationStatus.SUCCESS:
            logger.info(
                f"Joining inference pipeline. pipeline_id={pipeline_id} request_id={request_id}"
            )
            join_inference_pipeline(
                processes_table=self._processes_table, pipeline_id=pipeline_id
            )
            logger.info(
                f"Joined inference pipeline. pipeline_id={pipeline_id} request_id={request_id}"
            )
        serialised_response = prepare_response(
            request_id=request_id, response=response, pipeline_id=pipeline_id
        )
        send_data_trough_socket(
            target=self.request,
            header_size=HEADER_SIZE,
            data=serialised_response,
            request_id=request_id,
            pipeline_id=pipeline_id,
        )

def handle_command(
    processes_table: Dict[str, Tuple[Process, Queue, Queue]],
    request_id: str,
    pipeline_id: str,
    command: dict,
) -> dict:
    """Forward a command to the pipeline's process and await the matching response."""
    if pipeline_id not in processes_table:
        return describe_error(exception=None, error_type=ErrorType.NOT_FOUND)
    _, command_queue, responses_queue = processes_table[pipeline_id]
    command_queue.put((request_id, command))
    return get_response_ignoring_thrash(
        responses_queue=responses_queue, matching_request_id=request_id
    )

def get_response_ignoring_thrash(
    responses_queue: Queue, matching_request_id: str
) -> dict:
    """Pop responses until the one matching request_id is found, dropping stale ones."""
    while True:
        response = responses_queue.get()
        if response[0] == matching_request_id:
            return response[1]
        logger.warning(
            f"Dropping response for request_id={response[0]} with payload={response[1]}"
        )

def execute_termination(
    signal_number: int,
    frame: FrameType,
    processes_table: Dict[str, Tuple[Process, Queue, Queue]],
) -> None:
    pipeline_ids = list(processes_table.keys())
    for pipeline_id in pipeline_ids:
        logger.info(f"Terminating pipeline: {pipeline_id}")
        processes_table[pipeline_id][0].terminate()
        logger.info(f"Pipeline: {pipeline_id} terminated.")
        logger.info(f"Joining pipeline: {pipeline_id}")
        processes_table[pipeline_id][0].join()
        logger.info(f"Pipeline: {pipeline_id} joined.")
    logger.info("Termination handler completed.")
    sys.exit(0)

def join_inference_pipeline(
    processes_table: Dict[str, Tuple[Process, Queue, Queue]], pipeline_id: str
) -> None:
    inference_pipeline_manager, _, _ = processes_table[pipeline_id]
    inference_pipeline_manager.join()
    del processes_table[pipeline_id]
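
# --- Illustrative client sketch (not part of the manager itself) -------------
# A minimal example of how a client could talk to this server. It assumes the
# framing implemented by the communication helpers: a HEADER_SIZE-byte,
# big-endian length prefix followed by a UTF-8 JSON payload. Verify against
# receive_socket_data() / send_data_trough_socket() before relying on it.
def _example_list_pipelines_request(
    host: str = HOST, port: int = PORT, timeout: float = SOCKET_TIMEOUT
) -> dict:
    import json  # local import: keeps the example self-contained

    payload = json.dumps({TYPE_KEY: CommandType.LIST_PIPELINES.value}).encode("utf-8")
    with socket.create_connection((host, port), timeout=timeout) as client:
        # length header first, then the JSON body
        client.sendall(len(payload).to_bytes(HEADER_SIZE, byteorder="big") + payload)
        response_size = int.from_bytes(client.recv(HEADER_SIZE), byteorder="big")
        response = b""
        while len(response) < response_size:
            chunk = client.recv(min(SOCKET_BUFFER_SIZE, response_size - len(response)))
            if not chunk:
                break
            response += chunk
    return json.loads(response)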

if __name__ == "__main__":
    signal.signal(
        signal.SIGINT, partial(execute_termination, processes_table=PROCESSES_TABLE)
    )
    signal.signal(
        signal.SIGTERM, partial(execute_termination, processes_table=PROCESSES_TABLE)
    )
    with RoboflowTCPServer(
        server_address=(HOST, PORT),
        handler_class=partial(
            InferencePipelinesManagerHandler, processes_table=PROCESSES_TABLE
        ),
        socket_operations_timeout=SOCKET_TIMEOUT,
    ) as tcp_server:
        logger.info(
            f"Inference Pipeline Processes Manager is ready to accept connections at {(HOST, PORT)}"
        )
        tcp_server.serve_forever()
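
# To run the manager directly (assuming this module is saved as
# inference/enterprise/stream_management/manager/app.py inside the package):
#   python -m inference.enterprise.stream_management.manager.app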