Add tensor IPC transfer mechanism for multimodal data (#32104)
Signed-off-by: Brandon Pelfrey <bpelfrey@nvidia.com> Signed-off-by: Brandon Pelfrey <brandonpelfrey@gmail.com> Signed-off-by: Nick Hill <nickhill123@gmail.com> Co-authored-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -278,3 +278,148 @@ def test_custom_class_serialization_disallowed_without_pickle():
|
||||
with pytest.raises(TypeError):
|
||||
# Attempt to encode the custom class
|
||||
encoder.encode(obj)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestWithTensor:
|
||||
"""Mock request with non-multimodal tensor field like EngineCoreRequest."""
|
||||
|
||||
prompt_embeds: torch.Tensor | None
|
||||
data: str
|
||||
|
||||
|
||||
def test_non_multimodal_tensor_with_ipc():
    """Round-trip a request whose tensor sits in a non-multimodal field.

    Regression test: with IPC enabled, a field typed
    ``torch.Tensor | None`` (such as ``prompt_embeds``) previously failed
    to decode because ``_decode_tensor`` expected a raw tensor tuple but
    actually received a msgpack-decoded ``TensorIpcHandle`` list.
    """
    import torch.multiprocessing as torch_mp

    from vllm.v1.engine.tensor_ipc import TensorIpcReceiver, TensorIpcSender

    # One shared queue carries tensors out-of-band from encoder to decoder.
    queue = torch_mp.Queue()

    # Encoder side: tensors are published through the IPC sender.
    ipc_sender = TensorIpcSender(queue)
    msg_encoder = MsgpackEncoder(oob_tensor_consumer=ipc_sender)

    # Decoder side: the matching receiver pulls tensors back off the queue.
    ipc_receiver = TensorIpcReceiver(queue)
    msg_decoder = MsgpackDecoder(RequestWithTensor, oob_tensor_provider=ipc_receiver)

    # Build a request carrying a plain (non-multimodal) tensor.
    source = torch.randn(5, 10, dtype=torch.float32)
    req = RequestWithTensor(prompt_embeds=source, data="test_data")

    # Encoding should ship the tensor via IPC; verify something was produced.
    payload = msg_encoder.encode(req)
    assert len(payload) > 0

    # Decoding must fetch the tensor back from the IPC queue. This is the
    # step that used to fail when the decoder tried to unpack the handle
    # list as raw tensor bytes metadata.
    result = msg_decoder.decode(payload)

    # The decoded request must match what we encoded.
    assert isinstance(result, RequestWithTensor)
    assert result.data == "test_data"
    assert result.prompt_embeds is not None
    assert torch.allclose(result.prompt_embeds, source), (
        "Decoded tensor does not match the original tensor."
    )
def test_non_multimodal_tensor_with_ipc_none_value():
    """A ``None`` tensor field must survive encode/decode with IPC enabled."""
    import torch.multiprocessing as torch_mp

    from vllm.v1.engine.tensor_ipc import TensorIpcReceiver, TensorIpcSender

    # Shared out-of-band tensor queue for the encoder/decoder pair.
    queue = torch_mp.Queue()

    # Wire encoder and decoder to the same queue via sender/receiver.
    enc = MsgpackEncoder(oob_tensor_consumer=TensorIpcSender(queue))
    dec = MsgpackDecoder(
        RequestWithTensor, oob_tensor_provider=TensorIpcReceiver(queue)
    )

    # No tensor attached: the tensor field is explicitly None.
    req = RequestWithTensor(prompt_embeds=None, data="test_data_with_none")

    # Round-trip through encode/decode.
    round_tripped = dec.decode(enc.encode(req))

    # The None field and the string payload must both come back intact.
    assert isinstance(round_tripped, RequestWithTensor)
    assert round_tripped.data == "test_data_with_none"
    assert round_tripped.prompt_embeds is None
def test_multiple_senders_single_receiver_ipc():
    """Exercise N encoders sharing one tensor queue against one decoder.

    Mirrors the real vLLM topology: several API-server frontends each own
    a MsgpackEncoder + TensorIpcSender and push tensors onto one shared
    torch.mp queue, while a single engine core drains them with a single
    MsgpackDecoder + TensorIpcReceiver.
    """
    import torch.multiprocessing as torch_mp

    from vllm.v1.engine.tensor_ipc import TensorIpcReceiver, TensorIpcSender

    num_senders = 3
    num_messages_per_sender = 2
    shared_queue = torch_mp.Queue()

    # N independent encoders, each backed by its own sender (every sender
    # gets its own uuid-based sender_id); all share the one queue.
    encoders = [
        MsgpackEncoder(oob_tensor_consumer=TensorIpcSender(shared_queue))
        for _ in range(num_senders)
    ]

    # A single receiver/decoder drains everything.
    decoder = MsgpackDecoder(
        RequestWithTensor,
        oob_tensor_provider=TensorIpcReceiver(shared_queue),
    )

    # Encode with the senders interleaved, so tensors from different
    # senders land on the shared queue out of per-sender order.
    encoded_payloads: list[tuple[int, int, torch.Tensor, list]] = []
    for msg_idx in range(num_messages_per_sender):
        for sender_idx, enc in enumerate(encoders):
            tensor = torch.full(
                (sender_idx + 1, msg_idx + 2),
                float(sender_idx * 100 + msg_idx),
                dtype=torch.float32,
            )
            encoded = enc.encode(
                RequestWithTensor(
                    prompt_embeds=tensor,
                    data=f"s{sender_idx}_m{msg_idx}",
                )
            )
            encoded_payloads.append((sender_idx, msg_idx, tensor, encoded))

    # Decode every message — the receiver must pair each tensor handle
    # with the right TensorIpcData pulled from the shared queue.
    for sender_idx, msg_idx, original_tensor, encoded in encoded_payloads:
        decoded = decoder.decode(encoded)
        assert isinstance(decoded, RequestWithTensor)
        assert decoded.data == f"s{sender_idx}_m{msg_idx}"
        assert decoded.prompt_embeds is not None
        assert decoded.prompt_embeds.shape == original_tensor.shape, (
            f"Shape mismatch for sender {sender_idx} msg {msg_idx}: "
            f"{decoded.prompt_embeds.shape} != {original_tensor.shape}"
        )
        assert torch.allclose(decoded.prompt_embeds, original_tensor), (
            f"Value mismatch for sender {sender_idx} msg {msg_idx}"
        )
Reference in New Issue
Block a user