Add tensor IPC transfer mechanism for multimodal data (#32104)
Signed-off-by: Brandon Pelfrey <bpelfrey@nvidia.com>
Signed-off-by: Brandon Pelfrey <brandonpelfrey@gmail.com>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -10,6 +10,7 @@ from contextlib import AbstractContextManager
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing import connection
|
||||
from multiprocessing.process import BaseProcess
|
||||
from multiprocessing.queues import Queue
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@@ -173,6 +174,7 @@ class APIServerProcessManager:
|
||||
input_addresses: list[str],
|
||||
output_addresses: list[str],
|
||||
stats_update_address: str | None = None,
|
||||
tensor_queue: Queue | None = None,
|
||||
):
|
||||
"""Initialize and start API server worker processes.
|
||||
|
||||
@@ -185,6 +187,7 @@ class APIServerProcessManager:
|
||||
input_addresses: Input addresses for each API server
|
||||
output_addresses: Output addresses for each API server
|
||||
stats_update_address: Optional stats update address
|
||||
tensor_queue: Optional tensor IPC queue for sharing MM tensors
|
||||
"""
|
||||
self.listen_address = listen_address
|
||||
self.sock = sock
|
||||
@@ -205,6 +208,8 @@ class APIServerProcessManager:
|
||||
}
|
||||
if stats_update_address is not None:
|
||||
client_config["stats_update_address"] = stats_update_address
|
||||
if tensor_queue is not None:
|
||||
client_config["tensor_queue"] = tensor_queue
|
||||
|
||||
proc = spawn_context.Process(
|
||||
target=target_server_fn,
|
||||
@@ -419,7 +424,7 @@ def tensor_data(tensor: torch.Tensor) -> memoryview:
|
||||
Returns:
|
||||
A memoryview of the tensor data as uint8.
|
||||
"""
|
||||
- return tensor.flatten().contiguous().view(torch.uint8).numpy().data
|
||||
+ return tensor.flatten().cpu().contiguous().view(torch.uint8).numpy().data
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user