[Core] Add All-to-All communication backend for DCP (#34883)

Signed-off-by: Sungsoo Ha <sungsooh@nvidia.com>
Signed-off-by: sungsoo ha <hasungsoo@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Authored by sungsoo ha on 2026-03-04 07:01:57 -08:00, committed by GitHub
parent ead7bde1ab
commit 6cb901093f
8 changed files with 658 additions and 17 deletions

View File

@@ -0,0 +1,192 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for DCP A2A communication backend (no GPU required).
Tests cover:
1. DCP A2A config validation (--dcp-comm-backend)
2. KVP group function exists
3. LSE-weighted combination correctness
"""
import math
import pytest
import torch
from vllm.config.parallel import ParallelConfig
class TestDCPCommBackendConfig:
"""Test --dcp-comm-backend config validation."""
def test_default_is_ag_rs(self):
"""Default comm backend is ag_rs."""
config = ParallelConfig()
assert config.dcp_comm_backend == "ag_rs"
def test_a2a_requires_dcp_greater_than_1(self):
"""A2A backend requires decode_context_parallel_size > 1."""
with pytest.raises(
ValueError, match="requires decode_context_parallel_size > 1"
):
ParallelConfig(
dcp_comm_backend="a2a",
decode_context_parallel_size=1,
)
def test_a2a_with_dcp_valid(self):
"""A2A backend is valid when DCP > 1."""
config = ParallelConfig(
dcp_comm_backend="a2a",
tensor_parallel_size=8,
decode_context_parallel_size=4,
)
assert config.dcp_comm_backend == "a2a"
def test_invalid_backend_rejected(self):
"""Invalid backend values are rejected."""
with pytest.raises(ValueError, match="must be one of"):
ParallelConfig(
dcp_comm_backend="invalid",
)
def test_ag_rs_with_dcp_1_valid(self):
"""ag_rs backend is valid with DCP=1 (no DCP)."""
config = ParallelConfig(
dcp_comm_backend="ag_rs",
decode_context_parallel_size=1,
)
assert config.dcp_comm_backend == "ag_rs"
class TestLSEWeightedCombine:
"""Test LSE-weighted combination logic (CPU only, no GPU).
The _lse_weighted_combine function is the pure-PyTorch reference
implementation used to verify the Triton kernel's correctness. It computes:
result[b,h,d] = sum_n(w_n * output_n[b,h,d])
where w_n = softmax(lse_n) = exp(lse_n) / sum_k(exp(lse_k))
"""
def test_importable(self):
"""Verify _lse_weighted_combine is importable."""
from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine
assert callable(_lse_weighted_combine)
def test_single_rank(self):
"""Single rank: output unchanged."""
from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine
# N=1, B=2, H=4, D=8
outputs = torch.randn(1, 2, 4, 8)
lses = torch.randn(1, 2, 4)
result = _lse_weighted_combine(outputs, lses)
assert result.shape == (2, 4, 8)
torch.testing.assert_close(result, outputs.squeeze(0), rtol=1e-5, atol=1e-5)
def test_equal_lse(self):
"""Equal LSE values: outputs averaged equally."""
from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine
_N, B, H, D = 2, 1, 1, 4
outputs = torch.tensor(
[
[[[1.0, 2.0, 3.0, 4.0]]], # Rank 0
[[[5.0, 6.0, 7.0, 8.0]]], # Rank 1
]
)
lses = torch.tensor(
[
[[0.0]], # Rank 0
[[0.0]], # Rank 1
]
)
result = _lse_weighted_combine(outputs, lses)
expected = (outputs[0] + outputs[1]) / 2
assert result.shape == (B, H, D)
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
def test_dominant_rank(self):
"""Different LSE values: larger LSE gets more weight."""
from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine
B, H, D = 1, 1, 2
outputs = torch.tensor(
[
[[[0.0, 0.0]]], # Rank 0
[[[1.0, 1.0]]], # Rank 1
]
)
lses = torch.tensor(
[
[[-100.0]], # Rank 0: negligible contribution
[[0.0]], # Rank 1: dominant
]
)
result = _lse_weighted_combine(outputs, lses)
assert result.shape == (B, H, D)
torch.testing.assert_close(result, outputs[1].squeeze(0), atol=1e-5, rtol=1e-5)
def test_mathematically_correct(self):
"""Verify mathematical correctness of LSE combination."""
from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine
outputs = torch.tensor(
[
[[[2.0, 4.0]]],
[[[6.0, 8.0]]],
]
)
lses = torch.tensor(
[
[[1.0]], # exp(1) ≈ 2.718
[[2.0]], # exp(2) ≈ 7.389
]
)
result = _lse_weighted_combine(outputs, lses)
w0 = math.exp(1) / (math.exp(1) + math.exp(2))
w1 = math.exp(2) / (math.exp(1) + math.exp(2))
expected = torch.tensor([[[w0 * 2.0 + w1 * 6.0, w0 * 4.0 + w1 * 8.0]]])
torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4)
def test_return_lse(self):
"""return_lse=True returns global LSE (logsumexp of inputs)."""
from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine
B, H, D = 1, 1, 2
outputs = torch.tensor(
[
[[[1.0, 2.0]]],
[[[3.0, 4.0]]],
]
)
lses = torch.tensor(
[
[[1.0]],
[[2.0]],
]
)
result, global_lse = _lse_weighted_combine(outputs, lses, return_lse=True)
expected_global_lse = math.log(math.exp(1) + math.exp(2))
assert result.shape == (B, H, D)
assert global_lse.shape == (B, H)
assert abs(global_lse.item() - expected_global_lse) < 1e-5
if __name__ == "__main__":
pytest.main([__file__, "-v"])
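The combination rule exercised above is exact rather than heuristic: per-shard attention outputs weighted by softmax of their LSEs reproduce attention over the full key set. The short, self-contained sketch below (not part of this commit; shapes, names, and the two-way split are illustrative) checks that identity for a single query whose keys/values are divided across two shards.

import torch

torch.manual_seed(0)
D, S = 8, 16  # head dim, total KV length


def attend(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """Single-query softmax attention returning (output, logsumexp)."""
    scores = k @ q  # [S]
    return torch.softmax(scores, dim=0) @ v, torch.logsumexp(scores, dim=0)


q = torch.randn(D)
k = torch.randn(S, D)
v = torch.randn(S, D)

# Attention over the full key set vs. two disjoint KV shards.
out_full, _ = attend(q, k, v)
out0, lse0 = attend(q, k[:8], v[:8])
out1, lse1 = attend(q, k[8:], v[8:])

# LSE-weighted combination, mirroring _lse_weighted_combine: w_n = softmax(lse_n).
w = torch.softmax(torch.stack([lse0, lse1]), dim=0)
torch.testing.assert_close(w[0] * out0 + w[1] * out1, out_full, rtol=1e-5, atol=1e-5)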

View File

@@ -36,6 +36,7 @@ ExpertPlacementStrategy = Literal["linear", "round_robin"]
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
DataParallelBackend = Literal["ray", "mp"]
EPLBPolicyOption = Literal["default"]
DCPCommBackend = Literal["ag_rs", "a2a"]
All2AllBackend = Literal[
"naive",
"pplx",
@@ -287,6 +288,14 @@ class ParallelConfig:
and will be deprecated when PCP is fully supported.
"""
dcp_comm_backend: DCPCommBackend = "ag_rs"
"""Communication backend for Decode Context Parallel (DCP).
- "ag_rs": AllGather + ReduceScatter (default, existing behavior)
- "a2a": All-to-All exchange of partial outputs + LSE, then
combine with Triton kernel. Reduces NCCL calls from 3 to 2
per layer for MLA models.
"""
cp_kv_cache_interleave_size: int = 1
"""Interleave size of kv_cache storage while using DCP or PCP.
For `total_cp_rank = pcp_rank * dcp_world_size + dcp_rank`,
@@ -392,6 +401,11 @@ class ParallelConfig:
f"dcp_size={self.decode_context_parallel_size}." f"dcp_size={self.decode_context_parallel_size}."
) )
if self.dcp_comm_backend == "a2a" and self.decode_context_parallel_size <= 1:
raise ValueError(
"dcp_comm_backend='a2a' requires decode_context_parallel_size > 1."
)
return self
@property

View File

@@ -1645,6 +1645,8 @@ class VllmConfig:
f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}, " # noqa f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}, " # noqa
f"pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, " # noqa f"pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, " # noqa
f"data_parallel_size={self.parallel_config.data_parallel_size}, " # noqa f"data_parallel_size={self.parallel_config.data_parallel_size}, " # noqa
f"decode_context_parallel_size={self.parallel_config.decode_context_parallel_size}, " # noqa
f"dcp_comm_backend={self.parallel_config.dcp_comm_backend}, " # noqa
f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa
f"quantization={self.model_config.quantization}, " f"quantization={self.model_config.quantization}, "
f"enforce_eager={self.model_config.enforce_eager}, " f"enforce_eager={self.model_config.enforce_eager}, "

View File

@@ -85,6 +85,7 @@ from vllm.config.observability import DetailedTraceModules
from vllm.config.parallel import (
All2AllBackend,
DataParallelBackend,
DCPCommBackend,
DistributedExecutorBackend,
ExpertPlacementStrategy,
)
@@ -405,6 +406,7 @@ class EngineArgs:
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
prefill_context_parallel_size: int = ParallelConfig.prefill_context_parallel_size
decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
dcp_comm_backend: DCPCommBackend = ParallelConfig.dcp_comm_backend
dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
cp_kv_cache_interleave_size: int = ParallelConfig.cp_kv_cache_interleave_size
data_parallel_size: int = ParallelConfig.data_parallel_size
@@ -820,6 +822,10 @@ class EngineArgs:
"-dcp", "-dcp",
**parallel_kwargs["decode_context_parallel_size"], **parallel_kwargs["decode_context_parallel_size"],
) )
parallel_group.add_argument(
"--dcp-comm-backend",
**parallel_kwargs["dcp_comm_backend"],
)
parallel_group.add_argument(
"--dcp-kv-cache-interleave-size",
**parallel_kwargs["dcp_kv_cache_interleave_size"],
@@ -1720,6 +1726,7 @@ class EngineArgs:
worker_cls=self.worker_cls,
worker_extension_cls=self.worker_extension_cls,
decode_context_parallel_size=self.decode_context_parallel_size,
dcp_comm_backend=self.dcp_comm_backend,
dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size,
cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
_api_process_count=self._api_process_count,
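With the flag wired through EngineArgs, the backend can be selected either on the CLI (as in the module docstring further below: vllm serve model --tp 16 --dcp 16 --dcp-comm-backend a2a) or from the offline API. The snippet below is an illustrative sketch only, assuming the LLM entrypoint forwards these keyword arguments to EngineArgs as usual; the model name is a placeholder.

from vllm import LLM

llm = LLM(
    model="<your-mla-model>",          # placeholder
    tensor_parallel_size=8,
    decode_context_parallel_size=4,    # --decode-context-parallel-size / -dcp
    dcp_comm_backend="a2a",            # new option added by this commit
)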

View File

@@ -203,8 +203,17 @@ from tqdm import tqdm
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import (
CacheConfig,
ModelConfig,
VllmConfig,
get_current_vllm_config,
get_current_vllm_config_or_none,
)
from vllm.distributed.parallel_state import (
get_dcp_group,
is_global_first_rank,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
@@ -253,6 +262,7 @@ from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills,
)
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.dcp_alltoall import dcp_a2a_lse_reduce
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import (
@@ -393,6 +403,13 @@ class MLAAttention(nn.Module, AttentionLayerBase):
self.use_sparse = use_sparse
vllm_config = get_current_vllm_config_or_none()
self.dcp_a2a = (
vllm_config is not None
and vllm_config.parallel_config.decode_context_parallel_size > 1
and vllm_config.parallel_config.dcp_comm_backend == "a2a"
)
# Initialize q/k/v range constants.
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
@@ -647,6 +664,14 @@ class MLAAttention(nn.Module, AttentionLayerBase):
# correct dcp attn_out with lse.
if self.impl.dcp_world_size > 1:
if self.dcp_a2a:
attn_out = dcp_a2a_lse_reduce(
attn_out,
lse,
get_dcp_group(),
is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False),
)
else:
attn_out = cp_lse_ag_out_rs(
attn_out,
lse,

View File

@@ -23,6 +23,7 @@ from vllm.v1.attention.backends.fa_utils import (
is_flash_attn_varlen_func_available,
)
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.dcp_alltoall import dcp_a2a_lse_reduce
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
if is_flash_attn_varlen_func_available():
@@ -32,7 +33,12 @@ if is_flash_attn_varlen_func_available():
get_scheduler_metadata,
reshape_and_cache_flash,
)
from vllm.config import (
VllmConfig,
get_current_vllm_config,
get_current_vllm_config_or_none,
get_layers_from_vllm_config,
)
from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
@@ -609,6 +615,14 @@ class FlashAttentionImpl(AttentionImpl):
self.supports_quant_query_input = True
vllm_config = get_current_vllm_config_or_none()
dcp_a2a = (
vllm_config is not None
and vllm_config.parallel_config.decode_context_parallel_size > 1
and vllm_config.parallel_config.dcp_comm_backend == "a2a"
)
self.dcp_combine = dcp_a2a_lse_reduce if dcp_a2a else cp_lse_ag_out_rs
def forward(
self,
layer: torch.nn.Module,
@@ -857,8 +871,8 @@ class FlashAttentionImpl(AttentionImpl):
v_descale=v_descale,
num_splits=attn_metadata.max_num_splits,
)
# FA returns LSE in shape [ H, B ] but DCP combine wants [ B, H ]
context_attn_out_cor, context_lse_cor = self.dcp_combine(
context_attn_out,
context_lse.transpose(0, 1),
get_dcp_group(),

View File

@@ -3,6 +3,7 @@
"""Attention layer with FlashInfer.""" """Attention layer with FlashInfer."""
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial
from typing import ClassVar from typing import ClassVar
import numpy as np import numpy as np
@@ -19,7 +20,11 @@ from flashinfer.utils import FP4Tensor
from typing_extensions import override
from vllm import envs
from vllm.config import (
CUDAGraphMode,
VllmConfig,
get_current_vllm_config_or_none,
)
from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
@@ -59,6 +64,7 @@ from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills,
)
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.dcp_alltoall import dcp_a2a_lse_reduce
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from vllm.v1.kv_cache_interface import AttentionSpec, UniformTypeKVCacheSpecs
from vllm.v1.utils import CpuGpuBuffer
@@ -170,7 +176,12 @@ class BatchDCPPrefillWrapper:
def __init__(
self,
workspace_buffer: torch.Tensor | None = None,
dcp_a2a: bool = False,
): ):
if dcp_a2a:
self._dcp_combine = partial(dcp_a2a_lse_reduce, is_lse_base_on_e=False)
else:
self._dcp_combine = partial(cp_lse_ag_out_rs, is_lse_base_on_e=False)
self._context = BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, get_kv_cache_layout()
)
@@ -249,12 +260,11 @@ class BatchDCPPrefillWrapper:
v_scale=layer._v_scale_float,
return_lse=True,
)
output_context, lse_context = self._dcp_combine(
output_context_tmp,
lse_context_tmp,
get_dcp_group(),
return_lse=True,
)
lse_context = lse_context.transpose(0, 1).contiguous()
@@ -550,6 +560,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.dcp_rank = 0
self.dcp_kv_cache_interleave_size = 1
self.use_dcp = self.dcp_world_size > 1
self.dcp_a2a = (
self.use_dcp and vllm_config.parallel_config.dcp_comm_backend == "a2a"
)
self.num_qo_heads = self.model_config.get_num_attention_heads(
self.vllm_config.parallel_config
@@ -699,6 +712,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
if self.use_dcp:
self._prefill_wrapper = BatchDCPPrefillWrapper(
workspace_buffer=self._get_workspace_buffer(),
dcp_a2a=self.dcp_a2a,
)
else:
self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
@@ -1208,15 +1222,26 @@ class FlashInferImpl(AttentionImpl):
self.sinks = sinks
self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads)
vllm_config = get_current_vllm_config_or_none()
self.supports_quant_query_input = (
self.support_trtllm_attn
and vllm_config is not None
and not vllm_config.attention_config.disable_flashinfer_q_quantization
)
self.bmm1_scale: float | None = None
self.bmm2_scale: float | None = None
self.o_sf_scale: float | None = None
dcp_a2a = (
vllm_config is not None
and vllm_config.parallel_config.decode_context_parallel_size > 1
and vllm_config.parallel_config.dcp_comm_backend == "a2a"
)
if dcp_a2a:
self.dcp_combine = partial(dcp_a2a_lse_reduce, is_lse_base_on_e=False)
else:
self.dcp_combine = partial(cp_lse_ag_out_rs, is_lse_base_on_e=False)
def fused_output_quant_supported(self, quant_key: QuantKey):
return (
self.support_trtllm_attn
@@ -1503,11 +1528,10 @@ class FlashInferImpl(AttentionImpl):
lse=lse,
return_lse=True,
)
output[:num_decode_tokens] = self.dcp_combine(
output_tmp,
lse,
get_dcp_group(),
)
else:
decode_wrapper.run(
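Note that both FlashInfer call sites bind is_lse_base_on_e=False (via functools.partial), i.e. the LSEs handed to the combine step are treated as base-2, while the FlashAttention path uses the base-e default and the MLA path switches on whether FlashInfer prefill is in use. The two conventions are interchangeable up to a factor of ln 2 in the LSE, which the quick CPU check below illustrates (not part of the commit; shapes are arbitrary).

import math

import torch

from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine

outputs = torch.randn(2, 3, 4, 8)   # [N, B, H, D]
lse_e = torch.randn(2, 3, 4)        # natural-log LSEs
lse_2 = lse_e / math.log(2.0)       # the same quantities expressed in base 2

ref = _lse_weighted_combine(outputs, lse_e, is_lse_base_on_e=True)
alt = _lse_weighted_combine(outputs, lse_2, is_lse_base_on_e=False)
torch.testing.assert_close(ref, alt, rtol=1e-5, atol=1e-5)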

View File

@@ -0,0 +1,363 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
DCP All-to-All communication backend for attention.
Provides All-to-All (A2A) communication as an alternative to
AllGather + ReduceScatter (AG+RS) for Decode Context Parallel (DCP).
Instead of gathering the full Q tensor and scattering partial outputs,
A2A exchanges partial attention outputs and their LSE values across
ranks, then combines them with exact LSE-weighted reduction.
This reduces the number of NCCL calls per attention layer from 3
(AG for Q, AG for K metadata, RS for output) to 2 (A2A for output,
A2A for LSE), lowering per-step communication overhead for long-context
decode where NCCL latency is a significant fraction of step time.
Usage:
vllm serve model --tp 16 --dcp 16 --dcp-comm-backend a2a
Reference: https://arxiv.org/abs/2507.07120
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
import torch.distributed as dist
from vllm.triton_utils import tl, triton
if TYPE_CHECKING:
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.v1.attention.ops.common import CPTritonContext
def _lse_weighted_combine(
outputs: torch.Tensor,
lses: torch.Tensor,
return_lse: bool = False,
is_lse_base_on_e: bool = True,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
CPU reference implementation for LSE-weighted combination.
This is a pure PyTorch implementation used for testing and validation.
For GPU execution, use dcp_lse_combine_triton instead.
Args:
outputs: Partial attention outputs [N, B, H, D]
N = number of KV shards (ranks)
B = batch size (num_tokens)
H = number of heads per rank
D = head dimension
lses: Log-sum-exp values [N, B, H]
return_lse: If True, also return the global LSE
is_lse_base_on_e: If True, LSE is base e; if False, base 2
Returns:
Combined output [B, H, D], and optionally global LSE [B, H]
"""
N, B, H, D = outputs.shape
# Handle NaN and inf in LSEs
lses = torch.where(
torch.isnan(lses) | torch.isinf(lses),
torch.tensor(float("-inf"), device=lses.device, dtype=lses.dtype),
lses,
)
# Compute max LSE for numerical stability
lse_max, _ = lses.max(dim=0) # [B, H]
lse_max = torch.where(
lse_max == float("-inf"),
torch.zeros_like(lse_max),
lse_max,
)
# Compute weights: softmax over the N dimension
if is_lse_base_on_e:
weights = torch.exp(lses - lse_max.unsqueeze(0)) # [N, B, H]
else:
weights = torch.pow(2.0, lses - lse_max.unsqueeze(0)) # [N, B, H]
# Handle NaN weights
weights = torch.where(torch.isnan(weights), torch.zeros_like(weights), weights)
# Normalize weights
weight_sum = weights.sum(dim=0, keepdim=True) # [1, B, H]
weights = weights / weight_sum.clamp(min=1e-10) # [N, B, H]
# Weighted combination: sum over N dimension
result = (outputs * weights.unsqueeze(-1)).sum(dim=0) # [B, H, D]
if return_lse:
if is_lse_base_on_e:
global_lse = torch.log(weight_sum.squeeze(0)) + lse_max # [B, H]
else:
global_lse = torch.log2(weight_sum.squeeze(0)) + lse_max # [B, H]
return result, global_lse
return result
@triton.jit
def _dcp_lse_combine_kernel(
# Input pointers
recv_output_ptr,
recv_lse_ptr,
# Output pointers
out_ptr,
out_lse_ptr,
# Strides for recv_output [N, B, H_local, D]
ro_stride_N,
ro_stride_B,
ro_stride_H,
ro_stride_D,
# Strides for recv_lse [N, B, H_local]
rl_stride_N,
rl_stride_B,
rl_stride_H,
# Strides for output [B, H_local, D]
o_stride_B,
o_stride_H,
o_stride_D,
# Constants
N: tl.constexpr,
HEAD_DIM: tl.constexpr,
IS_BASE_E: tl.constexpr,
RETURN_LSE: tl.constexpr,
):
"""
Triton kernel for LSE-weighted combination of partial attention outputs.
After All-to-All, each rank has:
- recv_output [N, B, H_local, D]: partial outputs from all KV shards
- recv_lse [N, B, H_local]: partial LSEs from all KV shards
This kernel computes the weighted combination locally (no communication).
Grid: (B, H_local)
Each program handles one (batch, head) and processes all D elements.
"""
batch_idx = tl.program_id(0).to(tl.int64)
head_idx = tl.program_id(1).to(tl.int64)
# Base offset for this (batch, head)
base_lse_offset = batch_idx * rl_stride_B + head_idx * rl_stride_H
base_out_offset = batch_idx * ro_stride_B + head_idx * ro_stride_H
# First pass: find max LSE for numerical stability
lse_max = -float("inf")
for n in tl.static_range(N):
lse_offset = n * rl_stride_N + base_lse_offset
lse_val = tl.load(recv_lse_ptr + lse_offset)
lse_val = tl.where(
(lse_val != lse_val) | (lse_val == float("inf")),
-float("inf"),
lse_val,
)
lse_max = tl.maximum(lse_max, lse_val)
lse_max = tl.where(lse_max == -float("inf"), 0.0, lse_max)
# Second pass: compute sum of exp(lse - max)
lse_sum = 0.0
for n in tl.static_range(N):
lse_offset = n * rl_stride_N + base_lse_offset
lse_val = tl.load(recv_lse_ptr + lse_offset)
lse_val = tl.where(
(lse_val != lse_val) | (lse_val == float("inf")),
-float("inf"),
lse_val,
)
if IS_BASE_E:
lse_sum += tl.exp(lse_val - lse_max)
else:
lse_sum += tl.exp2(lse_val - lse_max)
# Compute global LSE
if IS_BASE_E: # noqa: SIM108
global_lse = tl.log(lse_sum) + lse_max
else:
global_lse = tl.log2(lse_sum) + lse_max
# Third pass: weighted combination across D dimension
d_offsets = tl.arange(0, HEAD_DIM)
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
for n in tl.static_range(N):
lse_offset = n * rl_stride_N + base_lse_offset
lse_val = tl.load(recv_lse_ptr + lse_offset)
lse_val = tl.where(
(lse_val != lse_val) | (lse_val == float("inf")),
-float("inf"),
lse_val,
)
if IS_BASE_E:
weight = tl.exp(lse_val - global_lse)
else:
weight = tl.exp2(lse_val - global_lse)
weight = tl.where(weight != weight, 0.0, weight)
out_offsets = n * ro_stride_N + base_out_offset + d_offsets * ro_stride_D
out_vals = tl.load(recv_output_ptr + out_offsets)
acc += out_vals.to(tl.float32) * weight
# Store result
final_offsets = (
batch_idx * o_stride_B + head_idx * o_stride_H + d_offsets * o_stride_D
)
tl.store(out_ptr + final_offsets, acc)
if RETURN_LSE:
tl.store(out_lse_ptr + base_lse_offset, global_lse)
def dcp_lse_combine_triton(
recv_output: torch.Tensor,
recv_lse: torch.Tensor,
return_lse: bool = False,
is_lse_base_on_e: bool = True,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
Triton-accelerated LSE-weighted combination for DCP A2A.
Args:
recv_output: [N, B, H_local, D] - partial outputs from all KV shards
recv_lse: [N, B, H_local] - partial LSEs from all KV shards
return_lse: If True, also return the global LSE
is_lse_base_on_e: If True, LSE is base e; if False, base 2
Returns:
Combined output [B, H_local, D]
If return_lse=True, also returns global_lse [B, H_local]
"""
N, B, H_local, D = recv_output.shape
out = torch.empty(
(B, H_local, D), device=recv_output.device, dtype=recv_output.dtype
)
if return_lse:
out_lse = torch.empty(
(B, H_local), device=recv_lse.device, dtype=recv_lse.dtype
)
else:
out_lse = torch.empty(1, device=recv_lse.device, dtype=recv_lse.dtype)
ro_stride_N, ro_stride_B, ro_stride_H, ro_stride_D = recv_output.stride()
rl_stride_N, rl_stride_B, rl_stride_H = recv_lse.stride()
o_stride_B, o_stride_H, o_stride_D = out.stride()
grid = (B, H_local, 1)
_dcp_lse_combine_kernel[grid](
recv_output,
recv_lse,
out,
out_lse,
ro_stride_N,
ro_stride_B,
ro_stride_H,
ro_stride_D,
rl_stride_N,
rl_stride_B,
rl_stride_H,
o_stride_B,
o_stride_H,
o_stride_D,
N=N,
HEAD_DIM=D,
IS_BASE_E=is_lse_base_on_e,
RETURN_LSE=return_lse,
)
if return_lse:
return out, out_lse
return out
def dcp_a2a_lse_reduce(
cp_attn_out: torch.Tensor,
cp_attn_lse: torch.Tensor,
cp_group: GroupCoordinator,
ctx: CPTritonContext | None = None,
return_lse: bool = False,
is_lse_base_on_e: bool = True,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
Combine partial attention outputs across DCP ranks using All-to-All.
Each rank holds attention output for all heads but only a local shard
of the KV cache. This function:
1. Exchanges partial outputs across ranks via All-to-All
2. Exchanges LSE values via All-to-All
3. Combines them with exact LSE-weighted reduction (Triton kernel)
Tensor flow:
Input: cp_attn_out [B, H, D] - all heads, local KV shard
Reshape: [N, B, H/N, D] - split heads across ranks
A2A: Two all_to_all_single calls (output and LSE)
Combine: recv [N, B, H/N, D] + lse [N, B, H/N] -> [B, H/N, D]
Args:
cp_attn_out: [B, H, D] where B=num_tokens, H=total_heads, D=head_dim
cp_attn_lse: [B, H] log-sum-exp values (fp32)
cp_group: GroupCoordinator for DCP communication
ctx: CPTritonContext (unused, for signature compatibility)
return_lse: If True, also return the combined global LSE
is_lse_base_on_e: If True, LSE is base e; if False, base 2
Returns:
Combined output [B, H/N, D] (head-scattered)
If return_lse=True, also returns global_lse [B, H/N]
"""
world_size = cp_group.world_size
if world_size == 1:
if return_lse:
return cp_attn_out, cp_attn_lse
return cp_attn_out
local_output = cp_attn_out.contiguous()
local_lse = cp_attn_lse.contiguous()
B, H, D = local_output.shape
H_per_rank = H // world_size
# Reshape for All-to-All: [B, H, D] -> [N, B, H/N, D]
# Split heads into N chunks, each destined for a different rank
send_output = (
local_output.view(B, world_size, H_per_rank, D).permute(1, 0, 2, 3).contiguous()
)
recv_output = torch.empty_like(send_output)
# Same for LSE: [B, H] -> [N, B, H/N]
send_lse = local_lse.view(B, world_size, H_per_rank).permute(1, 0, 2).contiguous()
recv_lse = torch.empty_like(send_lse)
# All-to-All for partial attention outputs and LSE values (async overlap)
work_output = dist.all_to_all_single(
recv_output.view(-1),
send_output.view(-1),
group=cp_group.device_group,
async_op=True,
)
work_lse = dist.all_to_all_single(
recv_lse.view(-1),
send_lse.view(-1),
group=cp_group.device_group,
async_op=True,
)
work_output.wait()
work_lse.wait()
# LSE-weighted combination via Triton kernel (local, no communication)
return dcp_lse_combine_triton(
recv_output,
recv_lse,
return_lse=return_lse,
is_lse_base_on_e=is_lse_base_on_e,
)
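The docstrings above describe _lse_weighted_combine as the CPU reference used to validate the Triton kernel. A single-GPU parity check along the following lines (illustrative only, not part of the commit) exercises that relationship directly; HEAD_DIM is chosen as a power of two since the kernel uses tl.arange(0, HEAD_DIM).

import torch

from vllm.v1.attention.ops.dcp_alltoall import (
    _lse_weighted_combine,
    dcp_lse_combine_triton,
)

N, B, H, D = 4, 16, 8, 128
recv_output = torch.randn(N, B, H, D, device="cuda", dtype=torch.float32)
recv_lse = torch.randn(N, B, H, device="cuda", dtype=torch.float32)

out_gpu, lse_gpu = dcp_lse_combine_triton(recv_output, recv_lse, return_lse=True)
out_ref, lse_ref = _lse_weighted_combine(recv_output.cpu(), recv_lse.cpu(), return_lse=True)

torch.testing.assert_close(out_gpu.cpu(), out_ref, rtol=1e-4, atol=1e-4)
torch.testing.assert_close(lse_gpu.cpu(), lse_ref, rtol=1e-4, atol=1e-4)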