class TestDCPCommBackendConfig:
    """Validate how ``ParallelConfig`` handles the ``--dcp-comm-backend`` option."""

    def test_default_is_ag_rs(self):
        """A config built with no arguments selects the ``ag_rs`` backend."""
        config = ParallelConfig()
        assert config.dcp_comm_backend == "ag_rs"

    def test_a2a_requires_dcp_greater_than_1(self):
        """``a2a`` is rejected when decode_context_parallel_size is 1."""
        with pytest.raises(
            ValueError, match="requires decode_context_parallel_size > 1"
        ):
            ParallelConfig(
                dcp_comm_backend="a2a",
                decode_context_parallel_size=1,
            )

    def test_a2a_with_dcp_valid(self):
        """``a2a`` is accepted once DCP is actually enabled (size > 1)."""
        config = ParallelConfig(
            dcp_comm_backend="a2a",
            tensor_parallel_size=8,
            decode_context_parallel_size=4,
        )
        assert config.dcp_comm_backend == "a2a"

    def test_invalid_backend_rejected(self):
        """Backend names outside the ``DCPCommBackend`` literal are rejected."""
        # NOTE(review): the only validation visible for this field is its
        # Literal["ag_rs", "a2a"] annotation. A Literal mismatch is presumably
        # reported by pydantic as "Input should be ..." rather than
        # "must be one of", so accept either wording; the test then does not
        # depend on which layer produces the message. Confirm against the
        # full ParallelConfig definition.
        with pytest.raises(ValueError, match=r"must be one of|Input should be"):
            ParallelConfig(
                dcp_comm_backend="invalid",
            )

    def test_ag_rs_with_dcp_1_valid(self):
        """``ag_rs`` stays valid when DCP is disabled (size == 1)."""
        config = ParallelConfig(
            dcp_comm_backend="ag_rs",
            decode_context_parallel_size=1,
        )
        assert config.dcp_comm_backend == "ag_rs"
class TestLSEWeightedCombine:
    """Exercise the pure-PyTorch LSE-weighted combine on CPU tensors.

    ``_lse_weighted_combine`` is the reference implementation used to check
    the Triton kernel. It computes

        result[b, h, d] = sum_n(w_n * output_n[b, h, d])

    where ``w_n = softmax(lse_n) = exp(lse_n) / sum_k(exp(lse_k))``.

    The function is imported inside each test so that an import failure is
    reported per test rather than aborting collection of the whole module.
    """

    def test_importable(self):
        """Verify _lse_weighted_combine is importable."""
        from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine

        assert callable(_lse_weighted_combine)

    def test_single_rank(self):
        """Single rank: output unchanged."""
        from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine

        # N=1, B=2, H=4, D=8
        outputs = torch.randn(1, 2, 4, 8)
        lses = torch.randn(1, 2, 4)

        result = _lse_weighted_combine(outputs, lses)

        assert result.shape == (2, 4, 8)
        torch.testing.assert_close(result, outputs.squeeze(0), rtol=1e-5, atol=1e-5)

    def test_equal_lse(self):
        """Equal LSE values: outputs averaged equally."""
        from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine

        _N, B, H, D = 2, 1, 1, 4
        outputs = torch.tensor(
            [
                [[[1.0, 2.0, 3.0, 4.0]]],  # Rank 0
                [[[5.0, 6.0, 7.0, 8.0]]],  # Rank 1
            ]
        )
        lses = torch.tensor(
            [
                [[0.0]],  # Rank 0
                [[0.0]],  # Rank 1
            ]
        )

        result = _lse_weighted_combine(outputs, lses)

        expected = (outputs[0] + outputs[1]) / 2
        assert result.shape == (B, H, D)
        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)

    def test_dominant_rank(self):
        """Different LSE values: larger LSE gets more weight."""
        from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine

        B, H, D = 1, 1, 2
        outputs = torch.tensor(
            [
                [[[0.0, 0.0]]],  # Rank 0
                [[[1.0, 1.0]]],  # Rank 1
            ]
        )
        lses = torch.tensor(
            [
                [[-100.0]],  # Rank 0: negligible contribution
                [[0.0]],  # Rank 1: dominant
            ]
        )

        result = _lse_weighted_combine(outputs, lses)

        assert result.shape == (B, H, D)
        # outputs[1] is already [B, H, D]. Squeezing it would yield [H, D],
        # and assert_close rejects mismatched shapes instead of broadcasting.
        torch.testing.assert_close(result, outputs[1], atol=1e-5, rtol=1e-5)

    def test_mathematically_correct(self):
        """Verify mathematical correctness of LSE combination."""
        from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine

        outputs = torch.tensor(
            [
                [[[2.0, 4.0]]],
                [[[6.0, 8.0]]],
            ]
        )
        lses = torch.tensor(
            [
                [[1.0]],  # exp(1) ≈ 2.718
                [[2.0]],  # exp(2) ≈ 7.389
            ]
        )

        result = _lse_weighted_combine(outputs, lses)

        # Softmax weights over the two ranks, computed independently in Python.
        w0 = math.exp(1) / (math.exp(1) + math.exp(2))
        w1 = math.exp(2) / (math.exp(1) + math.exp(2))
        expected = torch.tensor([[[w0 * 2.0 + w1 * 6.0, w0 * 4.0 + w1 * 8.0]]])

        torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4)

    def test_return_lse(self):
        """return_lse=True returns global LSE (logsumexp of inputs)."""
        from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine

        B, H, D = 1, 1, 2
        outputs = torch.tensor(
            [
                [[[1.0, 2.0]]],
                [[[3.0, 4.0]]],
            ]
        )
        lses = torch.tensor(
            [
                [[1.0]],
                [[2.0]],
            ]
        )

        result, global_lse = _lse_weighted_combine(outputs, lses, return_lse=True)

        expected_global_lse = math.log(math.exp(1) + math.exp(2))

        assert result.shape == (B, H, D)
        assert global_lse.shape == (B, H)
        assert abs(global_lse.item() - expected_global_lse) < 1e-5
+ - "ag_rs": AllGather + ReduceScatter (default, existing behavior) + - "a2a": All-to-All exchange of partial outputs + LSE, then + combine with Triton kernel. Reduces NCCL calls from 3 to 2 + per layer for MLA models. + """ + cp_kv_cache_interleave_size: int = 1 """Interleave size of kv_cache storage while using DCP or PCP. For `total_cp_rank = pcp_rank * dcp_world_size + dcp_rank`, @@ -392,6 +401,11 @@ class ParallelConfig: f"dcp_size={self.decode_context_parallel_size}." ) + if self.dcp_comm_backend == "a2a" and self.decode_context_parallel_size <= 1: + raise ValueError( + "dcp_comm_backend='a2a' requires decode_context_parallel_size > 1." + ) + return self @property diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 44d78d737..fd5e3b464 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -1645,6 +1645,8 @@ class VllmConfig: f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}, " # noqa f"pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, " # noqa f"data_parallel_size={self.parallel_config.data_parallel_size}, " # noqa + f"decode_context_parallel_size={self.parallel_config.decode_context_parallel_size}, " # noqa + f"dcp_comm_backend={self.parallel_config.dcp_comm_backend}, " # noqa f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa f"quantization={self.model_config.quantization}, " f"enforce_eager={self.model_config.enforce_eager}, " diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c4d3c039a..6d74e867b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -85,6 +85,7 @@ from vllm.config.observability import DetailedTraceModules from vllm.config.parallel import ( All2AllBackend, DataParallelBackend, + DCPCommBackend, DistributedExecutorBackend, ExpertPlacementStrategy, ) @@ -405,6 +406,7 @@ class EngineArgs: tensor_parallel_size: int = ParallelConfig.tensor_parallel_size prefill_context_parallel_size: int = 
ParallelConfig.prefill_context_parallel_size decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size + dcp_comm_backend: DCPCommBackend = ParallelConfig.dcp_comm_backend dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size cp_kv_cache_interleave_size: int = ParallelConfig.cp_kv_cache_interleave_size data_parallel_size: int = ParallelConfig.data_parallel_size @@ -820,6 +822,10 @@ class EngineArgs: "-dcp", **parallel_kwargs["decode_context_parallel_size"], ) + parallel_group.add_argument( + "--dcp-comm-backend", + **parallel_kwargs["dcp_comm_backend"], + ) parallel_group.add_argument( "--dcp-kv-cache-interleave-size", **parallel_kwargs["dcp_kv_cache_interleave_size"], @@ -1720,6 +1726,7 @@ class EngineArgs: worker_cls=self.worker_cls, worker_extension_cls=self.worker_extension_cls, decode_context_parallel_size=self.decode_context_parallel_size, + dcp_comm_backend=self.dcp_comm_backend, dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size, cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size, _api_process_count=self._api_process_count, diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index 820755b9c..25bc57de6 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -203,8 +203,17 @@ from tqdm import tqdm import vllm.envs as envs from vllm import _custom_ops as ops from vllm._aiter_ops import rocm_aiter_ops -from vllm.config import CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config -from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank +from vllm.config import ( + CacheConfig, + ModelConfig, + VllmConfig, + get_current_vllm_config, + get_current_vllm_config_or_none, +) +from vllm.distributed.parallel_state import ( + get_dcp_group, + is_global_first_rank, +) from vllm.forward_context import ForwardContext, 
get_forward_context from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp @@ -253,6 +262,7 @@ from vllm.v1.attention.backends.utils import ( split_decodes_and_prefills, ) from vllm.v1.attention.ops.common import cp_lse_ag_out_rs +from vllm.v1.attention.ops.dcp_alltoall import dcp_a2a_lse_reduce from vllm.v1.attention.ops.merge_attn_states import merge_attn_states from vllm.v1.attention.selector import get_attn_backend from vllm.v1.kv_cache_interface import ( @@ -393,6 +403,13 @@ class MLAAttention(nn.Module, AttentionLayerBase): self.use_sparse = use_sparse + vllm_config = get_current_vllm_config_or_none() + self.dcp_a2a = ( + vllm_config is not None + and vllm_config.parallel_config.decode_context_parallel_size > 1 + and vllm_config.parallel_config.dcp_comm_backend == "a2a" + ) + # Initialize q/k/v range constants. self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) @@ -647,12 +664,20 @@ class MLAAttention(nn.Module, AttentionLayerBase): # correct dcp attn_out with lse. 
if self.impl.dcp_world_size > 1: - attn_out = cp_lse_ag_out_rs( - attn_out, - lse, - get_dcp_group(), - is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False), - ) + if self.dcp_a2a: + attn_out = dcp_a2a_lse_reduce( + attn_out, + lse, + get_dcp_group(), + is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False), + ) + else: + attn_out = cp_lse_ag_out_rs( + attn_out, + lse, + get_dcp_group(), + is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False), + ) # v_up projection self._v_up_proj(attn_out, out=mqa_output_slice) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 91c49c55c..81d62629d 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -23,6 +23,7 @@ from vllm.v1.attention.backends.fa_utils import ( is_flash_attn_varlen_func_available, ) from vllm.v1.attention.ops.common import cp_lse_ag_out_rs +from vllm.v1.attention.ops.dcp_alltoall import dcp_a2a_lse_reduce from vllm.v1.attention.ops.merge_attn_states import merge_attn_states if is_flash_attn_varlen_func_available(): @@ -32,7 +33,12 @@ if is_flash_attn_varlen_func_available(): get_scheduler_metadata, reshape_and_cache_flash, ) -from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config +from vllm.config import ( + VllmConfig, + get_current_vllm_config, + get_current_vllm_config_or_none, + get_layers_from_vllm_config, +) from vllm.config.cache import CacheDType from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger @@ -609,6 +615,14 @@ class FlashAttentionImpl(AttentionImpl): self.supports_quant_query_input = True + vllm_config = get_current_vllm_config_or_none() + dcp_a2a = ( + vllm_config is not None + and vllm_config.parallel_config.decode_context_parallel_size > 1 + and vllm_config.parallel_config.dcp_comm_backend == "a2a" + ) + self.dcp_combine = dcp_a2a_lse_reduce if dcp_a2a else cp_lse_ag_out_rs + def forward( self, 
layer: torch.nn.Module, @@ -857,8 +871,8 @@ class FlashAttentionImpl(AttentionImpl): v_descale=v_descale, num_splits=attn_metadata.max_num_splits, ) - # FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ] - context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs( + # FA returns LSE in shape [ H, B ] but DCP combine wants [ B, H ] + context_attn_out_cor, context_lse_cor = self.dcp_combine( context_attn_out, context_lse.transpose(0, 1), get_dcp_group(), diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 4362bacb7..46e9d2cb5 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -3,6 +3,7 @@ """Attention layer with FlashInfer.""" from dataclasses import dataclass +from functools import partial from typing import ClassVar import numpy as np @@ -19,7 +20,11 @@ from flashinfer.utils import FP4Tensor from typing_extensions import override from vllm import envs -from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config +from vllm.config import ( + CUDAGraphMode, + VllmConfig, + get_current_vllm_config_or_none, +) from vllm.config.cache import CacheDType from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger @@ -59,6 +64,7 @@ from vllm.v1.attention.backends.utils import ( split_decodes_and_prefills, ) from vllm.v1.attention.ops.common import cp_lse_ag_out_rs +from vllm.v1.attention.ops.dcp_alltoall import dcp_a2a_lse_reduce from vllm.v1.attention.ops.merge_attn_states import merge_attn_states from vllm.v1.kv_cache_interface import AttentionSpec, UniformTypeKVCacheSpecs from vllm.v1.utils import CpuGpuBuffer @@ -170,7 +176,12 @@ class BatchDCPPrefillWrapper: def __init__( self, workspace_buffer: torch.Tensor | None = None, + dcp_a2a: bool = False, ): + if dcp_a2a: + self._dcp_combine = partial(dcp_a2a_lse_reduce, is_lse_base_on_e=False) + else: + self._dcp_combine = partial(cp_lse_ag_out_rs, 
is_lse_base_on_e=False) self._context = BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, get_kv_cache_layout() ) @@ -249,12 +260,11 @@ class BatchDCPPrefillWrapper: v_scale=layer._v_scale_float, return_lse=True, ) - output_context, lse_context = cp_lse_ag_out_rs( + output_context, lse_context = self._dcp_combine( output_context_tmp, lse_context_tmp, get_dcp_group(), return_lse=True, - is_lse_base_on_e=False, ) lse_context = lse_context.transpose(0, 1).contiguous() @@ -550,6 +560,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): self.dcp_rank = 0 self.dcp_kv_cache_interleave_size = 1 self.use_dcp = self.dcp_world_size > 1 + self.dcp_a2a = ( + self.use_dcp and vllm_config.parallel_config.dcp_comm_backend == "a2a" + ) self.num_qo_heads = self.model_config.get_num_attention_heads( self.vllm_config.parallel_config @@ -699,6 +712,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): if self.use_dcp: self._prefill_wrapper = BatchDCPPrefillWrapper( workspace_buffer=self._get_workspace_buffer(), + dcp_a2a=self.dcp_a2a, ) else: self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( @@ -1208,15 +1222,26 @@ class FlashInferImpl(AttentionImpl): self.sinks = sinks self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads) - vllm_config = get_current_vllm_config() + vllm_config = get_current_vllm_config_or_none() self.supports_quant_query_input = ( self.support_trtllm_attn + and vllm_config is not None and not vllm_config.attention_config.disable_flashinfer_q_quantization ) self.bmm1_scale: float | None = None self.bmm2_scale: float | None = None self.o_sf_scale: float | None = None + dcp_a2a = ( + vllm_config is not None + and vllm_config.parallel_config.decode_context_parallel_size > 1 + and vllm_config.parallel_config.dcp_comm_backend == "a2a" + ) + if dcp_a2a: + self.dcp_combine = partial(dcp_a2a_lse_reduce, is_lse_base_on_e=False) + else: + self.dcp_combine = 
partial(cp_lse_ag_out_rs, is_lse_base_on_e=False) + def fused_output_quant_supported(self, quant_key: QuantKey): return ( self.support_trtllm_attn @@ -1503,11 +1528,10 @@ class FlashInferImpl(AttentionImpl): lse=lse, return_lse=True, ) - output[:num_decode_tokens] = cp_lse_ag_out_rs( + output[:num_decode_tokens] = self.dcp_combine( output_tmp, lse, get_dcp_group(), - is_lse_base_on_e=False, ) else: decode_wrapper.run( diff --git a/vllm/v1/attention/ops/dcp_alltoall.py b/vllm/v1/attention/ops/dcp_alltoall.py new file mode 100644 index 000000000..92f50f63e --- /dev/null +++ b/vllm/v1/attention/ops/dcp_alltoall.py @@ -0,0 +1,363 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +DCP All-to-All communication backend for attention. + +Provides All-to-All (A2A) communication as an alternative to +AllGather + ReduceScatter (AG+RS) for Decode Context Parallel (DCP). +Instead of gathering the full Q tensor and scattering partial outputs, +A2A exchanges partial attention outputs and their LSE values across +ranks, then combines them with exact LSE-weighted reduction. + +This reduces the number of NCCL calls per attention layer from 3 +(AG for Q, AG for K metadata, RS for output) to 2 (A2A for output, +A2A for LSE), lowering per-step communication overhead for long-context +decode where NCCL latency is a significant fraction of step time. 
+ +Usage: + vllm serve model --tp 16 --dcp 16 --dcp-comm-backend a2a + +Reference: https://arxiv.org/abs/2507.07120 +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +import torch.distributed as dist + +from vllm.triton_utils import tl, triton + +if TYPE_CHECKING: + from vllm.distributed.parallel_state import GroupCoordinator + from vllm.v1.attention.ops.common import CPTritonContext + + +def _lse_weighted_combine( + outputs: torch.Tensor, + lses: torch.Tensor, + return_lse: bool = False, + is_lse_base_on_e: bool = True, +) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """ + CPU reference implementation for LSE-weighted combination. + + This is a pure PyTorch implementation used for testing and validation. + For GPU execution, use dcp_lse_combine_triton instead. + + Args: + outputs: Partial attention outputs [N, B, H, D] + N = number of KV shards (ranks) + B = batch size (num_tokens) + H = number of heads per rank + D = head dimension + lses: Log-sum-exp values [N, B, H] + return_lse: If True, also return the global LSE + is_lse_base_on_e: If True, LSE is base e; if False, base 2 + + Returns: + Combined output [B, H, D], and optionally global LSE [B, H] + """ + N, B, H, D = outputs.shape + + # Handle NaN and inf in LSEs + lses = torch.where( + torch.isnan(lses) | torch.isinf(lses), + torch.tensor(float("-inf"), device=lses.device, dtype=lses.dtype), + lses, + ) + + # Compute max LSE for numerical stability + lse_max, _ = lses.max(dim=0) # [B, H] + lse_max = torch.where( + lse_max == float("-inf"), + torch.zeros_like(lse_max), + lse_max, + ) + + # Compute weights: softmax over the N dimension + if is_lse_base_on_e: + weights = torch.exp(lses - lse_max.unsqueeze(0)) # [N, B, H] + else: + weights = torch.pow(2.0, lses - lse_max.unsqueeze(0)) # [N, B, H] + + # Handle NaN weights + weights = torch.where(torch.isnan(weights), torch.zeros_like(weights), weights) + + # Normalize weights + weight_sum = weights.sum(dim=0, 
keepdim=True) # [1, B, H] + weights = weights / weight_sum.clamp(min=1e-10) # [N, B, H] + + # Weighted combination: sum over N dimension + result = (outputs * weights.unsqueeze(-1)).sum(dim=0) # [B, H, D] + + if return_lse: + if is_lse_base_on_e: + global_lse = torch.log(weight_sum.squeeze(0)) + lse_max # [B, H] + else: + global_lse = torch.log2(weight_sum.squeeze(0)) + lse_max # [B, H] + return result, global_lse + + return result + + +@triton.jit +def _dcp_lse_combine_kernel( + # Input pointers + recv_output_ptr, + recv_lse_ptr, + # Output pointers + out_ptr, + out_lse_ptr, + # Strides for recv_output [N, B, H_local, D] + ro_stride_N, + ro_stride_B, + ro_stride_H, + ro_stride_D, + # Strides for recv_lse [N, B, H_local] + rl_stride_N, + rl_stride_B, + rl_stride_H, + # Strides for output [B, H_local, D] + o_stride_B, + o_stride_H, + o_stride_D, + # Constants + N: tl.constexpr, + HEAD_DIM: tl.constexpr, + IS_BASE_E: tl.constexpr, + RETURN_LSE: tl.constexpr, +): + """ + Triton kernel for LSE-weighted combination of partial attention outputs. + + After All-to-All, each rank has: + - recv_output [N, B, H_local, D]: partial outputs from all KV shards + - recv_lse [N, B, H_local]: partial LSEs from all KV shards + + This kernel computes the weighted combination locally (no communication). + + Grid: (B, H_local) + Each program handles one (batch, head) and processes all D elements. 
@triton.jit
def _dcp_lse_combine_kernel(
    # Input pointers
    recv_output_ptr,
    recv_lse_ptr,
    # Output pointers
    out_ptr,
    out_lse_ptr,
    # Strides for recv_output [N, B, H_local, D]
    ro_stride_N,
    ro_stride_B,
    ro_stride_H,
    ro_stride_D,
    # Strides for recv_lse [N, B, H_local]
    rl_stride_N,
    rl_stride_B,
    rl_stride_H,
    # Strides for output [B, H_local, D]
    o_stride_B,
    o_stride_H,
    o_stride_D,
    # Strides for out_lse [B, H_local] (only read when RETURN_LSE)
    ol_stride_B,
    ol_stride_H,
    # Constants
    N: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_D: tl.constexpr,
    IS_BASE_E: tl.constexpr,
    RETURN_LSE: tl.constexpr,
):
    """
    Triton kernel for LSE-weighted combination of partial attention outputs.

    Inputs are the tensors received from the All-to-All:
    - recv_output [N, B, H_local, D]: partial outputs from all KV shards
    - recv_lse [N, B, H_local]: partial LSEs from all KV shards

    The kernel performs the weighted combination locally (no communication).

    Grid: (B, H_local). Each program handles one (batch, head) pair and all
    HEAD_DIM elements of it. BLOCK_D is HEAD_DIM rounded up to a power of two
    (the length tl.arange needs); lanes >= HEAD_DIM are masked off.
    """
    batch_idx = tl.program_id(0).to(tl.int64)
    head_idx = tl.program_id(1).to(tl.int64)

    # Base offsets for this (batch, head) in the received tensors
    base_lse_offset = batch_idx * rl_stride_B + head_idx * rl_stride_H
    base_out_offset = batch_idx * ro_stride_B + head_idx * ro_stride_H

    # First pass: find max LSE for numerical stability.
    # NaN (x != x) and +inf are treated as -inf, i.e. "no contribution".
    lse_max = -float("inf")
    for n in tl.static_range(N):
        lse_offset = n * rl_stride_N + base_lse_offset
        lse_val = tl.load(recv_lse_ptr + lse_offset)
        lse_val = tl.where(
            (lse_val != lse_val) | (lse_val == float("inf")),
            -float("inf"),
            lse_val,
        )
        lse_max = tl.maximum(lse_max, lse_val)

    # All shards empty: use 0 as the shift to avoid (-inf) - (-inf).
    lse_max = tl.where(lse_max == -float("inf"), 0.0, lse_max)

    # Second pass: compute sum of exp(lse - max)
    lse_sum = 0.0
    for n in tl.static_range(N):
        lse_offset = n * rl_stride_N + base_lse_offset
        lse_val = tl.load(recv_lse_ptr + lse_offset)
        lse_val = tl.where(
            (lse_val != lse_val) | (lse_val == float("inf")),
            -float("inf"),
            lse_val,
        )
        if IS_BASE_E:
            lse_sum += tl.exp(lse_val - lse_max)
        else:
            lse_sum += tl.exp2(lse_val - lse_max)

    # Compute global LSE
    if IS_BASE_E:  # noqa: SIM108
        global_lse = tl.log(lse_sum) + lse_max
    else:
        global_lse = tl.log2(lse_sum) + lse_max

    # Third pass: weighted combination across the head dimension.
    d_offsets = tl.arange(0, BLOCK_D)
    d_mask = d_offsets < HEAD_DIM
    acc = tl.zeros([BLOCK_D], dtype=tl.float32)

    for n in tl.static_range(N):
        lse_offset = n * rl_stride_N + base_lse_offset
        lse_val = tl.load(recv_lse_ptr + lse_offset)
        lse_val = tl.where(
            (lse_val != lse_val) | (lse_val == float("inf")),
            -float("inf"),
            lse_val,
        )
        if IS_BASE_E:
            weight = tl.exp(lse_val - global_lse)
        else:
            weight = tl.exp2(lse_val - global_lse)
        # global_lse is -inf when every shard is empty, which makes the
        # exponent NaN; such shards contribute nothing.
        weight = tl.where(weight != weight, 0.0, weight)

        out_offsets = n * ro_stride_N + base_out_offset + d_offsets * ro_stride_D
        out_vals = tl.load(recv_output_ptr + out_offsets, mask=d_mask, other=0.0)
        acc += out_vals.to(tl.float32) * weight

    # Store result
    final_offsets = (
        batch_idx * o_stride_B + head_idx * o_stride_H + d_offsets * o_stride_D
    )
    tl.store(out_ptr + final_offsets, acc, mask=d_mask)

    if RETURN_LSE:
        # Index out_lse with its own strides: recv_lse may be laid out
        # differently from the freshly allocated output buffer.
        out_lse_offset = batch_idx * ol_stride_B + head_idx * ol_stride_H
        tl.store(out_lse_ptr + out_lse_offset, global_lse)


def dcp_lse_combine_triton(
    recv_output: torch.Tensor,
    recv_lse: torch.Tensor,
    return_lse: bool = False,
    is_lse_base_on_e: bool = True,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """
    Triton-accelerated LSE-weighted combination for DCP A2A.

    Args:
        recv_output: [N, B, H_local, D] - partial outputs from all KV shards
        recv_lse: [N, B, H_local] - partial LSEs from all KV shards
        return_lse: If True, also return the global LSE
        is_lse_base_on_e: If True, LSE is base e; if False, base 2

    Returns:
        Combined output [B, H_local, D]
        If return_lse=True, also returns global_lse [B, H_local]

    Raises:
        ValueError: If recv_output is not 4-D or recv_lse does not have the
            shape of recv_output without its last dimension.
    """
    N, B, H_local, D = recv_output.shape
    if recv_lse.shape != (N, B, H_local):
        # A mismatched LSE tensor would make the kernel read out of bounds.
        raise ValueError(
            f"recv_lse must have shape {(N, B, H_local)} to match recv_output, "
            f"got {tuple(recv_lse.shape)}."
        )

    out = torch.empty(
        (B, H_local, D), device=recv_output.device, dtype=recv_output.dtype
    )

    if return_lse:
        out_lse = torch.empty(
            (B, H_local), device=recv_lse.device, dtype=recv_lse.dtype
        )
        ol_stride_B, ol_stride_H = out_lse.stride()
    else:
        # Dummy buffer: the kernel never touches it when RETURN_LSE is False.
        out_lse = torch.empty(1, device=recv_lse.device, dtype=recv_lse.dtype)
        ol_stride_B, ol_stride_H = 0, 0

    ro_stride_N, ro_stride_B, ro_stride_H, ro_stride_D = recv_output.stride()
    rl_stride_N, rl_stride_B, rl_stride_H = recv_lse.stride()
    o_stride_B, o_stride_H, o_stride_D = out.stride()

    # tl.arange needs a power-of-two length; round the head dim up and let
    # the kernel mask the padding lanes.
    block_d = 1 << max(D - 1, 0).bit_length()

    grid = (B, H_local, 1)

    _dcp_lse_combine_kernel[grid](
        recv_output,
        recv_lse,
        out,
        out_lse,
        ro_stride_N,
        ro_stride_B,
        ro_stride_H,
        ro_stride_D,
        rl_stride_N,
        rl_stride_B,
        rl_stride_H,
        o_stride_B,
        o_stride_H,
        o_stride_D,
        ol_stride_B,
        ol_stride_H,
        N=N,
        HEAD_DIM=D,
        BLOCK_D=block_d,
        IS_BASE_E=is_lse_base_on_e,
        RETURN_LSE=return_lse,
    )

    if return_lse:
        return out, out_lse
    return out
def dcp_a2a_lse_reduce(
    cp_attn_out: torch.Tensor,
    cp_attn_lse: torch.Tensor,
    cp_group: GroupCoordinator,
    ctx: CPTritonContext | None = None,
    return_lse: bool = False,
    is_lse_base_on_e: bool = True,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """
    Combine partial attention outputs across DCP ranks using All-to-All.

    Each rank holds attention output for all heads but only a local shard
    of the KV cache. This function:
    1. Exchanges partial outputs across ranks via All-to-All
    2. Exchanges LSE values via All-to-All
    3. Combines them with exact LSE-weighted reduction (Triton kernel)

    Tensor flow:
        Input:   cp_attn_out [B, H, D] - all heads, local KV shard
        Reshape: [N, B, H/N, D] - split heads across ranks
        A2A:     Two all_to_all_single calls (output and LSE)
        Combine: recv [N, B, H/N, D] + lse [N, B, H/N] -> [B, H/N, D]

    Args:
        cp_attn_out: [B, H, D] where B=num_tokens, H=total_heads, D=head_dim
        cp_attn_lse: [B, H] log-sum-exp values (fp32)
        cp_group: GroupCoordinator for DCP communication
        ctx: CPTritonContext (unused, for signature compatibility)
        return_lse: If True, also return the combined global LSE
        is_lse_base_on_e: If True, LSE is base e; if False, base 2

    Returns:
        Combined output [B, H/N, D] (head-scattered)
        If return_lse=True, also returns global_lse [B, H/N]
        With a single-rank group the inputs are returned unchanged.

    Raises:
        ValueError: If cp_attn_out is not 3-D, or H is not divisible by the
            world size of ``cp_group``.
    """
    world_size = cp_group.world_size

    # Nothing to exchange with a single rank: hand the inputs back untouched.
    if world_size == 1:
        if return_lse:
            return cp_attn_out, cp_attn_lse
        return cp_attn_out

    B, H, D = cp_attn_out.shape
    if H % world_size != 0:
        # Fail early with a clear message, before any copy or communication,
        # instead of an opaque view() size error further down.
        raise ValueError(
            f"dcp_a2a_lse_reduce requires the number of heads ({H}) to be "
            f"divisible by the DCP world size ({world_size})."
        )
    H_per_rank = H // world_size

    local_output = cp_attn_out.contiguous()
    local_lse = cp_attn_lse.contiguous()

    # Reshape for All-to-All: [B, H, D] -> [N, B, H/N, D]
    # Split heads into N chunks, each destined for a different rank
    send_output = (
        local_output.view(B, world_size, H_per_rank, D).permute(1, 0, 2, 3).contiguous()
    )
    recv_output = torch.empty_like(send_output)

    # Same for LSE: [B, H] -> [N, B, H/N]
    send_lse = local_lse.view(B, world_size, H_per_rank).permute(1, 0, 2).contiguous()
    recv_lse = torch.empty_like(send_lse)

    # Launch both exchanges asynchronously so they can overlap, then wait for
    # both before the local combine reads the receive buffers.
    work_output = dist.all_to_all_single(
        recv_output.view(-1),
        send_output.view(-1),
        group=cp_group.device_group,
        async_op=True,
    )
    work_lse = dist.all_to_all_single(
        recv_lse.view(-1),
        send_lse.view(-1),
        group=cp_group.device_group,
        async_op=True,
    )
    work_output.wait()
    work_lse.wait()

    # LSE-weighted combination via Triton kernel (local, no communication)
    return dcp_lse_combine_triton(
        recv_output,
        recv_lse,
        return_lse=return_lse,
        is_lse_base_on_e=is_lse_base_on_e,
    )