[Feat][RL][1/2] Native Weight Syncing API: NCCL (#31943)
Signed-off-by: ahao-anyscale <ahao@anyscale.com>
Signed-off-by: Aaron Hao <ahao@anyscale.com>
Co-authored-by: SumanthRH <sumanthrh99@gmail.com>
tests/distributed/test_packed_tensor.py (new file, +443)
@@ -0,0 +1,443 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for packed tensor broadcasting functionality.

Unit tests for packed_broadcast_producer and packed_broadcast_consumer.
These utilities enable efficient batched tensor transfer over NCCL.
"""

import pytest
import torch

from vllm.distributed.weight_transfer.nccl_engine import NCCLWeightTransferUpdateInfo
from vllm.distributed.weight_transfer.packed_tensor import (
    packed_broadcast_consumer,
    packed_broadcast_producer,
)


class MockCommunicationGroup:
    """Mock communication group for testing producer broadcast operations."""

    def __init__(self):
        self.broadcasted_tensors: list[torch.Tensor] = []
        self.broadcast_count = 0
        self.device = torch.device("cuda:0")

    def broadcast(self, tensor, src):
        """Mock broadcast that stores the tensor for later verification."""
        self.broadcasted_tensors.append(tensor.clone())
        self.broadcast_count += 1


class MockConsumerCommunicationGroup:
    """Mock communication group for consumer that returns pre-stored tensors."""

    def __init__(self, tensors_to_return: list[torch.Tensor]):
        self.tensors_to_return = tensors_to_return
        self.current_index = 0
        self.device = torch.device("cuda:0")

    def broadcast(self, tensor, src):
        """Mock broadcast that fills the tensor with pre-stored data."""
        if self.current_index < len(self.tensors_to_return):
            tensor.copy_(self.tensors_to_return[self.current_index])
            self.current_index += 1


def create_mock_model_params(
    num_layers: int = 3,
    dtype: torch.dtype = torch.float32,
) -> list[tuple[str, torch.Tensor]]:
    """Create mock model parameters for testing."""
    params = []
    for i in range(num_layers):
        params.append((f"layer{i}.weight", torch.randn(10, 20, dtype=dtype)))
        params.append((f"layer{i}.bias", torch.randn(10, dtype=dtype)))
    return params


def create_state_dict_info(
    params: list[tuple[str, torch.Tensor]],
) -> dict[str, tuple[tuple[int, ...], torch.dtype]]:
    """Create state dict info (name -> (shape, dtype)) from params."""
    return {name: (tuple(tensor.shape), tensor.dtype) for name, tensor in params}
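
# Note: in a real weight sync the consumer needs this (shape, dtype) metadata
# before the broadcast begins; the names/dtype_names/shapes/packed fields of
# NCCLWeightTransferUpdateInfo (exercised below) appear to carry exactly that.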


# --- Unit Tests: NCCLWeightTransferUpdateInfo packed field ---


class TestNCCLWeightTransferUpdateInfoPacked:
    """Test NCCLWeightTransferUpdateInfo dataclass packed field."""

    def test_packed_default_false(self):
        """Test that packed defaults to False."""
        info = NCCLWeightTransferUpdateInfo(
            names=["layer.weight"],
            dtype_names=["float32"],
            shapes=[[10, 10]],
        )
        assert info.packed is False

    def test_packed_can_be_set_true(self):
        """Test that packed can be set to True."""
        info = NCCLWeightTransferUpdateInfo(
            names=["layer.weight"],
            dtype_names=["float32"],
            shapes=[[10, 10]],
            packed=True,
        )
        assert info.packed is True


# --- Unit Tests: packed_broadcast_producer ---


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestPackedBroadcastProducer:
    """Test packed_broadcast_producer function."""

    def test_producer_broadcasts_tensors(self):
        """Test that producer broadcasts all tensors."""
        params = create_mock_model_params()
        params_cuda = [(name, tensor.cuda()) for name, tensor in params]

        mock_group = MockCommunicationGroup()

        # Use a small target size to force multiple batches
        packed_broadcast_producer(
            iterator=iter(params_cuda),
            group=mock_group,
            src=0,
            post_iter_func=lambda x: x[1],
            buffer_size_bytes=500,
        )

        # Should have broadcasted some tensors
        assert mock_group.broadcast_count > 0
        assert len(mock_group.broadcasted_tensors) > 0

    def test_producer_single_large_tensor(self):
        """Test with a single tensor larger than target size."""
        # Create a large tensor
        large_tensor = torch.randn(1000, 1000, dtype=torch.float32).cuda()
        params = [("large_weight", large_tensor)]

        mock_group = MockCommunicationGroup()

        # Small target size to force the tensor to exceed it
        packed_broadcast_producer(
            iterator=iter(params),
            group=mock_group,
            src=0,
            post_iter_func=lambda x: x[1],
            buffer_size_bytes=100,
        )

        # Should still broadcast the tensor (at least 1 broadcast)
        assert mock_group.broadcast_count >= 1
        assert len(mock_group.broadcasted_tensors) >= 1

        # Verify the total broadcasted size matches the tensor
        expected_size = large_tensor.numel() * large_tensor.element_size()
        actual_size = sum(t.numel() for t in mock_group.broadcasted_tensors)
        assert actual_size == expected_size

    def test_producer_multiple_batches(self):
        """Test that tensors are properly batched when exceeding target size."""
        # Create many small tensors
        params = [
            (f"weight_{i}", torch.randn(10, 10, dtype=torch.float32).cuda())
            for i in range(20)
        ]

        mock_group = MockCommunicationGroup()

        # Small target size to force multiple batches
        packed_broadcast_producer(
            iterator=iter(params),
            group=mock_group,
            src=0,
            post_iter_func=lambda x: x[1],
            buffer_size_bytes=2000,
        )

        # Should have multiple broadcasts
        assert mock_group.broadcast_count > 1

        # Total size should match sum of all tensors
        expected_total = sum(t.numel() * t.element_size() for _, t in params)
        actual_total = sum(t.numel() for t in mock_group.broadcasted_tensors)
        assert actual_total == expected_total

    def test_producer_empty_iterator(self):
        """Test producer handles empty iterator gracefully."""
        mock_group = MockCommunicationGroup()

        packed_broadcast_producer(
            iterator=iter([]),
            group=mock_group,
            src=0,
            post_iter_func=lambda x: x[1],
            buffer_size_bytes=1000,
        )

        # No broadcasts for empty iterator
        assert mock_group.broadcast_count == 0


# --- Unit Tests: packed_broadcast_consumer ---


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestPackedBroadcastConsumer:
    """Test packed_broadcast_consumer function."""

    def test_consumer_receives_tensors(self):
        """Test that consumer receives and unpacks tensors."""
        params = create_mock_model_params()
        params_cuda = [(name, tensor.cuda()) for name, tensor in params]

        buffer_size = 2000

        # First, run producer to get the broadcasted tensors
        producer_group = MockCommunicationGroup()

        packed_broadcast_producer(
            iterator=iter(params_cuda),
            group=producer_group,
            src=0,
            post_iter_func=lambda x: x[1],
            buffer_size_bytes=buffer_size,
        )

        # Now run consumer with the broadcasted tensors
        consumer_group = MockConsumerCommunicationGroup(
            producer_group.broadcasted_tensors
        )

        state_dict_info = create_state_dict_info(params_cuda)

        unpacked_tensors = {}

        def post_unpack_func(tensor_list):
            for name, tensor in tensor_list:
                unpacked_tensors[name] = tensor.clone()

        packed_broadcast_consumer(
            iterator=iter(state_dict_info.items()),
            group=consumer_group,
            src=0,
            post_unpack_func=post_unpack_func,
            buffer_size_bytes=buffer_size,
        )

        # Verify all parameters were unpacked
        assert len(unpacked_tensors) == len(params)

        # Verify each tensor matches the original
        for name, original_tensor in params_cuda:
            assert name in unpacked_tensors
            unpacked = unpacked_tensors[name]
            assert unpacked.shape == original_tensor.shape
            assert unpacked.dtype == original_tensor.dtype
            assert torch.allclose(unpacked, original_tensor, rtol=1e-5, atol=1e-7)


# --- Integration Tests: Producer-Consumer Roundtrip ---


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestPackedBroadcastRoundtrip:
    """Test producer-consumer roundtrip behavior."""

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
    def test_roundtrip_different_dtypes(self, dtype):
        """Test roundtrip with different data types."""
        params = create_mock_model_params(num_layers=2, dtype=dtype)
        params_cuda = [(name, tensor.cuda()) for name, tensor in params]

        buffer_size = 1000
        producer_group = MockCommunicationGroup()

        packed_broadcast_producer(
            iterator=iter(params_cuda),
            group=producer_group,
            src=0,
            post_iter_func=lambda x: x[1],
            buffer_size_bytes=buffer_size,
        )

        consumer_group = MockConsumerCommunicationGroup(
            producer_group.broadcasted_tensors
        )

        state_dict_info = create_state_dict_info(params_cuda)
        unpacked_tensors = {}

        def post_unpack_func(tensor_list):
            for name, tensor in tensor_list:
                unpacked_tensors[name] = tensor.clone()

        packed_broadcast_consumer(
            iterator=iter(state_dict_info.items()),
            group=consumer_group,
            src=0,
            post_unpack_func=post_unpack_func,
            buffer_size_bytes=buffer_size,
        )

        # Verify roundtrip preserves data
        for name, original_tensor in params_cuda:
            assert name in unpacked_tensors
            unpacked = unpacked_tensors[name]
            assert unpacked.dtype == dtype
            assert torch.allclose(unpacked, original_tensor, rtol=1e-4, atol=1e-6)

    def test_roundtrip_mixed_dtypes(self):
        """Test roundtrip with mixed data types."""
        # Create params with mixed dtypes
        params = [
            ("layer1.weight", torch.randn(10, 20, dtype=torch.float32).cuda()),
            ("layer1.bias", torch.randn(10, dtype=torch.float16).cuda()),
            ("layer2.weight", torch.randn(20, 30, dtype=torch.bfloat16).cuda()),
        ]

        buffer_size = 500
        producer_group = MockCommunicationGroup()

        packed_broadcast_producer(
            iterator=iter(params),
            group=producer_group,
            src=0,
            post_iter_func=lambda x: x[1],
            buffer_size_bytes=buffer_size,
        )

        consumer_group = MockConsumerCommunicationGroup(
            producer_group.broadcasted_tensors
        )

        state_dict_info = create_state_dict_info(params)
        unpacked_tensors = {}

        def post_unpack_func(tensor_list):
            for name, tensor in tensor_list:
                unpacked_tensors[name] = tensor.clone()

        packed_broadcast_consumer(
            iterator=iter(state_dict_info.items()),
            group=consumer_group,
            src=0,
            post_unpack_func=post_unpack_func,
            buffer_size_bytes=buffer_size,
        )

        # Verify all params roundtrip correctly with correct dtypes
        for name, original_tensor in params:
            assert name in unpacked_tensors
            unpacked = unpacked_tensors[name]
            assert unpacked.shape == original_tensor.shape
            assert unpacked.dtype == original_tensor.dtype
            assert torch.allclose(unpacked, original_tensor, rtol=1e-4, atol=1e-6)

    @pytest.mark.parametrize("target_size", [100, 1000, 10000, 100000])
    def test_roundtrip_different_batch_sizes(self, target_size):
        """Test roundtrip with different target batch sizes."""
        params = create_mock_model_params(num_layers=5)
        params_cuda = [(name, tensor.cuda()) for name, tensor in params]

        producer_group = MockCommunicationGroup()

        packed_broadcast_producer(
            iterator=iter(params_cuda),
            group=producer_group,
            src=0,
            post_iter_func=lambda x: x[1],
            buffer_size_bytes=target_size,
        )

        consumer_group = MockConsumerCommunicationGroup(
            producer_group.broadcasted_tensors
        )

        state_dict_info = create_state_dict_info(params_cuda)
        unpacked_tensors = {}

        def post_unpack_func(tensor_list):
            for name, tensor in tensor_list:
                unpacked_tensors[name] = tensor.clone()

        packed_broadcast_consumer(
            iterator=iter(state_dict_info.items()),
            group=consumer_group,
            src=0,
            post_unpack_func=post_unpack_func,
            buffer_size_bytes=target_size,
        )

        # Verify all params roundtrip correctly
        assert len(unpacked_tensors) == len(params)
        for name, original_tensor in params_cuda:
            assert name in unpacked_tensors
            assert torch.allclose(
                unpacked_tensors[name], original_tensor, rtol=1e-5, atol=1e-7
            )

    def test_roundtrip_non_contiguous_tensors(self):
        """Test roundtrip with non-contiguous tensors from the trainer."""
        # Create non-contiguous tensors (simulating trainer outputs)
        # Transposed tensors are non-contiguous
        weight1 = torch.randn(20, 10, dtype=torch.float32).cuda().T
        # Sliced tensors with step are non-contiguous
        weight2 = torch.randn(40, 30, dtype=torch.float16).cuda()[::2, ::2]
        # Permuted tensors are non-contiguous
        weight3 = torch.randn(5, 10, 15, dtype=torch.bfloat16).cuda().permute(2, 0, 1)

        params = [
            ("layer1.weight", weight1),
            ("layer2.weight", weight2),
            ("layer3.weight", weight3),
        ]

        # Verify tensors are indeed non-contiguous
        for name, tensor in params:
            assert not tensor.is_contiguous(), f"{name} should be non-contiguous"

        buffer_size = 500
        producer_group = MockCommunicationGroup()

        packed_broadcast_producer(
            iterator=iter(params),
            group=producer_group,
            src=0,
            post_iter_func=lambda x: x[1],
            buffer_size_bytes=buffer_size,
        )

        consumer_group = MockConsumerCommunicationGroup(
            producer_group.broadcasted_tensors
        )

        state_dict_info = create_state_dict_info(params)
        unpacked_tensors = {}

        def post_unpack_func(tensor_list):
            for name, tensor in tensor_list:
                unpacked_tensors[name] = tensor.clone()

        packed_broadcast_consumer(
            iterator=iter(state_dict_info.items()),
            group=consumer_group,
            src=0,
            post_unpack_func=post_unpack_func,
            buffer_size_bytes=buffer_size,
        )

        # Verify all non-contiguous params roundtrip correctly
        for name, original_tensor in params:
            assert name in unpacked_tensors
            unpacked = unpacked_tensors[name]
            assert unpacked.shape == original_tensor.shape
            assert unpacked.dtype == original_tensor.dtype
            assert torch.allclose(unpacked, original_tensor, rtol=1e-4, atol=1e-6)
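
# A minimal sketch of how the two halves pair up outside these mocks, assuming
# a real NCCL-backed group that exposes .broadcast(tensor, src) and .device
# (the rank/src handling and the load_weights_fn callback below are
# placeholders for illustration, not part of the tested API):
#
#     if rank == src:  # e.g. the trainer that owns the updated weights
#         packed_broadcast_producer(
#             iterator=iter(model.named_parameters()),
#             group=nccl_group,
#             src=src,
#             post_iter_func=lambda x: x[1],
#             buffer_size_bytes=buffer_size,
#         )
#     else:  # e.g. each inference worker receiving the update
#         packed_broadcast_consumer(
#             iterator=iter(state_dict_info.items()),
#             group=nccl_group,
#             src=src,
#             post_unpack_func=load_weights_fn,
#             buffer_size_bytes=buffer_size,
#         )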