Add tensor IPC transfer mechanism for multimodal data (#32104)
Signed-off-by: Brandon Pelfrey <bpelfrey@nvidia.com> Signed-off-by: Brandon Pelfrey <brandonpelfrey@gmail.com> Signed-off-by: Nick Hill <nickhill123@gmail.com> Co-authored-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -278,3 +278,148 @@ def test_custom_class_serialization_disallowed_without_pickle():
|
||||
with pytest.raises(TypeError):
|
||||
# Attempt to encode the custom class
|
||||
encoder.encode(obj)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestWithTensor:
|
||||
"""Mock request with non-multimodal tensor field like EngineCoreRequest."""
|
||||
|
||||
prompt_embeds: torch.Tensor | None
|
||||
data: str
|
||||
|
||||
|
||||
def test_non_multimodal_tensor_with_ipc():
|
||||
"""Test that non-multimodal tensor fields work correctly with IPC enabled.
|
||||
|
||||
This reproduces the bug where fields like prompt_embeds: torch.Tensor | None
|
||||
would fail to decode when IPC is enabled because _decode_tensor expected a
|
||||
raw tensor tuple but received a msgpack-decoded TensorIpcHandle list.
|
||||
"""
|
||||
import torch.multiprocessing as torch_mp
|
||||
|
||||
from vllm.v1.engine.tensor_ipc import TensorIpcReceiver, TensorIpcSender
|
||||
|
||||
# Create tensor queues for IPC
|
||||
tensor_queues = [torch_mp.Queue()]
|
||||
|
||||
# Create encoder with IPC sender
|
||||
sender = TensorIpcSender(tensor_queues[0])
|
||||
encoder = MsgpackEncoder(oob_tensor_consumer=sender)
|
||||
|
||||
# Create decoder with IPC receiver
|
||||
receiver = TensorIpcReceiver(tensor_queues[0])
|
||||
decoder = MsgpackDecoder(RequestWithTensor, oob_tensor_provider=receiver)
|
||||
|
||||
# Create a request with a non-multimodal tensor
|
||||
original_tensor = torch.randn(5, 10, dtype=torch.float32)
|
||||
request = RequestWithTensor(prompt_embeds=original_tensor, data="test_data")
|
||||
|
||||
# Encode the request - this should send the tensor via IPC
|
||||
encoded = encoder.encode(request)
|
||||
|
||||
# Verify encoding succeeded
|
||||
assert len(encoded) > 0
|
||||
|
||||
# Decode the request - this should retrieve the tensor from IPC queue
|
||||
# Previously this would fail because the decoder tried to unpack the
|
||||
# handle list as raw tensor bytes metadata.
|
||||
decoded = decoder.decode(encoded)
|
||||
|
||||
# Verify the decoded request matches the original
|
||||
assert isinstance(decoded, RequestWithTensor)
|
||||
assert decoded.data == "test_data"
|
||||
assert decoded.prompt_embeds is not None
|
||||
assert torch.allclose(decoded.prompt_embeds, original_tensor), (
|
||||
"Decoded tensor does not match the original tensor."
|
||||
)
|
||||
|
||||
|
||||
def test_non_multimodal_tensor_with_ipc_none_value():
|
||||
"""Test that None values for tensor fields work correctly with IPC enabled."""
|
||||
import torch.multiprocessing as torch_mp
|
||||
|
||||
from vllm.v1.engine.tensor_ipc import TensorIpcReceiver, TensorIpcSender
|
||||
|
||||
# Create tensor queues for IPC
|
||||
tensor_queues = [torch_mp.Queue()]
|
||||
|
||||
# Create encoder with IPC sender
|
||||
sender = TensorIpcSender(tensor_queues[0])
|
||||
encoder = MsgpackEncoder(oob_tensor_consumer=sender)
|
||||
|
||||
# Create decoder with IPC receiver
|
||||
receiver = TensorIpcReceiver(tensor_queues[0])
|
||||
decoder = MsgpackDecoder(RequestWithTensor, oob_tensor_provider=receiver)
|
||||
|
||||
# Create a request with None for the tensor field
|
||||
request = RequestWithTensor(prompt_embeds=None, data="test_data_with_none")
|
||||
|
||||
# Encode and decode the request
|
||||
encoded = encoder.encode(request)
|
||||
decoded = decoder.decode(encoded)
|
||||
|
||||
# Verify the decoded request matches the original
|
||||
assert isinstance(decoded, RequestWithTensor)
|
||||
assert decoded.data == "test_data_with_none"
|
||||
assert decoded.prompt_embeds is None
|
||||
|
||||
|
||||
def test_multiple_senders_single_receiver_ipc():
|
||||
"""Test N senders sharing a queue with a single receiver via msgpack.
|
||||
|
||||
Simulates the real vLLM topology where multiple API server frontends
|
||||
each have their own MsgpackEncoder + TensorIpcSender, all putting
|
||||
tensors onto the same torch.mp queue, and a single engine core
|
||||
decodes them with one MsgpackDecoder + TensorIpcReceiver.
|
||||
"""
|
||||
import torch.multiprocessing as torch_mp
|
||||
|
||||
from vllm.v1.engine.tensor_ipc import TensorIpcReceiver, TensorIpcSender
|
||||
|
||||
num_senders = 3
|
||||
num_messages_per_sender = 2
|
||||
tensor_queue = torch_mp.Queue()
|
||||
|
||||
# Create N independent senders (each gets its own uuid-based sender_id)
|
||||
senders = []
|
||||
encoders = []
|
||||
for _ in range(num_senders):
|
||||
s = TensorIpcSender(tensor_queue)
|
||||
senders.append(s)
|
||||
encoders.append(MsgpackEncoder(oob_tensor_consumer=s))
|
||||
|
||||
# Single receiver
|
||||
receiver = TensorIpcReceiver(tensor_queue)
|
||||
decoder = MsgpackDecoder(RequestWithTensor, oob_tensor_provider=receiver)
|
||||
|
||||
# Encode messages from all senders, interleaving the order
|
||||
# so that tensors from different senders land on the queue interleaved.
|
||||
encoded_payloads: list[tuple[int, int, torch.Tensor, list]] = []
|
||||
for msg_idx in range(num_messages_per_sender):
|
||||
for sender_idx in range(num_senders):
|
||||
tensor = torch.full(
|
||||
(sender_idx + 1, msg_idx + 2),
|
||||
float(sender_idx * 100 + msg_idx),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
req = RequestWithTensor(
|
||||
prompt_embeds=tensor,
|
||||
data=f"s{sender_idx}_m{msg_idx}",
|
||||
)
|
||||
encoded = encoders[sender_idx].encode(req)
|
||||
encoded_payloads.append((sender_idx, msg_idx, tensor, encoded))
|
||||
|
||||
# Decode all messages — the receiver must correctly match each
|
||||
# tensor handle to the right TensorIpcData from the shared queue.
|
||||
for sender_idx, msg_idx, original_tensor, encoded in encoded_payloads:
|
||||
decoded = decoder.decode(encoded)
|
||||
assert isinstance(decoded, RequestWithTensor)
|
||||
assert decoded.data == f"s{sender_idx}_m{msg_idx}"
|
||||
assert decoded.prompt_embeds is not None
|
||||
assert decoded.prompt_embeds.shape == original_tensor.shape, (
|
||||
f"Shape mismatch for sender {sender_idx} msg {msg_idx}: "
|
||||
f"{decoded.prompt_embeds.shape} != {original_tensor.shape}"
|
||||
)
|
||||
assert torch.allclose(decoded.prompt_embeds, original_tensor), (
|
||||
f"Value mismatch for sender {sender_idx} msg {msg_idx}"
|
||||
)
|
||||
|
||||
943
tests/v1/test_tensor_ipc_queue.py
Normal file
943
tests/v1/test_tensor_ipc_queue.py
Normal file
@@ -0,0 +1,943 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""Tests for tensor IPC queue functionality."""
|
||||
|
||||
import contextlib
|
||||
import multiprocessing as mp
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing.synchronize import Barrier as BarrierType
|
||||
from multiprocessing.synchronize import Event as EventType
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as torch_mp
|
||||
|
||||
from vllm.v1.engine.tensor_ipc import (
|
||||
TensorIpcData,
|
||||
TensorIpcReceiver,
|
||||
TensorIpcSender,
|
||||
)
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def setup_multiprocessing():
|
||||
"""Set multiprocessing start method to 'spawn' for compatibility."""
|
||||
with contextlib.suppress(RuntimeError):
|
||||
# Already set, which is fine
|
||||
torch_mp.set_start_method("spawn", force=True)
|
||||
yield
|
||||
|
||||
|
||||
@dataclass
|
||||
# Use a typed container so the test covers the real vLLM path where tensor IPC
|
||||
# handles are encoded and decoded as fields nested inside larger msgpack payloads.
|
||||
class TensorEnvelope:
|
||||
tensor: torch.Tensor
|
||||
label: str
|
||||
|
||||
|
||||
def encoder_process(
|
||||
tensor_queue: torch_mp.Queue,
|
||||
payload_queue: mp.Queue,
|
||||
result_queue: mp.Queue,
|
||||
tensor_data: dict[str, Any],
|
||||
ready_event: EventType,
|
||||
retrieval_done: EventType,
|
||||
):
|
||||
"""Process that msgpack-encodes and sends tensors via IPC."""
|
||||
try:
|
||||
sender = TensorIpcSender(tensor_queue)
|
||||
encoder = MsgpackEncoder(oob_tensor_consumer=sender)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda:0"
|
||||
tensor = torch.randn(
|
||||
*tensor_data["shape"], dtype=tensor_data["dtype"], device=device
|
||||
)
|
||||
else:
|
||||
# Fall back to CPU for testing
|
||||
device = "cpu"
|
||||
tensor = torch.randn(*tensor_data["shape"], dtype=tensor_data["dtype"])
|
||||
|
||||
message = TensorEnvelope(tensor=tensor, label="cuda-msgpack")
|
||||
encoded = encoder.encode(message)
|
||||
payload_queue.put(encoded, timeout=10.0)
|
||||
|
||||
ready_event.set()
|
||||
|
||||
result_queue.put(
|
||||
{
|
||||
"success": True,
|
||||
"encoded_length": len(encoded),
|
||||
"device": str(device),
|
||||
"tensor_shape": tuple(tensor.shape),
|
||||
}
|
||||
)
|
||||
retrieval_done.wait(timeout=30.0)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
ready_event.set()
|
||||
retrieval_done.set()
|
||||
result_queue.put(
|
||||
{"success": False, "error": str(e), "traceback": traceback.format_exc()}
|
||||
)
|
||||
|
||||
|
||||
def decoder_process(
|
||||
tensor_queue: torch_mp.Queue,
|
||||
payload_queue: mp.Queue,
|
||||
result_queue: mp.Queue,
|
||||
expected_shape: tuple,
|
||||
encoder_ready: EventType,
|
||||
retrieval_done: EventType,
|
||||
):
|
||||
"""Process that msgpack-decodes tensors received via IPC."""
|
||||
try:
|
||||
if not encoder_ready.wait(timeout=10.0):
|
||||
raise TimeoutError("Encoder did not signal ready")
|
||||
|
||||
encoded = payload_queue.get(timeout=5.0)
|
||||
receiver = TensorIpcReceiver(tensor_queue)
|
||||
decoder = MsgpackDecoder(TensorEnvelope, oob_tensor_provider=receiver)
|
||||
decoded = decoder.decode(encoded)
|
||||
|
||||
result_queue.put(
|
||||
{
|
||||
"success": True,
|
||||
"tensor_shape": tuple(decoded.tensor.shape),
|
||||
"device": str(decoded.tensor.device),
|
||||
"label": decoded.label,
|
||||
"matches_expected": tuple(decoded.tensor.shape) == expected_shape,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
retrieval_done.set()
|
||||
result_queue.put(
|
||||
{"success": False, "error": str(e), "traceback": traceback.format_exc()}
|
||||
)
|
||||
else:
|
||||
retrieval_done.set()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_cuda_tensor_queue_basic():
|
||||
"""Test CUDA tensor IPC through the msgpack encoder/decoder path."""
|
||||
tensor_queue = torch_mp.Queue()
|
||||
payload_queue: mp.Queue = mp.Queue()
|
||||
result_queue: mp.Queue = mp.Queue()
|
||||
encoder_ready = mp.Event()
|
||||
retrieval_done = mp.Event()
|
||||
|
||||
tensor_shape = (4, 8, 16)
|
||||
tensor_dtype = torch.float32
|
||||
|
||||
encoder_proc = mp.Process(
|
||||
target=encoder_process,
|
||||
args=(
|
||||
tensor_queue,
|
||||
payload_queue,
|
||||
result_queue,
|
||||
{"shape": tensor_shape, "dtype": tensor_dtype},
|
||||
encoder_ready,
|
||||
retrieval_done,
|
||||
),
|
||||
)
|
||||
encoder_proc.start()
|
||||
|
||||
decoder_proc = mp.Process(
|
||||
target=decoder_process,
|
||||
args=(
|
||||
tensor_queue,
|
||||
payload_queue,
|
||||
result_queue,
|
||||
tensor_shape,
|
||||
encoder_ready,
|
||||
retrieval_done,
|
||||
),
|
||||
)
|
||||
decoder_proc.start()
|
||||
|
||||
encoder_result = result_queue.get(timeout=10.0)
|
||||
decoder_result = result_queue.get(timeout=10.0)
|
||||
|
||||
encoder_proc.join(timeout=5.0)
|
||||
decoder_proc.join(timeout=5.0)
|
||||
|
||||
# Verify results
|
||||
assert encoder_result["success"], (
|
||||
f"Encoder failed: {encoder_result.get('error')}\n"
|
||||
f"{encoder_result.get('traceback', '')}"
|
||||
)
|
||||
assert decoder_result["success"], (
|
||||
f"Decoder failed: {decoder_result.get('error')}\n"
|
||||
f"{decoder_result.get('traceback', '')}"
|
||||
)
|
||||
assert decoder_result["matches_expected"], "Tensor shape mismatch"
|
||||
assert "cuda" in decoder_result["device"], "Tensor not on CUDA device"
|
||||
assert decoder_result["label"] == "cuda-msgpack"
|
||||
|
||||
|
||||
def test_cpu_tensor_fallback():
|
||||
"""Test that CPU tensors use standard serialization path."""
|
||||
encoder = MsgpackEncoder()
|
||||
|
||||
# Create a CPU tensor
|
||||
tensor = torch.randn(3, 4, dtype=torch.float32)
|
||||
|
||||
# Encode the tensor (should use standard path, not queue)
|
||||
encoded = encoder.encode({"test_tensor": tensor})
|
||||
|
||||
# Verify encoding succeeded
|
||||
assert len(encoded) > 0
|
||||
assert isinstance(encoded, (list, tuple))
|
||||
|
||||
# Basic check: no queue should be used, so tensor goes through standard path
|
||||
# This is mainly to ensure no exceptions are raised
|
||||
|
||||
|
||||
def test_msgpack_encoder_decoder_with_ipc():
|
||||
"""Test the full msgpack + tensor IPC path in one process."""
|
||||
tensor_queue = torch_mp.Queue()
|
||||
sender = TensorIpcSender(tensor_queue)
|
||||
encoder = MsgpackEncoder(oob_tensor_consumer=sender)
|
||||
receiver = TensorIpcReceiver(tensor_queue)
|
||||
decoder = MsgpackDecoder(TensorEnvelope, oob_tensor_provider=receiver)
|
||||
|
||||
# Use CPU here to exercise the msgpack + sender/receiver integration
|
||||
# without relying on same-process CUDA IPC behavior.
|
||||
tensor = torch.randn(2, 3)
|
||||
|
||||
message = TensorEnvelope(tensor=tensor, label="test")
|
||||
encoded = encoder.encode(message)
|
||||
assert len(encoded) > 0
|
||||
|
||||
decoded = decoder.decode(encoded)
|
||||
assert isinstance(decoded, TensorEnvelope)
|
||||
assert decoded.label == "test"
|
||||
assert torch.allclose(decoded.tensor, tensor)
|
||||
|
||||
|
||||
def test_decoder_buffer_management():
|
||||
"""Test receiver's tensor buffer management when draining queue."""
|
||||
tensor_queue = torch_mp.Queue()
|
||||
|
||||
sender_id = "test_sender"
|
||||
message_id = 1
|
||||
|
||||
# Put multiple tensors in queue using TensorIpcData
|
||||
tensors_data = [
|
||||
(0, torch.randn(2, 3)),
|
||||
(1, torch.randn(4, 5)),
|
||||
(2, torch.randn(6, 7)),
|
||||
]
|
||||
|
||||
for tensor_id, tensor in tensors_data:
|
||||
ipc_data = TensorIpcData(
|
||||
sender_id=sender_id,
|
||||
message_id=message_id,
|
||||
tensor_id=tensor_id,
|
||||
tensor=tensor,
|
||||
)
|
||||
tensor_queue.put(ipc_data)
|
||||
|
||||
# Create receiver directly
|
||||
receiver = TensorIpcReceiver(tensor_queue)
|
||||
|
||||
# Request tensor_id=2 (should buffer tensor_id=0 and tensor_id=1)
|
||||
handle = {"sender_id": sender_id, "message_id": message_id, "tensor_id": 2}
|
||||
|
||||
result = receiver("float32", (6, 7), handle)
|
||||
assert result.shape == (6, 7)
|
||||
|
||||
# Verify buffer has tensor_id 0 and 1
|
||||
sender = receiver._tensor_buffers[sender_id]
|
||||
tensors = sender.tensors.get(message_id, {})
|
||||
assert 0 in tensors
|
||||
assert 1 in tensors
|
||||
|
||||
# Request buffered tensor
|
||||
handle2 = {"sender_id": sender_id, "message_id": message_id, "tensor_id": 0}
|
||||
|
||||
result2 = receiver("float32", (2, 3), handle2)
|
||||
assert result2.shape == (2, 3)
|
||||
# tensor_id 0 should be removed from buffer
|
||||
sender = receiver._tensor_buffers[sender_id]
|
||||
tensors = sender.tensors.get(message_id, {})
|
||||
assert 0 not in tensors
|
||||
|
||||
|
||||
def api_server_worker(
|
||||
server_id: int,
|
||||
tensor_queue: torch_mp.Queue,
|
||||
result_queue: mp.Queue,
|
||||
barrier: BarrierType,
|
||||
retrieval_done: EventType,
|
||||
):
|
||||
"""Worker simulating an API server sending tensors."""
|
||||
try:
|
||||
# Each server sends a unique tensor
|
||||
tensor = torch.ones(server_id + 1, server_id + 2) * server_id
|
||||
sender_id = f"server_{server_id}"
|
||||
|
||||
# Wait for all servers to be ready
|
||||
barrier.wait()
|
||||
|
||||
# Send tensor using TensorIpcData
|
||||
ipc_data = TensorIpcData(
|
||||
sender_id=sender_id,
|
||||
message_id=0,
|
||||
tensor_id=0,
|
||||
tensor=tensor,
|
||||
)
|
||||
tensor_queue.put(ipc_data)
|
||||
|
||||
result_queue.put({"server_id": server_id, "success": True})
|
||||
|
||||
# Keep process alive until main process has retrieved all tensors
|
||||
# This prevents shared memory handles from being invalidated
|
||||
retrieval_done.wait(timeout=30.0)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
result_queue.put(
|
||||
{
|
||||
"server_id": server_id,
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"traceback": traceback.format_exc(),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_multiple_api_servers_to_engine():
|
||||
"""Test multiple API servers sending to one engine core via multiprocessing."""
|
||||
num_api_servers = 3
|
||||
tensor_queue = torch_mp.Queue()
|
||||
result_queue: mp.Queue = mp.Queue()
|
||||
barrier = mp.Barrier(num_api_servers)
|
||||
retrieval_done = mp.Event()
|
||||
|
||||
# Start multiple API server processes
|
||||
processes = []
|
||||
for server_id in range(num_api_servers):
|
||||
proc = mp.Process(
|
||||
target=api_server_worker,
|
||||
args=(server_id, tensor_queue, result_queue, barrier, retrieval_done),
|
||||
)
|
||||
proc.start()
|
||||
processes.append(proc)
|
||||
|
||||
# Collect results from all servers
|
||||
results = []
|
||||
for _ in range(num_api_servers):
|
||||
result = result_queue.get(timeout=10.0)
|
||||
results.append(result)
|
||||
|
||||
# Verify all servers succeeded
|
||||
for result in results:
|
||||
assert result["success"], (
|
||||
f"Server {result['server_id']} failed: {result.get('error')}"
|
||||
)
|
||||
|
||||
# Verify all tensors are in queue
|
||||
received_tensors = []
|
||||
for _ in range(num_api_servers):
|
||||
ipc_data = tensor_queue.get(timeout=1.0)
|
||||
received_tensors.append((ipc_data.sender_id, ipc_data.tensor))
|
||||
|
||||
assert len(received_tensors) == num_api_servers
|
||||
|
||||
# Verify tensor content (order may vary with multiprocessing)
|
||||
tensor_by_sender = {sid: t for sid, t in received_tensors}
|
||||
for server_id in range(num_api_servers):
|
||||
expected_id = f"server_{server_id}"
|
||||
assert expected_id in tensor_by_sender, (
|
||||
f"Missing tensor from server {server_id}"
|
||||
)
|
||||
expected_tensor = torch.ones(server_id + 1, server_id + 2) * server_id
|
||||
assert torch.allclose(tensor_by_sender[expected_id], expected_tensor)
|
||||
|
||||
# Signal workers that retrieval is complete
|
||||
retrieval_done.set()
|
||||
|
||||
# Wait for all processes to complete
|
||||
for proc in processes:
|
||||
proc.join(timeout=5.0)
|
||||
|
||||
|
||||
def mixed_tensor_encoder_process(
|
||||
tensor_queue: torch_mp.Queue,
|
||||
result_queue: mp.Queue,
|
||||
ready_event: EventType,
|
||||
retrieval_done: EventType,
|
||||
):
|
||||
"""Process that encodes mixed CPU/CUDA tensors."""
|
||||
try:
|
||||
sender = TensorIpcSender(tensor_queue)
|
||||
_encoder = MsgpackEncoder(oob_tensor_consumer=sender)
|
||||
|
||||
# Create only CUDA tensor for IPC (CPU will be serialized)
|
||||
# But actually, let's just send CUDA tensor directly
|
||||
cuda_tensor = torch.randn(4, 5, device="cuda:0")
|
||||
|
||||
# Manually send via IPC to test the mechanism
|
||||
cuda_tensor_shared = cuda_tensor.share_memory_()
|
||||
|
||||
ipc_data = TensorIpcData(
|
||||
sender_id="mixed_encoder",
|
||||
message_id=0,
|
||||
tensor_id=0,
|
||||
tensor=cuda_tensor_shared,
|
||||
)
|
||||
tensor_queue.put(ipc_data, timeout=10.0)
|
||||
|
||||
ready_event.set()
|
||||
|
||||
result_queue.put({"success": True, "sent_cuda": True})
|
||||
|
||||
# Keep process alive until decoder has retrieved the tensor
|
||||
retrieval_done.wait(timeout=30.0)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
ready_event.set()
|
||||
result_queue.put(
|
||||
{"success": False, "error": str(e), "traceback": traceback.format_exc()}
|
||||
)
|
||||
|
||||
|
||||
def mixed_tensor_decoder_process(
|
||||
tensor_queue: torch_mp.Queue,
|
||||
result_queue: mp.Queue,
|
||||
encoder_ready: EventType,
|
||||
retrieval_done: EventType,
|
||||
):
|
||||
"""Process that retrieves mixed tensors from queue."""
|
||||
try:
|
||||
# Wait for encoder to finish
|
||||
if not encoder_ready.wait(timeout=10.0):
|
||||
raise TimeoutError("Encoder did not signal ready")
|
||||
|
||||
# Try to get CUDA tensor from queue
|
||||
ipc_data = tensor_queue.get(timeout=5.0)
|
||||
|
||||
result_queue.put(
|
||||
{
|
||||
"success": True,
|
||||
"is_cuda": ipc_data.tensor.is_cuda,
|
||||
"shape": tuple(ipc_data.tensor.shape),
|
||||
}
|
||||
)
|
||||
|
||||
# Signal that retrieval is complete
|
||||
retrieval_done.set()
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
retrieval_done.set() # Signal even on failure
|
||||
result_queue.put(
|
||||
{"success": False, "error": str(e), "traceback": traceback.format_exc()}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_mixed_cpu_cuda_tensors():
|
||||
"""Test encoding with mixed CPU and CUDA tensors using multiprocessing."""
|
||||
tensor_queue = torch_mp.Queue()
|
||||
result_queue: mp.Queue = mp.Queue()
|
||||
encoder_ready = mp.Event()
|
||||
retrieval_done = mp.Event()
|
||||
|
||||
# Start encoder process
|
||||
encoder_proc = mp.Process(
|
||||
target=mixed_tensor_encoder_process,
|
||||
args=(tensor_queue, result_queue, encoder_ready, retrieval_done),
|
||||
)
|
||||
encoder_proc.start()
|
||||
|
||||
# Start decoder process
|
||||
decoder_proc = mp.Process(
|
||||
target=mixed_tensor_decoder_process,
|
||||
args=(tensor_queue, result_queue, encoder_ready, retrieval_done),
|
||||
)
|
||||
decoder_proc.start()
|
||||
|
||||
# Get results
|
||||
encoder_result = result_queue.get(timeout=10.0)
|
||||
decoder_result = result_queue.get(timeout=10.0)
|
||||
|
||||
encoder_proc.join(timeout=5.0)
|
||||
decoder_proc.join(timeout=5.0)
|
||||
|
||||
# Verify encoder succeeded
|
||||
assert encoder_result["success"], (
|
||||
f"Encoder failed: {encoder_result.get('error')}\n"
|
||||
f"{encoder_result.get('traceback', '')}"
|
||||
)
|
||||
|
||||
# Verify decoder succeeded and got CUDA tensor
|
||||
assert decoder_result["success"], (
|
||||
f"Decoder failed: {decoder_result.get('error')}\n"
|
||||
f"{decoder_result.get('traceback', '')}"
|
||||
)
|
||||
assert decoder_result["is_cuda"], "Retrieved tensor is not on CUDA"
|
||||
assert decoder_result["shape"] == (4, 5), (
|
||||
f"Unexpected shape: {decoder_result['shape']}"
|
||||
)
|
||||
|
||||
|
||||
def cpu_tensor_ipc_encoder_process(
|
||||
tensor_queue: torch_mp.Queue,
|
||||
result_queue: mp.Queue,
|
||||
tensor_shape: tuple,
|
||||
ready_event: EventType,
|
||||
retrieval_done: EventType,
|
||||
):
|
||||
"""Process that encodes and sends CPU tensors via IPC queue."""
|
||||
try:
|
||||
# Create encoder with IPC enabled for all tensors
|
||||
sender = TensorIpcSender(tensor_queue)
|
||||
encoder = MsgpackEncoder(oob_tensor_consumer=sender)
|
||||
|
||||
# Create a CPU tensor
|
||||
tensor = torch.randn(*tensor_shape, dtype=torch.float32)
|
||||
|
||||
# Encode the tensor (should use IPC queue, not standard serialization)
|
||||
encoded = encoder.encode({"test_tensor": tensor})
|
||||
|
||||
# Signal that encoding is complete
|
||||
ready_event.set()
|
||||
|
||||
result_queue.put(
|
||||
{
|
||||
"success": True,
|
||||
"encoded_length": len(encoded),
|
||||
"device": str(tensor.device),
|
||||
"tensor_shape": tuple(tensor.shape),
|
||||
}
|
||||
)
|
||||
|
||||
# Keep process alive until decoder has retrieved the tensor
|
||||
# This is necessary for CPU tensor shared memory to remain valid
|
||||
retrieval_done.wait(timeout=30.0)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
ready_event.set()
|
||||
result_queue.put(
|
||||
{"success": False, "error": str(e), "traceback": traceback.format_exc()}
|
||||
)
|
||||
|
||||
|
||||
def cpu_tensor_ipc_decoder_process(
|
||||
tensor_queue: torch_mp.Queue,
|
||||
result_queue: mp.Queue,
|
||||
expected_shape: tuple,
|
||||
encoder_ready: EventType,
|
||||
retrieval_done: EventType,
|
||||
):
|
||||
"""Process that decodes and receives CPU tensors from IPC queue."""
|
||||
try:
|
||||
# Wait for encoder to finish sending
|
||||
if not encoder_ready.wait(timeout=10.0):
|
||||
raise TimeoutError("Encoder did not signal ready")
|
||||
|
||||
# Get tensor from queue
|
||||
ipc_data = tensor_queue.get(timeout=5.0)
|
||||
|
||||
result_queue.put(
|
||||
{
|
||||
"success": True,
|
||||
"tensor_id": ipc_data.tensor_id,
|
||||
"tensor_shape": tuple(ipc_data.tensor.shape),
|
||||
"device": str(ipc_data.tensor.device),
|
||||
"matches_expected": tuple(ipc_data.tensor.shape) == expected_shape,
|
||||
"is_cpu": ipc_data.tensor.device.type == "cpu",
|
||||
}
|
||||
)
|
||||
|
||||
# Signal that retrieval is complete
|
||||
retrieval_done.set()
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
retrieval_done.set() # Signal even on failure
|
||||
result_queue.put(
|
||||
{"success": False, "error": str(e), "traceback": traceback.format_exc()}
|
||||
)
|
||||
|
||||
|
||||
def test_cpu_tensor_ipc():
|
||||
"""Test CPU tensor sharing via IPC queue when mm_tensor_ipc is enabled."""
|
||||
# Set up single queue and synchronization
|
||||
tensor_queue = torch_mp.Queue()
|
||||
result_queue: mp.Queue = mp.Queue()
|
||||
encoder_ready = mp.Event()
|
||||
retrieval_done = mp.Event()
|
||||
|
||||
tensor_shape = (3, 5, 7)
|
||||
|
||||
# Start encoder process
|
||||
encoder_proc = mp.Process(
|
||||
target=cpu_tensor_ipc_encoder_process,
|
||||
args=(
|
||||
tensor_queue,
|
||||
result_queue,
|
||||
tensor_shape,
|
||||
encoder_ready,
|
||||
retrieval_done,
|
||||
),
|
||||
)
|
||||
encoder_proc.start()
|
||||
|
||||
# Start decoder process
|
||||
decoder_proc = mp.Process(
|
||||
target=cpu_tensor_ipc_decoder_process,
|
||||
args=(
|
||||
tensor_queue,
|
||||
result_queue,
|
||||
tensor_shape,
|
||||
encoder_ready,
|
||||
retrieval_done,
|
||||
),
|
||||
)
|
||||
decoder_proc.start()
|
||||
|
||||
# Wait for processes and collect results
|
||||
encoder_result = result_queue.get(timeout=10.0)
|
||||
decoder_result = result_queue.get(timeout=10.0)
|
||||
|
||||
encoder_proc.join(timeout=5.0)
|
||||
decoder_proc.join(timeout=5.0)
|
||||
|
||||
# Verify results
|
||||
assert encoder_result["success"], (
|
||||
f"Encoder failed: {encoder_result.get('error')}\n"
|
||||
f"{encoder_result.get('traceback', '')}"
|
||||
)
|
||||
assert decoder_result["success"], (
|
||||
f"Decoder failed: {decoder_result.get('error')}\n"
|
||||
f"{decoder_result.get('traceback', '')}"
|
||||
)
|
||||
assert decoder_result["matches_expected"], "Tensor shape mismatch"
|
||||
assert decoder_result["is_cpu"], "Tensor not on CPU device"
|
||||
|
||||
|
||||
def test_ipc_disabled_mode():
|
||||
"""Test that IPC is disabled when no sender is provided."""
|
||||
tensor_queues = [torch_mp.Queue()]
|
||||
|
||||
# Create encoder without IPC sender (IPC disabled)
|
||||
encoder = MsgpackEncoder()
|
||||
|
||||
# Create a CPU tensor
|
||||
cpu_tensor = torch.randn(2, 3, dtype=torch.float32)
|
||||
|
||||
# Encode the tensor (should use standard serialization, not IPC)
|
||||
encoded = encoder.encode({"test_tensor": cpu_tensor})
|
||||
|
||||
# Verify encoding succeeded
|
||||
assert len(encoded) > 0
|
||||
assert isinstance(encoded, (list, tuple))
|
||||
|
||||
# Verify queue is empty (no IPC was used)
|
||||
assert tensor_queues[0].empty(), "Tensor queue should be empty when IPC is disabled"
|
||||
|
||||
# If CUDA is available, test with CUDA tensor too
|
||||
if torch.cuda.is_available():
|
||||
cuda_tensor = torch.randn(4, 5, device="cuda:0")
|
||||
encoded_cuda = encoder.encode({"cuda_tensor": cuda_tensor})
|
||||
assert len(encoded_cuda) > 0
|
||||
assert tensor_queues[0].empty(), (
|
||||
"Tensor queue should be empty for CUDA tensor when IPC is disabled"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiTensorMessage:
|
||||
"""Message with multiple tensors to test multi-tensor IPC."""
|
||||
|
||||
t1: torch.Tensor
|
||||
t2: torch.Tensor
|
||||
sender_label: str
|
||||
|
||||
|
||||
def concurrent_sender_process(
|
||||
tensor_queue: torch_mp.Queue,
|
||||
payload_queue: mp.Queue,
|
||||
result_queue: mp.Queue,
|
||||
sender_index: int,
|
||||
num_messages: int,
|
||||
barrier: BarrierType,
|
||||
retrieval_done: EventType,
|
||||
):
|
||||
"""Process that acts as one of N concurrent senders."""
|
||||
try:
|
||||
sender = TensorIpcSender(tensor_queue)
|
||||
encoder = MsgpackEncoder(oob_tensor_consumer=sender)
|
||||
|
||||
# Wait for all senders to be ready before sending
|
||||
barrier.wait(timeout=10.0)
|
||||
|
||||
encoded_payloads = []
|
||||
for msg_idx in range(num_messages):
|
||||
# Each sender creates uniquely-shaped tensors so we can
|
||||
# verify correct routing on the receiver side.
|
||||
t1 = torch.full((sender_index + 1, 3), float(msg_idx), dtype=torch.float32)
|
||||
t2 = torch.full(
|
||||
(2, sender_index + 2), float(msg_idx + 100), dtype=torch.float64
|
||||
)
|
||||
msg = MultiTensorMessage(
|
||||
t1=t1,
|
||||
t2=t2,
|
||||
sender_label=f"sender_{sender_index}_msg_{msg_idx}",
|
||||
)
|
||||
encoded = encoder.encode(msg)
|
||||
encoded_payloads.append(encoded)
|
||||
|
||||
# Send all encoded payloads via the regular (non-tensor) queue
|
||||
for encoded in encoded_payloads:
|
||||
payload_queue.put(encoded, timeout=10.0)
|
||||
|
||||
result_queue.put(
|
||||
{
|
||||
"success": True,
|
||||
"sender_index": sender_index,
|
||||
"num_sent": num_messages,
|
||||
}
|
||||
)
|
||||
|
||||
# Keep alive so shared-memory handles remain valid
|
||||
retrieval_done.wait(timeout=30.0)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
result_queue.put(
|
||||
{
|
||||
"success": False,
|
||||
"sender_index": sender_index,
|
||||
"error": str(e),
|
||||
"traceback": traceback.format_exc(),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_concurrent_senders_single_receiver():
|
||||
"""Test N concurrent senders sharing one queue with a single receiver.
|
||||
|
||||
Each sender encodes multiple messages (each containing two tensors) via
|
||||
its own MsgpackEncoder + TensorIpcSender. A single TensorIpcReceiver
|
||||
on the receiving side must correctly drain-and-buffer interleaved
|
||||
TensorIpcData items from the shared queue and match them back to the
|
||||
right message handles during decode.
|
||||
"""
|
||||
num_senders = 4
|
||||
num_messages_per_sender = 3
|
||||
tensor_queue = torch_mp.Queue()
|
||||
payload_queue: mp.Queue = mp.Queue()
|
||||
result_queue: mp.Queue = mp.Queue()
|
||||
barrier = mp.Barrier(num_senders)
|
||||
retrieval_done = mp.Event()
|
||||
|
||||
# Launch sender processes
|
||||
processes = []
|
||||
for i in range(num_senders):
|
||||
proc = mp.Process(
|
||||
target=concurrent_sender_process,
|
||||
args=(
|
||||
tensor_queue,
|
||||
payload_queue,
|
||||
result_queue,
|
||||
i,
|
||||
num_messages_per_sender,
|
||||
barrier,
|
||||
retrieval_done,
|
||||
),
|
||||
)
|
||||
proc.start()
|
||||
processes.append(proc)
|
||||
|
||||
# Collect send confirmations
|
||||
send_results = []
|
||||
for _ in range(num_senders):
|
||||
send_results.append(result_queue.get(timeout=15.0))
|
||||
for r in send_results:
|
||||
assert r["success"], (
|
||||
f"Sender {r['sender_index']} failed: {r.get('error')}\n"
|
||||
f"{r.get('traceback', '')}"
|
||||
)
|
||||
|
||||
# Now decode all messages from the main process using a single receiver
|
||||
receiver = TensorIpcReceiver(tensor_queue)
|
||||
decoder = MsgpackDecoder(MultiTensorMessage, oob_tensor_provider=receiver)
|
||||
|
||||
decoded_messages: list[MultiTensorMessage] = []
|
||||
total = num_senders * num_messages_per_sender
|
||||
for _ in range(total):
|
||||
encoded = payload_queue.get(timeout=10.0)
|
||||
decoded = decoder.decode(encoded)
|
||||
assert isinstance(decoded, MultiTensorMessage)
|
||||
decoded_messages.append(decoded)
|
||||
|
||||
# Signal senders they can exit
|
||||
retrieval_done.set()
|
||||
|
||||
# Group by sender_label prefix to verify all messages arrived
|
||||
by_sender: dict[int, list[MultiTensorMessage]] = {}
|
||||
for msg in decoded_messages:
|
||||
# label format: "sender_{i}_msg_{j}"
|
||||
parts = msg.sender_label.split("_")
|
||||
sender_idx = int(parts[1])
|
||||
by_sender.setdefault(sender_idx, []).append(msg)
|
||||
|
||||
assert len(by_sender) == num_senders, (
|
||||
f"Expected {num_senders} senders, got {len(by_sender)}"
|
||||
)
|
||||
|
||||
for sender_idx in range(num_senders):
|
||||
msgs = sorted(by_sender[sender_idx], key=lambda m: m.sender_label)
|
||||
assert len(msgs) == num_messages_per_sender, (
|
||||
f"Sender {sender_idx}: expected {num_messages_per_sender} "
|
||||
f"messages, got {len(msgs)}"
|
||||
)
|
||||
for msg_idx, msg in enumerate(msgs):
|
||||
assert msg.sender_label == f"sender_{sender_idx}_msg_{msg_idx}"
|
||||
# Verify tensor shapes match what the sender created
|
||||
assert msg.t1.shape == (sender_idx + 1, 3)
|
||||
assert msg.t2.shape == (2, sender_idx + 2)
|
||||
# Verify tensor values
|
||||
assert torch.allclose(msg.t1, torch.full_like(msg.t1, float(msg_idx)))
|
||||
assert torch.allclose(msg.t2, torch.full_like(msg.t2, float(msg_idx + 100)))
|
||||
|
||||
for proc in processes:
|
||||
proc.join(timeout=5.0)
|
||||
|
||||
|
||||
def test_concurrent_senders_interleaved_buffer():
|
||||
"""Test receiver buffering when tensors from multiple senders interleave.
|
||||
|
||||
Manually enqueue TensorIpcData from two senders in an interleaved order
|
||||
and verify the receiver correctly buffers and retrieves each tensor by
|
||||
its (sender_id, message_id, tensor_id) handle.
|
||||
"""
|
||||
tensor_queue = torch_mp.Queue()
|
||||
|
||||
# Sender A: 2 tensors for message 1
|
||||
a_t0 = torch.randn(2, 3)
|
||||
a_t1 = torch.randn(4, 5)
|
||||
# Sender B: 2 tensors for message 1
|
||||
b_t0 = torch.randn(6, 7)
|
||||
b_t1 = torch.randn(8, 9)
|
||||
|
||||
# Interleave: B_t0, A_t0, B_t1, A_t1
|
||||
for sid, mid, tid, t in [
|
||||
("B", 1, 0, b_t0),
|
||||
("A", 1, 0, a_t0),
|
||||
("B", 1, 1, b_t1),
|
||||
("A", 1, 1, a_t1),
|
||||
]:
|
||||
tensor_queue.put(
|
||||
TensorIpcData(sender_id=sid, message_id=mid, tensor_id=tid, tensor=t)
|
||||
)
|
||||
|
||||
receiver = TensorIpcReceiver(tensor_queue)
|
||||
|
||||
# Request A_t1 first — receiver must drain and buffer B_t0, A_t0, B_t1
|
||||
result = receiver(
|
||||
"float32", a_t1.shape, {"sender_id": "A", "message_id": 1, "tensor_id": 1}
|
||||
)
|
||||
assert torch.equal(result, a_t1)
|
||||
|
||||
# Now request B_t0 from buffer
|
||||
result = receiver(
|
||||
"float32", b_t0.shape, {"sender_id": "B", "message_id": 1, "tensor_id": 0}
|
||||
)
|
||||
assert torch.equal(result, b_t0)
|
||||
|
||||
# Request A_t0 from buffer
|
||||
result = receiver(
|
||||
"float32", a_t0.shape, {"sender_id": "A", "message_id": 1, "tensor_id": 0}
|
||||
)
|
||||
assert torch.equal(result, a_t0)
|
||||
|
||||
# Request B_t1 from buffer
|
||||
result = receiver(
|
||||
"float64", b_t1.shape, {"sender_id": "B", "message_id": 1, "tensor_id": 1}
|
||||
)
|
||||
assert torch.equal(result, b_t1)
|
||||
|
||||
# All buffers should be drained
|
||||
for sid in ("A", "B"):
|
||||
tensors = receiver._tensor_buffers[sid].tensors.get(1, {})
|
||||
assert len(tensors) == 0, f"Sender {sid} buffer not empty: {tensors}"
|
||||
|
||||
|
||||
def test_mixed_cpu_cuda_with_ipc_enabled():
|
||||
"""Test that encoder is configured correctly for IPC with all tensor types."""
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA not available")
|
||||
|
||||
tensor_queue = torch_mp.Queue()
|
||||
|
||||
# Create sender and encoder with IPC enabled
|
||||
sender = TensorIpcSender(tensor_queue)
|
||||
encoder = MsgpackEncoder(oob_tensor_consumer=sender)
|
||||
|
||||
# Verify sender configuration
|
||||
assert encoder.oob_tensor_consumer is not None, "Consumer should be set"
|
||||
|
||||
# Note: Actual IPC transfer only works across processes
|
||||
# (tested in test_cpu_tensor_ipc)
|
||||
# This test just verifies the configuration is correct
|
||||
|
||||
|
||||
def test_tensor_cleanup_after_decode():
|
||||
"""Test that tensors are removed from tracking after successful decode."""
|
||||
# Create a tensor queue
|
||||
tensor_queue = torch_mp.Queue()
|
||||
|
||||
# Create and encode a tensor
|
||||
tensor = torch.randn(5, 5)
|
||||
# Move to shared memory for IPC
|
||||
if not tensor.is_shared():
|
||||
tensor.share_memory_()
|
||||
|
||||
# Manually create a TensorIpcData and put it in the queue
|
||||
sender_id = "test_sender"
|
||||
message_id = 0
|
||||
tensor_id = 0
|
||||
ipc_data = TensorIpcData(
|
||||
sender_id=sender_id,
|
||||
message_id=message_id,
|
||||
tensor_id=tensor_id,
|
||||
tensor=tensor,
|
||||
)
|
||||
tensor_queue.put(ipc_data)
|
||||
|
||||
# Create receiver directly
|
||||
receiver = TensorIpcReceiver(tensor_queue)
|
||||
|
||||
handle = {
|
||||
"sender_id": sender_id,
|
||||
"message_id": message_id,
|
||||
"tensor_id": tensor_id,
|
||||
}
|
||||
|
||||
# Receive the tensor - this should retrieve it from the queue
|
||||
decoded_tensor = receiver(
|
||||
str(tensor.dtype).removeprefix("torch."), tensor.shape, handle
|
||||
)
|
||||
|
||||
# Verify the tensor was decoded
|
||||
assert decoded_tensor.shape == tensor.shape, "Decoded tensor should match shape"
|
||||
|
||||
# Verify the tensor was removed from buffer after decode
|
||||
sender = receiver._tensor_buffers[sender_id]
|
||||
tensors = sender.tensors.get(message_id, {})
|
||||
assert tensor_id not in tensors, "Tensor should be removed from buffer"
|
||||
@@ -14,7 +14,12 @@ import vllm.envs as envs
|
||||
from vllm.config.model_arch import (
|
||||
ModelArchitectureConfig,
|
||||
)
|
||||
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig
|
||||
from vllm.config.multimodal import (
|
||||
MMCacheType,
|
||||
MMEncoderTPMode,
|
||||
MMTensorIPC,
|
||||
MultiModalConfig,
|
||||
)
|
||||
from vllm.config.pooler import PoolerConfig
|
||||
from vllm.config.scheduler import RunnerType
|
||||
from vllm.config.utils import config, getattr_iter
|
||||
@@ -310,6 +315,7 @@ class ModelConfig:
|
||||
interleave_mm_strings: InitVar[bool | None] = None
|
||||
skip_mm_profiling: InitVar[bool | None] = None
|
||||
video_pruning_rate: InitVar[float | None] = None
|
||||
mm_tensor_ipc: InitVar[MMTensorIPC] = None
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
@@ -430,6 +436,7 @@ class ModelConfig:
|
||||
interleave_mm_strings: bool | None,
|
||||
skip_mm_profiling: bool | None,
|
||||
video_pruning_rate: float | None,
|
||||
mm_tensor_ipc: MMTensorIPC,
|
||||
) -> None:
|
||||
# Keep set served_model_name before maybe_model_redirect(self.model)
|
||||
self.served_model_name = get_served_model_name(
|
||||
@@ -612,6 +619,7 @@ class ModelConfig:
|
||||
interleave_mm_strings=interleave_mm_strings,
|
||||
skip_mm_profiling=skip_mm_profiling,
|
||||
video_pruning_rate=video_pruning_rate,
|
||||
mm_tensor_ipc=mm_tensor_ipc,
|
||||
)
|
||||
|
||||
mm_config_kwargs = {
|
||||
@@ -1112,6 +1120,22 @@ class ModelConfig:
|
||||
f"({parallel_config.decode_context_parallel_size})."
|
||||
)
|
||||
|
||||
# torch_shm uses a single IPC queue to rank 0; DP>1 is
|
||||
# incompatible because API servers can't know which
|
||||
# CoreEngine the scheduler will assign work to. TP>1 is
|
||||
# also not supported because this requires broadcasting
|
||||
# MM tensors between all TP ranks.
|
||||
if (
|
||||
self.multimodal_config is not None
|
||||
and self.multimodal_config.mm_tensor_ipc == "torch_shm"
|
||||
and parallel_config.world_size_across_dp > 1
|
||||
):
|
||||
raise ValueError(
|
||||
"mm_tensor_ipc='torch_shm' is not supported with "
|
||||
"data_parallel_size > 1 or tensor_parallel_size > 1 "
|
||||
"or pipeline_parallel_size > 1."
|
||||
)
|
||||
|
||||
def get_sliding_window(self) -> int | None:
|
||||
"""Get the sliding window size from the HF text config if present."""
|
||||
return getattr(self.hf_text_config, "sliding_window", None)
|
||||
|
||||
@@ -59,6 +59,7 @@ class MultiModalDummyOptionsBuiltins(TypedDict, total=False):
|
||||
|
||||
MMEncoderTPMode = Literal["weights", "data"]
|
||||
MMCacheType = Literal["shm", "lru"]
|
||||
MMTensorIPC = Literal["direct_rpc", "torch_shm"]
|
||||
MMDummyOptions: TypeAlias = dict[str, BaseDummyOptions]
|
||||
"""
|
||||
A dictionary containing an entry for each modality type of dummy data.
|
||||
@@ -172,6 +173,11 @@ class MultiModalConfig:
|
||||
Value sits in range [0;1) and determines fraction of media tokens
|
||||
from each video to be pruned.
|
||||
"""
|
||||
mm_tensor_ipc: MMTensorIPC = "direct_rpc"
|
||||
"""IPC (inter-process communication) method for multimodal tensors.
|
||||
- "direct_rpc": Use msgspec serialization via RPC
|
||||
- "torch_shm": Use torch.multiprocessing shared memory for zero-copy IPC
|
||||
Defaults to "direct_rpc". """
|
||||
|
||||
@field_validator("limit_per_prompt", mode="before")
|
||||
@classmethod
|
||||
|
||||
@@ -766,6 +766,17 @@ class VllmConfig:
|
||||
else:
|
||||
self.parallel_config.disable_nccl_for_dp_synchronization = False
|
||||
|
||||
if (
|
||||
self.model_config is not None
|
||||
and self.model_config.multimodal_config is not None
|
||||
and self.model_config.multimodal_config.mm_tensor_ipc == "torch_shm"
|
||||
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
|
||||
):
|
||||
raise ValueError(
|
||||
"torch_shm is known to fail without "
|
||||
"VLLM_WORKER_MULTIPROC_METHOD set to spawn"
|
||||
)
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if (
|
||||
|
||||
@@ -79,7 +79,7 @@ from vllm.config.model import (
|
||||
RunnerOption,
|
||||
TokenizerMode,
|
||||
)
|
||||
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode
|
||||
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MMTensorIPC
|
||||
from vllm.config.observability import DetailedTraceModules
|
||||
from vllm.config.parallel import (
|
||||
All2AllBackend,
|
||||
@@ -509,6 +509,7 @@ class EngineArgs:
|
||||
io_processor_plugin: str | None = None
|
||||
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
|
||||
video_pruning_rate: float | None = MultiModalConfig.video_pruning_rate
|
||||
mm_tensor_ipc: MMTensorIPC = MultiModalConfig.mm_tensor_ipc
|
||||
# LoRA fields
|
||||
enable_lora: bool = False
|
||||
max_loras: int = LoRAConfig.max_loras
|
||||
@@ -1097,6 +1098,9 @@ class EngineArgs:
|
||||
multimodal_group.add_argument(
|
||||
"--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"]
|
||||
)
|
||||
multimodal_group.add_argument(
|
||||
"--mm-tensor-ipc", **multimodal_kwargs["mm_tensor_ipc"]
|
||||
)
|
||||
|
||||
# LoRA related configs
|
||||
lora_kwargs = get_kwargs(LoRAConfig)
|
||||
@@ -1423,6 +1427,7 @@ class EngineArgs:
|
||||
override_attention_dtype=self.override_attention_dtype,
|
||||
logits_processors=self.logits_processors,
|
||||
video_pruning_rate=self.video_pruning_rate,
|
||||
mm_tensor_ipc=self.mm_tensor_ipc,
|
||||
io_processor_plugin=self.io_processor_plugin,
|
||||
)
|
||||
|
||||
|
||||
@@ -290,7 +290,7 @@ def run_multi_api_server(args: argparse.Namespace):
|
||||
|
||||
with launch_core_engines(
|
||||
vllm_config, executor_class, log_stats, addresses, num_api_servers
|
||||
) as (local_engine_manager, coordinator, addresses):
|
||||
) as (local_engine_manager, coordinator, addresses, tensor_queue):
|
||||
# Construct common args for the APIServerProcessManager up-front.
|
||||
api_server_manager_kwargs = dict(
|
||||
target_server_fn=run_api_server_worker_proc,
|
||||
@@ -303,6 +303,7 @@ def run_multi_api_server(args: argparse.Namespace):
|
||||
stats_update_address=coordinator.get_stats_publish_address()
|
||||
if coordinator
|
||||
else None,
|
||||
tensor_queue=tensor_queue,
|
||||
)
|
||||
|
||||
# For dp ranks > 0 in external/hybrid DP LB modes, we must delay the
|
||||
|
||||
@@ -13,6 +13,7 @@ from enum import IntEnum
|
||||
from functools import partial
|
||||
from inspect import isclass, signature
|
||||
from logging import DEBUG
|
||||
from multiprocessing.queues import Queue
|
||||
from typing import Any, TypeVar, cast
|
||||
|
||||
import msgspec
|
||||
@@ -59,6 +60,7 @@ from vllm.v1.engine import (
|
||||
UtilityOutput,
|
||||
UtilityResult,
|
||||
)
|
||||
from vllm.v1.engine.tensor_ipc import TensorIpcReceiver
|
||||
from vllm.v1.engine.utils import (
|
||||
EngineHandshakeMetadata,
|
||||
EngineZmqAddresses,
|
||||
@@ -788,6 +790,7 @@ class EngineCoreProc(EngineCore):
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
client_handshake_address: str | None = None,
|
||||
tensor_queue: Queue | None = None,
|
||||
*,
|
||||
engine_index: int = 0,
|
||||
):
|
||||
@@ -802,6 +805,12 @@ class EngineCoreProc(EngineCore):
|
||||
self.engines_running = False
|
||||
self.shutdown_state = EngineShutdownState.RUNNING
|
||||
|
||||
# Receiver for tensor IPC
|
||||
self.tensor_ipc_receiver: TensorIpcReceiver | None = None
|
||||
if tensor_queue is not None:
|
||||
self.tensor_ipc_receiver = TensorIpcReceiver(tensor_queue)
|
||||
logger.info("Using tensor IPC queue for multimodal tensor sharing")
|
||||
|
||||
with self._perform_handshakes(
|
||||
handshake_address,
|
||||
identity,
|
||||
@@ -1340,9 +1349,11 @@ class EngineCoreProc(EngineCore):
|
||||
):
|
||||
"""Input socket IO thread."""
|
||||
|
||||
# Msgpack serialization decoding.
|
||||
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
|
||||
generic_decoder = MsgpackDecoder()
|
||||
# Msgpack serialization decoding with optional tensor IPC receiver.
|
||||
add_request_decoder = MsgpackDecoder(
|
||||
EngineCoreRequest, oob_tensor_provider=self.tensor_ipc_receiver
|
||||
)
|
||||
generic_decoder = MsgpackDecoder(oob_tensor_provider=self.tensor_ipc_receiver)
|
||||
|
||||
with ExitStack() as stack, zmq.Context() as ctx:
|
||||
input_sockets = [
|
||||
@@ -1418,10 +1429,7 @@ class EngineCoreProc(EngineCore):
|
||||
self.input_queue.put_nowait((request_type, request))
|
||||
|
||||
def process_output_sockets(
|
||||
self,
|
||||
output_paths: list[str],
|
||||
coord_output_path: str | None,
|
||||
engine_index: int,
|
||||
self, output_paths: list[str], coord_output_path: str | None, engine_index: int
|
||||
):
|
||||
"""Output socket IO thread."""
|
||||
|
||||
@@ -1580,6 +1588,7 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
client_handshake_address: str | None = None,
|
||||
tensor_queue: Queue | None = None,
|
||||
):
|
||||
assert vllm_config.model_config.is_moe, (
|
||||
"DPEngineCoreProc should only be used for MoE models"
|
||||
@@ -1605,6 +1614,7 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
log_stats,
|
||||
client_handshake_address,
|
||||
engine_index=dp_rank,
|
||||
tensor_queue=tensor_queue,
|
||||
)
|
||||
|
||||
def _init_data_parallel(self, vllm_config: VllmConfig):
|
||||
|
||||
@@ -12,6 +12,7 @@ from collections import defaultdict, deque
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from concurrent.futures import Future
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing.queues import Queue
|
||||
from threading import Thread
|
||||
from typing import Any, TypeAlias, TypeVar
|
||||
|
||||
@@ -45,6 +46,7 @@ from vllm.v1.engine import (
|
||||
from vllm.v1.engine.coordinator import DPCoordinator
|
||||
from vllm.v1.engine.core import EngineCore, EngineCoreProc
|
||||
from vllm.v1.engine.exceptions import EngineDeadError
|
||||
from vllm.v1.engine.tensor_ipc import TensorIpcSender
|
||||
from vllm.v1.engine.utils import (
|
||||
CoreEngineActorManager,
|
||||
CoreEngineProcManager,
|
||||
@@ -477,9 +479,6 @@ class MPClient(EngineCoreClient):
|
||||
client_addresses: dict[str, str] | None = None,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
# Serialization setup.
|
||||
self.encoder = MsgpackEncoder()
|
||||
self.decoder = MsgpackDecoder(EngineCoreOutputs)
|
||||
|
||||
# ZMQ setup.
|
||||
sync_ctx = zmq.Context(io_threads=2)
|
||||
@@ -501,11 +500,14 @@ class MPClient(EngineCoreClient):
|
||||
enable_input_socket_handover = parallel_config.enable_elastic_ep
|
||||
|
||||
self.stats_update_address: str | None = None
|
||||
tensor_queue: Queue | None = None
|
||||
if client_addresses:
|
||||
# Engines are managed externally to this client.
|
||||
input_address = client_addresses["input_address"]
|
||||
output_address = client_addresses["output_address"]
|
||||
self.stats_update_address = client_addresses.get("stats_update_address")
|
||||
# Tensor queues passed via client_addresses for multi-API-server case
|
||||
tensor_queue = client_addresses.get("tensor_queue") # type: ignore[assignment]
|
||||
self.input_socket = self.resources.input_socket = make_zmq_socket(
|
||||
self.ctx,
|
||||
input_address,
|
||||
@@ -532,7 +534,7 @@ class MPClient(EngineCoreClient):
|
||||
|
||||
with launch_core_engines(
|
||||
vllm_config, executor_class, log_stats, addresses
|
||||
) as (engine_manager, coordinator, addresses):
|
||||
) as (engine_manager, coordinator, addresses, tensor_queue):
|
||||
self.resources.coordinator = coordinator
|
||||
self.resources.engine_manager = engine_manager
|
||||
|
||||
@@ -542,6 +544,17 @@ class MPClient(EngineCoreClient):
|
||||
coordinator.get_stats_publish_address()
|
||||
)
|
||||
|
||||
# Serialization setup with tensor queues for multimodal tensor IPC.
|
||||
tensor_ipc_sender: TensorIpcSender | None = None
|
||||
model_config = getattr(vllm_config, "model_config", None)
|
||||
if model_config is not None and model_config.multimodal_config is not None:
|
||||
mm_tensor_ipc = model_config.multimodal_config.mm_tensor_ipc
|
||||
if mm_tensor_ipc == "torch_shm" and tensor_queue is not None:
|
||||
tensor_ipc_sender = TensorIpcSender(tensor_queue)
|
||||
|
||||
self.encoder = MsgpackEncoder(oob_tensor_consumer=tensor_ipc_sender)
|
||||
self.decoder = MsgpackDecoder(EngineCoreOutputs)
|
||||
|
||||
dp_size = parallel_config.data_parallel_size
|
||||
dp_rank = parallel_config.data_parallel_index
|
||||
dp_local_size = parallel_config.data_parallel_size_local
|
||||
|
||||
178
vllm/v1/engine/tensor_ipc.py
Normal file
178
vllm/v1/engine/tensor_ipc.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""Tensor IPC transport via torch.multiprocessing.Queue.
|
||||
|
||||
This module contains the queue-based transport logic for sharing tensors
|
||||
between processes (e.g., API server -> engine core). The msgpack layer
|
||||
emits/consumes lightweight :class:`TensorIpcData` values, while transport
|
||||
state such as request association, handle generation, queue routing, buffering,
|
||||
and cleanup lives here.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from dataclasses import field
|
||||
from multiprocessing.queues import Queue as MPQueue
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.serial_utils import OOBTensorConsumer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
TensorIpcQueue = MPQueue
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TensorIpcData:
|
||||
"""
|
||||
Data sent via torch.multiprocessing.Queue for zero-copy IPC.
|
||||
|
||||
Contains the tensor_id and the actual tensor. The tensor is
|
||||
shared in memory (GPU or CPU) for efficient inter-process communication.
|
||||
"""
|
||||
|
||||
sender_id: str
|
||||
message_id: int
|
||||
tensor_id: int
|
||||
tensor: torch.Tensor
|
||||
|
||||
|
||||
class TensorIpcSender(OOBTensorConsumer):
|
||||
"""Send-side logic for tensor IPC via torch.multiprocessing.Queue.
|
||||
|
||||
Uses a single queue targeting rank 0 (the only rank that consumes
|
||||
multimodal tensors during TP>1 / PP>1. Note: DP>1 not supported).
|
||||
"""
|
||||
|
||||
def __init__(self, queue: TensorIpcQueue):
|
||||
self.queue = queue
|
||||
self._tensor_id_counter = 0
|
||||
self._message_counter = 0
|
||||
self._sender_id = uuid.uuid4().hex[:8]
|
||||
|
||||
def set_target_engine(self, target_engine: int) -> None:
|
||||
if target_engine != 0:
|
||||
raise IndexError(
|
||||
"TensorIpcSender only supports a single queue; "
|
||||
f"got target engine {target_engine}"
|
||||
)
|
||||
|
||||
def new_message(self) -> None:
|
||||
self._message_counter += 1
|
||||
self._tensor_id_counter = 0
|
||||
|
||||
def __call__(self, tensor: torch.Tensor) -> dict[str, Any] | None:
|
||||
"""Send tensor via queue, return its handle. Returns None if failed."""
|
||||
try:
|
||||
# Move tensor to shared memory for IPC
|
||||
# This is required for proper inter-process communication
|
||||
if not tensor.is_shared():
|
||||
tensor = tensor.share_memory_()
|
||||
|
||||
metadata = {
|
||||
"sender_id": self._sender_id,
|
||||
"message_id": self._message_counter,
|
||||
"tensor_id": self._tensor_id_counter,
|
||||
}
|
||||
|
||||
self._tensor_id_counter += 1
|
||||
|
||||
ipc_data = TensorIpcData(**metadata, tensor=tensor) # type: ignore[arg-type]
|
||||
|
||||
# Use a timeout to avoid blocking indefinitely
|
||||
self.queue.put(ipc_data, timeout=10.0)
|
||||
|
||||
logger.debug(
|
||||
"Sent tensor %s for (shape=%s, device=%s) "
|
||||
"via IPC queue (shared memory)",
|
||||
metadata,
|
||||
tensor.shape,
|
||||
tensor.device,
|
||||
)
|
||||
|
||||
return metadata
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to send tensor via IPC queue: %s. "
|
||||
"Falling back to standard serialization.",
|
||||
e,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _Sender:
|
||||
current_message_id: int = -1
|
||||
tensors: dict[int, dict[int, torch.Tensor]] = field(default_factory=dict)
|
||||
|
||||
|
||||
class TensorIpcReceiver:
|
||||
"""Receive-side logic for tensor IPC via torch.multiprocessing.Queue.
|
||||
|
||||
Wraps the queue receive logic previously embedded in MsgpackDecoder.
|
||||
"""
|
||||
|
||||
def __init__(self, queue: TensorIpcQueue):
|
||||
self.queue = queue
|
||||
self._tensor_buffers = defaultdict[str, _Sender](_Sender)
|
||||
|
||||
def __call__(
|
||||
self, dtype: str, shape: tuple[int, ...], meta: dict[str, Any]
|
||||
) -> torch.Tensor:
|
||||
"""Retrieve a tensor from torch.multiprocessing.Queue.
|
||||
|
||||
Uses a drain-and-buffer pattern: drains all available tensors from
|
||||
the queue, buffering them, until the requested tensor is found.
|
||||
Works for CUDA and CPU.
|
||||
"""
|
||||
|
||||
# Create lookup key from handle
|
||||
sender_id: str = meta["sender_id"]
|
||||
message_id: int = meta["message_id"]
|
||||
tensor_id: int = meta["tensor_id"]
|
||||
|
||||
# Drain all available tensors. We save them regardless if this is
|
||||
# the one we're waiting for as they may arrive out of order from
|
||||
# multiple producers.
|
||||
while True:
|
||||
sender = self._tensor_buffers.get(sender_id)
|
||||
if sender is not None:
|
||||
tensors = sender.tensors
|
||||
tensor = tensors.get(message_id, {}).pop(tensor_id, None)
|
||||
if tensor is not None:
|
||||
if sender.current_message_id != message_id:
|
||||
while tensors and (mid := next(iter(tensors))) < message_id:
|
||||
if sender.tensors.pop(mid):
|
||||
logger.warning(
|
||||
"Discarding %d stale tensors from sender %s",
|
||||
sender_id,
|
||||
)
|
||||
sender.current_message_id = message_id
|
||||
logger.debug(
|
||||
"Received tensor %s from sender %s for (shape=%s, device=%s) "
|
||||
"via IPC queue (shared memory)",
|
||||
(message_id, tensor_id),
|
||||
sender_id,
|
||||
tensor.shape,
|
||||
tensor.device,
|
||||
)
|
||||
return tensor
|
||||
|
||||
ipc_data: TensorIpcData = self.queue.get(timeout=10.0)
|
||||
|
||||
# Store tensor
|
||||
sender = self._tensor_buffers[ipc_data.sender_id]
|
||||
if sender.current_message_id > ipc_data.message_id:
|
||||
logger.warning(
|
||||
"Ignoring stale tensor from sender %s", ipc_data.sender_id
|
||||
)
|
||||
continue
|
||||
|
||||
sender.tensors.setdefault(ipc_data.message_id, {})[ipc_data.tensor_id] = (
|
||||
ipc_data.tensor
|
||||
)
|
||||
@@ -10,6 +10,7 @@ from dataclasses import dataclass
from enum import Enum, auto
from multiprocessing import Process, connection
from multiprocessing.process import BaseProcess
from multiprocessing.queues import Queue
from typing import TYPE_CHECKING
from unittest.mock import patch

@@ -95,6 +96,7 @@ class CoreEngineProcManager:
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
        tensor_queue: Queue | None = None,
    ):
        context = get_mp_context()
        common_kwargs = {
@@ -103,6 +105,7 @@ class CoreEngineProcManager:
            "handshake_address": handshake_address,
            "executor_class": executor_class,
            "log_stats": log_stats,
            "tensor_queue": tensor_queue,
        }

        if client_handshake_address:
@@ -864,6 +867,7 @@ def launch_core_engines(
        CoreEngineProcManager | CoreEngineActorManager | None,
        DPCoordinator | None,
        EngineZmqAddresses,
        Queue | None,
    ]
]:
    """Launch engine and DP coordinator processes as needed."""
@@ -878,6 +882,14 @@

    offline_mode = local_start_index is not None

    # Create a single tensor IPC queue for sharing multimodal tensors between
    # API servers and engine core. Returns a single queue since we only support
    # DP=1 for this data flow.
    tensor_queue: Queue | None = None
    multimodal_config = vllm_config.model_config.multimodal_config
    if multimodal_config is not None and multimodal_config.mm_tensor_ipc == "torch_shm":
        tensor_queue = get_mp_context().Queue()

    # Run the DP Coordinator process with rank 0 when in online DP mode.
    # The coordinator is needed for:
    # 1. Internal/hybrid LB: collecting and publishing queue stats for load balancing
@@ -913,7 +925,7 @@
            log_stats=log_stats,
        )

        yield engine_actor_manager, coordinator, addresses
        yield engine_actor_manager, coordinator, addresses, tensor_queue
        return

    if offline_mode:
@@ -975,11 +987,12 @@
            local_engine_count=local_engine_count,
            start_index=dp_rank,
            local_start_index=local_start_index or 0,
            tensor_queue=tensor_queue,
        )
    else:
        local_engine_manager = None

    yield local_engine_manager, coordinator, addresses
    yield local_engine_manager, coordinator, addresses, tensor_queue

    # Now wait for engines to start.
    wait_for_engine_startup(

@@ -4,6 +4,7 @@
import dataclasses
import importlib
import pickle
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from functools import partial
from inspect import isclass
@@ -53,6 +54,27 @@ MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
bytestr: TypeAlias = bytes | bytearray | memoryview | zmq.Frame


class OOBTensorConsumer(ABC):
    @abstractmethod
    def __call__(self, tensor: torch.Tensor) -> dict | None:
        """
        Called with tensors for the current message.
        Returns None to reject the tensor (falls back to regular serialization),
        otherwise a dict with arbitrary placeholder data to be included
        in the serialized message.
        """
        return None

    @abstractmethod
    def new_message(self) -> None:
        """Called at the start of each new encoded message."""
        pass


# dtype, shape, metadata -> tensor
OOBTensorProvider = Callable[[str, tuple[int, ...], dict], torch.Tensor]

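To make the contract concrete, here is a hypothetical in-process implementation of the consumer/provider pair (the DictTensorConsumer name and the handle_id key are illustrative only; the real implementations in this change are TensorIpcSender and TensorIpcReceiver, and this toy version only works when encoder and decoder live in the same process). It assumes the surrounding module's torch import and the OOBTensorConsumer base defined just above.

class DictTensorConsumer(OOBTensorConsumer):
    """Stores tensors in a local dict and returns a small handle dict."""

    def __init__(self) -> None:
        self._store: dict[int, torch.Tensor] = {}
        self._next_id = 0

    def new_message(self) -> None:
        # Nothing to reset for this toy implementation.
        pass

    def __call__(self, tensor: torch.Tensor) -> dict | None:
        handle = {"handle_id": self._next_id}
        self._store[self._next_id] = tensor
        self._next_id += 1
        return handle

    def provide(self, dtype: str, shape: tuple[int, ...], meta: dict) -> torch.Tensor:
        # Matches the OOBTensorProvider signature; the stored tensor already
        # carries its dtype and shape, so only the handle is needed here.
        return self._store.pop(meta["handle_id"])

An encoder constructed with this object as its consumer and a decoder given consumer.provide as its provider would then move tensors through the dict instead of the serialized payload.
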
def _log_insecure_serialization_warning():
    logger.warning_once(
        "Allowing insecure serialization using pickle due to "
@@ -119,9 +141,16 @@ class MsgpackEncoder:

    By default, arrays below 256B are serialized inline. Larger ones will get sent
    via dedicated messages. Note that this is a per-tensor limit.

    When an ``oob_tensor_consumer`` is provided, tensors (CUDA and CPU) will be
    offered to it for out-of-band handling.
    """

    def __init__(self, size_threshold: int | None = None):
    def __init__(
        self,
        size_threshold: int | None = None,
        oob_tensor_consumer: OOBTensorConsumer | None = None,
    ):
        if size_threshold is None:
            size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
        self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
@@ -130,11 +159,14 @@ class MsgpackEncoder:
        # pass custom data to the hook otherwise.
        self.aux_buffers: list[bytestr] | None = None
        self.size_threshold = size_threshold
        self.oob_tensor_consumer = oob_tensor_consumer
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def encode(self, obj: Any) -> Sequence[bytestr]:
        try:
            if self.oob_tensor_consumer is not None:
                self.oob_tensor_consumer.new_message()
            self.aux_buffers = bufs = [b""]
            bufs[0] = self.encoder.encode(obj)
            # This `bufs` list allows us to collect direct pointers to backing
@@ -147,6 +179,8 @@ class MsgpackEncoder:

    def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
        try:
            if self.oob_tensor_consumer is not None:
                self.oob_tensor_consumer.new_message()
            self.aux_buffers = [buf]
            bufs = self.aux_buffers
            self.encoder.encode_into(obj, buf)
@@ -222,17 +256,19 @@ class MsgpackEncoder:

    def _encode_tensor(
        self, obj: torch.Tensor
    ) -> tuple[str, tuple[int, ...], int | memoryview]:
        assert self.aux_buffers is not None
    ) -> tuple[str, tuple[int, ...], int | dict | memoryview]:
        oob_consumer = self.oob_tensor_consumer
        # view the tensor as a contiguous 1D array of bytes
        arr_data = tensor_data(obj)
        if obj.nbytes < self.size_threshold:
        if obj.nbytes < self.size_threshold and obj.is_cpu:
            # Smaller tensors are encoded inline, just like ndarrays.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, tensor_data(obj))
        elif oob_consumer is not None and (data := oob_consumer(obj)) is not None:
            assert isinstance(data, dict)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            assert self.aux_buffers is not None
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr_data)
            self.aux_buffers.append(tensor_data(obj))
        dtype = str(obj.dtype).removeprefix("torch.")
        return dtype, obj.shape, data

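To summarize the dispatch above, the third element of the tuple returned by _encode_tensor now takes one of three forms (the literal values below are illustrative):

# Inline (small CPU tensor): raw bytes embedded in the message.
("float32", (4,), msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, b"..."))
# Out-of-band (consumer accepted the tensor): the handle dict returned by the
# consumer, e.g. the sender_id/message_id/tensor_id metadata from TensorIpcSender.
("float32", (4,), {"sender_id": "...", "message_id": 0, "tensor_id": 0})
# Aux buffer (everything else): index into aux_buffers.
("float32", (4,), 1)

On the decode side, _decode_tensor (further below) dispatches on exactly these types: a dict is routed to the oob_tensor_provider, an int indexes aux_buffers, and the remaining case wraps the inline bytes in a memoryview.
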
@@ -279,9 +315,17 @@ class MsgpackDecoder:

    Note that unlike vanilla `msgspec` Decoders, this interface is generally
    not thread-safe when encoding tensors / numpy arrays.

    ``oob_tensor_provider`` must be used when an OOBTensorConsumer is used on the
    encoder side.
    """

    def __init__(self, t: Any | None = None, share_mem: bool = True):
    def __init__(
        self,
        t: Any | None = None,
        share_mem: bool = True,
        oob_tensor_provider: OOBTensorProvider | None = None,
    ):
        self.share_mem = share_mem
        self.pin_tensors = is_pin_memory_available()
        args = () if t is None else (t,)
@@ -289,6 +333,7 @@ class MsgpackDecoder:
            *args, ext_hook=self.ext_hook, dec_hook=self.dec_hook
        )
        self.aux_buffers: Sequence[bytestr] = ()
        self.oob_tensor_provider = oob_tensor_provider
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

@@ -353,6 +398,12 @@ class MsgpackDecoder:

    def _decode_tensor(self, arr: Any) -> torch.Tensor:
        dtype, shape, data = arr
        if isinstance(data, dict):
            assert self.oob_tensor_provider, (
                "Received OOB tensor but tensor provider is not set"
            )
            return self.oob_tensor_provider(dtype, shape, data)

        is_aux = isinstance(data, int)
        buffer = self.aux_buffers[data] if is_aux else data
        buffer = buffer if isinstance(buffer, memoryview) else memoryview(buffer)

@@ -10,6 +10,7 @@ from contextlib import AbstractContextManager
from dataclasses import dataclass
from multiprocessing import connection
from multiprocessing.process import BaseProcess
from multiprocessing.queues import Queue
from typing import (
    TYPE_CHECKING,
    Any,
@@ -173,6 +174,7 @@ class APIServerProcessManager:
        input_addresses: list[str],
        output_addresses: list[str],
        stats_update_address: str | None = None,
        tensor_queue: Queue | None = None,
    ):
        """Initialize and start API server worker processes.

@@ -185,6 +187,7 @@
            input_addresses: Input addresses for each API server
            output_addresses: Output addresses for each API server
            stats_update_address: Optional stats update address
            tensor_queue: Optional tensor IPC queue for sharing MM tensors
        """
        self.listen_address = listen_address
        self.sock = sock
@@ -205,6 +208,8 @@ class APIServerProcessManager:
            }
            if stats_update_address is not None:
                client_config["stats_update_address"] = stats_update_address
            if tensor_queue is not None:
                client_config["tensor_queue"] = tensor_queue

            proc = spawn_context.Process(
                target=target_server_fn,
@@ -419,7 +424,7 @@ def tensor_data(tensor: torch.Tensor) -> memoryview:
    Returns:
        A memoryview of the tensor data as uint8.
    """
    return tensor.flatten().contiguous().view(torch.uint8).numpy().data
    return tensor.flatten().cpu().contiguous().view(torch.uint8).numpy().data


@dataclass
