[Core] Add All-to-All communication backend for DCP (#34883)

Signed-off-by: Sungsoo Ha <sungsooh@nvidia.com>
Signed-off-by: sungsoo ha <hasungsoo@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
sungsoo ha
2026-03-04 07:01:57 -08:00
committed by GitHub
parent ead7bde1ab
commit 6cb901093f
8 changed files with 658 additions and 17 deletions

View File

@@ -0,0 +1,192 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for DCP A2A communication backend (no GPU required).
Tests cover:
1. DCP A2A config validation (--dcp-comm-backend)
2. KVP group function exists
3. LSE-weighted combination correctness
"""
import math
import pytest
import torch
from vllm.config.parallel import ParallelConfig
class TestDCPCommBackendConfig:
"""Test --dcp-comm-backend config validation."""
def test_default_is_ag_rs(self):
"""Default comm backend is ag_rs."""
config = ParallelConfig()
assert config.dcp_comm_backend == "ag_rs"
def test_a2a_requires_dcp_greater_than_1(self):
"""A2A backend requires decode_context_parallel_size > 1."""
with pytest.raises(
ValueError, match="requires decode_context_parallel_size > 1"
):
ParallelConfig(
dcp_comm_backend="a2a",
decode_context_parallel_size=1,
)
def test_a2a_with_dcp_valid(self):
"""A2A backend is valid when DCP > 1."""
config = ParallelConfig(
dcp_comm_backend="a2a",
tensor_parallel_size=8,
decode_context_parallel_size=4,
)
assert config.dcp_comm_backend == "a2a"
def test_invalid_backend_rejected(self):
"""Invalid backend values are rejected."""
with pytest.raises(ValueError, match="must be one of"):
ParallelConfig(
dcp_comm_backend="invalid",
)
def test_ag_rs_with_dcp_1_valid(self):
"""ag_rs backend is valid with DCP=1 (no DCP)."""
config = ParallelConfig(
dcp_comm_backend="ag_rs",
decode_context_parallel_size=1,
)
assert config.dcp_comm_backend == "ag_rs"
class TestLSEWeightedCombine:
"""Test LSE-weighted combination logic (CPU only, no GPU).
The _lse_weighted_combine function is the reference implementation
that verifies the Triton kernel's correctness. It computes:
result[b,h,d] = sum_n(w_n * output_n[b,h,d])
where w_n = softmax(lse_n) = exp(lse_n) / sum_k(exp(lse_k))
"""
def test_importable(self):
"""Verify _lse_weighted_combine is importable."""
from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine
assert callable(_lse_weighted_combine)
def test_single_rank(self):
"""Single rank: output unchanged."""
from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine
# N=1, B=2, H=4, D=8
outputs = torch.randn(1, 2, 4, 8)
lses = torch.randn(1, 2, 4)
result = _lse_weighted_combine(outputs, lses)
assert result.shape == (2, 4, 8)
torch.testing.assert_close(result, outputs.squeeze(0), rtol=1e-5, atol=1e-5)
def test_equal_lse(self):
"""Equal LSE values: outputs averaged equally."""
from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine
_N, B, H, D = 2, 1, 1, 4
outputs = torch.tensor(
[
[[[1.0, 2.0, 3.0, 4.0]]], # Rank 0
[[[5.0, 6.0, 7.0, 8.0]]], # Rank 1
]
)
lses = torch.tensor(
[
[[0.0]], # Rank 0
[[0.0]], # Rank 1
]
)
result = _lse_weighted_combine(outputs, lses)
expected = (outputs[0] + outputs[1]) / 2
assert result.shape == (B, H, D)
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
def test_dominant_rank(self):
"""Different LSE values: larger LSE gets more weight."""
from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine
B, H, D = 1, 1, 2
outputs = torch.tensor(
[
[[[0.0, 0.0]]], # Rank 0
[[[1.0, 1.0]]], # Rank 1
]
)
lses = torch.tensor(
[
[[-100.0]], # Rank 0: negligible contribution
[[0.0]], # Rank 1: dominant
]
)
result = _lse_weighted_combine(outputs, lses)
assert result.shape == (B, H, D)
torch.testing.assert_close(result, outputs[1].squeeze(0), atol=1e-5, rtol=1e-5)
def test_mathematically_correct(self):
"""Verify mathematical correctness of LSE combination."""
from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine
outputs = torch.tensor(
[
[[[2.0, 4.0]]],
[[[6.0, 8.0]]],
]
)
lses = torch.tensor(
[
[[1.0]], # exp(1) ≈ 2.718
[[2.0]], # exp(2) ≈ 7.389
]
)
result = _lse_weighted_combine(outputs, lses)
w0 = math.exp(1) / (math.exp(1) + math.exp(2))
w1 = math.exp(2) / (math.exp(1) + math.exp(2))
expected = torch.tensor([[[w0 * 2.0 + w1 * 6.0, w0 * 4.0 + w1 * 8.0]]])
torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4)
def test_return_lse(self):
"""return_lse=True returns global LSE (logsumexp of inputs)."""
from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine
B, H, D = 1, 1, 2
outputs = torch.tensor(
[
[[[1.0, 2.0]]],
[[[3.0, 4.0]]],
]
)
lses = torch.tensor(
[
[[1.0]],
[[2.0]],
]
)
result, global_lse = _lse_weighted_combine(outputs, lses, return_lse=True)
expected_global_lse = math.log(math.exp(1) + math.exp(2))
assert result.shape == (B, H, D)
assert global_lse.shape == (B, H)
assert abs(global_lse.item() - expected_global_lse) < 1e-5
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -36,6 +36,7 @@ ExpertPlacementStrategy = Literal["linear", "round_robin"]
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
DataParallelBackend = Literal["ray", "mp"]
EPLBPolicyOption = Literal["default"]
DCPCommBackend = Literal["ag_rs", "a2a"]
All2AllBackend = Literal[
"naive",
"pplx",
@@ -287,6 +288,14 @@ class ParallelConfig:
and will be deprecated when PCP is fully supported.
"""
dcp_comm_backend: DCPCommBackend = "ag_rs"
"""Communication backend for Decode Context Parallel (DCP).
- "ag_rs": AllGather + ReduceScatter (default, existing behavior)
- "a2a": All-to-All exchange of partial outputs + LSE, then
combine with Triton kernel. Reduces NCCL calls from 3 to 2
per layer for MLA models.
"""
cp_kv_cache_interleave_size: int = 1
"""Interleave size of kv_cache storage while using DCP or PCP.
For `total_cp_rank = pcp_rank * dcp_world_size + dcp_rank`,
@@ -392,6 +401,11 @@ class ParallelConfig:
f"dcp_size={self.decode_context_parallel_size}."
)
if self.dcp_comm_backend == "a2a" and self.decode_context_parallel_size <= 1:
raise ValueError(
"dcp_comm_backend='a2a' requires decode_context_parallel_size > 1."
)
return self
@property

View File

@@ -1645,6 +1645,8 @@ class VllmConfig:
f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}, " # noqa
f"pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, " # noqa
f"data_parallel_size={self.parallel_config.data_parallel_size}, " # noqa
f"decode_context_parallel_size={self.parallel_config.decode_context_parallel_size}, " # noqa
f"dcp_comm_backend={self.parallel_config.dcp_comm_backend}, " # noqa
f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa
f"quantization={self.model_config.quantization}, "
f"enforce_eager={self.model_config.enforce_eager}, "

View File

@@ -85,6 +85,7 @@ from vllm.config.observability import DetailedTraceModules
from vllm.config.parallel import (
All2AllBackend,
DataParallelBackend,
DCPCommBackend,
DistributedExecutorBackend,
ExpertPlacementStrategy,
)
@@ -405,6 +406,7 @@ class EngineArgs:
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
prefill_context_parallel_size: int = ParallelConfig.prefill_context_parallel_size
decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
dcp_comm_backend: DCPCommBackend = ParallelConfig.dcp_comm_backend
dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
cp_kv_cache_interleave_size: int = ParallelConfig.cp_kv_cache_interleave_size
data_parallel_size: int = ParallelConfig.data_parallel_size
@@ -820,6 +822,10 @@ class EngineArgs:
"-dcp",
**parallel_kwargs["decode_context_parallel_size"],
)
parallel_group.add_argument(
"--dcp-comm-backend",
**parallel_kwargs["dcp_comm_backend"],
)
parallel_group.add_argument(
"--dcp-kv-cache-interleave-size",
**parallel_kwargs["dcp_kv_cache_interleave_size"],
@@ -1720,6 +1726,7 @@ class EngineArgs:
worker_cls=self.worker_cls,
worker_extension_cls=self.worker_extension_cls,
decode_context_parallel_size=self.decode_context_parallel_size,
dcp_comm_backend=self.dcp_comm_backend,
dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size,
cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
_api_process_count=self._api_process_count,

View File

@@ -203,8 +203,17 @@ from tqdm import tqdm
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
from vllm.config import (
CacheConfig,
ModelConfig,
VllmConfig,
get_current_vllm_config,
get_current_vllm_config_or_none,
)
from vllm.distributed.parallel_state import (
get_dcp_group,
is_global_first_rank,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
@@ -253,6 +262,7 @@ from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills,
)
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.dcp_alltoall import dcp_a2a_lse_reduce
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import (
@@ -393,6 +403,13 @@ class MLAAttention(nn.Module, AttentionLayerBase):
self.use_sparse = use_sparse
vllm_config = get_current_vllm_config_or_none()
self.dcp_a2a = (
vllm_config is not None
and vllm_config.parallel_config.decode_context_parallel_size > 1
and vllm_config.parallel_config.dcp_comm_backend == "a2a"
)
# Initialize q/k/v range constants.
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
@@ -647,6 +664,14 @@ class MLAAttention(nn.Module, AttentionLayerBase):
# correct dcp attn_out with lse.
if self.impl.dcp_world_size > 1:
if self.dcp_a2a:
attn_out = dcp_a2a_lse_reduce(
attn_out,
lse,
get_dcp_group(),
is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False),
)
else:
attn_out = cp_lse_ag_out_rs(
attn_out,
lse,

View File

@@ -23,6 +23,7 @@ from vllm.v1.attention.backends.fa_utils import (
is_flash_attn_varlen_func_available,
)
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.dcp_alltoall import dcp_a2a_lse_reduce
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
if is_flash_attn_varlen_func_available():
@@ -32,7 +33,12 @@ if is_flash_attn_varlen_func_available():
get_scheduler_metadata,
reshape_and_cache_flash,
)
from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
from vllm.config import (
VllmConfig,
get_current_vllm_config,
get_current_vllm_config_or_none,
get_layers_from_vllm_config,
)
from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
@@ -609,6 +615,14 @@ class FlashAttentionImpl(AttentionImpl):
self.supports_quant_query_input = True
vllm_config = get_current_vllm_config_or_none()
dcp_a2a = (
vllm_config is not None
and vllm_config.parallel_config.decode_context_parallel_size > 1
and vllm_config.parallel_config.dcp_comm_backend == "a2a"
)
self.dcp_combine = dcp_a2a_lse_reduce if dcp_a2a else cp_lse_ag_out_rs
def forward(
self,
layer: torch.nn.Module,
@@ -857,8 +871,8 @@ class FlashAttentionImpl(AttentionImpl):
v_descale=v_descale,
num_splits=attn_metadata.max_num_splits,
)
# FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ]
context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
# FA returns LSE in shape [ H, B ] but DCP combine wants [ B, H ]
context_attn_out_cor, context_lse_cor = self.dcp_combine(
context_attn_out,
context_lse.transpose(0, 1),
get_dcp_group(),

View File

@@ -3,6 +3,7 @@
"""Attention layer with FlashInfer."""
from dataclasses import dataclass
from functools import partial
from typing import ClassVar
import numpy as np
@@ -19,7 +20,11 @@ from flashinfer.utils import FP4Tensor
from typing_extensions import override
from vllm import envs
from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
from vllm.config import (
CUDAGraphMode,
VllmConfig,
get_current_vllm_config_or_none,
)
from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
@@ -59,6 +64,7 @@ from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills,
)
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.dcp_alltoall import dcp_a2a_lse_reduce
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from vllm.v1.kv_cache_interface import AttentionSpec, UniformTypeKVCacheSpecs
from vllm.v1.utils import CpuGpuBuffer
@@ -170,7 +176,12 @@ class BatchDCPPrefillWrapper:
def __init__(
self,
workspace_buffer: torch.Tensor | None = None,
dcp_a2a: bool = False,
):
if dcp_a2a:
self._dcp_combine = partial(dcp_a2a_lse_reduce, is_lse_base_on_e=False)
else:
self._dcp_combine = partial(cp_lse_ag_out_rs, is_lse_base_on_e=False)
self._context = BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, get_kv_cache_layout()
)
@@ -249,12 +260,11 @@ class BatchDCPPrefillWrapper:
v_scale=layer._v_scale_float,
return_lse=True,
)
output_context, lse_context = cp_lse_ag_out_rs(
output_context, lse_context = self._dcp_combine(
output_context_tmp,
lse_context_tmp,
get_dcp_group(),
return_lse=True,
is_lse_base_on_e=False,
)
lse_context = lse_context.transpose(0, 1).contiguous()
@@ -550,6 +560,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.dcp_rank = 0
self.dcp_kv_cache_interleave_size = 1
self.use_dcp = self.dcp_world_size > 1
self.dcp_a2a = (
self.use_dcp and vllm_config.parallel_config.dcp_comm_backend == "a2a"
)
self.num_qo_heads = self.model_config.get_num_attention_heads(
self.vllm_config.parallel_config
@@ -699,6 +712,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
if self.use_dcp:
self._prefill_wrapper = BatchDCPPrefillWrapper(
workspace_buffer=self._get_workspace_buffer(),
dcp_a2a=self.dcp_a2a,
)
else:
self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
@@ -1208,15 +1222,26 @@ class FlashInferImpl(AttentionImpl):
self.sinks = sinks
self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads)
vllm_config = get_current_vllm_config()
vllm_config = get_current_vllm_config_or_none()
self.supports_quant_query_input = (
self.support_trtllm_attn
and vllm_config is not None
and not vllm_config.attention_config.disable_flashinfer_q_quantization
)
self.bmm1_scale: float | None = None
self.bmm2_scale: float | None = None
self.o_sf_scale: float | None = None
dcp_a2a = (
vllm_config is not None
and vllm_config.parallel_config.decode_context_parallel_size > 1
and vllm_config.parallel_config.dcp_comm_backend == "a2a"
)
if dcp_a2a:
self.dcp_combine = partial(dcp_a2a_lse_reduce, is_lse_base_on_e=False)
else:
self.dcp_combine = partial(cp_lse_ag_out_rs, is_lse_base_on_e=False)
def fused_output_quant_supported(self, quant_key: QuantKey):
return (
self.support_trtllm_attn
@@ -1503,11 +1528,10 @@ class FlashInferImpl(AttentionImpl):
lse=lse,
return_lse=True,
)
output[:num_decode_tokens] = cp_lse_ag_out_rs(
output[:num_decode_tokens] = self.dcp_combine(
output_tmp,
lse,
get_dcp_group(),
is_lse_base_on_e=False,
)
else:
decode_wrapper.run(

View File

@@ -0,0 +1,363 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
DCP All-to-All communication backend for attention.
Provides All-to-All (A2A) communication as an alternative to
AllGather + ReduceScatter (AG+RS) for Decode Context Parallel (DCP).
Instead of gathering the full Q tensor and scattering partial outputs,
A2A exchanges partial attention outputs and their LSE values across
ranks, then combines them with exact LSE-weighted reduction.
This reduces the number of NCCL calls per attention layer from 3
(AG for Q, AG for K metadata, RS for output) to 2 (A2A for output,
A2A for LSE), lowering per-step communication overhead for long-context
decode where NCCL latency is a significant fraction of step time.
Usage:
vllm serve model --tp 16 --dcp 16 --dcp-comm-backend a2a
Reference: https://arxiv.org/abs/2507.07120
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
import torch.distributed as dist
from vllm.triton_utils import tl, triton
if TYPE_CHECKING:
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.v1.attention.ops.common import CPTritonContext
def _lse_weighted_combine(
outputs: torch.Tensor,
lses: torch.Tensor,
return_lse: bool = False,
is_lse_base_on_e: bool = True,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
CPU reference implementation for LSE-weighted combination.
This is a pure PyTorch implementation used for testing and validation.
For GPU execution, use dcp_lse_combine_triton instead.
Args:
outputs: Partial attention outputs [N, B, H, D]
N = number of KV shards (ranks)
B = batch size (num_tokens)
H = number of heads per rank
D = head dimension
lses: Log-sum-exp values [N, B, H]
return_lse: If True, also return the global LSE
is_lse_base_on_e: If True, LSE is base e; if False, base 2
Returns:
Combined output [B, H, D], and optionally global LSE [B, H]
"""
N, B, H, D = outputs.shape
# Handle NaN and inf in LSEs
lses = torch.where(
torch.isnan(lses) | torch.isinf(lses),
torch.tensor(float("-inf"), device=lses.device, dtype=lses.dtype),
lses,
)
# Compute max LSE for numerical stability
lse_max, _ = lses.max(dim=0) # [B, H]
lse_max = torch.where(
lse_max == float("-inf"),
torch.zeros_like(lse_max),
lse_max,
)
# Compute weights: softmax over the N dimension
if is_lse_base_on_e:
weights = torch.exp(lses - lse_max.unsqueeze(0)) # [N, B, H]
else:
weights = torch.pow(2.0, lses - lse_max.unsqueeze(0)) # [N, B, H]
# Handle NaN weights
weights = torch.where(torch.isnan(weights), torch.zeros_like(weights), weights)
# Normalize weights
weight_sum = weights.sum(dim=0, keepdim=True) # [1, B, H]
weights = weights / weight_sum.clamp(min=1e-10) # [N, B, H]
# Weighted combination: sum over N dimension
result = (outputs * weights.unsqueeze(-1)).sum(dim=0) # [B, H, D]
if return_lse:
if is_lse_base_on_e:
global_lse = torch.log(weight_sum.squeeze(0)) + lse_max # [B, H]
else:
global_lse = torch.log2(weight_sum.squeeze(0)) + lse_max # [B, H]
return result, global_lse
return result
@triton.jit
def _dcp_lse_combine_kernel(
# Input pointers
recv_output_ptr,
recv_lse_ptr,
# Output pointers
out_ptr,
out_lse_ptr,
# Strides for recv_output [N, B, H_local, D]
ro_stride_N,
ro_stride_B,
ro_stride_H,
ro_stride_D,
# Strides for recv_lse [N, B, H_local]
rl_stride_N,
rl_stride_B,
rl_stride_H,
# Strides for output [B, H_local, D]
o_stride_B,
o_stride_H,
o_stride_D,
# Constants
N: tl.constexpr,
HEAD_DIM: tl.constexpr,
IS_BASE_E: tl.constexpr,
RETURN_LSE: tl.constexpr,
):
"""
Triton kernel for LSE-weighted combination of partial attention outputs.
After All-to-All, each rank has:
- recv_output [N, B, H_local, D]: partial outputs from all KV shards
- recv_lse [N, B, H_local]: partial LSEs from all KV shards
This kernel computes the weighted combination locally (no communication).
Grid: (B, H_local)
Each program handles one (batch, head) and processes all D elements.
"""
batch_idx = tl.program_id(0).to(tl.int64)
head_idx = tl.program_id(1).to(tl.int64)
# Base offset for this (batch, head)
base_lse_offset = batch_idx * rl_stride_B + head_idx * rl_stride_H
base_out_offset = batch_idx * ro_stride_B + head_idx * ro_stride_H
# First pass: find max LSE for numerical stability
lse_max = -float("inf")
for n in tl.static_range(N):
lse_offset = n * rl_stride_N + base_lse_offset
lse_val = tl.load(recv_lse_ptr + lse_offset)
lse_val = tl.where(
(lse_val != lse_val) | (lse_val == float("inf")),
-float("inf"),
lse_val,
)
lse_max = tl.maximum(lse_max, lse_val)
lse_max = tl.where(lse_max == -float("inf"), 0.0, lse_max)
# Second pass: compute sum of exp(lse - max)
lse_sum = 0.0
for n in tl.static_range(N):
lse_offset = n * rl_stride_N + base_lse_offset
lse_val = tl.load(recv_lse_ptr + lse_offset)
lse_val = tl.where(
(lse_val != lse_val) | (lse_val == float("inf")),
-float("inf"),
lse_val,
)
if IS_BASE_E:
lse_sum += tl.exp(lse_val - lse_max)
else:
lse_sum += tl.exp2(lse_val - lse_max)
# Compute global LSE
if IS_BASE_E: # noqa: SIM108
global_lse = tl.log(lse_sum) + lse_max
else:
global_lse = tl.log2(lse_sum) + lse_max
# Third pass: weighted combination across D dimension
d_offsets = tl.arange(0, HEAD_DIM)
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
for n in tl.static_range(N):
lse_offset = n * rl_stride_N + base_lse_offset
lse_val = tl.load(recv_lse_ptr + lse_offset)
lse_val = tl.where(
(lse_val != lse_val) | (lse_val == float("inf")),
-float("inf"),
lse_val,
)
if IS_BASE_E:
weight = tl.exp(lse_val - global_lse)
else:
weight = tl.exp2(lse_val - global_lse)
weight = tl.where(weight != weight, 0.0, weight)
out_offsets = n * ro_stride_N + base_out_offset + d_offsets * ro_stride_D
out_vals = tl.load(recv_output_ptr + out_offsets)
acc += out_vals.to(tl.float32) * weight
# Store result
final_offsets = (
batch_idx * o_stride_B + head_idx * o_stride_H + d_offsets * o_stride_D
)
tl.store(out_ptr + final_offsets, acc)
if RETURN_LSE:
tl.store(out_lse_ptr + base_lse_offset, global_lse)
def dcp_lse_combine_triton(
recv_output: torch.Tensor,
recv_lse: torch.Tensor,
return_lse: bool = False,
is_lse_base_on_e: bool = True,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
Triton-accelerated LSE-weighted combination for DCP A2A.
Args:
recv_output: [N, B, H_local, D] - partial outputs from all KV shards
recv_lse: [N, B, H_local] - partial LSEs from all KV shards
return_lse: If True, also return the global LSE
is_lse_base_on_e: If True, LSE is base e; if False, base 2
Returns:
Combined output [B, H_local, D]
If return_lse=True, also returns global_lse [B, H_local]
"""
N, B, H_local, D = recv_output.shape
out = torch.empty(
(B, H_local, D), device=recv_output.device, dtype=recv_output.dtype
)
if return_lse:
out_lse = torch.empty(
(B, H_local), device=recv_lse.device, dtype=recv_lse.dtype
)
else:
out_lse = torch.empty(1, device=recv_lse.device, dtype=recv_lse.dtype)
ro_stride_N, ro_stride_B, ro_stride_H, ro_stride_D = recv_output.stride()
rl_stride_N, rl_stride_B, rl_stride_H = recv_lse.stride()
o_stride_B, o_stride_H, o_stride_D = out.stride()
grid = (B, H_local, 1)
_dcp_lse_combine_kernel[grid](
recv_output,
recv_lse,
out,
out_lse,
ro_stride_N,
ro_stride_B,
ro_stride_H,
ro_stride_D,
rl_stride_N,
rl_stride_B,
rl_stride_H,
o_stride_B,
o_stride_H,
o_stride_D,
N=N,
HEAD_DIM=D,
IS_BASE_E=is_lse_base_on_e,
RETURN_LSE=return_lse,
)
if return_lse:
return out, out_lse
return out
def dcp_a2a_lse_reduce(
cp_attn_out: torch.Tensor,
cp_attn_lse: torch.Tensor,
cp_group: GroupCoordinator,
ctx: CPTritonContext | None = None,
return_lse: bool = False,
is_lse_base_on_e: bool = True,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
Combine partial attention outputs across DCP ranks using All-to-All.
Each rank holds attention output for all heads but only a local shard
of the KV cache. This function:
1. Exchanges partial outputs across ranks via All-to-All
2. Exchanges LSE values via All-to-All
3. Combines them with exact LSE-weighted reduction (Triton kernel)
Tensor flow:
Input: cp_attn_out [B, H, D] - all heads, local KV shard
Reshape: [N, B, H/N, D] - split heads across ranks
A2A: Two all_to_all_single calls (output and LSE)
Combine: recv [N, B, H/N, D] + lse [N, B, H/N] -> [B, H/N, D]
Args:
cp_attn_out: [B, H, D] where B=num_tokens, H=total_heads, D=head_dim
cp_attn_lse: [B, H] log-sum-exp values (fp32)
cp_group: GroupCoordinator for DCP communication
ctx: CPTritonContext (unused, for signature compatibility)
return_lse: If True, also return the combined global LSE
is_lse_base_on_e: If True, LSE is base e; if False, base 2
Returns:
Combined output [B, H/N, D] (head-scattered)
If return_lse=True, also returns global_lse [B, H/N]
"""
world_size = cp_group.world_size
if world_size == 1:
if return_lse:
return cp_attn_out, cp_attn_lse
return cp_attn_out
local_output = cp_attn_out.contiguous()
local_lse = cp_attn_lse.contiguous()
B, H, D = local_output.shape
H_per_rank = H // world_size
# Reshape for All-to-All: [B, H, D] -> [N, B, H/N, D]
# Split heads into N chunks, each destined for a different rank
send_output = (
local_output.view(B, world_size, H_per_rank, D).permute(1, 0, 2, 3).contiguous()
)
recv_output = torch.empty_like(send_output)
# Same for LSE: [B, H] -> [N, B, H/N]
send_lse = local_lse.view(B, world_size, H_per_rank).permute(1, 0, 2).contiguous()
recv_lse = torch.empty_like(send_lse)
# All-to-All for partial attention outputs and LSE values (async overlap)
work_output = dist.all_to_all_single(
recv_output.view(-1),
send_output.view(-1),
group=cp_group.device_group,
async_op=True,
)
work_lse = dist.all_to_all_single(
recv_lse.view(-1),
send_lse.view(-1),
group=cp_group.device_group,
async_op=True,
)
work_output.wait()
work_lse.wait()
# LSE-weighted combination via Triton kernel (local, no communication)
return dcp_lse_combine_triton(
recv_output,
recv_lse,
return_lse=return_lse,
is_lse_base_on_e=is_lse_base_on_e,
)