Add tensor IPC transfer mechanism for multimodal data (#32104)

Signed-off-by: Brandon Pelfrey <bpelfrey@nvidia.com>
Signed-off-by: Brandon Pelfrey <brandonpelfrey@gmail.com>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
Brandon Pelfrey
2026-03-21 13:10:20 -07:00
committed by GitHub
parent 61e381dcf0
commit 80b70884eb
13 changed files with 1430 additions and 25 deletions

View File

@@ -278,3 +278,148 @@ def test_custom_class_serialization_disallowed_without_pickle():
with pytest.raises(TypeError):
# Attempt to encode the custom class
encoder.encode(obj)
@dataclass
class RequestWithTensor:
"""Mock request with non-multimodal tensor field like EngineCoreRequest."""
prompt_embeds: torch.Tensor | None
data: str
def test_non_multimodal_tensor_with_ipc():
"""Test that non-multimodal tensor fields work correctly with IPC enabled.
This reproduces the bug where fields like prompt_embeds: torch.Tensor | None
would fail to decode when IPC is enabled because _decode_tensor expected a
raw tensor tuple but received a msgpack-decoded TensorIpcHandle list.
"""
import torch.multiprocessing as torch_mp
from vllm.v1.engine.tensor_ipc import TensorIpcReceiver, TensorIpcSender
# Create tensor queues for IPC
tensor_queues = [torch_mp.Queue()]
# Create encoder with IPC sender
sender = TensorIpcSender(tensor_queues[0])
encoder = MsgpackEncoder(oob_tensor_consumer=sender)
# Create decoder with IPC receiver
receiver = TensorIpcReceiver(tensor_queues[0])
decoder = MsgpackDecoder(RequestWithTensor, oob_tensor_provider=receiver)
# Create a request with a non-multimodal tensor
original_tensor = torch.randn(5, 10, dtype=torch.float32)
request = RequestWithTensor(prompt_embeds=original_tensor, data="test_data")
# Encode the request - this should send the tensor via IPC
encoded = encoder.encode(request)
# Verify encoding succeeded
assert len(encoded) > 0
# Decode the request - this should retrieve the tensor from IPC queue
# Previously this would fail because the decoder tried to unpack the
# handle list as raw tensor bytes metadata.
decoded = decoder.decode(encoded)
# Verify the decoded request matches the original
assert isinstance(decoded, RequestWithTensor)
assert decoded.data == "test_data"
assert decoded.prompt_embeds is not None
assert torch.allclose(decoded.prompt_embeds, original_tensor), (
"Decoded tensor does not match the original tensor."
)
def test_non_multimodal_tensor_with_ipc_none_value():
"""Test that None values for tensor fields work correctly with IPC enabled."""
import torch.multiprocessing as torch_mp
from vllm.v1.engine.tensor_ipc import TensorIpcReceiver, TensorIpcSender
# Create tensor queues for IPC
tensor_queues = [torch_mp.Queue()]
# Create encoder with IPC sender
sender = TensorIpcSender(tensor_queues[0])
encoder = MsgpackEncoder(oob_tensor_consumer=sender)
# Create decoder with IPC receiver
receiver = TensorIpcReceiver(tensor_queues[0])
decoder = MsgpackDecoder(RequestWithTensor, oob_tensor_provider=receiver)
# Create a request with None for the tensor field
request = RequestWithTensor(prompt_embeds=None, data="test_data_with_none")
# Encode and decode the request
encoded = encoder.encode(request)
decoded = decoder.decode(encoded)
# Verify the decoded request matches the original
assert isinstance(decoded, RequestWithTensor)
assert decoded.data == "test_data_with_none"
assert decoded.prompt_embeds is None
def test_multiple_senders_single_receiver_ipc():
"""Test N senders sharing a queue with a single receiver via msgpack.
Simulates the real vLLM topology where multiple API server frontends
each have their own MsgpackEncoder + TensorIpcSender, all putting
tensors onto the same torch.mp queue, and a single engine core
decodes them with one MsgpackDecoder + TensorIpcReceiver.
"""
import torch.multiprocessing as torch_mp
from vllm.v1.engine.tensor_ipc import TensorIpcReceiver, TensorIpcSender
num_senders = 3
num_messages_per_sender = 2
tensor_queue = torch_mp.Queue()
# Create N independent senders (each gets its own uuid-based sender_id)
senders = []
encoders = []
for _ in range(num_senders):
s = TensorIpcSender(tensor_queue)
senders.append(s)
encoders.append(MsgpackEncoder(oob_tensor_consumer=s))
# Single receiver
receiver = TensorIpcReceiver(tensor_queue)
decoder = MsgpackDecoder(RequestWithTensor, oob_tensor_provider=receiver)
# Encode messages from all senders, interleaving the order
# so that tensors from different senders land on the queue interleaved.
encoded_payloads: list[tuple[int, int, torch.Tensor, list]] = []
for msg_idx in range(num_messages_per_sender):
for sender_idx in range(num_senders):
tensor = torch.full(
(sender_idx + 1, msg_idx + 2),
float(sender_idx * 100 + msg_idx),
dtype=torch.float32,
)
req = RequestWithTensor(
prompt_embeds=tensor,
data=f"s{sender_idx}_m{msg_idx}",
)
encoded = encoders[sender_idx].encode(req)
encoded_payloads.append((sender_idx, msg_idx, tensor, encoded))
# Decode all messages — the receiver must correctly match each
# tensor handle to the right TensorIpcData from the shared queue.
for sender_idx, msg_idx, original_tensor, encoded in encoded_payloads:
decoded = decoder.decode(encoded)
assert isinstance(decoded, RequestWithTensor)
assert decoded.data == f"s{sender_idx}_m{msg_idx}"
assert decoded.prompt_embeds is not None
assert decoded.prompt_embeds.shape == original_tensor.shape, (
f"Shape mismatch for sender {sender_idx} msg {msg_idx}: "
f"{decoded.prompt_embeds.shape} != {original_tensor.shape}"
)
assert torch.allclose(decoded.prompt_embeds, original_tensor), (
f"Value mismatch for sender {sender_idx} msg {msg_idx}"
)

View File

@@ -0,0 +1,943 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for tensor IPC queue functionality."""
import contextlib
import multiprocessing as mp
from dataclasses import dataclass
from multiprocessing.synchronize import Barrier as BarrierType
from multiprocessing.synchronize import Event as EventType
from typing import Any
import pytest
import torch
import torch.multiprocessing as torch_mp
from vllm.v1.engine.tensor_ipc import (
TensorIpcData,
TensorIpcReceiver,
TensorIpcSender,
)
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
@pytest.fixture(scope="module", autouse=True)
def setup_multiprocessing():
"""Set multiprocessing start method to 'spawn' for compatibility."""
with contextlib.suppress(RuntimeError):
# Already set, which is fine
torch_mp.set_start_method("spawn", force=True)
yield
@dataclass
class TensorEnvelope:
"""Typed container so the test covers the real vLLM path where tensor IPC
handles are encoded and decoded as fields nested inside larger msgpack payloads."""
tensor: torch.Tensor
label: str
def encoder_process(
tensor_queue: torch_mp.Queue,
payload_queue: mp.Queue,
result_queue: mp.Queue,
tensor_data: dict[str, Any],
ready_event: EventType,
retrieval_done: EventType,
):
"""Process that msgpack-encodes and sends tensors via IPC."""
try:
sender = TensorIpcSender(tensor_queue)
encoder = MsgpackEncoder(oob_tensor_consumer=sender)
if torch.cuda.is_available():
device = "cuda:0"
tensor = torch.randn(
*tensor_data["shape"], dtype=tensor_data["dtype"], device=device
)
else:
# Fall back to CPU for testing
device = "cpu"
tensor = torch.randn(*tensor_data["shape"], dtype=tensor_data["dtype"])
message = TensorEnvelope(tensor=tensor, label="cuda-msgpack")
encoded = encoder.encode(message)
payload_queue.put(encoded, timeout=10.0)
ready_event.set()
result_queue.put(
{
"success": True,
"encoded_length": len(encoded),
"device": str(device),
"tensor_shape": tuple(tensor.shape),
}
)
retrieval_done.wait(timeout=30.0)
except Exception as e:
import traceback
ready_event.set()
retrieval_done.set()
result_queue.put(
{"success": False, "error": str(e), "traceback": traceback.format_exc()}
)
def decoder_process(
tensor_queue: torch_mp.Queue,
payload_queue: mp.Queue,
result_queue: mp.Queue,
expected_shape: tuple,
encoder_ready: EventType,
retrieval_done: EventType,
):
"""Process that msgpack-decodes tensors received via IPC."""
try:
if not encoder_ready.wait(timeout=10.0):
raise TimeoutError("Encoder did not signal ready")
encoded = payload_queue.get(timeout=5.0)
receiver = TensorIpcReceiver(tensor_queue)
decoder = MsgpackDecoder(TensorEnvelope, oob_tensor_provider=receiver)
decoded = decoder.decode(encoded)
result_queue.put(
{
"success": True,
"tensor_shape": tuple(decoded.tensor.shape),
"device": str(decoded.tensor.device),
"label": decoded.label,
"matches_expected": tuple(decoded.tensor.shape) == expected_shape,
}
)
except Exception as e:
import traceback
retrieval_done.set()
result_queue.put(
{"success": False, "error": str(e), "traceback": traceback.format_exc()}
)
else:
retrieval_done.set()
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_cuda_tensor_queue_basic():
"""Test CUDA tensor IPC through the msgpack encoder/decoder path."""
tensor_queue = torch_mp.Queue()
payload_queue: mp.Queue = mp.Queue()
result_queue: mp.Queue = mp.Queue()
encoder_ready = mp.Event()
retrieval_done = mp.Event()
tensor_shape = (4, 8, 16)
tensor_dtype = torch.float32
encoder_proc = mp.Process(
target=encoder_process,
args=(
tensor_queue,
payload_queue,
result_queue,
{"shape": tensor_shape, "dtype": tensor_dtype},
encoder_ready,
retrieval_done,
),
)
encoder_proc.start()
decoder_proc = mp.Process(
target=decoder_process,
args=(
tensor_queue,
payload_queue,
result_queue,
tensor_shape,
encoder_ready,
retrieval_done,
),
)
decoder_proc.start()
encoder_result = result_queue.get(timeout=10.0)
decoder_result = result_queue.get(timeout=10.0)
encoder_proc.join(timeout=5.0)
decoder_proc.join(timeout=5.0)
# Verify results
assert encoder_result["success"], (
f"Encoder failed: {encoder_result.get('error')}\n"
f"{encoder_result.get('traceback', '')}"
)
assert decoder_result["success"], (
f"Decoder failed: {decoder_result.get('error')}\n"
f"{decoder_result.get('traceback', '')}"
)
assert decoder_result["matches_expected"], "Tensor shape mismatch"
assert "cuda" in decoder_result["device"], "Tensor not on CUDA device"
assert decoder_result["label"] == "cuda-msgpack"
def test_cpu_tensor_fallback():
"""Test that CPU tensors use standard serialization path."""
encoder = MsgpackEncoder()
# Create a CPU tensor
tensor = torch.randn(3, 4, dtype=torch.float32)
# Encode the tensor (should use standard path, not queue)
encoded = encoder.encode({"test_tensor": tensor})
# Verify encoding succeeded
assert len(encoded) > 0
assert isinstance(encoded, (list, tuple))
# Basic check: no queue should be used, so tensor goes through standard path
# This is mainly to ensure no exceptions are raised
def test_msgpack_encoder_decoder_with_ipc():
"""Test the full msgpack + tensor IPC path in one process."""
tensor_queue = torch_mp.Queue()
sender = TensorIpcSender(tensor_queue)
encoder = MsgpackEncoder(oob_tensor_consumer=sender)
receiver = TensorIpcReceiver(tensor_queue)
decoder = MsgpackDecoder(TensorEnvelope, oob_tensor_provider=receiver)
# Use CPU here to exercise the msgpack + sender/receiver integration
# without relying on same-process CUDA IPC behavior.
tensor = torch.randn(2, 3)
message = TensorEnvelope(tensor=tensor, label="test")
encoded = encoder.encode(message)
assert len(encoded) > 0
decoded = decoder.decode(encoded)
assert isinstance(decoded, TensorEnvelope)
assert decoded.label == "test"
assert torch.allclose(decoded.tensor, tensor)
def test_decoder_buffer_management():
"""Test receiver's tensor buffer management when draining queue."""
tensor_queue = torch_mp.Queue()
sender_id = "test_sender"
message_id = 1
# Put multiple tensors in queue using TensorIpcData
tensors_data = [
(0, torch.randn(2, 3)),
(1, torch.randn(4, 5)),
(2, torch.randn(6, 7)),
]
for tensor_id, tensor in tensors_data:
ipc_data = TensorIpcData(
sender_id=sender_id,
message_id=message_id,
tensor_id=tensor_id,
tensor=tensor,
)
tensor_queue.put(ipc_data)
# Create receiver directly
receiver = TensorIpcReceiver(tensor_queue)
# Request tensor_id=2 (should buffer tensor_id=0 and tensor_id=1)
handle = {"sender_id": sender_id, "message_id": message_id, "tensor_id": 2}
result = receiver("float32", (6, 7), handle)
assert result.shape == (6, 7)
# Verify buffer has tensor_id 0 and 1
sender = receiver._tensor_buffers[sender_id]
tensors = sender.tensors.get(message_id, {})
assert 0 in tensors
assert 1 in tensors
# Request buffered tensor
handle2 = {"sender_id": sender_id, "message_id": message_id, "tensor_id": 0}
result2 = receiver("float32", (2, 3), handle2)
assert result2.shape == (2, 3)
# tensor_id 0 should be removed from buffer
sender = receiver._tensor_buffers[sender_id]
tensors = sender.tensors.get(message_id, {})
assert 0 not in tensors
def api_server_worker(
server_id: int,
tensor_queue: torch_mp.Queue,
result_queue: mp.Queue,
barrier: BarrierType,
retrieval_done: EventType,
):
"""Worker simulating an API server sending tensors."""
try:
# Each server sends a unique tensor
tensor = torch.ones(server_id + 1, server_id + 2) * server_id
sender_id = f"server_{server_id}"
# Wait for all servers to be ready
barrier.wait()
# Send tensor using TensorIpcData
ipc_data = TensorIpcData(
sender_id=sender_id,
message_id=0,
tensor_id=0,
tensor=tensor,
)
tensor_queue.put(ipc_data)
result_queue.put({"server_id": server_id, "success": True})
# Keep process alive until main process has retrieved all tensors
# This prevents shared memory handles from being invalidated
retrieval_done.wait(timeout=30.0)
except Exception as e:
import traceback
result_queue.put(
{
"server_id": server_id,
"success": False,
"error": str(e),
"traceback": traceback.format_exc(),
}
)
def test_multiple_api_servers_to_engine():
"""Test multiple API servers sending to one engine core via multiprocessing."""
num_api_servers = 3
tensor_queue = torch_mp.Queue()
result_queue: mp.Queue = mp.Queue()
barrier = mp.Barrier(num_api_servers)
retrieval_done = mp.Event()
# Start multiple API server processes
processes = []
for server_id in range(num_api_servers):
proc = mp.Process(
target=api_server_worker,
args=(server_id, tensor_queue, result_queue, barrier, retrieval_done),
)
proc.start()
processes.append(proc)
# Collect results from all servers
results = []
for _ in range(num_api_servers):
result = result_queue.get(timeout=10.0)
results.append(result)
# Verify all servers succeeded
for result in results:
assert result["success"], (
f"Server {result['server_id']} failed: {result.get('error')}"
)
# Verify all tensors are in queue
received_tensors = []
for _ in range(num_api_servers):
ipc_data = tensor_queue.get(timeout=1.0)
received_tensors.append((ipc_data.sender_id, ipc_data.tensor))
assert len(received_tensors) == num_api_servers
# Verify tensor content (order may vary with multiprocessing)
tensor_by_sender = {sid: t for sid, t in received_tensors}
for server_id in range(num_api_servers):
expected_id = f"server_{server_id}"
assert expected_id in tensor_by_sender, (
f"Missing tensor from server {server_id}"
)
expected_tensor = torch.ones(server_id + 1, server_id + 2) * server_id
assert torch.allclose(tensor_by_sender[expected_id], expected_tensor)
# Signal workers that retrieval is complete
retrieval_done.set()
# Wait for all processes to complete
for proc in processes:
proc.join(timeout=5.0)
def mixed_tensor_encoder_process(
tensor_queue: torch_mp.Queue,
result_queue: mp.Queue,
ready_event: EventType,
retrieval_done: EventType,
):
"""Process that encodes mixed CPU/CUDA tensors."""
try:
sender = TensorIpcSender(tensor_queue)
_encoder = MsgpackEncoder(oob_tensor_consumer=sender)
# Send a CUDA tensor directly through the queue to exercise the raw
# IPC mechanism (CPU tensors would take the standard serialization path)
cuda_tensor = torch.randn(4, 5, device="cuda:0")
# Manually send via IPC to test the mechanism
cuda_tensor_shared = cuda_tensor.share_memory_()
ipc_data = TensorIpcData(
sender_id="mixed_encoder",
message_id=0,
tensor_id=0,
tensor=cuda_tensor_shared,
)
tensor_queue.put(ipc_data, timeout=10.0)
ready_event.set()
result_queue.put({"success": True, "sent_cuda": True})
# Keep process alive until decoder has retrieved the tensor
retrieval_done.wait(timeout=30.0)
except Exception as e:
import traceback
ready_event.set()
result_queue.put(
{"success": False, "error": str(e), "traceback": traceback.format_exc()}
)
def mixed_tensor_decoder_process(
tensor_queue: torch_mp.Queue,
result_queue: mp.Queue,
encoder_ready: EventType,
retrieval_done: EventType,
):
"""Process that retrieves mixed tensors from queue."""
try:
# Wait for encoder to finish
if not encoder_ready.wait(timeout=10.0):
raise TimeoutError("Encoder did not signal ready")
# Try to get CUDA tensor from queue
ipc_data = tensor_queue.get(timeout=5.0)
result_queue.put(
{
"success": True,
"is_cuda": ipc_data.tensor.is_cuda,
"shape": tuple(ipc_data.tensor.shape),
}
)
# Signal that retrieval is complete
retrieval_done.set()
except Exception as e:
import traceback
retrieval_done.set() # Signal even on failure
result_queue.put(
{"success": False, "error": str(e), "traceback": traceback.format_exc()}
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_mixed_cpu_cuda_tensors():
"""Test encoding with mixed CPU and CUDA tensors using multiprocessing."""
tensor_queue = torch_mp.Queue()
result_queue: mp.Queue = mp.Queue()
encoder_ready = mp.Event()
retrieval_done = mp.Event()
# Start encoder process
encoder_proc = mp.Process(
target=mixed_tensor_encoder_process,
args=(tensor_queue, result_queue, encoder_ready, retrieval_done),
)
encoder_proc.start()
# Start decoder process
decoder_proc = mp.Process(
target=mixed_tensor_decoder_process,
args=(tensor_queue, result_queue, encoder_ready, retrieval_done),
)
decoder_proc.start()
# Get results
encoder_result = result_queue.get(timeout=10.0)
decoder_result = result_queue.get(timeout=10.0)
encoder_proc.join(timeout=5.0)
decoder_proc.join(timeout=5.0)
# Verify encoder succeeded
assert encoder_result["success"], (
f"Encoder failed: {encoder_result.get('error')}\n"
f"{encoder_result.get('traceback', '')}"
)
# Verify decoder succeeded and got CUDA tensor
assert decoder_result["success"], (
f"Decoder failed: {decoder_result.get('error')}\n"
f"{decoder_result.get('traceback', '')}"
)
assert decoder_result["is_cuda"], "Retrieved tensor is not on CUDA"
assert decoder_result["shape"] == (4, 5), (
f"Unexpected shape: {decoder_result['shape']}"
)
def cpu_tensor_ipc_encoder_process(
tensor_queue: torch_mp.Queue,
result_queue: mp.Queue,
tensor_shape: tuple,
ready_event: EventType,
retrieval_done: EventType,
):
"""Process that encodes and sends CPU tensors via IPC queue."""
try:
# Create encoder with IPC enabled for all tensors
sender = TensorIpcSender(tensor_queue)
encoder = MsgpackEncoder(oob_tensor_consumer=sender)
# Create a CPU tensor
tensor = torch.randn(*tensor_shape, dtype=torch.float32)
# Encode the tensor (should use IPC queue, not standard serialization)
encoded = encoder.encode({"test_tensor": tensor})
# Signal that encoding is complete
ready_event.set()
result_queue.put(
{
"success": True,
"encoded_length": len(encoded),
"device": str(tensor.device),
"tensor_shape": tuple(tensor.shape),
}
)
# Keep process alive until decoder has retrieved the tensor
# This is necessary for CPU tensor shared memory to remain valid
retrieval_done.wait(timeout=30.0)
except Exception as e:
import traceback
ready_event.set()
result_queue.put(
{"success": False, "error": str(e), "traceback": traceback.format_exc()}
)
def cpu_tensor_ipc_decoder_process(
tensor_queue: torch_mp.Queue,
result_queue: mp.Queue,
expected_shape: tuple,
encoder_ready: EventType,
retrieval_done: EventType,
):
"""Process that decodes and receives CPU tensors from IPC queue."""
try:
# Wait for encoder to finish sending
if not encoder_ready.wait(timeout=10.0):
raise TimeoutError("Encoder did not signal ready")
# Get tensor from queue
ipc_data = tensor_queue.get(timeout=5.0)
result_queue.put(
{
"success": True,
"tensor_id": ipc_data.tensor_id,
"tensor_shape": tuple(ipc_data.tensor.shape),
"device": str(ipc_data.tensor.device),
"matches_expected": tuple(ipc_data.tensor.shape) == expected_shape,
"is_cpu": ipc_data.tensor.device.type == "cpu",
}
)
# Signal that retrieval is complete
retrieval_done.set()
except Exception as e:
import traceback
retrieval_done.set() # Signal even on failure
result_queue.put(
{"success": False, "error": str(e), "traceback": traceback.format_exc()}
)
def test_cpu_tensor_ipc():
"""Test CPU tensor sharing via IPC queue when mm_tensor_ipc is enabled."""
# Set up single queue and synchronization
tensor_queue = torch_mp.Queue()
result_queue: mp.Queue = mp.Queue()
encoder_ready = mp.Event()
retrieval_done = mp.Event()
tensor_shape = (3, 5, 7)
# Start encoder process
encoder_proc = mp.Process(
target=cpu_tensor_ipc_encoder_process,
args=(
tensor_queue,
result_queue,
tensor_shape,
encoder_ready,
retrieval_done,
),
)
encoder_proc.start()
# Start decoder process
decoder_proc = mp.Process(
target=cpu_tensor_ipc_decoder_process,
args=(
tensor_queue,
result_queue,
tensor_shape,
encoder_ready,
retrieval_done,
),
)
decoder_proc.start()
# Wait for processes and collect results
encoder_result = result_queue.get(timeout=10.0)
decoder_result = result_queue.get(timeout=10.0)
encoder_proc.join(timeout=5.0)
decoder_proc.join(timeout=5.0)
# Verify results
assert encoder_result["success"], (
f"Encoder failed: {encoder_result.get('error')}\n"
f"{encoder_result.get('traceback', '')}"
)
assert decoder_result["success"], (
f"Decoder failed: {decoder_result.get('error')}\n"
f"{decoder_result.get('traceback', '')}"
)
assert decoder_result["matches_expected"], "Tensor shape mismatch"
assert decoder_result["is_cpu"], "Tensor not on CPU device"
def test_ipc_disabled_mode():
"""Test that IPC is disabled when no sender is provided."""
tensor_queues = [torch_mp.Queue()]
# Create encoder without IPC sender (IPC disabled)
encoder = MsgpackEncoder()
# Create a CPU tensor
cpu_tensor = torch.randn(2, 3, dtype=torch.float32)
# Encode the tensor (should use standard serialization, not IPC)
encoded = encoder.encode({"test_tensor": cpu_tensor})
# Verify encoding succeeded
assert len(encoded) > 0
assert isinstance(encoded, (list, tuple))
# Verify queue is empty (no IPC was used)
assert tensor_queues[0].empty(), "Tensor queue should be empty when IPC is disabled"
# If CUDA is available, test with CUDA tensor too
if torch.cuda.is_available():
cuda_tensor = torch.randn(4, 5, device="cuda:0")
encoded_cuda = encoder.encode({"cuda_tensor": cuda_tensor})
assert len(encoded_cuda) > 0
assert tensor_queues[0].empty(), (
"Tensor queue should be empty for CUDA tensor when IPC is disabled"
)
@dataclass
class MultiTensorMessage:
"""Message with multiple tensors to test multi-tensor IPC."""
t1: torch.Tensor
t2: torch.Tensor
sender_label: str
def concurrent_sender_process(
tensor_queue: torch_mp.Queue,
payload_queue: mp.Queue,
result_queue: mp.Queue,
sender_index: int,
num_messages: int,
barrier: BarrierType,
retrieval_done: EventType,
):
"""Process that acts as one of N concurrent senders."""
try:
sender = TensorIpcSender(tensor_queue)
encoder = MsgpackEncoder(oob_tensor_consumer=sender)
# Wait for all senders to be ready before sending
barrier.wait(timeout=10.0)
encoded_payloads = []
for msg_idx in range(num_messages):
# Each sender creates uniquely-shaped tensors so we can
# verify correct routing on the receiver side.
t1 = torch.full((sender_index + 1, 3), float(msg_idx), dtype=torch.float32)
t2 = torch.full(
(2, sender_index + 2), float(msg_idx + 100), dtype=torch.float64
)
msg = MultiTensorMessage(
t1=t1,
t2=t2,
sender_label=f"sender_{sender_index}_msg_{msg_idx}",
)
encoded = encoder.encode(msg)
encoded_payloads.append(encoded)
# Send all encoded payloads via the regular (non-tensor) queue
for encoded in encoded_payloads:
payload_queue.put(encoded, timeout=10.0)
result_queue.put(
{
"success": True,
"sender_index": sender_index,
"num_sent": num_messages,
}
)
# Keep alive so shared-memory handles remain valid
retrieval_done.wait(timeout=30.0)
except Exception as e:
import traceback
result_queue.put(
{
"success": False,
"sender_index": sender_index,
"error": str(e),
"traceback": traceback.format_exc(),
}
)
def test_concurrent_senders_single_receiver():
"""Test N concurrent senders sharing one queue with a single receiver.
Each sender encodes multiple messages (each containing two tensors) via
its own MsgpackEncoder + TensorIpcSender. A single TensorIpcReceiver
on the receiving side must correctly drain-and-buffer interleaved
TensorIpcData items from the shared queue and match them back to the
right message handles during decode.
"""
num_senders = 4
num_messages_per_sender = 3
tensor_queue = torch_mp.Queue()
payload_queue: mp.Queue = mp.Queue()
result_queue: mp.Queue = mp.Queue()
barrier = mp.Barrier(num_senders)
retrieval_done = mp.Event()
# Launch sender processes
processes = []
for i in range(num_senders):
proc = mp.Process(
target=concurrent_sender_process,
args=(
tensor_queue,
payload_queue,
result_queue,
i,
num_messages_per_sender,
barrier,
retrieval_done,
),
)
proc.start()
processes.append(proc)
# Collect send confirmations
send_results = []
for _ in range(num_senders):
send_results.append(result_queue.get(timeout=15.0))
for r in send_results:
assert r["success"], (
f"Sender {r['sender_index']} failed: {r.get('error')}\n"
f"{r.get('traceback', '')}"
)
# Now decode all messages from the main process using a single receiver
receiver = TensorIpcReceiver(tensor_queue)
decoder = MsgpackDecoder(MultiTensorMessage, oob_tensor_provider=receiver)
decoded_messages: list[MultiTensorMessage] = []
total = num_senders * num_messages_per_sender
for _ in range(total):
encoded = payload_queue.get(timeout=10.0)
decoded = decoder.decode(encoded)
assert isinstance(decoded, MultiTensorMessage)
decoded_messages.append(decoded)
# Signal senders they can exit
retrieval_done.set()
# Group by sender_label prefix to verify all messages arrived
by_sender: dict[int, list[MultiTensorMessage]] = {}
for msg in decoded_messages:
# label format: "sender_{i}_msg_{j}"
parts = msg.sender_label.split("_")
sender_idx = int(parts[1])
by_sender.setdefault(sender_idx, []).append(msg)
assert len(by_sender) == num_senders, (
f"Expected {num_senders} senders, got {len(by_sender)}"
)
for sender_idx in range(num_senders):
msgs = sorted(by_sender[sender_idx], key=lambda m: m.sender_label)
assert len(msgs) == num_messages_per_sender, (
f"Sender {sender_idx}: expected {num_messages_per_sender} "
f"messages, got {len(msgs)}"
)
for msg_idx, msg in enumerate(msgs):
assert msg.sender_label == f"sender_{sender_idx}_msg_{msg_idx}"
# Verify tensor shapes match what the sender created
assert msg.t1.shape == (sender_idx + 1, 3)
assert msg.t2.shape == (2, sender_idx + 2)
# Verify tensor values
assert torch.allclose(msg.t1, torch.full_like(msg.t1, float(msg_idx)))
assert torch.allclose(msg.t2, torch.full_like(msg.t2, float(msg_idx + 100)))
for proc in processes:
proc.join(timeout=5.0)
def test_concurrent_senders_interleaved_buffer():
"""Test receiver buffering when tensors from multiple senders interleave.
Manually enqueue TensorIpcData from two senders in an interleaved order
and verify the receiver correctly buffers and retrieves each tensor by
its (sender_id, message_id, tensor_id) handle.
"""
tensor_queue = torch_mp.Queue()
# Sender A: 2 tensors for message 1
a_t0 = torch.randn(2, 3)
a_t1 = torch.randn(4, 5)
# Sender B: 2 tensors for message 1
b_t0 = torch.randn(6, 7)
b_t1 = torch.randn(8, 9)
# Interleave: B_t0, A_t0, B_t1, A_t1
for sid, mid, tid, t in [
("B", 1, 0, b_t0),
("A", 1, 0, a_t0),
("B", 1, 1, b_t1),
("A", 1, 1, a_t1),
]:
tensor_queue.put(
TensorIpcData(sender_id=sid, message_id=mid, tensor_id=tid, tensor=t)
)
receiver = TensorIpcReceiver(tensor_queue)
# Request A_t1 first — receiver must drain and buffer B_t0, A_t0, B_t1
result = receiver(
"float32", a_t1.shape, {"sender_id": "A", "message_id": 1, "tensor_id": 1}
)
assert torch.equal(result, a_t1)
# Now request B_t0 from buffer
result = receiver(
"float32", b_t0.shape, {"sender_id": "B", "message_id": 1, "tensor_id": 0}
)
assert torch.equal(result, b_t0)
# Request A_t0 from buffer
result = receiver(
"float32", a_t0.shape, {"sender_id": "A", "message_id": 1, "tensor_id": 0}
)
assert torch.equal(result, a_t0)
# Request B_t1 from buffer
result = receiver(
"float64", b_t1.shape, {"sender_id": "B", "message_id": 1, "tensor_id": 1}
)
assert torch.equal(result, b_t1)
# All buffers should be drained
for sid in ("A", "B"):
tensors = receiver._tensor_buffers[sid].tensors.get(1, {})
assert len(tensors) == 0, f"Sender {sid} buffer not empty: {tensors}"
def test_mixed_cpu_cuda_with_ipc_enabled():
"""Test that encoder is configured correctly for IPC with all tensor types."""
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
tensor_queue = torch_mp.Queue()
# Create sender and encoder with IPC enabled
sender = TensorIpcSender(tensor_queue)
encoder = MsgpackEncoder(oob_tensor_consumer=sender)
# Verify sender configuration
assert encoder.oob_tensor_consumer is not None, "Consumer should be set"
# Note: Actual IPC transfer only works across processes
# (tested in test_cpu_tensor_ipc)
# This test just verifies the configuration is correct
def test_tensor_cleanup_after_decode():
"""Test that tensors are removed from tracking after successful decode."""
# Create a tensor queue
tensor_queue = torch_mp.Queue()
# Create and encode a tensor
tensor = torch.randn(5, 5)
# Move to shared memory for IPC
if not tensor.is_shared():
tensor.share_memory_()
# Manually create a TensorIpcData and put it in the queue
sender_id = "test_sender"
message_id = 0
tensor_id = 0
ipc_data = TensorIpcData(
sender_id=sender_id,
message_id=message_id,
tensor_id=tensor_id,
tensor=tensor,
)
tensor_queue.put(ipc_data)
# Create receiver directly
receiver = TensorIpcReceiver(tensor_queue)
handle = {
"sender_id": sender_id,
"message_id": message_id,
"tensor_id": tensor_id,
}
# Receive the tensor - this should retrieve it from the queue
decoded_tensor = receiver(
str(tensor.dtype).removeprefix("torch."), tensor.shape, handle
)
# Verify the tensor was decoded
assert decoded_tensor.shape == tensor.shape, "Decoded tensor should match shape"
# Verify the tensor was removed from buffer after decode
sender = receiver._tensor_buffers[sender_id]
tensors = sender.tensors.get(message_id, {})
assert tensor_id not in tensors, "Tensor should be removed from buffer"

View File

@@ -14,7 +14,12 @@ import vllm.envs as envs
from vllm.config.model_arch import (
ModelArchitectureConfig,
)
-from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig
+from vllm.config.multimodal import (
+MMCacheType,
+MMEncoderTPMode,
+MMTensorIPC,
+MultiModalConfig,
+)
from vllm.config.pooler import PoolerConfig
from vllm.config.scheduler import RunnerType
from vllm.config.utils import config, getattr_iter
@@ -310,6 +315,7 @@ class ModelConfig:
interleave_mm_strings: InitVar[bool | None] = None
skip_mm_profiling: InitVar[bool | None] = None
video_pruning_rate: InitVar[float | None] = None
mm_tensor_ipc: InitVar[MMTensorIPC | None] = None
def compute_hash(self) -> str:
"""
@@ -430,6 +436,7 @@ class ModelConfig:
interleave_mm_strings: bool | None,
skip_mm_profiling: bool | None,
video_pruning_rate: float | None,
mm_tensor_ipc: MMTensorIPC | None,
) -> None:
# Keep set served_model_name before maybe_model_redirect(self.model)
self.served_model_name = get_served_model_name(
@@ -612,6 +619,7 @@ class ModelConfig:
interleave_mm_strings=interleave_mm_strings,
skip_mm_profiling=skip_mm_profiling,
video_pruning_rate=video_pruning_rate,
mm_tensor_ipc=mm_tensor_ipc,
)
mm_config_kwargs = {
@@ -1112,6 +1120,22 @@ class ModelConfig:
f"({parallel_config.decode_context_parallel_size})."
)
# torch_shm uses a single IPC queue to rank 0; DP>1 is
# incompatible because API servers can't know which
# CoreEngine the scheduler will assign work to. TP>1 is
# also unsupported because it would require broadcasting
# MM tensors to all TP ranks.
if (
self.multimodal_config is not None
and self.multimodal_config.mm_tensor_ipc == "torch_shm"
and parallel_config.world_size_across_dp > 1
):
raise ValueError(
"mm_tensor_ipc='torch_shm' is not supported with "
"data_parallel_size > 1 or tensor_parallel_size > 1 "
"or pipeline_parallel_size > 1."
)
def get_sliding_window(self) -> int | None:
"""Get the sliding window size from the HF text config if present."""
return getattr(self.hf_text_config, "sliding_window", None)

View File

@@ -59,6 +59,7 @@ class MultiModalDummyOptionsBuiltins(TypedDict, total=False):
MMEncoderTPMode = Literal["weights", "data"]
MMCacheType = Literal["shm", "lru"]
MMTensorIPC = Literal["direct_rpc", "torch_shm"]
MMDummyOptions: TypeAlias = dict[str, BaseDummyOptions]
"""
A dictionary containing an entry for each modality type of dummy data.
@@ -172,6 +173,11 @@ class MultiModalConfig:
Value sits in range [0;1) and determines fraction of media tokens
from each video to be pruned.
"""
mm_tensor_ipc: MMTensorIPC = "direct_rpc"
"""IPC (inter-process communication) method for multimodal tensors.
- "direct_rpc": Use msgspec serialization via RPC
- "torch_shm": Use torch.multiprocessing shared memory for zero-copy IPC
Defaults to "direct_rpc". """
@field_validator("limit_per_prompt", mode="before")
@classmethod

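For context, a minimal usage sketch (not part of this diff) of selecting the new mode through EngineArgs. The import path and model name here are assumptions for illustration; the spawn requirement comes from the VllmConfig check later in this diff, and the equivalent CLI flag is `--mm-tensor-ipc torch_shm` (wired up in the EngineArgs changes below):

import os

# torch_shm is only supported with the spawn worker start method
# (enforced by a VllmConfig check added in this PR).
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

from vllm.engine.arg_utils import EngineArgs  # assumed import path

engine_args = EngineArgs(
    model="llava-hf/llava-1.5-7b-hf",  # placeholder multimodal model
    mm_tensor_ipc="torch_shm",  # default is "direct_rpc"
)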
View File

@@ -766,6 +766,17 @@ class VllmConfig:
else:
self.parallel_config.disable_nccl_for_dp_synchronization = False
if (
self.model_config is not None
and self.model_config.multimodal_config is not None
and self.model_config.multimodal_config.mm_tensor_ipc == "torch_shm"
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
):
raise ValueError(
"torch_shm is known to fail without "
"VLLM_WORKER_MULTIPROC_METHOD set to spawn"
)
from vllm.platforms import current_platform
if (

View File

@@ -79,7 +79,7 @@ from vllm.config.model import (
RunnerOption,
TokenizerMode,
)
-from vllm.config.multimodal import MMCacheType, MMEncoderTPMode
+from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MMTensorIPC
from vllm.config.observability import DetailedTraceModules
from vllm.config.parallel import (
All2AllBackend,
@@ -509,6 +509,7 @@ class EngineArgs:
io_processor_plugin: str | None = None
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
video_pruning_rate: float | None = MultiModalConfig.video_pruning_rate
mm_tensor_ipc: MMTensorIPC = MultiModalConfig.mm_tensor_ipc
# LoRA fields
enable_lora: bool = False
max_loras: int = LoRAConfig.max_loras
@@ -1097,6 +1098,9 @@ class EngineArgs:
multimodal_group.add_argument(
"--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"]
)
multimodal_group.add_argument(
"--mm-tensor-ipc", **multimodal_kwargs["mm_tensor_ipc"]
)
# LoRA related configs
lora_kwargs = get_kwargs(LoRAConfig)
@@ -1423,6 +1427,7 @@ class EngineArgs:
override_attention_dtype=self.override_attention_dtype,
logits_processors=self.logits_processors,
video_pruning_rate=self.video_pruning_rate,
mm_tensor_ipc=self.mm_tensor_ipc,
io_processor_plugin=self.io_processor_plugin,
)

View File

@@ -290,7 +290,7 @@ def run_multi_api_server(args: argparse.Namespace):
with launch_core_engines(
vllm_config, executor_class, log_stats, addresses, num_api_servers
-) as (local_engine_manager, coordinator, addresses):
+) as (local_engine_manager, coordinator, addresses, tensor_queue):
# Construct common args for the APIServerProcessManager up-front.
api_server_manager_kwargs = dict(
target_server_fn=run_api_server_worker_proc,
@@ -303,6 +303,7 @@ def run_multi_api_server(args: argparse.Namespace):
stats_update_address=coordinator.get_stats_publish_address()
if coordinator
else None,
tensor_queue=tensor_queue,
)
# For dp ranks > 0 in external/hybrid DP LB modes, we must delay the

View File

@@ -13,6 +13,7 @@ from enum import IntEnum
from functools import partial
from inspect import isclass, signature
from logging import DEBUG
from multiprocessing.queues import Queue
from typing import Any, TypeVar, cast
import msgspec
@@ -59,6 +60,7 @@ from vllm.v1.engine import (
UtilityOutput,
UtilityResult,
)
from vllm.v1.engine.tensor_ipc import TensorIpcReceiver
from vllm.v1.engine.utils import (
EngineHandshakeMetadata,
EngineZmqAddresses,
@@ -788,6 +790,7 @@ class EngineCoreProc(EngineCore):
executor_class: type[Executor],
log_stats: bool,
client_handshake_address: str | None = None,
tensor_queue: Queue | None = None,
*,
engine_index: int = 0,
):
@@ -802,6 +805,12 @@ class EngineCoreProc(EngineCore):
self.engines_running = False
self.shutdown_state = EngineShutdownState.RUNNING
# Receiver for tensor IPC
self.tensor_ipc_receiver: TensorIpcReceiver | None = None
if tensor_queue is not None:
self.tensor_ipc_receiver = TensorIpcReceiver(tensor_queue)
logger.info("Using tensor IPC queue for multimodal tensor sharing")
with self._perform_handshakes(
handshake_address,
identity,
@@ -1340,9 +1349,11 @@ class EngineCoreProc(EngineCore):
):
"""Input socket IO thread."""
-# Msgpack serialization decoding.
-add_request_decoder = MsgpackDecoder(EngineCoreRequest)
-generic_decoder = MsgpackDecoder()
+# Msgpack serialization decoding with optional tensor IPC receiver.
+add_request_decoder = MsgpackDecoder(
+EngineCoreRequest, oob_tensor_provider=self.tensor_ipc_receiver
+)
+generic_decoder = MsgpackDecoder(oob_tensor_provider=self.tensor_ipc_receiver)
with ExitStack() as stack, zmq.Context() as ctx:
input_sockets = [
@@ -1418,10 +1429,7 @@ class EngineCoreProc(EngineCore):
self.input_queue.put_nowait((request_type, request))
def process_output_sockets(
-self,
-output_paths: list[str],
-coord_output_path: str | None,
-engine_index: int,
+self, output_paths: list[str], coord_output_path: str | None, engine_index: int
):
"""Output socket IO thread."""
@@ -1580,6 +1588,7 @@ class DPEngineCoreProc(EngineCoreProc):
executor_class: type[Executor],
log_stats: bool,
client_handshake_address: str | None = None,
tensor_queue: Queue | None = None,
):
assert vllm_config.model_config.is_moe, (
"DPEngineCoreProc should only be used for MoE models"
@@ -1605,6 +1614,7 @@ class DPEngineCoreProc(EngineCoreProc):
log_stats,
client_handshake_address,
engine_index=dp_rank,
tensor_queue=tensor_queue,
)
def _init_data_parallel(self, vllm_config: VllmConfig):

View File

@@ -12,6 +12,7 @@ from collections import defaultdict, deque
from collections.abc import Awaitable, Callable, Sequence
from concurrent.futures import Future
from dataclasses import dataclass
from multiprocessing.queues import Queue
from threading import Thread
from typing import Any, TypeAlias, TypeVar
@@ -45,6 +46,7 @@ from vllm.v1.engine import (
from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.engine.exceptions import EngineDeadError
from vllm.v1.engine.tensor_ipc import TensorIpcSender
from vllm.v1.engine.utils import (
CoreEngineActorManager,
CoreEngineProcManager,
@@ -477,9 +479,6 @@ class MPClient(EngineCoreClient):
client_addresses: dict[str, str] | None = None,
):
self.vllm_config = vllm_config
-# Serialization setup.
-self.encoder = MsgpackEncoder()
-self.decoder = MsgpackDecoder(EngineCoreOutputs)
# ZMQ setup.
sync_ctx = zmq.Context(io_threads=2)
@@ -501,11 +500,14 @@ class MPClient(EngineCoreClient):
enable_input_socket_handover = parallel_config.enable_elastic_ep
self.stats_update_address: str | None = None
tensor_queue: Queue | None = None
if client_addresses:
# Engines are managed externally to this client.
input_address = client_addresses["input_address"]
output_address = client_addresses["output_address"]
self.stats_update_address = client_addresses.get("stats_update_address")
# Tensor queue passed via client_addresses in the multi-API-server case
tensor_queue = client_addresses.get("tensor_queue") # type: ignore[assignment]
self.input_socket = self.resources.input_socket = make_zmq_socket(
self.ctx,
input_address,
@@ -532,7 +534,7 @@ class MPClient(EngineCoreClient):
with launch_core_engines(
vllm_config, executor_class, log_stats, addresses
-) as (engine_manager, coordinator, addresses):
+) as (engine_manager, coordinator, addresses, tensor_queue):
self.resources.coordinator = coordinator
self.resources.engine_manager = engine_manager
@@ -542,6 +544,17 @@ class MPClient(EngineCoreClient):
coordinator.get_stats_publish_address()
)
# Serialization setup with tensor queues for multimodal tensor IPC.
tensor_ipc_sender: TensorIpcSender | None = None
model_config = getattr(vllm_config, "model_config", None)
if model_config is not None and model_config.multimodal_config is not None:
mm_tensor_ipc = model_config.multimodal_config.mm_tensor_ipc
if mm_tensor_ipc == "torch_shm" and tensor_queue is not None:
tensor_ipc_sender = TensorIpcSender(tensor_queue)
self.encoder = MsgpackEncoder(oob_tensor_consumer=tensor_ipc_sender)
self.decoder = MsgpackDecoder(EngineCoreOutputs)
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_index
dp_local_size = parallel_config.data_parallel_size_local

View File

@@ -0,0 +1,178 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tensor IPC transport via torch.multiprocessing.Queue.
This module contains the queue-based transport logic for sharing tensors
between processes (e.g., API server -> engine core). The msgpack layer
emits/consumes lightweight :class:`TensorIpcData` values, while transport
state such as request association, handle generation, queue routing, buffering,
and cleanup lives here.
"""
import dataclasses
import uuid
from collections import defaultdict
from dataclasses import field
from multiprocessing.queues import Queue as MPQueue
from typing import Any
import torch
from vllm.logger import init_logger
from vllm.v1.serial_utils import OOBTensorConsumer
logger = init_logger(__name__)
TensorIpcQueue = MPQueue
@dataclasses.dataclass
class TensorIpcData:
"""
Data sent via torch.multiprocessing.Queue for zero-copy IPC.
Carries the (sender_id, message_id, tensor_id) handle alongside the tensor
itself, which resides in shared memory (GPU or CPU) for efficient
inter-process communication.
"""
sender_id: str
message_id: int
tensor_id: int
tensor: torch.Tensor
class TensorIpcSender(OOBTensorConsumer):
"""Send-side logic for tensor IPC via torch.multiprocessing.Queue.
Uses a single queue targeting rank 0 (the only rank that consumes
multimodal tensors under TP>1 / PP>1). Note: DP>1 is not supported.
"""
def __init__(self, queue: TensorIpcQueue):
self.queue = queue
self._tensor_id_counter = 0
self._message_counter = 0
self._sender_id = uuid.uuid4().hex[:8]
def set_target_engine(self, target_engine: int) -> None:
if target_engine != 0:
raise IndexError(
"TensorIpcSender only supports a single queue; "
f"got target engine {target_engine}"
)
def new_message(self) -> None:
self._message_counter += 1
self._tensor_id_counter = 0
def __call__(self, tensor: torch.Tensor) -> dict[str, Any] | None:
"""Send tensor via queue, return its handle. Returns None if failed."""
try:
# Move tensor to shared memory for IPC
# This is required for proper inter-process communication
if not tensor.is_shared():
tensor = tensor.share_memory_()
metadata = {
"sender_id": self._sender_id,
"message_id": self._message_counter,
"tensor_id": self._tensor_id_counter,
}
self._tensor_id_counter += 1
ipc_data = TensorIpcData(**metadata, tensor=tensor) # type: ignore[arg-type]
# Use a timeout to avoid blocking indefinitely
self.queue.put(ipc_data, timeout=10.0)
logger.debug(
"Sent tensor %s for (shape=%s, device=%s) "
"via IPC queue (shared memory)",
metadata,
tensor.shape,
tensor.device,
)
return metadata
except Exception as e:
logger.warning(
"Failed to send tensor via IPC queue: %s. "
"Falling back to standard serialization.",
e,
)
return None
@dataclasses.dataclass
class _Sender:
current_message_id: int = -1
tensors: dict[int, dict[int, torch.Tensor]] = field(default_factory=dict)
class TensorIpcReceiver:
"""Receive-side logic for tensor IPC via torch.multiprocessing.Queue.
Wraps the queue receive logic previously embedded in MsgpackDecoder.
"""
def __init__(self, queue: TensorIpcQueue):
self.queue = queue
self._tensor_buffers = defaultdict[str, _Sender](_Sender)
def __call__(
self, dtype: str, shape: tuple[int, ...], meta: dict[str, Any]
) -> torch.Tensor:
"""Retrieve a tensor from torch.multiprocessing.Queue.
Uses a drain-and-buffer pattern: drains all available tensors from
the queue, buffering them, until the requested tensor is found.
Works for CUDA and CPU.
"""
# Create lookup key from handle
sender_id: str = meta["sender_id"]
message_id: int = meta["message_id"]
tensor_id: int = meta["tensor_id"]
# Drain all available tensors, buffering each one regardless of whether
# it is the one we're waiting for, since tensors from multiple producers
# may arrive out of order.
while True:
sender = self._tensor_buffers.get(sender_id)
if sender is not None:
tensors = sender.tensors
tensor = tensors.get(message_id, {}).pop(tensor_id, None)
if tensor is not None:
if sender.current_message_id != message_id:
while tensors and (mid := next(iter(tensors))) < message_id:
if stale := sender.tensors.pop(mid):
logger.warning(
"Discarding %d stale tensors from sender %s",
len(stale),
sender_id,
)
sender.current_message_id = message_id
logger.debug(
"Received tensor %s from sender %s for (shape=%s, device=%s) "
"via IPC queue (shared memory)",
(message_id, tensor_id),
sender_id,
tensor.shape,
tensor.device,
)
return tensor
ipc_data: TensorIpcData = self.queue.get(timeout=10.0)
# Store tensor
sender = self._tensor_buffers[ipc_data.sender_id]
if sender.current_message_id > ipc_data.message_id:
logger.warning(
"Ignoring stale tensor from sender %s", ipc_data.sender_id
)
continue
sender.tensors.setdefault(ipc_data.message_id, {})[ipc_data.tensor_id] = (
ipc_data.tensor
)

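A minimal same-process round trip of the handle protocol above (a sketch for illustration; real use crosses process boundaries): the sender enqueues a TensorIpcData and returns the handle metadata that the msgpack layer embeds in the message, and the receiver later resolves that handle by draining the queue.

import torch
import torch.multiprocessing as torch_mp

from vllm.v1.engine.tensor_ipc import TensorIpcReceiver, TensorIpcSender

queue = torch_mp.Queue()
sender = TensorIpcSender(queue)
receiver = TensorIpcReceiver(queue)

sender.new_message()  # bumps message_id, resets the tensor_id counter
t = torch.randn(2, 3)
handle = sender(t)  # puts TensorIpcData on the queue, returns the handle dict
assert handle is not None  # None would mean "fall back to regular serialization"

# dtype/shape mirror what MsgpackDecoder._decode_tensor passes from the wire.
out = receiver("float32", (2, 3), handle)
assert torch.equal(out, t)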
View File

@@ -10,6 +10,7 @@ from dataclasses import dataclass
from enum import Enum, auto
from multiprocessing import Process, connection
from multiprocessing.process import BaseProcess
from multiprocessing.queues import Queue
from typing import TYPE_CHECKING
from unittest.mock import patch
@@ -95,6 +96,7 @@ class CoreEngineProcManager:
executor_class: type[Executor],
log_stats: bool,
client_handshake_address: str | None = None,
tensor_queue: Queue | None = None,
):
context = get_mp_context()
common_kwargs = {
@@ -103,6 +105,7 @@ class CoreEngineProcManager:
"handshake_address": handshake_address,
"executor_class": executor_class,
"log_stats": log_stats,
"tensor_queue": tensor_queue,
}
if client_handshake_address:
@@ -864,6 +867,7 @@ def launch_core_engines(
CoreEngineProcManager | CoreEngineActorManager | None,
DPCoordinator | None,
EngineZmqAddresses,
Queue | None,
]
]:
"""Launch engine and DP coordinator processes as needed."""
@@ -878,6 +882,14 @@ def launch_core_engines(
offline_mode = local_start_index is not None
# Create a single tensor IPC queue for sharing multimodal tensors between
# API servers and engine core. A single queue suffices because only DP=1
# is supported for this data flow.
tensor_queue: Queue | None = None
multimodal_config = vllm_config.model_config.multimodal_config
if multimodal_config is not None and multimodal_config.mm_tensor_ipc == "torch_shm":
tensor_queue = get_mp_context().Queue()
# Run the DP Coordinator process with rank 0 when in online DP mode.
# The coordinator is needed for:
# 1. Internal/hybrid LB: collecting and publishing queue stats for load balancing
@@ -913,7 +925,7 @@ def launch_core_engines(
log_stats=log_stats,
)
-yield engine_actor_manager, coordinator, addresses
+yield engine_actor_manager, coordinator, addresses, tensor_queue
return
if offline_mode:
@@ -975,11 +987,12 @@ def launch_core_engines(
local_engine_count=local_engine_count,
start_index=dp_rank,
local_start_index=local_start_index or 0,
tensor_queue=tensor_queue,
)
else:
local_engine_manager = None
-yield local_engine_manager, coordinator, addresses
+yield local_engine_manager, coordinator, addresses, tensor_queue
# Now wait for engines to start.
wait_for_engine_startup(

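Since launch_core_engines now yields a 4-tuple, every caller unpacks the extra tensor_queue element. A sketch of the updated call shape (mirroring the MPClient and run_multi_api_server changes above), where tensor_queue is None unless mm_tensor_ipc == "torch_shm":

with launch_core_engines(vllm_config, executor_class, log_stats, addresses) as (
    engine_manager,
    coordinator,
    addresses,
    tensor_queue,
):
    # Only populated in torch_shm mode; otherwise callers pass None through.
    if tensor_queue is not None:
        sender = TensorIpcSender(tensor_queue)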
View File

@@ -4,6 +4,7 @@
import dataclasses
import importlib
import pickle
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from functools import partial
from inspect import isclass
@@ -53,6 +54,27 @@ MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
bytestr: TypeAlias = bytes | bytearray | memoryview | zmq.Frame
class OOBTensorConsumer(ABC):
@abstractmethod
def __call__(self, tensor: torch.Tensor) -> dict | None:
"""
Called with tensors for the current message.
Returns None to reject the tensor (falls back to regular serialization),
otherwise a dict with arbitrary placeholder data to be included
in the serialized message.
"""
return None
@abstractmethod
def new_message(self) -> None:
"""Called at the start of each new encoded message."""
pass
# dtype, shape, metadata -> tensor
OOBTensorProvider = Callable[[str, tuple[int, ...], dict], torch.Tensor]
def _log_insecure_serialization_warning():
logger.warning_once(
"Allowing insecure serialization using pickle due to "
@@ -119,9 +141,16 @@ class MsgpackEncoder:
By default, arrays below 256B are serialized inline; larger ones are sent
via dedicated messages. Note that this is a per-tensor limit.
When an ``oob_tensor_consumer`` is provided, tensors (CUDA and CPU) will be
offered to it for out-of-band handling.
"""
-def __init__(self, size_threshold: int | None = None):
+def __init__(
+self,
+size_threshold: int | None = None,
+oob_tensor_consumer: OOBTensorConsumer | None = None,
+):
if size_threshold is None:
size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
@@ -130,11 +159,14 @@ class MsgpackEncoder:
# pass custom data to the hook otherwise.
self.aux_buffers: list[bytestr] | None = None
self.size_threshold = size_threshold
self.oob_tensor_consumer = oob_tensor_consumer
if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
_log_insecure_serialization_warning()
def encode(self, obj: Any) -> Sequence[bytestr]:
try:
if self.oob_tensor_consumer is not None:
self.oob_tensor_consumer.new_message()
self.aux_buffers = bufs = [b""]
bufs[0] = self.encoder.encode(obj)
# This `bufs` list allows us to collect direct pointers to backing
@@ -147,6 +179,8 @@ class MsgpackEncoder:
def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
try:
if self.oob_tensor_consumer is not None:
self.oob_tensor_consumer.new_message()
self.aux_buffers = [buf]
bufs = self.aux_buffers
self.encoder.encode_into(obj, buf)
@@ -222,17 +256,19 @@ class MsgpackEncoder:
def _encode_tensor(
self, obj: torch.Tensor
-) -> tuple[str, tuple[int, ...], int | memoryview]:
-assert self.aux_buffers is not None
+) -> tuple[str, tuple[int, ...], int | dict | memoryview]:
+oob_consumer = self.oob_tensor_consumer
-# view the tensor as a contiguous 1D array of bytes
-arr_data = tensor_data(obj)
-if obj.nbytes < self.size_threshold:
+if obj.nbytes < self.size_threshold and obj.is_cpu:
# Smaller tensors are encoded inline, just like ndarrays.
-data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
+data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, tensor_data(obj))
+elif oob_consumer is not None and (data := oob_consumer(obj)) is not None:
+assert isinstance(data, dict)
else:
# Otherwise encode index of backing buffer to avoid copy.
+assert self.aux_buffers is not None
data = len(self.aux_buffers)
-self.aux_buffers.append(arr_data)
+self.aux_buffers.append(tensor_data(obj))
dtype = str(obj.dtype).removeprefix("torch.")
return dtype, obj.shape, data
@@ -279,9 +315,17 @@ class MsgpackDecoder:
Note that unlike vanilla `msgspec` Decoders, this interface is generally
not thread-safe when encoding tensors / numpy arrays.
An ``oob_tensor_provider`` must be provided when an ``OOBTensorConsumer`` is
used on the encoder side.
"""
-def __init__(self, t: Any | None = None, share_mem: bool = True):
+def __init__(
+self,
+t: Any | None = None,
+share_mem: bool = True,
+oob_tensor_provider: OOBTensorProvider | None = None,
+):
self.share_mem = share_mem
self.pin_tensors = is_pin_memory_available()
args = () if t is None else (t,)
@@ -289,6 +333,7 @@ class MsgpackDecoder:
*args, ext_hook=self.ext_hook, dec_hook=self.dec_hook
)
self.aux_buffers: Sequence[bytestr] = ()
self.oob_tensor_provider = oob_tensor_provider
if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
_log_insecure_serialization_warning()
@@ -353,6 +398,12 @@ class MsgpackDecoder:
def _decode_tensor(self, arr: Any) -> torch.Tensor:
dtype, shape, data = arr
if isinstance(data, dict):
assert self.oob_tensor_provider, (
"Received OOB tensor but tensor provider is not set"
)
return self.oob_tensor_provider(dtype, shape, data)
is_aux = isinstance(data, int)
buffer = self.aux_buffers[data] if is_aux else data
buffer = buffer if isinstance(buffer, memoryview) else memoryview(buffer)

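TensorIpcSender is the only in-tree OOBTensorConsumer, but the ABC admits other transports. A hypothetical minimal consumer, for illustration only, that rejects every tensor so the encoder always falls back to the inline/aux-buffer path:

import torch

from vllm.v1.serial_utils import MsgpackEncoder, OOBTensorConsumer


class RejectAllConsumer(OOBTensorConsumer):
    """Hypothetical consumer that never takes tensors out of band."""

    def new_message(self) -> None:
        # Called once at the start of each encode() / encode_into().
        pass

    def __call__(self, tensor: torch.Tensor) -> dict | None:
        # Returning None tells _encode_tensor to use the regular path.
        return None


encoder = MsgpackEncoder(oob_tensor_consumer=RejectAllConsumer())
payload = encoder.encode({"x": torch.randn(4, 4)})  # no OOB handles emitted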
View File

@@ -10,6 +10,7 @@ from contextlib import AbstractContextManager
from dataclasses import dataclass
from multiprocessing import connection
from multiprocessing.process import BaseProcess
from multiprocessing.queues import Queue
from typing import (
TYPE_CHECKING,
Any,
@@ -173,6 +174,7 @@ class APIServerProcessManager:
input_addresses: list[str],
output_addresses: list[str],
stats_update_address: str | None = None,
tensor_queue: Queue | None = None,
):
"""Initialize and start API server worker processes.
@@ -185,6 +187,7 @@ class APIServerProcessManager:
input_addresses: Input addresses for each API server
output_addresses: Output addresses for each API server
stats_update_address: Optional stats update address
tensor_queue: Optional tensor IPC queue for sharing MM tensors
"""
self.listen_address = listen_address
self.sock = sock
@@ -205,6 +208,8 @@ class APIServerProcessManager:
}
if stats_update_address is not None:
client_config["stats_update_address"] = stats_update_address
if tensor_queue is not None:
client_config["tensor_queue"] = tensor_queue
proc = spawn_context.Process(
target=target_server_fn,
@@ -419,7 +424,7 @@ def tensor_data(tensor: torch.Tensor) -> memoryview:
Returns:
A memoryview of the tensor data as uint8.
"""
-return tensor.flatten().contiguous().view(torch.uint8).numpy().data
+return tensor.flatten().cpu().contiguous().view(torch.uint8).numpy().data
@dataclass