Add tensor IPC transfer mechanism for multimodal data (#32104)

Signed-off-by: Brandon Pelfrey <bpelfrey@nvidia.com>
Signed-off-by: Brandon Pelfrey <brandonpelfrey@gmail.com>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Brandon Pelfrey
2026-03-21 13:10:20 -07:00
committed by GitHub
parent 61e381dcf0
commit 80b70884eb
13 changed files with 1430 additions and 25 deletions

View File

@@ -10,6 +10,7 @@ from contextlib import AbstractContextManager
from dataclasses import dataclass
from multiprocessing import connection
from multiprocessing.process import BaseProcess
from multiprocessing.queues import Queue
from typing import (
TYPE_CHECKING,
Any,
@@ -173,6 +174,7 @@ class APIServerProcessManager:
input_addresses: list[str],
output_addresses: list[str],
stats_update_address: str | None = None,
tensor_queue: Queue | None = None,
):
"""Initialize and start API server worker processes.
@@ -185,6 +187,7 @@ class APIServerProcessManager:
input_addresses: Input addresses for each API server
output_addresses: Output addresses for each API server
stats_update_address: Optional stats update address
tensor_queue: Optional tensor IPC queue for sharing MM tensors
"""
self.listen_address = listen_address
self.sock = sock
@@ -205,6 +208,8 @@ class APIServerProcessManager:
}
if stats_update_address is not None:
client_config["stats_update_address"] = stats_update_address
if tensor_queue is not None:
client_config["tensor_queue"] = tensor_queue
proc = spawn_context.Process(
target=target_server_fn,
@@ -419,7 +424,7 @@ def tensor_data(tensor: torch.Tensor) -> memoryview:
Returns:
A memoryview of the tensor data as uint8.
"""
return tensor.flatten().contiguous().view(torch.uint8).numpy().data
return tensor.flatten().cpu().contiguous().view(torch.uint8).numpy().data
@dataclass