[Core] Add All-to-All communication backend for DCP (#34883)
Signed-off-by: Sungsoo Ha <sungsooh@nvidia.com> Signed-off-by: sungsoo ha <hasungsoo@gmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
192
tests/distributed/test_dcp_a2a.py
Normal file
192
tests/distributed/test_dcp_a2a.py
Normal file
@@ -0,0 +1,192 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Unit tests for DCP A2A communication backend (no GPU required).
|
||||
|
||||
Tests cover:
|
||||
1. DCP A2A config validation (--dcp-comm-backend)
|
||||
2. _lse_weighted_combine is importable
|
||||
3. LSE-weighted combination correctness
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
|
||||
|
||||
class TestDCPCommBackendConfig:
    """Checks how ``ParallelConfig`` validates ``dcp_comm_backend``.

    Every case constructs a ``ParallelConfig`` and either inspects the
    resulting ``dcp_comm_backend`` attribute or expects construction to
    fail with a ``ValueError`` carrying a specific message fragment.
    """

    def test_default_is_ag_rs(self):
        """A config built with no arguments reports the ``ag_rs`` backend."""
        assert ParallelConfig().dcp_comm_backend == "ag_rs"

    def test_a2a_requires_dcp_greater_than_1(self):
        """Choosing ``a2a`` with a DCP size of 1 raises ``ValueError``."""
        expected_message = "requires decode_context_parallel_size > 1"
        bad_kwargs = {
            "dcp_comm_backend": "a2a",
            "decode_context_parallel_size": 1,
        }
        with pytest.raises(ValueError, match=expected_message):
            ParallelConfig(**bad_kwargs)

    def test_a2a_with_dcp_valid(self):
        """``a2a`` is accepted once the DCP size is larger than 1."""
        good_kwargs = {
            "dcp_comm_backend": "a2a",
            "tensor_parallel_size": 8,
            "decode_context_parallel_size": 4,
        }
        parallel_cfg = ParallelConfig(**good_kwargs)
        assert parallel_cfg.dcp_comm_backend == "a2a"

    def test_invalid_backend_rejected(self):
        """An unrecognised backend name raises ``ValueError``."""
        expected_message = "must be one of"
        with pytest.raises(ValueError, match=expected_message):
            ParallelConfig(dcp_comm_backend="invalid")

    def test_ag_rs_with_dcp_1_valid(self):
        """``ag_rs`` is accepted even when the DCP size is 1 (DCP unused)."""
        good_kwargs = {
            "dcp_comm_backend": "ag_rs",
            "decode_context_parallel_size": 1,
        }
        parallel_cfg = ParallelConfig(**good_kwargs)
        assert parallel_cfg.dcp_comm_backend == "ag_rs"
|
||||
|
||||
|
||||
class TestLSEWeightedCombine:
    """Test LSE-weighted combination logic (CPU only, no GPU).

    The _lse_weighted_combine function is the reference implementation
    that verifies the Triton kernel's correctness. It computes:

        result[b,h,d] = sum_n(w_n * output_n[b,h,d])

    where w_n = softmax(lse_n) = exp(lse_n) / sum_k(exp(lse_k))

    In every test below ``outputs`` is laid out as (N, B, H, D) and
    ``lses`` as (N, B, H), with N the number of ranks being combined.
    The function is imported inside each test rather than at module
    level, matching the original layout of this file.
    """

    def test_importable(self):
        """Verify _lse_weighted_combine is importable."""
        from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine

        assert callable(_lse_weighted_combine)

    def test_single_rank(self):
        """Single rank: output unchanged."""
        from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine

        # N=1, B=2, H=4, D=8
        outputs = torch.randn(1, 2, 4, 8)
        lses = torch.randn(1, 2, 4)

        result = _lse_weighted_combine(outputs, lses)

        assert result.shape == (2, 4, 8)
        # Dropping the size-1 rank axis yields the (B, H, D) expectation.
        torch.testing.assert_close(result, outputs.squeeze(0), rtol=1e-5, atol=1e-5)

    def test_equal_lse(self):
        """Equal LSE values: outputs averaged equally."""
        from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine

        B, H, D = 1, 1, 4
        outputs = torch.tensor(
            [
                [[[1.0, 2.0, 3.0, 4.0]]],  # Rank 0
                [[[5.0, 6.0, 7.0, 8.0]]],  # Rank 1
            ]
        )
        lses = torch.tensor(
            [
                [[0.0]],  # Rank 0
                [[0.0]],  # Rank 1
            ]
        )

        result = _lse_weighted_combine(outputs, lses)

        # Identical LSEs give each rank a weight of 1/2.
        expected = (outputs[0] + outputs[1]) / 2
        assert result.shape == (B, H, D)
        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)

    def test_dominant_rank(self):
        """Different LSE values: larger LSE gets more weight."""
        from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine

        B, H, D = 1, 1, 2
        outputs = torch.tensor(
            [
                [[[0.0, 0.0]]],  # Rank 0
                [[[1.0, 1.0]]],  # Rank 1
            ]
        )
        lses = torch.tensor(
            [
                [[-100.0]],  # Rank 0: negligible contribution
                [[0.0]],  # Rank 1: dominant
            ]
        )

        result = _lse_weighted_combine(outputs, lses)

        assert result.shape == (B, H, D)
        # outputs[1] already has shape (B, H, D). It must not be squeezed:
        # with B == 1 a squeeze(0) would drop the batch axis, and
        # torch.testing.assert_close rejects tensors whose shapes differ
        # instead of broadcasting them.
        torch.testing.assert_close(result, outputs[1], atol=1e-5, rtol=1e-5)

    def test_mathematically_correct(self):
        """Verify mathematical correctness of LSE combination."""
        from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine

        outputs = torch.tensor(
            [
                [[[2.0, 4.0]]],
                [[[6.0, 8.0]]],
            ]
        )
        lses = torch.tensor(
            [
                [[1.0]],  # exp(1) ≈ 2.718
                [[2.0]],  # exp(2) ≈ 7.389
            ]
        )

        result = _lse_weighted_combine(outputs, lses)

        # Softmax weights over the two ranks, computed in plain Python.
        denom = math.exp(1) + math.exp(2)
        w0 = math.exp(1) / denom
        w1 = math.exp(2) / denom
        expected = torch.tensor([[[w0 * 2.0 + w1 * 6.0, w0 * 4.0 + w1 * 8.0]]])

        torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4)

    def test_return_lse(self):
        """return_lse=True returns global LSE (logsumexp of inputs)."""
        from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine

        B, H, D = 1, 1, 2
        outputs = torch.tensor(
            [
                [[[1.0, 2.0]]],
                [[[3.0, 4.0]]],
            ]
        )
        lses = torch.tensor(
            [
                [[1.0]],
                [[2.0]],
            ]
        )

        result, global_lse = _lse_weighted_combine(outputs, lses, return_lse=True)

        # log(exp(1) + exp(2)): the logsumexp of the two per-rank LSEs.
        expected_global_lse = math.log(math.exp(1) + math.exp(2))

        assert result.shape == (B, H, D)
        assert global_lse.shape == (B, H)
        assert abs(global_lse.item() - expected_global_lse) < 1e-5
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this file directly as a script. pytest.main returns the
    # run's exit code; hand it to SystemExit so a failing run produces a
    # non-zero process status instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
||||
Reference in New Issue
Block a user