# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for DCP A2A communication backend (no GPU required).

Tests cover:
1. DCP A2A config validation (--dcp-comm-backend)
2. LSE-weighted combination correctness
"""

import math

import pytest
import torch

from vllm.config.parallel import ParallelConfig


class TestDCPCommBackendConfig:
    """Test --dcp-comm-backend config validation."""

    def test_default_is_ag_rs(self):
        """Default comm backend is ag_rs."""
        config = ParallelConfig()
        assert config.dcp_comm_backend == "ag_rs"

    def test_a2a_requires_dcp_greater_than_1(self):
        """A2A backend requires decode_context_parallel_size > 1."""
        with pytest.raises(
            ValueError, match="requires decode_context_parallel_size > 1"
        ):
            ParallelConfig(
                dcp_comm_backend="a2a",
                decode_context_parallel_size=1,
            )

    def test_a2a_with_dcp_valid(self):
        """A2A backend is valid when DCP > 1."""
        config = ParallelConfig(
            dcp_comm_backend="a2a",
            tensor_parallel_size=8,
            decode_context_parallel_size=4,
        )
        assert config.dcp_comm_backend == "a2a"

    def test_invalid_backend_rejected(self):
        """Invalid backend values are rejected."""
        with pytest.raises(ValueError, match="must be one of"):
            ParallelConfig(
                dcp_comm_backend="invalid",
            )

    def test_ag_rs_with_dcp_1_valid(self):
        """ag_rs backend is valid with DCP=1 (no DCP)."""
        config = ParallelConfig(
            dcp_comm_backend="ag_rs",
            decode_context_parallel_size=1,
        )
        assert config.dcp_comm_backend == "ag_rs"


class TestLSEWeightedCombine:
    """Test LSE-weighted combination logic (CPU only, no GPU).

    The _lse_weighted_combine function is the reference implementation
    that verifies the Triton kernel's correctness. It computes:

        result[b,h,d] = sum_n(w_n * output_n[b,h,d])

    where w_n = softmax(lse_n) = exp(lse_n) / sum_k(exp(lse_k))

    Every test feeds ``outputs`` laid out as (N, B, H, D) and ``lses`` laid
    out as (N, B, H), and expects a (B, H, D) result. The function is
    imported inside each test so that an import failure is reported per
    test rather than breaking collection of the whole module.
    """

    def test_importable(self):
        """Verify _lse_weighted_combine is importable."""
        from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine

        assert callable(_lse_weighted_combine)

    def test_single_rank(self):
        """Single rank: output unchanged."""
        from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine

        # N=1, B=2, H=4, D=8
        outputs = torch.randn(1, 2, 4, 8)
        lses = torch.randn(1, 2, 4)

        result = _lse_weighted_combine(outputs, lses)

        assert result.shape == (2, 4, 8)
        # With a single rank the softmax weight is 1, so the result is the
        # input with the leading rank dimension removed.
        torch.testing.assert_close(result, outputs.squeeze(0), rtol=1e-5, atol=1e-5)

    def test_equal_lse(self):
        """Equal LSE values: outputs averaged equally."""
        from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine

        B, H, D = 1, 1, 4
        outputs = torch.tensor(
            [
                [[[1.0, 2.0, 3.0, 4.0]]],  # Rank 0
                [[[5.0, 6.0, 7.0, 8.0]]],  # Rank 1
            ]
        )
        lses = torch.tensor(
            [
                [[0.0]],  # Rank 0
                [[0.0]],  # Rank 1
            ]
        )

        result = _lse_weighted_combine(outputs, lses)

        # Identical LSEs give each rank a weight of 1/2.
        expected = (outputs[0] + outputs[1]) / 2

        assert result.shape == (B, H, D)
        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)

    def test_dominant_rank(self):
        """Different LSE values: larger LSE gets more weight."""
        from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine

        B, H, D = 1, 1, 2
        outputs = torch.tensor(
            [
                [[[0.0, 0.0]]],  # Rank 0
                [[[1.0, 1.0]]],  # Rank 1
            ]
        )
        lses = torch.tensor(
            [
                [[-100.0]],  # Rank 0: negligible contribution
                [[0.0]],  # Rank 1: dominant
            ]
        )

        result = _lse_weighted_combine(outputs, lses)

        assert result.shape == (B, H, D)
        # outputs[1] already has shape (B, H, D). assert_close requires both
        # tensors to have the same shape, so it must not be squeezed further.
        torch.testing.assert_close(result, outputs[1], atol=1e-5, rtol=1e-5)

    def test_mathematically_correct(self):
        """Verify mathematical correctness of LSE combination."""
        from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine

        outputs = torch.tensor(
            [
                [[[2.0, 4.0]]],
                [[[6.0, 8.0]]],
            ]
        )
        lses = torch.tensor(
            [
                [[1.0]],  # exp(1) ≈ 2.718
                [[2.0]],  # exp(2) ≈ 7.389
            ]
        )

        result = _lse_weighted_combine(outputs, lses)

        # Hand-computed softmax weights over the two ranks.
        denom = math.exp(1) + math.exp(2)
        w0 = math.exp(1) / denom
        w1 = math.exp(2) / denom
        expected = torch.tensor([[[w0 * 2.0 + w1 * 6.0, w0 * 4.0 + w1 * 8.0]]])

        torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4)

    def test_return_lse(self):
        """return_lse=True returns global LSE (logsumexp of inputs)."""
        from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine

        B, H, D = 1, 1, 2
        outputs = torch.tensor(
            [
                [[[1.0, 2.0]]],
                [[[3.0, 4.0]]],
            ]
        )
        lses = torch.tensor(
            [
                [[1.0]],
                [[2.0]],
            ]
        )

        result, global_lse = _lse_weighted_combine(outputs, lses, return_lse=True)

        # Global LSE is log(sum_n exp(lse_n)) for the single (b, h) entry.
        expected_global_lse = math.log(math.exp(1) + math.exp(2))

        assert result.shape == (B, H, D)
        assert global_lse.shape == (B, H)
        assert abs(global_lse.item() - expected_global_lse) < 1e-5


if __name__ == "__main__":
    pytest.main([__file__, "-v"])