# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for packed tensor broadcasting functionality.

Unit tests for packed_broadcast_producer and packed_broadcast_consumer.
These utilities enable efficient batched tensor transfer over NCCL.
"""
import pytest
import torch

from vllm.distributed.weight_transfer.nccl_engine import NCCLWeightTransferUpdateInfo
from vllm.distributed.weight_transfer.packed_tensor import (
    packed_broadcast_consumer,
    packed_broadcast_producer,
)
class MockCommunicationGroup:
    """Producer-side stand-in for a collective communication group.

    Records every tensor handed to :meth:`broadcast` so tests can inspect
    exactly what a producer would have sent over the wire.
    """

    def __init__(self):
        # Mirrors the ``device`` attribute of a real communication group.
        self.device = torch.device("cuda:0")
        # Snapshot of every broadcast payload, in call order.
        self.broadcasted_tensors: list[torch.Tensor] = []
        self.broadcast_count = 0

    def broadcast(self, tensor, src):
        """Record a copy of *tensor* instead of performing a collective."""
        snapshot = tensor.clone()
        self.broadcasted_tensors.append(snapshot)
        self.broadcast_count = self.broadcast_count + 1
class MockConsumerCommunicationGroup:
    """Consumer-side stand-in that replays a fixed sequence of tensors.

    Each call to :meth:`broadcast` copies the next queued tensor into the
    caller-provided buffer, emulating data arriving from a producer rank.
    """

    def __init__(self, tensors_to_return: list[torch.Tensor]):
        # Mirrors the ``device`` attribute of a real communication group.
        self.device = torch.device("cuda:0")
        self.tensors_to_return = tensors_to_return
        # Position of the next tensor to replay.
        self.current_index = 0

    def broadcast(self, tensor, src):
        """Fill *tensor* in place with the next pre-stored payload, if any."""
        if self.current_index >= len(self.tensors_to_return):
            # Queue exhausted: extra broadcasts are silently ignored and
            # leave the caller's buffer untouched.
            return
        tensor.copy_(self.tensors_to_return[self.current_index])
        self.current_index += 1
def create_mock_model_params(
    num_layers: int = 3,
    dtype: torch.dtype = torch.float32,
) -> list[tuple[str, torch.Tensor]]:
    """Build (name, tensor) pairs for *num_layers* mock layers.

    Each layer contributes a (10, 20) weight and a (10,) bias, both drawn
    from a standard normal distribution in the requested *dtype*.
    """
    return [
        pair
        for layer in range(num_layers)
        for pair in (
            (f"layer{layer}.weight", torch.randn(10, 20, dtype=dtype)),
            (f"layer{layer}.bias", torch.randn(10, dtype=dtype)),
        )
    ]
def create_state_dict_info(
    params: list[tuple[str, torch.Tensor]],
) -> dict[str, tuple[tuple[int, ...], torch.dtype]]:
    """Map each parameter name to its (shape, dtype) metadata."""
    info: dict[str, tuple[tuple[int, ...], torch.dtype]] = {}
    for name, tensor in params:
        info[name] = (tuple(tensor.shape), tensor.dtype)
    return info
# --- Unit Tests: NCCLWeightTransferUpdateInfo packed field ---


class TestNCCLWeightTransferUpdateInfoPacked:
    """Exercise the ``packed`` flag on NCCLWeightTransferUpdateInfo."""

    @staticmethod
    def _make_info(**overrides):
        """Build a minimal single-parameter update info, applying *overrides*."""
        kwargs = {
            "names": ["layer.weight"],
            "dtype_names": ["float32"],
            "shapes": [[10, 10]],
        }
        kwargs.update(overrides)
        return NCCLWeightTransferUpdateInfo(**kwargs)

    def test_packed_default_false(self):
        """Omitting ``packed`` must leave it at its False default."""
        assert self._make_info().packed is False

    def test_packed_can_be_set_true(self):
        """Passing ``packed=True`` must be honored."""
        assert self._make_info(packed=True).packed is True
# --- Unit Tests: packed_broadcast_producer ---


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestPackedBroadcastProducer:
    """Test packed_broadcast_producer function."""

    @staticmethod
    def _run_producer(named_tensors, buffer_size_bytes):
        """Drive packed_broadcast_producer over *named_tensors*; return the mock group."""
        group = MockCommunicationGroup()
        packed_broadcast_producer(
            iterator=iter(named_tensors),
            group=group,
            src=0,
            post_iter_func=lambda x: x[1],
            buffer_size_bytes=buffer_size_bytes,
        )
        return group

    def test_producer_broadcasts_tensors(self):
        """A handful of parameters must yield at least one broadcast."""
        named_tensors = [
            (name, tensor.cuda()) for name, tensor in create_mock_model_params()
        ]
        # A small buffer size forces the producer to emit several batches.
        group = self._run_producer(named_tensors, buffer_size_bytes=500)
        assert group.broadcast_count > 0
        assert len(group.broadcasted_tensors) > 0

    def test_producer_single_large_tensor(self):
        """A tensor larger than the buffer must still be sent in full."""
        big = torch.randn(1000, 1000, dtype=torch.float32).cuda()
        group = self._run_producer([("large_weight", big)], buffer_size_bytes=100)
        assert group.broadcast_count >= 1
        assert len(group.broadcasted_tensors) >= 1
        # The packed payloads' element counts must add up to the source
        # tensor's byte size.
        byte_size = big.numel() * big.element_size()
        sent = sum(t.numel() for t in group.broadcasted_tensors)
        assert sent == byte_size

    def test_producer_multiple_batches(self):
        """Many small tensors must be split across several broadcasts."""
        named_tensors = [
            (f"weight_{i}", torch.randn(10, 10, dtype=torch.float32).cuda())
            for i in range(20)
        ]
        # Buffer smaller than the aggregate payload forces batching.
        group = self._run_producer(named_tensors, buffer_size_bytes=2000)
        assert group.broadcast_count > 1
        # Nothing may be lost or duplicated across batches.
        total_bytes = sum(t.numel() * t.element_size() for _, t in named_tensors)
        sent = sum(t.numel() for t in group.broadcasted_tensors)
        assert sent == total_bytes

    def test_producer_empty_iterator(self):
        """An empty parameter iterator must produce no broadcasts."""
        group = self._run_producer([], buffer_size_bytes=1000)
        assert group.broadcast_count == 0
# --- Unit Tests: packed_broadcast_consumer ---


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestPackedBroadcastConsumer:
    """Test packed_broadcast_consumer function."""

    def test_consumer_receives_tensors(self):
        """Consumer must reconstruct every tensor the producer packed."""
        buffer_size = 2000
        named_tensors = [
            (name, tensor.cuda()) for name, tensor in create_mock_model_params()
        ]

        # Stage 1: capture what a real producer would broadcast.
        producer_group = MockCommunicationGroup()
        packed_broadcast_producer(
            iterator=iter(named_tensors),
            group=producer_group,
            src=0,
            post_iter_func=lambda x: x[1],
            buffer_size_bytes=buffer_size,
        )

        # Stage 2: replay those payloads into the consumer.
        consumer_group = MockConsumerCommunicationGroup(
            producer_group.broadcasted_tensors
        )
        received: dict[str, torch.Tensor] = {}

        def capture(tensor_list):
            for name, tensor in tensor_list:
                received[name] = tensor.clone()

        packed_broadcast_consumer(
            iterator=iter(create_state_dict_info(named_tensors).items()),
            group=consumer_group,
            src=0,
            post_unpack_func=capture,
            buffer_size_bytes=buffer_size,
        )

        # Every parameter must come back with identical metadata and values.
        assert len(received) == len(named_tensors)
        for name, original in named_tensors:
            assert name in received
            got = received[name]
            assert got.shape == original.shape
            assert got.dtype == original.dtype
            assert torch.allclose(got, original, rtol=1e-5, atol=1e-7)
# --- Integration Tests: Producer-Consumer Roundtrip ---


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestPackedBroadcastRoundtrip:
    """Test producer-consumer roundtrip behavior."""

    @staticmethod
    def _roundtrip(named_tensors, buffer_size):
        """Run producer then consumer over *named_tensors*; return name -> tensor."""
        producer_group = MockCommunicationGroup()
        packed_broadcast_producer(
            iterator=iter(named_tensors),
            group=producer_group,
            src=0,
            post_iter_func=lambda x: x[1],
            buffer_size_bytes=buffer_size,
        )

        # Replay the captured payloads into the consumer side.
        consumer_group = MockConsumerCommunicationGroup(
            producer_group.broadcasted_tensors
        )
        received: dict[str, torch.Tensor] = {}

        def capture(tensor_list):
            for name, tensor in tensor_list:
                received[name] = tensor.clone()

        packed_broadcast_consumer(
            iterator=iter(create_state_dict_info(named_tensors).items()),
            group=consumer_group,
            src=0,
            post_unpack_func=capture,
            buffer_size_bytes=buffer_size,
        )
        return received

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
    def test_roundtrip_different_dtypes(self, dtype):
        """Test roundtrip with different data types."""
        named_tensors = [
            (name, tensor.cuda())
            for name, tensor in create_mock_model_params(num_layers=2, dtype=dtype)
        ]
        received = self._roundtrip(named_tensors, buffer_size=1000)
        # Roundtrip must preserve both dtype and values.
        for name, original in named_tensors:
            assert name in received
            got = received[name]
            assert got.dtype == dtype
            assert torch.allclose(got, original, rtol=1e-4, atol=1e-6)

    def test_roundtrip_mixed_dtypes(self):
        """Test roundtrip with mixed data types."""
        named_tensors = [
            ("layer1.weight", torch.randn(10, 20, dtype=torch.float32).cuda()),
            ("layer1.bias", torch.randn(10, dtype=torch.float16).cuda()),
            ("layer2.weight", torch.randn(20, 30, dtype=torch.bfloat16).cuda()),
        ]
        received = self._roundtrip(named_tensors, buffer_size=500)
        # Each parameter must keep its own dtype through the packed transfer.
        for name, original in named_tensors:
            assert name in received
            got = received[name]
            assert got.shape == original.shape
            assert got.dtype == original.dtype
            assert torch.allclose(got, original, rtol=1e-4, atol=1e-6)

    @pytest.mark.parametrize("target_size", [100, 1000, 10000, 100000])
    def test_roundtrip_different_batch_sizes(self, target_size):
        """Test roundtrip with different target batch sizes."""
        named_tensors = [
            (name, tensor.cuda())
            for name, tensor in create_mock_model_params(num_layers=5)
        ]
        received = self._roundtrip(named_tensors, buffer_size=target_size)
        assert len(received) == len(named_tensors)
        for name, original in named_tensors:
            assert name in received
            assert torch.allclose(received[name], original, rtol=1e-5, atol=1e-7)

    def test_roundtrip_non_contiguous_tensors(self):
        """Test roundtrip with non-contiguous tensors from the trainer."""
        # Transposing, strided slicing, and permuting all break contiguity,
        # mimicking tensors as a trainer might hand them over.
        named_tensors = [
            ("layer1.weight", torch.randn(20, 10, dtype=torch.float32).cuda().T),
            (
                "layer2.weight",
                torch.randn(40, 30, dtype=torch.float16).cuda()[::2, ::2],
            ),
            (
                "layer3.weight",
                torch.randn(5, 10, 15, dtype=torch.bfloat16).cuda().permute(2, 0, 1),
            ),
        ]
        # Sanity-check the fixtures before exercising the transfer path.
        for name, tensor in named_tensors:
            assert not tensor.is_contiguous(), f"{name} should be non-contiguous"

        received = self._roundtrip(named_tensors, buffer_size=500)
        for name, original in named_tensors:
            assert name in received
            got = received[name]
            assert got.shape == original.shape
            assert got.dtype == original.dtype
            assert torch.allclose(got, original, rtol=1e-4, atol=1e-6)