[Core] Add All-to-All communication backend for DCP (#34883)
Signed-off-by: Sungsoo Ha <sungsooh@nvidia.com> Signed-off-by: sungsoo ha <hasungsoo@gmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
192
tests/distributed/test_dcp_a2a.py
Normal file
192
tests/distributed/test_dcp_a2a.py
Normal file
@@ -0,0 +1,192 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Unit tests for DCP A2A communication backend (no GPU required).
|
||||
|
||||
Tests cover:
|
||||
1. DCP A2A config validation (--dcp-comm-backend)
|
||||
2. KVP group function exists
|
||||
3. LSE-weighted combination correctness
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
|
||||
|
||||
class TestDCPCommBackendConfig:
|
||||
"""Test --dcp-comm-backend config validation."""
|
||||
|
||||
def test_default_is_ag_rs(self):
|
||||
"""Default comm backend is ag_rs."""
|
||||
config = ParallelConfig()
|
||||
assert config.dcp_comm_backend == "ag_rs"
|
||||
|
||||
def test_a2a_requires_dcp_greater_than_1(self):
|
||||
"""A2A backend requires decode_context_parallel_size > 1."""
|
||||
with pytest.raises(
|
||||
ValueError, match="requires decode_context_parallel_size > 1"
|
||||
):
|
||||
ParallelConfig(
|
||||
dcp_comm_backend="a2a",
|
||||
decode_context_parallel_size=1,
|
||||
)
|
||||
|
||||
def test_a2a_with_dcp_valid(self):
|
||||
"""A2A backend is valid when DCP > 1."""
|
||||
config = ParallelConfig(
|
||||
dcp_comm_backend="a2a",
|
||||
tensor_parallel_size=8,
|
||||
decode_context_parallel_size=4,
|
||||
)
|
||||
assert config.dcp_comm_backend == "a2a"
|
||||
|
||||
def test_invalid_backend_rejected(self):
|
||||
"""Invalid backend values are rejected."""
|
||||
with pytest.raises(ValueError, match="must be one of"):
|
||||
ParallelConfig(
|
||||
dcp_comm_backend="invalid",
|
||||
)
|
||||
|
||||
def test_ag_rs_with_dcp_1_valid(self):
|
||||
"""ag_rs backend is valid with DCP=1 (no DCP)."""
|
||||
config = ParallelConfig(
|
||||
dcp_comm_backend="ag_rs",
|
||||
decode_context_parallel_size=1,
|
||||
)
|
||||
assert config.dcp_comm_backend == "ag_rs"
|
||||
|
||||
|
||||
class TestLSEWeightedCombine:
|
||||
"""Test LSE-weighted combination logic (CPU only, no GPU).
|
||||
|
||||
The _lse_weighted_combine function is the reference implementation
|
||||
that verifies the Triton kernel's correctness. It computes:
|
||||
|
||||
result[b,h,d] = sum_n(w_n * output_n[b,h,d])
|
||||
|
||||
where w_n = softmax(lse_n) = exp(lse_n) / sum_k(exp(lse_k))
|
||||
"""
|
||||
|
||||
def test_importable(self):
|
||||
"""Verify _lse_weighted_combine is importable."""
|
||||
from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine
|
||||
|
||||
assert callable(_lse_weighted_combine)
|
||||
|
||||
def test_single_rank(self):
|
||||
"""Single rank: output unchanged."""
|
||||
from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine
|
||||
|
||||
# N=1, B=2, H=4, D=8
|
||||
outputs = torch.randn(1, 2, 4, 8)
|
||||
lses = torch.randn(1, 2, 4)
|
||||
|
||||
result = _lse_weighted_combine(outputs, lses)
|
||||
|
||||
assert result.shape == (2, 4, 8)
|
||||
torch.testing.assert_close(result, outputs.squeeze(0), rtol=1e-5, atol=1e-5)
|
||||
|
||||
def test_equal_lse(self):
|
||||
"""Equal LSE values: outputs averaged equally."""
|
||||
from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine
|
||||
|
||||
_N, B, H, D = 2, 1, 1, 4
|
||||
outputs = torch.tensor(
|
||||
[
|
||||
[[[1.0, 2.0, 3.0, 4.0]]], # Rank 0
|
||||
[[[5.0, 6.0, 7.0, 8.0]]], # Rank 1
|
||||
]
|
||||
)
|
||||
lses = torch.tensor(
|
||||
[
|
||||
[[0.0]], # Rank 0
|
||||
[[0.0]], # Rank 1
|
||||
]
|
||||
)
|
||||
|
||||
result = _lse_weighted_combine(outputs, lses)
|
||||
|
||||
expected = (outputs[0] + outputs[1]) / 2
|
||||
assert result.shape == (B, H, D)
|
||||
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
|
||||
|
||||
def test_dominant_rank(self):
|
||||
"""Different LSE values: larger LSE gets more weight."""
|
||||
from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine
|
||||
|
||||
B, H, D = 1, 1, 2
|
||||
outputs = torch.tensor(
|
||||
[
|
||||
[[[0.0, 0.0]]], # Rank 0
|
||||
[[[1.0, 1.0]]], # Rank 1
|
||||
]
|
||||
)
|
||||
lses = torch.tensor(
|
||||
[
|
||||
[[-100.0]], # Rank 0: negligible contribution
|
||||
[[0.0]], # Rank 1: dominant
|
||||
]
|
||||
)
|
||||
|
||||
result = _lse_weighted_combine(outputs, lses)
|
||||
|
||||
assert result.shape == (B, H, D)
|
||||
torch.testing.assert_close(result, outputs[1].squeeze(0), atol=1e-5, rtol=1e-5)
|
||||
|
||||
def test_mathematically_correct(self):
|
||||
"""Verify mathematical correctness of LSE combination."""
|
||||
from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine
|
||||
|
||||
outputs = torch.tensor(
|
||||
[
|
||||
[[[2.0, 4.0]]],
|
||||
[[[6.0, 8.0]]],
|
||||
]
|
||||
)
|
||||
lses = torch.tensor(
|
||||
[
|
||||
[[1.0]], # exp(1) ≈ 2.718
|
||||
[[2.0]], # exp(2) ≈ 7.389
|
||||
]
|
||||
)
|
||||
|
||||
result = _lse_weighted_combine(outputs, lses)
|
||||
|
||||
w0 = math.exp(1) / (math.exp(1) + math.exp(2))
|
||||
w1 = math.exp(2) / (math.exp(1) + math.exp(2))
|
||||
expected = torch.tensor([[[w0 * 2.0 + w1 * 6.0, w0 * 4.0 + w1 * 8.0]]])
|
||||
|
||||
torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4)
|
||||
|
||||
def test_return_lse(self):
|
||||
"""return_lse=True returns global LSE (logsumexp of inputs)."""
|
||||
from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine
|
||||
|
||||
B, H, D = 1, 1, 2
|
||||
outputs = torch.tensor(
|
||||
[
|
||||
[[[1.0, 2.0]]],
|
||||
[[[3.0, 4.0]]],
|
||||
]
|
||||
)
|
||||
lses = torch.tensor(
|
||||
[
|
||||
[[1.0]],
|
||||
[[2.0]],
|
||||
]
|
||||
)
|
||||
|
||||
result, global_lse = _lse_weighted_combine(outputs, lses, return_lse=True)
|
||||
|
||||
expected_global_lse = math.log(math.exp(1) + math.exp(2))
|
||||
|
||||
assert result.shape == (B, H, D)
|
||||
assert global_lse.shape == (B, H)
|
||||
assert abs(global_lse.item() - expected_global_lse) < 1e-5
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -36,6 +36,7 @@ ExpertPlacementStrategy = Literal["linear", "round_robin"]
|
||||
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
|
||||
DataParallelBackend = Literal["ray", "mp"]
|
||||
EPLBPolicyOption = Literal["default"]
|
||||
DCPCommBackend = Literal["ag_rs", "a2a"]
|
||||
All2AllBackend = Literal[
|
||||
"naive",
|
||||
"pplx",
|
||||
@@ -287,6 +288,14 @@ class ParallelConfig:
|
||||
and will be deprecated when PCP is fully supported.
|
||||
|
||||
"""
|
||||
dcp_comm_backend: DCPCommBackend = "ag_rs"
|
||||
"""Communication backend for Decode Context Parallel (DCP).
|
||||
- "ag_rs": AllGather + ReduceScatter (default, existing behavior)
|
||||
- "a2a": All-to-All exchange of partial outputs + LSE, then
|
||||
combine with Triton kernel. Reduces NCCL calls from 3 to 2
|
||||
per layer for MLA models.
|
||||
"""
|
||||
|
||||
cp_kv_cache_interleave_size: int = 1
|
||||
"""Interleave size of kv_cache storage while using DCP or PCP.
|
||||
For `total_cp_rank = pcp_rank * dcp_world_size + dcp_rank`,
|
||||
@@ -392,6 +401,11 @@ class ParallelConfig:
|
||||
f"dcp_size={self.decode_context_parallel_size}."
|
||||
)
|
||||
|
||||
if self.dcp_comm_backend == "a2a" and self.decode_context_parallel_size <= 1:
|
||||
raise ValueError(
|
||||
"dcp_comm_backend='a2a' requires decode_context_parallel_size > 1."
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
@property
|
||||
|
||||
@@ -1645,6 +1645,8 @@ class VllmConfig:
|
||||
f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}, " # noqa
|
||||
f"pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, " # noqa
|
||||
f"data_parallel_size={self.parallel_config.data_parallel_size}, " # noqa
|
||||
f"decode_context_parallel_size={self.parallel_config.decode_context_parallel_size}, " # noqa
|
||||
f"dcp_comm_backend={self.parallel_config.dcp_comm_backend}, " # noqa
|
||||
f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa
|
||||
f"quantization={self.model_config.quantization}, "
|
||||
f"enforce_eager={self.model_config.enforce_eager}, "
|
||||
|
||||
@@ -85,6 +85,7 @@ from vllm.config.observability import DetailedTraceModules
|
||||
from vllm.config.parallel import (
|
||||
All2AllBackend,
|
||||
DataParallelBackend,
|
||||
DCPCommBackend,
|
||||
DistributedExecutorBackend,
|
||||
ExpertPlacementStrategy,
|
||||
)
|
||||
@@ -405,6 +406,7 @@ class EngineArgs:
|
||||
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
|
||||
prefill_context_parallel_size: int = ParallelConfig.prefill_context_parallel_size
|
||||
decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
|
||||
dcp_comm_backend: DCPCommBackend = ParallelConfig.dcp_comm_backend
|
||||
dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
|
||||
cp_kv_cache_interleave_size: int = ParallelConfig.cp_kv_cache_interleave_size
|
||||
data_parallel_size: int = ParallelConfig.data_parallel_size
|
||||
@@ -820,6 +822,10 @@ class EngineArgs:
|
||||
"-dcp",
|
||||
**parallel_kwargs["decode_context_parallel_size"],
|
||||
)
|
||||
parallel_group.add_argument(
|
||||
"--dcp-comm-backend",
|
||||
**parallel_kwargs["dcp_comm_backend"],
|
||||
)
|
||||
parallel_group.add_argument(
|
||||
"--dcp-kv-cache-interleave-size",
|
||||
**parallel_kwargs["dcp_kv_cache_interleave_size"],
|
||||
@@ -1720,6 +1726,7 @@ class EngineArgs:
|
||||
worker_cls=self.worker_cls,
|
||||
worker_extension_cls=self.worker_extension_cls,
|
||||
decode_context_parallel_size=self.decode_context_parallel_size,
|
||||
dcp_comm_backend=self.dcp_comm_backend,
|
||||
dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size,
|
||||
cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
|
||||
_api_process_count=self._api_process_count,
|
||||
|
||||
@@ -203,8 +203,17 @@ from tqdm import tqdm
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
ModelConfig,
|
||||
VllmConfig,
|
||||
get_current_vllm_config,
|
||||
get_current_vllm_config_or_none,
|
||||
)
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_dcp_group,
|
||||
is_global_first_rank,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
@@ -253,6 +262,7 @@ from vllm.v1.attention.backends.utils import (
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
|
||||
from vllm.v1.attention.ops.dcp_alltoall import dcp_a2a_lse_reduce
|
||||
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
|
||||
from vllm.v1.attention.selector import get_attn_backend
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
@@ -393,6 +403,13 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
|
||||
self.use_sparse = use_sparse
|
||||
|
||||
vllm_config = get_current_vllm_config_or_none()
|
||||
self.dcp_a2a = (
|
||||
vllm_config is not None
|
||||
and vllm_config.parallel_config.decode_context_parallel_size > 1
|
||||
and vllm_config.parallel_config.dcp_comm_backend == "a2a"
|
||||
)
|
||||
|
||||
# Initialize q/k/v range constants.
|
||||
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
|
||||
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
|
||||
@@ -647,6 +664,14 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
|
||||
# correct dcp attn_out with lse.
|
||||
if self.impl.dcp_world_size > 1:
|
||||
if self.dcp_a2a:
|
||||
attn_out = dcp_a2a_lse_reduce(
|
||||
attn_out,
|
||||
lse,
|
||||
get_dcp_group(),
|
||||
is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False),
|
||||
)
|
||||
else:
|
||||
attn_out = cp_lse_ag_out_rs(
|
||||
attn_out,
|
||||
lse,
|
||||
|
||||
@@ -23,6 +23,7 @@ from vllm.v1.attention.backends.fa_utils import (
|
||||
is_flash_attn_varlen_func_available,
|
||||
)
|
||||
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
|
||||
from vllm.v1.attention.ops.dcp_alltoall import dcp_a2a_lse_reduce
|
||||
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
|
||||
|
||||
if is_flash_attn_varlen_func_available():
|
||||
@@ -32,7 +33,12 @@ if is_flash_attn_varlen_func_available():
|
||||
get_scheduler_metadata,
|
||||
reshape_and_cache_flash,
|
||||
)
|
||||
from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
|
||||
from vllm.config import (
|
||||
VllmConfig,
|
||||
get_current_vllm_config,
|
||||
get_current_vllm_config_or_none,
|
||||
get_layers_from_vllm_config,
|
||||
)
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.distributed.parallel_state import get_dcp_group
|
||||
from vllm.logger import init_logger
|
||||
@@ -609,6 +615,14 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
|
||||
self.supports_quant_query_input = True
|
||||
|
||||
vllm_config = get_current_vllm_config_or_none()
|
||||
dcp_a2a = (
|
||||
vllm_config is not None
|
||||
and vllm_config.parallel_config.decode_context_parallel_size > 1
|
||||
and vllm_config.parallel_config.dcp_comm_backend == "a2a"
|
||||
)
|
||||
self.dcp_combine = dcp_a2a_lse_reduce if dcp_a2a else cp_lse_ag_out_rs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -857,8 +871,8 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
v_descale=v_descale,
|
||||
num_splits=attn_metadata.max_num_splits,
|
||||
)
|
||||
# FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ]
|
||||
context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
|
||||
# FA returns LSE in shape [ H, B ] but DCP combine wants [ B, H ]
|
||||
context_attn_out_cor, context_lse_cor = self.dcp_combine(
|
||||
context_attn_out,
|
||||
context_lse.transpose(0, 1),
|
||||
get_dcp_group(),
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
"""Attention layer with FlashInfer."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import ClassVar
|
||||
|
||||
import numpy as np
|
||||
@@ -19,7 +20,11 @@ from flashinfer.utils import FP4Tensor
|
||||
from typing_extensions import override
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
|
||||
from vllm.config import (
|
||||
CUDAGraphMode,
|
||||
VllmConfig,
|
||||
get_current_vllm_config_or_none,
|
||||
)
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.distributed.parallel_state import get_dcp_group
|
||||
from vllm.logger import init_logger
|
||||
@@ -59,6 +64,7 @@ from vllm.v1.attention.backends.utils import (
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
|
||||
from vllm.v1.attention.ops.dcp_alltoall import dcp_a2a_lse_reduce
|
||||
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, UniformTypeKVCacheSpecs
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
@@ -170,7 +176,12 @@ class BatchDCPPrefillWrapper:
|
||||
def __init__(
|
||||
self,
|
||||
workspace_buffer: torch.Tensor | None = None,
|
||||
dcp_a2a: bool = False,
|
||||
):
|
||||
if dcp_a2a:
|
||||
self._dcp_combine = partial(dcp_a2a_lse_reduce, is_lse_base_on_e=False)
|
||||
else:
|
||||
self._dcp_combine = partial(cp_lse_ag_out_rs, is_lse_base_on_e=False)
|
||||
self._context = BatchPrefillWithPagedKVCacheWrapper(
|
||||
workspace_buffer, get_kv_cache_layout()
|
||||
)
|
||||
@@ -249,12 +260,11 @@ class BatchDCPPrefillWrapper:
|
||||
v_scale=layer._v_scale_float,
|
||||
return_lse=True,
|
||||
)
|
||||
output_context, lse_context = cp_lse_ag_out_rs(
|
||||
output_context, lse_context = self._dcp_combine(
|
||||
output_context_tmp,
|
||||
lse_context_tmp,
|
||||
get_dcp_group(),
|
||||
return_lse=True,
|
||||
is_lse_base_on_e=False,
|
||||
)
|
||||
lse_context = lse_context.transpose(0, 1).contiguous()
|
||||
|
||||
@@ -550,6 +560,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
self.dcp_rank = 0
|
||||
self.dcp_kv_cache_interleave_size = 1
|
||||
self.use_dcp = self.dcp_world_size > 1
|
||||
self.dcp_a2a = (
|
||||
self.use_dcp and vllm_config.parallel_config.dcp_comm_backend == "a2a"
|
||||
)
|
||||
|
||||
self.num_qo_heads = self.model_config.get_num_attention_heads(
|
||||
self.vllm_config.parallel_config
|
||||
@@ -699,6 +712,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
if self.use_dcp:
|
||||
self._prefill_wrapper = BatchDCPPrefillWrapper(
|
||||
workspace_buffer=self._get_workspace_buffer(),
|
||||
dcp_a2a=self.dcp_a2a,
|
||||
)
|
||||
else:
|
||||
self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
|
||||
@@ -1208,15 +1222,26 @@ class FlashInferImpl(AttentionImpl):
|
||||
self.sinks = sinks
|
||||
|
||||
self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads)
|
||||
vllm_config = get_current_vllm_config()
|
||||
vllm_config = get_current_vllm_config_or_none()
|
||||
self.supports_quant_query_input = (
|
||||
self.support_trtllm_attn
|
||||
and vllm_config is not None
|
||||
and not vllm_config.attention_config.disable_flashinfer_q_quantization
|
||||
)
|
||||
self.bmm1_scale: float | None = None
|
||||
self.bmm2_scale: float | None = None
|
||||
self.o_sf_scale: float | None = None
|
||||
|
||||
dcp_a2a = (
|
||||
vllm_config is not None
|
||||
and vllm_config.parallel_config.decode_context_parallel_size > 1
|
||||
and vllm_config.parallel_config.dcp_comm_backend == "a2a"
|
||||
)
|
||||
if dcp_a2a:
|
||||
self.dcp_combine = partial(dcp_a2a_lse_reduce, is_lse_base_on_e=False)
|
||||
else:
|
||||
self.dcp_combine = partial(cp_lse_ag_out_rs, is_lse_base_on_e=False)
|
||||
|
||||
def fused_output_quant_supported(self, quant_key: QuantKey):
|
||||
return (
|
||||
self.support_trtllm_attn
|
||||
@@ -1503,11 +1528,10 @@ class FlashInferImpl(AttentionImpl):
|
||||
lse=lse,
|
||||
return_lse=True,
|
||||
)
|
||||
output[:num_decode_tokens] = cp_lse_ag_out_rs(
|
||||
output[:num_decode_tokens] = self.dcp_combine(
|
||||
output_tmp,
|
||||
lse,
|
||||
get_dcp_group(),
|
||||
is_lse_base_on_e=False,
|
||||
)
|
||||
else:
|
||||
decode_wrapper.run(
|
||||
|
||||
363
vllm/v1/attention/ops/dcp_alltoall.py
Normal file
363
vllm/v1/attention/ops/dcp_alltoall.py
Normal file
@@ -0,0 +1,363 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
DCP All-to-All communication backend for attention.
|
||||
|
||||
Provides All-to-All (A2A) communication as an alternative to
|
||||
AllGather + ReduceScatter (AG+RS) for Decode Context Parallel (DCP).
|
||||
Instead of gathering the full Q tensor and scattering partial outputs,
|
||||
A2A exchanges partial attention outputs and their LSE values across
|
||||
ranks, then combines them with exact LSE-weighted reduction.
|
||||
|
||||
This reduces the number of NCCL calls per attention layer from 3
|
||||
(AG for Q, AG for K metadata, RS for output) to 2 (A2A for output,
|
||||
A2A for LSE), lowering per-step communication overhead for long-context
|
||||
decode where NCCL latency is a significant fraction of step time.
|
||||
|
||||
Usage:
|
||||
vllm serve model --tp 16 --dcp 16 --dcp-comm-backend a2a
|
||||
|
||||
Reference: https://arxiv.org/abs/2507.07120
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
from vllm.v1.attention.ops.common import CPTritonContext
|
||||
|
||||
|
||||
def _lse_weighted_combine(
|
||||
outputs: torch.Tensor,
|
||||
lses: torch.Tensor,
|
||||
return_lse: bool = False,
|
||||
is_lse_base_on_e: bool = True,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
CPU reference implementation for LSE-weighted combination.
|
||||
|
||||
This is a pure PyTorch implementation used for testing and validation.
|
||||
For GPU execution, use dcp_lse_combine_triton instead.
|
||||
|
||||
Args:
|
||||
outputs: Partial attention outputs [N, B, H, D]
|
||||
N = number of KV shards (ranks)
|
||||
B = batch size (num_tokens)
|
||||
H = number of heads per rank
|
||||
D = head dimension
|
||||
lses: Log-sum-exp values [N, B, H]
|
||||
return_lse: If True, also return the global LSE
|
||||
is_lse_base_on_e: If True, LSE is base e; if False, base 2
|
||||
|
||||
Returns:
|
||||
Combined output [B, H, D], and optionally global LSE [B, H]
|
||||
"""
|
||||
N, B, H, D = outputs.shape
|
||||
|
||||
# Handle NaN and inf in LSEs
|
||||
lses = torch.where(
|
||||
torch.isnan(lses) | torch.isinf(lses),
|
||||
torch.tensor(float("-inf"), device=lses.device, dtype=lses.dtype),
|
||||
lses,
|
||||
)
|
||||
|
||||
# Compute max LSE for numerical stability
|
||||
lse_max, _ = lses.max(dim=0) # [B, H]
|
||||
lse_max = torch.where(
|
||||
lse_max == float("-inf"),
|
||||
torch.zeros_like(lse_max),
|
||||
lse_max,
|
||||
)
|
||||
|
||||
# Compute weights: softmax over the N dimension
|
||||
if is_lse_base_on_e:
|
||||
weights = torch.exp(lses - lse_max.unsqueeze(0)) # [N, B, H]
|
||||
else:
|
||||
weights = torch.pow(2.0, lses - lse_max.unsqueeze(0)) # [N, B, H]
|
||||
|
||||
# Handle NaN weights
|
||||
weights = torch.where(torch.isnan(weights), torch.zeros_like(weights), weights)
|
||||
|
||||
# Normalize weights
|
||||
weight_sum = weights.sum(dim=0, keepdim=True) # [1, B, H]
|
||||
weights = weights / weight_sum.clamp(min=1e-10) # [N, B, H]
|
||||
|
||||
# Weighted combination: sum over N dimension
|
||||
result = (outputs * weights.unsqueeze(-1)).sum(dim=0) # [B, H, D]
|
||||
|
||||
if return_lse:
|
||||
if is_lse_base_on_e:
|
||||
global_lse = torch.log(weight_sum.squeeze(0)) + lse_max # [B, H]
|
||||
else:
|
||||
global_lse = torch.log2(weight_sum.squeeze(0)) + lse_max # [B, H]
|
||||
return result, global_lse
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _dcp_lse_combine_kernel(
|
||||
# Input pointers
|
||||
recv_output_ptr,
|
||||
recv_lse_ptr,
|
||||
# Output pointers
|
||||
out_ptr,
|
||||
out_lse_ptr,
|
||||
# Strides for recv_output [N, B, H_local, D]
|
||||
ro_stride_N,
|
||||
ro_stride_B,
|
||||
ro_stride_H,
|
||||
ro_stride_D,
|
||||
# Strides for recv_lse [N, B, H_local]
|
||||
rl_stride_N,
|
||||
rl_stride_B,
|
||||
rl_stride_H,
|
||||
# Strides for output [B, H_local, D]
|
||||
o_stride_B,
|
||||
o_stride_H,
|
||||
o_stride_D,
|
||||
# Constants
|
||||
N: tl.constexpr,
|
||||
HEAD_DIM: tl.constexpr,
|
||||
IS_BASE_E: tl.constexpr,
|
||||
RETURN_LSE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Triton kernel for LSE-weighted combination of partial attention outputs.
|
||||
|
||||
After All-to-All, each rank has:
|
||||
- recv_output [N, B, H_local, D]: partial outputs from all KV shards
|
||||
- recv_lse [N, B, H_local]: partial LSEs from all KV shards
|
||||
|
||||
This kernel computes the weighted combination locally (no communication).
|
||||
|
||||
Grid: (B, H_local)
|
||||
Each program handles one (batch, head) and processes all D elements.
|
||||
"""
|
||||
batch_idx = tl.program_id(0).to(tl.int64)
|
||||
head_idx = tl.program_id(1).to(tl.int64)
|
||||
|
||||
# Base offset for this (batch, head)
|
||||
base_lse_offset = batch_idx * rl_stride_B + head_idx * rl_stride_H
|
||||
base_out_offset = batch_idx * ro_stride_B + head_idx * ro_stride_H
|
||||
|
||||
# First pass: find max LSE for numerical stability
|
||||
lse_max = -float("inf")
|
||||
for n in tl.static_range(N):
|
||||
lse_offset = n * rl_stride_N + base_lse_offset
|
||||
lse_val = tl.load(recv_lse_ptr + lse_offset)
|
||||
lse_val = tl.where(
|
||||
(lse_val != lse_val) | (lse_val == float("inf")),
|
||||
-float("inf"),
|
||||
lse_val,
|
||||
)
|
||||
lse_max = tl.maximum(lse_max, lse_val)
|
||||
|
||||
lse_max = tl.where(lse_max == -float("inf"), 0.0, lse_max)
|
||||
|
||||
# Second pass: compute sum of exp(lse - max)
|
||||
lse_sum = 0.0
|
||||
for n in tl.static_range(N):
|
||||
lse_offset = n * rl_stride_N + base_lse_offset
|
||||
lse_val = tl.load(recv_lse_ptr + lse_offset)
|
||||
lse_val = tl.where(
|
||||
(lse_val != lse_val) | (lse_val == float("inf")),
|
||||
-float("inf"),
|
||||
lse_val,
|
||||
)
|
||||
if IS_BASE_E:
|
||||
lse_sum += tl.exp(lse_val - lse_max)
|
||||
else:
|
||||
lse_sum += tl.exp2(lse_val - lse_max)
|
||||
|
||||
# Compute global LSE
|
||||
if IS_BASE_E: # noqa: SIM108
|
||||
global_lse = tl.log(lse_sum) + lse_max
|
||||
else:
|
||||
global_lse = tl.log2(lse_sum) + lse_max
|
||||
|
||||
# Third pass: weighted combination across D dimension
|
||||
d_offsets = tl.arange(0, HEAD_DIM)
|
||||
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
|
||||
|
||||
for n in tl.static_range(N):
|
||||
lse_offset = n * rl_stride_N + base_lse_offset
|
||||
lse_val = tl.load(recv_lse_ptr + lse_offset)
|
||||
lse_val = tl.where(
|
||||
(lse_val != lse_val) | (lse_val == float("inf")),
|
||||
-float("inf"),
|
||||
lse_val,
|
||||
)
|
||||
if IS_BASE_E:
|
||||
weight = tl.exp(lse_val - global_lse)
|
||||
else:
|
||||
weight = tl.exp2(lse_val - global_lse)
|
||||
weight = tl.where(weight != weight, 0.0, weight)
|
||||
|
||||
out_offsets = n * ro_stride_N + base_out_offset + d_offsets * ro_stride_D
|
||||
out_vals = tl.load(recv_output_ptr + out_offsets)
|
||||
acc += out_vals.to(tl.float32) * weight
|
||||
|
||||
# Store result
|
||||
final_offsets = (
|
||||
batch_idx * o_stride_B + head_idx * o_stride_H + d_offsets * o_stride_D
|
||||
)
|
||||
tl.store(out_ptr + final_offsets, acc)
|
||||
|
||||
if RETURN_LSE:
|
||||
tl.store(out_lse_ptr + base_lse_offset, global_lse)
|
||||
|
||||
|
||||
def dcp_lse_combine_triton(
|
||||
recv_output: torch.Tensor,
|
||||
recv_lse: torch.Tensor,
|
||||
return_lse: bool = False,
|
||||
is_lse_base_on_e: bool = True,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Triton-accelerated LSE-weighted combination for DCP A2A.
|
||||
|
||||
Args:
|
||||
recv_output: [N, B, H_local, D] - partial outputs from all KV shards
|
||||
recv_lse: [N, B, H_local] - partial LSEs from all KV shards
|
||||
return_lse: If True, also return the global LSE
|
||||
is_lse_base_on_e: If True, LSE is base e; if False, base 2
|
||||
|
||||
Returns:
|
||||
Combined output [B, H_local, D]
|
||||
If return_lse=True, also returns global_lse [B, H_local]
|
||||
"""
|
||||
N, B, H_local, D = recv_output.shape
|
||||
|
||||
out = torch.empty(
|
||||
(B, H_local, D), device=recv_output.device, dtype=recv_output.dtype
|
||||
)
|
||||
|
||||
if return_lse:
|
||||
out_lse = torch.empty(
|
||||
(B, H_local), device=recv_lse.device, dtype=recv_lse.dtype
|
||||
)
|
||||
else:
|
||||
out_lse = torch.empty(1, device=recv_lse.device, dtype=recv_lse.dtype)
|
||||
|
||||
ro_stride_N, ro_stride_B, ro_stride_H, ro_stride_D = recv_output.stride()
|
||||
rl_stride_N, rl_stride_B, rl_stride_H = recv_lse.stride()
|
||||
o_stride_B, o_stride_H, o_stride_D = out.stride()
|
||||
|
||||
grid = (B, H_local, 1)
|
||||
|
||||
_dcp_lse_combine_kernel[grid](
|
||||
recv_output,
|
||||
recv_lse,
|
||||
out,
|
||||
out_lse,
|
||||
ro_stride_N,
|
||||
ro_stride_B,
|
||||
ro_stride_H,
|
||||
ro_stride_D,
|
||||
rl_stride_N,
|
||||
rl_stride_B,
|
||||
rl_stride_H,
|
||||
o_stride_B,
|
||||
o_stride_H,
|
||||
o_stride_D,
|
||||
N=N,
|
||||
HEAD_DIM=D,
|
||||
IS_BASE_E=is_lse_base_on_e,
|
||||
RETURN_LSE=return_lse,
|
||||
)
|
||||
|
||||
if return_lse:
|
||||
return out, out_lse
|
||||
return out
|
||||
|
||||
|
||||
def dcp_a2a_lse_reduce(
|
||||
cp_attn_out: torch.Tensor,
|
||||
cp_attn_lse: torch.Tensor,
|
||||
cp_group: GroupCoordinator,
|
||||
ctx: CPTritonContext | None = None,
|
||||
return_lse: bool = False,
|
||||
is_lse_base_on_e: bool = True,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Combine partial attention outputs across DCP ranks using All-to-All.
|
||||
|
||||
Each rank holds attention output for all heads but only a local shard
|
||||
of the KV cache. This function:
|
||||
1. Exchanges partial outputs across ranks via All-to-All
|
||||
2. Exchanges LSE values via All-to-All
|
||||
3. Combines them with exact LSE-weighted reduction (Triton kernel)
|
||||
|
||||
Tensor flow:
|
||||
Input: cp_attn_out [B, H, D] - all heads, local KV shard
|
||||
Reshape: [N, B, H/N, D] - split heads across ranks
|
||||
A2A: Two all_to_all_single calls (output and LSE)
|
||||
Combine: recv [N, B, H/N, D] + lse [N, B, H/N] -> [B, H/N, D]
|
||||
|
||||
Args:
|
||||
cp_attn_out: [B, H, D] where B=num_tokens, H=total_heads, D=head_dim
|
||||
cp_attn_lse: [B, H] log-sum-exp values (fp32)
|
||||
cp_group: GroupCoordinator for DCP communication
|
||||
ctx: CPTritonContext (unused, for signature compatibility)
|
||||
return_lse: If True, also return the combined global LSE
|
||||
is_lse_base_on_e: If True, LSE is base e; if False, base 2
|
||||
|
||||
Returns:
|
||||
Combined output [B, H/N, D] (head-scattered)
|
||||
If return_lse=True, also returns global_lse [B, H/N]
|
||||
"""
|
||||
world_size = cp_group.world_size
|
||||
|
||||
if world_size == 1:
|
||||
if return_lse:
|
||||
return cp_attn_out, cp_attn_lse
|
||||
return cp_attn_out
|
||||
|
||||
local_output = cp_attn_out.contiguous()
|
||||
local_lse = cp_attn_lse.contiguous()
|
||||
|
||||
B, H, D = local_output.shape
|
||||
H_per_rank = H // world_size
|
||||
|
||||
# Reshape for All-to-All: [B, H, D] -> [N, B, H/N, D]
|
||||
# Split heads into N chunks, each destined for a different rank
|
||||
send_output = (
|
||||
local_output.view(B, world_size, H_per_rank, D).permute(1, 0, 2, 3).contiguous()
|
||||
)
|
||||
recv_output = torch.empty_like(send_output)
|
||||
|
||||
# Same for LSE: [B, H] -> [N, B, H/N]
|
||||
send_lse = local_lse.view(B, world_size, H_per_rank).permute(1, 0, 2).contiguous()
|
||||
recv_lse = torch.empty_like(send_lse)
|
||||
|
||||
# All-to-All for partial attention outputs and LSE values (async overlap)
|
||||
work_output = dist.all_to_all_single(
|
||||
recv_output.view(-1),
|
||||
send_output.view(-1),
|
||||
group=cp_group.device_group,
|
||||
async_op=True,
|
||||
)
|
||||
work_lse = dist.all_to_all_single(
|
||||
recv_lse.view(-1),
|
||||
send_lse.view(-1),
|
||||
group=cp_group.device_group,
|
||||
async_op=True,
|
||||
)
|
||||
work_output.wait()
|
||||
work_lse.wait()
|
||||
|
||||
# LSE-weighted combination via Triton kernel (local, no communication)
|
||||
return dcp_lse_combine_triton(
|
||||
recv_output,
|
||||
recv_lse,
|
||||
return_lse=return_lse,
|
||||
is_lse_base_on_e=is_lse_base_on_e,
|
||||
)
|
||||
Reference in New Issue
Block a user