[Core] Add All-to-All communication backend for DCP (#34883)
Signed-off-by: Sungsoo Ha <sungsooh@nvidia.com> Signed-off-by: sungsoo ha <hasungsoo@gmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
192
tests/distributed/test_dcp_a2a.py
Normal file
192
tests/distributed/test_dcp_a2a.py
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Unit tests for DCP A2A communication backend (no GPU required).
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
1. DCP A2A config validation (--dcp-comm-backend)
|
||||||
|
2. KVP group function exists
|
||||||
|
3. LSE-weighted combination correctness
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.config.parallel import ParallelConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TestDCPCommBackendConfig:
|
||||||
|
"""Test --dcp-comm-backend config validation."""
|
||||||
|
|
||||||
|
def test_default_is_ag_rs(self):
|
||||||
|
"""Default comm backend is ag_rs."""
|
||||||
|
config = ParallelConfig()
|
||||||
|
assert config.dcp_comm_backend == "ag_rs"
|
||||||
|
|
||||||
|
def test_a2a_requires_dcp_greater_than_1(self):
|
||||||
|
"""A2A backend requires decode_context_parallel_size > 1."""
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError, match="requires decode_context_parallel_size > 1"
|
||||||
|
):
|
||||||
|
ParallelConfig(
|
||||||
|
dcp_comm_backend="a2a",
|
||||||
|
decode_context_parallel_size=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_a2a_with_dcp_valid(self):
|
||||||
|
"""A2A backend is valid when DCP > 1."""
|
||||||
|
config = ParallelConfig(
|
||||||
|
dcp_comm_backend="a2a",
|
||||||
|
tensor_parallel_size=8,
|
||||||
|
decode_context_parallel_size=4,
|
||||||
|
)
|
||||||
|
assert config.dcp_comm_backend == "a2a"
|
||||||
|
|
||||||
|
def test_invalid_backend_rejected(self):
|
||||||
|
"""Invalid backend values are rejected."""
|
||||||
|
with pytest.raises(ValueError, match="must be one of"):
|
||||||
|
ParallelConfig(
|
||||||
|
dcp_comm_backend="invalid",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_ag_rs_with_dcp_1_valid(self):
|
||||||
|
"""ag_rs backend is valid with DCP=1 (no DCP)."""
|
||||||
|
config = ParallelConfig(
|
||||||
|
dcp_comm_backend="ag_rs",
|
||||||
|
decode_context_parallel_size=1,
|
||||||
|
)
|
||||||
|
assert config.dcp_comm_backend == "ag_rs"
|
||||||
|
|
||||||
|
|
||||||
|
class TestLSEWeightedCombine:
|
||||||
|
"""Test LSE-weighted combination logic (CPU only, no GPU).
|
||||||
|
|
||||||
|
The _lse_weighted_combine function is the reference implementation
|
||||||
|
that verifies the Triton kernel's correctness. It computes:
|
||||||
|
|
||||||
|
result[b,h,d] = sum_n(w_n * output_n[b,h,d])
|
||||||
|
|
||||||
|
where w_n = softmax(lse_n) = exp(lse_n) / sum_k(exp(lse_k))
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_importable(self):
|
||||||
|
"""Verify _lse_weighted_combine is importable."""
|
||||||
|
from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine
|
||||||
|
|
||||||
|
assert callable(_lse_weighted_combine)
|
||||||
|
|
||||||
|
def test_single_rank(self):
|
||||||
|
"""Single rank: output unchanged."""
|
||||||
|
from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine
|
||||||
|
|
||||||
|
# N=1, B=2, H=4, D=8
|
||||||
|
outputs = torch.randn(1, 2, 4, 8)
|
||||||
|
lses = torch.randn(1, 2, 4)
|
||||||
|
|
||||||
|
result = _lse_weighted_combine(outputs, lses)
|
||||||
|
|
||||||
|
assert result.shape == (2, 4, 8)
|
||||||
|
torch.testing.assert_close(result, outputs.squeeze(0), rtol=1e-5, atol=1e-5)
|
||||||
|
|
||||||
|
def test_equal_lse(self):
|
||||||
|
"""Equal LSE values: outputs averaged equally."""
|
||||||
|
from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine
|
||||||
|
|
||||||
|
_N, B, H, D = 2, 1, 1, 4
|
||||||
|
outputs = torch.tensor(
|
||||||
|
[
|
||||||
|
[[[1.0, 2.0, 3.0, 4.0]]], # Rank 0
|
||||||
|
[[[5.0, 6.0, 7.0, 8.0]]], # Rank 1
|
||||||
|
]
|
||||||
|
)
|
||||||
|
lses = torch.tensor(
|
||||||
|
[
|
||||||
|
[[0.0]], # Rank 0
|
||||||
|
[[0.0]], # Rank 1
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
result = _lse_weighted_combine(outputs, lses)
|
||||||
|
|
||||||
|
expected = (outputs[0] + outputs[1]) / 2
|
||||||
|
assert result.shape == (B, H, D)
|
||||||
|
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
|
||||||
|
|
||||||
|
def test_dominant_rank(self):
|
||||||
|
"""Different LSE values: larger LSE gets more weight."""
|
||||||
|
from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine
|
||||||
|
|
||||||
|
B, H, D = 1, 1, 2
|
||||||
|
outputs = torch.tensor(
|
||||||
|
[
|
||||||
|
[[[0.0, 0.0]]], # Rank 0
|
||||||
|
[[[1.0, 1.0]]], # Rank 1
|
||||||
|
]
|
||||||
|
)
|
||||||
|
lses = torch.tensor(
|
||||||
|
[
|
||||||
|
[[-100.0]], # Rank 0: negligible contribution
|
||||||
|
[[0.0]], # Rank 1: dominant
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
result = _lse_weighted_combine(outputs, lses)
|
||||||
|
|
||||||
|
assert result.shape == (B, H, D)
|
||||||
|
torch.testing.assert_close(result, outputs[1].squeeze(0), atol=1e-5, rtol=1e-5)
|
||||||
|
|
||||||
|
def test_mathematically_correct(self):
|
||||||
|
"""Verify mathematical correctness of LSE combination."""
|
||||||
|
from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine
|
||||||
|
|
||||||
|
outputs = torch.tensor(
|
||||||
|
[
|
||||||
|
[[[2.0, 4.0]]],
|
||||||
|
[[[6.0, 8.0]]],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
lses = torch.tensor(
|
||||||
|
[
|
||||||
|
[[1.0]], # exp(1) ≈ 2.718
|
||||||
|
[[2.0]], # exp(2) ≈ 7.389
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
result = _lse_weighted_combine(outputs, lses)
|
||||||
|
|
||||||
|
w0 = math.exp(1) / (math.exp(1) + math.exp(2))
|
||||||
|
w1 = math.exp(2) / (math.exp(1) + math.exp(2))
|
||||||
|
expected = torch.tensor([[[w0 * 2.0 + w1 * 6.0, w0 * 4.0 + w1 * 8.0]]])
|
||||||
|
|
||||||
|
torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4)
|
||||||
|
|
||||||
|
def test_return_lse(self):
|
||||||
|
"""return_lse=True returns global LSE (logsumexp of inputs)."""
|
||||||
|
from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine
|
||||||
|
|
||||||
|
B, H, D = 1, 1, 2
|
||||||
|
outputs = torch.tensor(
|
||||||
|
[
|
||||||
|
[[[1.0, 2.0]]],
|
||||||
|
[[[3.0, 4.0]]],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
lses = torch.tensor(
|
||||||
|
[
|
||||||
|
[[1.0]],
|
||||||
|
[[2.0]],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
result, global_lse = _lse_weighted_combine(outputs, lses, return_lse=True)
|
||||||
|
|
||||||
|
expected_global_lse = math.log(math.exp(1) + math.exp(2))
|
||||||
|
|
||||||
|
assert result.shape == (B, H, D)
|
||||||
|
assert global_lse.shape == (B, H)
|
||||||
|
assert abs(global_lse.item() - expected_global_lse) < 1e-5
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
@@ -36,6 +36,7 @@ ExpertPlacementStrategy = Literal["linear", "round_robin"]
|
|||||||
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
|
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
|
||||||
DataParallelBackend = Literal["ray", "mp"]
|
DataParallelBackend = Literal["ray", "mp"]
|
||||||
EPLBPolicyOption = Literal["default"]
|
EPLBPolicyOption = Literal["default"]
|
||||||
|
DCPCommBackend = Literal["ag_rs", "a2a"]
|
||||||
All2AllBackend = Literal[
|
All2AllBackend = Literal[
|
||||||
"naive",
|
"naive",
|
||||||
"pplx",
|
"pplx",
|
||||||
@@ -287,6 +288,14 @@ class ParallelConfig:
|
|||||||
and will be deprecated when PCP is fully supported.
|
and will be deprecated when PCP is fully supported.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
dcp_comm_backend: DCPCommBackend = "ag_rs"
|
||||||
|
"""Communication backend for Decode Context Parallel (DCP).
|
||||||
|
- "ag_rs": AllGather + ReduceScatter (default, existing behavior)
|
||||||
|
- "a2a": All-to-All exchange of partial outputs + LSE, then
|
||||||
|
combine with Triton kernel. Reduces NCCL calls from 3 to 2
|
||||||
|
per layer for MLA models.
|
||||||
|
"""
|
||||||
|
|
||||||
cp_kv_cache_interleave_size: int = 1
|
cp_kv_cache_interleave_size: int = 1
|
||||||
"""Interleave size of kv_cache storage while using DCP or PCP.
|
"""Interleave size of kv_cache storage while using DCP or PCP.
|
||||||
For `total_cp_rank = pcp_rank * dcp_world_size + dcp_rank`,
|
For `total_cp_rank = pcp_rank * dcp_world_size + dcp_rank`,
|
||||||
@@ -392,6 +401,11 @@ class ParallelConfig:
|
|||||||
f"dcp_size={self.decode_context_parallel_size}."
|
f"dcp_size={self.decode_context_parallel_size}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.dcp_comm_backend == "a2a" and self.decode_context_parallel_size <= 1:
|
||||||
|
raise ValueError(
|
||||||
|
"dcp_comm_backend='a2a' requires decode_context_parallel_size > 1."
|
||||||
|
)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -1645,6 +1645,8 @@ class VllmConfig:
|
|||||||
f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}, " # noqa
|
f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}, " # noqa
|
||||||
f"pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, " # noqa
|
f"pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, " # noqa
|
||||||
f"data_parallel_size={self.parallel_config.data_parallel_size}, " # noqa
|
f"data_parallel_size={self.parallel_config.data_parallel_size}, " # noqa
|
||||||
|
f"decode_context_parallel_size={self.parallel_config.decode_context_parallel_size}, " # noqa
|
||||||
|
f"dcp_comm_backend={self.parallel_config.dcp_comm_backend}, " # noqa
|
||||||
f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa
|
f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa
|
||||||
f"quantization={self.model_config.quantization}, "
|
f"quantization={self.model_config.quantization}, "
|
||||||
f"enforce_eager={self.model_config.enforce_eager}, "
|
f"enforce_eager={self.model_config.enforce_eager}, "
|
||||||
|
|||||||
@@ -85,6 +85,7 @@ from vllm.config.observability import DetailedTraceModules
|
|||||||
from vllm.config.parallel import (
|
from vllm.config.parallel import (
|
||||||
All2AllBackend,
|
All2AllBackend,
|
||||||
DataParallelBackend,
|
DataParallelBackend,
|
||||||
|
DCPCommBackend,
|
||||||
DistributedExecutorBackend,
|
DistributedExecutorBackend,
|
||||||
ExpertPlacementStrategy,
|
ExpertPlacementStrategy,
|
||||||
)
|
)
|
||||||
@@ -405,6 +406,7 @@ class EngineArgs:
|
|||||||
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
|
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
|
||||||
prefill_context_parallel_size: int = ParallelConfig.prefill_context_parallel_size
|
prefill_context_parallel_size: int = ParallelConfig.prefill_context_parallel_size
|
||||||
decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
|
decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
|
||||||
|
dcp_comm_backend: DCPCommBackend = ParallelConfig.dcp_comm_backend
|
||||||
dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
|
dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
|
||||||
cp_kv_cache_interleave_size: int = ParallelConfig.cp_kv_cache_interleave_size
|
cp_kv_cache_interleave_size: int = ParallelConfig.cp_kv_cache_interleave_size
|
||||||
data_parallel_size: int = ParallelConfig.data_parallel_size
|
data_parallel_size: int = ParallelConfig.data_parallel_size
|
||||||
@@ -820,6 +822,10 @@ class EngineArgs:
|
|||||||
"-dcp",
|
"-dcp",
|
||||||
**parallel_kwargs["decode_context_parallel_size"],
|
**parallel_kwargs["decode_context_parallel_size"],
|
||||||
)
|
)
|
||||||
|
parallel_group.add_argument(
|
||||||
|
"--dcp-comm-backend",
|
||||||
|
**parallel_kwargs["dcp_comm_backend"],
|
||||||
|
)
|
||||||
parallel_group.add_argument(
|
parallel_group.add_argument(
|
||||||
"--dcp-kv-cache-interleave-size",
|
"--dcp-kv-cache-interleave-size",
|
||||||
**parallel_kwargs["dcp_kv_cache_interleave_size"],
|
**parallel_kwargs["dcp_kv_cache_interleave_size"],
|
||||||
@@ -1720,6 +1726,7 @@ class EngineArgs:
|
|||||||
worker_cls=self.worker_cls,
|
worker_cls=self.worker_cls,
|
||||||
worker_extension_cls=self.worker_extension_cls,
|
worker_extension_cls=self.worker_extension_cls,
|
||||||
decode_context_parallel_size=self.decode_context_parallel_size,
|
decode_context_parallel_size=self.decode_context_parallel_size,
|
||||||
|
dcp_comm_backend=self.dcp_comm_backend,
|
||||||
dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size,
|
dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size,
|
||||||
cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
|
cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
|
||||||
_api_process_count=self._api_process_count,
|
_api_process_count=self._api_process_count,
|
||||||
|
|||||||
@@ -203,8 +203,17 @@ from tqdm import tqdm
|
|||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm._aiter_ops import rocm_aiter_ops
|
from vllm._aiter_ops import rocm_aiter_ops
|
||||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config
|
from vllm.config import (
|
||||||
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
|
CacheConfig,
|
||||||
|
ModelConfig,
|
||||||
|
VllmConfig,
|
||||||
|
get_current_vllm_config,
|
||||||
|
get_current_vllm_config_or_none,
|
||||||
|
)
|
||||||
|
from vllm.distributed.parallel_state import (
|
||||||
|
get_dcp_group,
|
||||||
|
is_global_first_rank,
|
||||||
|
)
|
||||||
from vllm.forward_context import ForwardContext, get_forward_context
|
from vllm.forward_context import ForwardContext, get_forward_context
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
@@ -253,6 +262,7 @@ from vllm.v1.attention.backends.utils import (
|
|||||||
split_decodes_and_prefills,
|
split_decodes_and_prefills,
|
||||||
)
|
)
|
||||||
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
|
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
|
||||||
|
from vllm.v1.attention.ops.dcp_alltoall import dcp_a2a_lse_reduce
|
||||||
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
|
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
|
||||||
from vllm.v1.attention.selector import get_attn_backend
|
from vllm.v1.attention.selector import get_attn_backend
|
||||||
from vllm.v1.kv_cache_interface import (
|
from vllm.v1.kv_cache_interface import (
|
||||||
@@ -393,6 +403,13 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
|||||||
|
|
||||||
self.use_sparse = use_sparse
|
self.use_sparse = use_sparse
|
||||||
|
|
||||||
|
vllm_config = get_current_vllm_config_or_none()
|
||||||
|
self.dcp_a2a = (
|
||||||
|
vllm_config is not None
|
||||||
|
and vllm_config.parallel_config.decode_context_parallel_size > 1
|
||||||
|
and vllm_config.parallel_config.dcp_comm_backend == "a2a"
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize q/k/v range constants.
|
# Initialize q/k/v range constants.
|
||||||
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
|
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
|
||||||
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
|
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
|
||||||
@@ -647,12 +664,20 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
|||||||
|
|
||||||
# correct dcp attn_out with lse.
|
# correct dcp attn_out with lse.
|
||||||
if self.impl.dcp_world_size > 1:
|
if self.impl.dcp_world_size > 1:
|
||||||
attn_out = cp_lse_ag_out_rs(
|
if self.dcp_a2a:
|
||||||
attn_out,
|
attn_out = dcp_a2a_lse_reduce(
|
||||||
lse,
|
attn_out,
|
||||||
get_dcp_group(),
|
lse,
|
||||||
is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False),
|
get_dcp_group(),
|
||||||
)
|
is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attn_out = cp_lse_ag_out_rs(
|
||||||
|
attn_out,
|
||||||
|
lse,
|
||||||
|
get_dcp_group(),
|
||||||
|
is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False),
|
||||||
|
)
|
||||||
|
|
||||||
# v_up projection
|
# v_up projection
|
||||||
self._v_up_proj(attn_out, out=mqa_output_slice)
|
self._v_up_proj(attn_out, out=mqa_output_slice)
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from vllm.v1.attention.backends.fa_utils import (
|
|||||||
is_flash_attn_varlen_func_available,
|
is_flash_attn_varlen_func_available,
|
||||||
)
|
)
|
||||||
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
|
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
|
||||||
|
from vllm.v1.attention.ops.dcp_alltoall import dcp_a2a_lse_reduce
|
||||||
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
|
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
|
||||||
|
|
||||||
if is_flash_attn_varlen_func_available():
|
if is_flash_attn_varlen_func_available():
|
||||||
@@ -32,7 +33,12 @@ if is_flash_attn_varlen_func_available():
|
|||||||
get_scheduler_metadata,
|
get_scheduler_metadata,
|
||||||
reshape_and_cache_flash,
|
reshape_and_cache_flash,
|
||||||
)
|
)
|
||||||
from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
|
from vllm.config import (
|
||||||
|
VllmConfig,
|
||||||
|
get_current_vllm_config,
|
||||||
|
get_current_vllm_config_or_none,
|
||||||
|
get_layers_from_vllm_config,
|
||||||
|
)
|
||||||
from vllm.config.cache import CacheDType
|
from vllm.config.cache import CacheDType
|
||||||
from vllm.distributed.parallel_state import get_dcp_group
|
from vllm.distributed.parallel_state import get_dcp_group
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@@ -609,6 +615,14 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
|
|
||||||
self.supports_quant_query_input = True
|
self.supports_quant_query_input = True
|
||||||
|
|
||||||
|
vllm_config = get_current_vllm_config_or_none()
|
||||||
|
dcp_a2a = (
|
||||||
|
vllm_config is not None
|
||||||
|
and vllm_config.parallel_config.decode_context_parallel_size > 1
|
||||||
|
and vllm_config.parallel_config.dcp_comm_backend == "a2a"
|
||||||
|
)
|
||||||
|
self.dcp_combine = dcp_a2a_lse_reduce if dcp_a2a else cp_lse_ag_out_rs
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@@ -857,8 +871,8 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
v_descale=v_descale,
|
v_descale=v_descale,
|
||||||
num_splits=attn_metadata.max_num_splits,
|
num_splits=attn_metadata.max_num_splits,
|
||||||
)
|
)
|
||||||
# FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ]
|
# FA returns LSE in shape [ H, B ] but DCP combine wants [ B, H ]
|
||||||
context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
|
context_attn_out_cor, context_lse_cor = self.dcp_combine(
|
||||||
context_attn_out,
|
context_attn_out,
|
||||||
context_lse.transpose(0, 1),
|
context_lse.transpose(0, 1),
|
||||||
get_dcp_group(),
|
get_dcp_group(),
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
"""Attention layer with FlashInfer."""
|
"""Attention layer with FlashInfer."""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from functools import partial
|
||||||
from typing import ClassVar
|
from typing import ClassVar
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -19,7 +20,11 @@ from flashinfer.utils import FP4Tensor
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
|
from vllm.config import (
|
||||||
|
CUDAGraphMode,
|
||||||
|
VllmConfig,
|
||||||
|
get_current_vllm_config_or_none,
|
||||||
|
)
|
||||||
from vllm.config.cache import CacheDType
|
from vllm.config.cache import CacheDType
|
||||||
from vllm.distributed.parallel_state import get_dcp_group
|
from vllm.distributed.parallel_state import get_dcp_group
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@@ -59,6 +64,7 @@ from vllm.v1.attention.backends.utils import (
|
|||||||
split_decodes_and_prefills,
|
split_decodes_and_prefills,
|
||||||
)
|
)
|
||||||
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
|
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
|
||||||
|
from vllm.v1.attention.ops.dcp_alltoall import dcp_a2a_lse_reduce
|
||||||
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
|
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec, UniformTypeKVCacheSpecs
|
from vllm.v1.kv_cache_interface import AttentionSpec, UniformTypeKVCacheSpecs
|
||||||
from vllm.v1.utils import CpuGpuBuffer
|
from vllm.v1.utils import CpuGpuBuffer
|
||||||
@@ -170,7 +176,12 @@ class BatchDCPPrefillWrapper:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
workspace_buffer: torch.Tensor | None = None,
|
workspace_buffer: torch.Tensor | None = None,
|
||||||
|
dcp_a2a: bool = False,
|
||||||
):
|
):
|
||||||
|
if dcp_a2a:
|
||||||
|
self._dcp_combine = partial(dcp_a2a_lse_reduce, is_lse_base_on_e=False)
|
||||||
|
else:
|
||||||
|
self._dcp_combine = partial(cp_lse_ag_out_rs, is_lse_base_on_e=False)
|
||||||
self._context = BatchPrefillWithPagedKVCacheWrapper(
|
self._context = BatchPrefillWithPagedKVCacheWrapper(
|
||||||
workspace_buffer, get_kv_cache_layout()
|
workspace_buffer, get_kv_cache_layout()
|
||||||
)
|
)
|
||||||
@@ -249,12 +260,11 @@ class BatchDCPPrefillWrapper:
|
|||||||
v_scale=layer._v_scale_float,
|
v_scale=layer._v_scale_float,
|
||||||
return_lse=True,
|
return_lse=True,
|
||||||
)
|
)
|
||||||
output_context, lse_context = cp_lse_ag_out_rs(
|
output_context, lse_context = self._dcp_combine(
|
||||||
output_context_tmp,
|
output_context_tmp,
|
||||||
lse_context_tmp,
|
lse_context_tmp,
|
||||||
get_dcp_group(),
|
get_dcp_group(),
|
||||||
return_lse=True,
|
return_lse=True,
|
||||||
is_lse_base_on_e=False,
|
|
||||||
)
|
)
|
||||||
lse_context = lse_context.transpose(0, 1).contiguous()
|
lse_context = lse_context.transpose(0, 1).contiguous()
|
||||||
|
|
||||||
@@ -550,6 +560,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
self.dcp_rank = 0
|
self.dcp_rank = 0
|
||||||
self.dcp_kv_cache_interleave_size = 1
|
self.dcp_kv_cache_interleave_size = 1
|
||||||
self.use_dcp = self.dcp_world_size > 1
|
self.use_dcp = self.dcp_world_size > 1
|
||||||
|
self.dcp_a2a = (
|
||||||
|
self.use_dcp and vllm_config.parallel_config.dcp_comm_backend == "a2a"
|
||||||
|
)
|
||||||
|
|
||||||
self.num_qo_heads = self.model_config.get_num_attention_heads(
|
self.num_qo_heads = self.model_config.get_num_attention_heads(
|
||||||
self.vllm_config.parallel_config
|
self.vllm_config.parallel_config
|
||||||
@@ -699,6 +712,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
if self.use_dcp:
|
if self.use_dcp:
|
||||||
self._prefill_wrapper = BatchDCPPrefillWrapper(
|
self._prefill_wrapper = BatchDCPPrefillWrapper(
|
||||||
workspace_buffer=self._get_workspace_buffer(),
|
workspace_buffer=self._get_workspace_buffer(),
|
||||||
|
dcp_a2a=self.dcp_a2a,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
|
self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
|
||||||
@@ -1208,15 +1222,26 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
self.sinks = sinks
|
self.sinks = sinks
|
||||||
|
|
||||||
self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads)
|
self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads)
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config_or_none()
|
||||||
self.supports_quant_query_input = (
|
self.supports_quant_query_input = (
|
||||||
self.support_trtllm_attn
|
self.support_trtllm_attn
|
||||||
|
and vllm_config is not None
|
||||||
and not vllm_config.attention_config.disable_flashinfer_q_quantization
|
and not vllm_config.attention_config.disable_flashinfer_q_quantization
|
||||||
)
|
)
|
||||||
self.bmm1_scale: float | None = None
|
self.bmm1_scale: float | None = None
|
||||||
self.bmm2_scale: float | None = None
|
self.bmm2_scale: float | None = None
|
||||||
self.o_sf_scale: float | None = None
|
self.o_sf_scale: float | None = None
|
||||||
|
|
||||||
|
dcp_a2a = (
|
||||||
|
vllm_config is not None
|
||||||
|
and vllm_config.parallel_config.decode_context_parallel_size > 1
|
||||||
|
and vllm_config.parallel_config.dcp_comm_backend == "a2a"
|
||||||
|
)
|
||||||
|
if dcp_a2a:
|
||||||
|
self.dcp_combine = partial(dcp_a2a_lse_reduce, is_lse_base_on_e=False)
|
||||||
|
else:
|
||||||
|
self.dcp_combine = partial(cp_lse_ag_out_rs, is_lse_base_on_e=False)
|
||||||
|
|
||||||
def fused_output_quant_supported(self, quant_key: QuantKey):
|
def fused_output_quant_supported(self, quant_key: QuantKey):
|
||||||
return (
|
return (
|
||||||
self.support_trtllm_attn
|
self.support_trtllm_attn
|
||||||
@@ -1503,11 +1528,10 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
lse=lse,
|
lse=lse,
|
||||||
return_lse=True,
|
return_lse=True,
|
||||||
)
|
)
|
||||||
output[:num_decode_tokens] = cp_lse_ag_out_rs(
|
output[:num_decode_tokens] = self.dcp_combine(
|
||||||
output_tmp,
|
output_tmp,
|
||||||
lse,
|
lse,
|
||||||
get_dcp_group(),
|
get_dcp_group(),
|
||||||
is_lse_base_on_e=False,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
decode_wrapper.run(
|
decode_wrapper.run(
|
||||||
|
|||||||
363
vllm/v1/attention/ops/dcp_alltoall.py
Normal file
363
vllm/v1/attention/ops/dcp_alltoall.py
Normal file
@@ -0,0 +1,363 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
DCP All-to-All communication backend for attention.
|
||||||
|
|
||||||
|
Provides All-to-All (A2A) communication as an alternative to
|
||||||
|
AllGather + ReduceScatter (AG+RS) for Decode Context Parallel (DCP).
|
||||||
|
Instead of gathering the full Q tensor and scattering partial outputs,
|
||||||
|
A2A exchanges partial attention outputs and their LSE values across
|
||||||
|
ranks, then combines them with exact LSE-weighted reduction.
|
||||||
|
|
||||||
|
This reduces the number of NCCL calls per attention layer from 3
|
||||||
|
(AG for Q, AG for K metadata, RS for output) to 2 (A2A for output,
|
||||||
|
A2A for LSE), lowering per-step communication overhead for long-context
|
||||||
|
decode where NCCL latency is a significant fraction of step time.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
vllm serve model --tp 16 --dcp 16 --dcp-comm-backend a2a
|
||||||
|
|
||||||
|
Reference: https://arxiv.org/abs/2507.07120
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.distributed.parallel_state import GroupCoordinator
|
||||||
|
from vllm.v1.attention.ops.common import CPTritonContext
|
||||||
|
|
||||||
|
|
||||||
|
def _lse_weighted_combine(
|
||||||
|
outputs: torch.Tensor,
|
||||||
|
lses: torch.Tensor,
|
||||||
|
return_lse: bool = False,
|
||||||
|
is_lse_base_on_e: bool = True,
|
||||||
|
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
CPU reference implementation for LSE-weighted combination.
|
||||||
|
|
||||||
|
This is a pure PyTorch implementation used for testing and validation.
|
||||||
|
For GPU execution, use dcp_lse_combine_triton instead.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
outputs: Partial attention outputs [N, B, H, D]
|
||||||
|
N = number of KV shards (ranks)
|
||||||
|
B = batch size (num_tokens)
|
||||||
|
H = number of heads per rank
|
||||||
|
D = head dimension
|
||||||
|
lses: Log-sum-exp values [N, B, H]
|
||||||
|
return_lse: If True, also return the global LSE
|
||||||
|
is_lse_base_on_e: If True, LSE is base e; if False, base 2
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Combined output [B, H, D], and optionally global LSE [B, H]
|
||||||
|
"""
|
||||||
|
N, B, H, D = outputs.shape
|
||||||
|
|
||||||
|
# Handle NaN and inf in LSEs
|
||||||
|
lses = torch.where(
|
||||||
|
torch.isnan(lses) | torch.isinf(lses),
|
||||||
|
torch.tensor(float("-inf"), device=lses.device, dtype=lses.dtype),
|
||||||
|
lses,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute max LSE for numerical stability
|
||||||
|
lse_max, _ = lses.max(dim=0) # [B, H]
|
||||||
|
lse_max = torch.where(
|
||||||
|
lse_max == float("-inf"),
|
||||||
|
torch.zeros_like(lse_max),
|
||||||
|
lse_max,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute weights: softmax over the N dimension
|
||||||
|
if is_lse_base_on_e:
|
||||||
|
weights = torch.exp(lses - lse_max.unsqueeze(0)) # [N, B, H]
|
||||||
|
else:
|
||||||
|
weights = torch.pow(2.0, lses - lse_max.unsqueeze(0)) # [N, B, H]
|
||||||
|
|
||||||
|
# Handle NaN weights
|
||||||
|
weights = torch.where(torch.isnan(weights), torch.zeros_like(weights), weights)
|
||||||
|
|
||||||
|
# Normalize weights
|
||||||
|
weight_sum = weights.sum(dim=0, keepdim=True) # [1, B, H]
|
||||||
|
weights = weights / weight_sum.clamp(min=1e-10) # [N, B, H]
|
||||||
|
|
||||||
|
# Weighted combination: sum over N dimension
|
||||||
|
result = (outputs * weights.unsqueeze(-1)).sum(dim=0) # [B, H, D]
|
||||||
|
|
||||||
|
if return_lse:
|
||||||
|
if is_lse_base_on_e:
|
||||||
|
global_lse = torch.log(weight_sum.squeeze(0)) + lse_max # [B, H]
|
||||||
|
else:
|
||||||
|
global_lse = torch.log2(weight_sum.squeeze(0)) + lse_max # [B, H]
|
||||||
|
return result, global_lse
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _dcp_lse_combine_kernel(
    # Input pointers
    recv_output_ptr,  # [N, B, H_local, D] partial attention outputs
    recv_lse_ptr,  # [N, B, H_local] partial log-sum-exp values
    # Output pointers
    out_ptr,  # [B, H_local, D] combined output
    out_lse_ptr,  # [B, H_local] global LSE (written only when RETURN_LSE)
    # Strides for recv_output [N, B, H_local, D]
    ro_stride_N,
    ro_stride_B,
    ro_stride_H,
    ro_stride_D,
    # Strides for recv_lse [N, B, H_local]
    rl_stride_N,
    rl_stride_B,
    rl_stride_H,
    # Strides for output [B, H_local, D]
    o_stride_B,
    o_stride_H,
    o_stride_D,
    # Constants
    N: tl.constexpr,  # number of KV shards to combine (DCP world size)
    HEAD_DIM: tl.constexpr,  # D; tl.arange requires this to be a power of 2
    IS_BASE_E: tl.constexpr,  # True: LSE is base e; False: base 2
    RETURN_LSE: tl.constexpr,  # also store the combined global LSE
):
    """
    Triton kernel for LSE-weighted combination of partial attention outputs.

    After All-to-All, each rank has:
    - recv_output [N, B, H_local, D]: partial outputs from all KV shards
    - recv_lse [N, B, H_local]: partial LSEs from all KV shards

    This kernel computes the weighted combination locally (no communication).

    Grid: (B, H_local)
    Each program handles one (batch, head) and processes all D elements.
    """
    # int64 indices so large B * stride products cannot overflow int32.
    batch_idx = tl.program_id(0).to(tl.int64)
    head_idx = tl.program_id(1).to(tl.int64)

    # Base offset for this (batch, head)
    base_lse_offset = batch_idx * rl_stride_B + head_idx * rl_stride_H
    base_out_offset = batch_idx * ro_stride_B + head_idx * ro_stride_H

    # First pass: find max LSE for numerical stability.
    # NaN (lse_val != lse_val) and +inf are sanitized to -inf so an invalid
    # shard gets zero weight instead of poisoning the combination.
    lse_max = -float("inf")
    for n in tl.static_range(N):
        lse_offset = n * rl_stride_N + base_lse_offset
        lse_val = tl.load(recv_lse_ptr + lse_offset)
        lse_val = tl.where(
            (lse_val != lse_val) | (lse_val == float("inf")),
            -float("inf"),
            lse_val,
        )
        lse_max = tl.maximum(lse_max, lse_val)

    # If every shard was invalid (all -inf), fall back to 0 so the exp/exp2
    # below stay finite; all weights then come out as 0.
    lse_max = tl.where(lse_max == -float("inf"), 0.0, lse_max)

    # Second pass: compute sum of exp(lse - max)
    lse_sum = 0.0
    for n in tl.static_range(N):
        lse_offset = n * rl_stride_N + base_lse_offset
        lse_val = tl.load(recv_lse_ptr + lse_offset)
        # Same NaN/+inf sanitization as the first pass.
        lse_val = tl.where(
            (lse_val != lse_val) | (lse_val == float("inf")),
            -float("inf"),
            lse_val,
        )
        if IS_BASE_E:
            lse_sum += tl.exp(lse_val - lse_max)
        else:
            lse_sum += tl.exp2(lse_val - lse_max)

    # Compute global LSE in the same base as the inputs:
    # global_lse = log(sum_n exp(lse_n - max)) + max
    if IS_BASE_E:  # noqa: SIM108
        global_lse = tl.log(lse_sum) + lse_max
    else:
        global_lse = tl.log2(lse_sum) + lse_max

    # Third pass: weighted combination across D dimension
    d_offsets = tl.arange(0, HEAD_DIM)
    # Accumulate in fp32 regardless of the output dtype for accuracy;
    # tl.store casts back on write.
    acc = tl.zeros([HEAD_DIM], dtype=tl.float32)

    for n in tl.static_range(N):
        lse_offset = n * rl_stride_N + base_lse_offset
        lse_val = tl.load(recv_lse_ptr + lse_offset)
        lse_val = tl.where(
            (lse_val != lse_val) | (lse_val == float("inf")),
            -float("inf"),
            lse_val,
        )
        # exp(lse_n - global_lse) == exp(lse_n - max) / sum: the normalized
        # softmax weight of shard n, without a separate division.
        if IS_BASE_E:
            weight = tl.exp(lse_val - global_lse)
        else:
            weight = tl.exp2(lse_val - global_lse)
        # Guard against NaN weights (e.g. -inf - -inf paths).
        weight = tl.where(weight != weight, 0.0, weight)

        out_offsets = n * ro_stride_N + base_out_offset + d_offsets * ro_stride_D
        out_vals = tl.load(recv_output_ptr + out_offsets)
        acc += out_vals.to(tl.float32) * weight

    # Store result
    final_offsets = (
        batch_idx * o_stride_B + head_idx * o_stride_H + d_offsets * o_stride_D
    )
    tl.store(out_ptr + final_offsets, acc)

    if RETURN_LSE:
        # Output LSE shares the [B, H_local] layout of recv_lse's trailing dims.
        tl.store(out_lse_ptr + base_lse_offset, global_lse)
|
||||||
|
|
||||||
|
|
||||||
|
def dcp_lse_combine_triton(
    recv_output: torch.Tensor,
    recv_lse: torch.Tensor,
    return_lse: bool = False,
    is_lse_base_on_e: bool = True,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """
    Triton-accelerated LSE-weighted combination for DCP A2A.

    Args:
        recv_output: [N, B, H_local, D] - partial outputs from all KV shards
        recv_lse: [N, B, H_local] - partial LSEs from all KV shards
        return_lse: If True, also return the global LSE
        is_lse_base_on_e: If True, LSE is base e; if False, base 2

    Returns:
        Combined output [B, H_local, D]
        If return_lse=True, also returns global_lse [B, H_local]
    """
    num_shards, num_tokens, num_heads, head_dim = recv_output.shape

    combined = torch.empty(
        (num_tokens, num_heads, head_dim),
        device=recv_output.device,
        dtype=recv_output.dtype,
    )

    # The kernel unconditionally takes an LSE output pointer; when the caller
    # does not want it back, hand over a 1-element dummy buffer instead.
    if return_lse:
        combined_lse = torch.empty(
            (num_tokens, num_heads), device=recv_lse.device, dtype=recv_lse.dtype
        )
    else:
        combined_lse = torch.empty(1, device=recv_lse.device, dtype=recv_lse.dtype)

    src_strides = recv_output.stride()
    lse_strides = recv_lse.stride()
    dst_strides = combined.stride()

    # One program per (token, head) pair.
    _dcp_lse_combine_kernel[(num_tokens, num_heads, 1)](
        recv_output,
        recv_lse,
        combined,
        combined_lse,
        src_strides[0],
        src_strides[1],
        src_strides[2],
        src_strides[3],
        lse_strides[0],
        lse_strides[1],
        lse_strides[2],
        dst_strides[0],
        dst_strides[1],
        dst_strides[2],
        N=num_shards,
        HEAD_DIM=head_dim,
        IS_BASE_E=is_lse_base_on_e,
        RETURN_LSE=return_lse,
    )

    if return_lse:
        return combined, combined_lse
    return combined
|
||||||
|
|
||||||
|
|
||||||
|
def dcp_a2a_lse_reduce(
    cp_attn_out: torch.Tensor,
    cp_attn_lse: torch.Tensor,
    cp_group: GroupCoordinator,
    ctx: CPTritonContext | None = None,
    return_lse: bool = False,
    is_lse_base_on_e: bool = True,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """
    Combine partial attention outputs across DCP ranks using All-to-All.

    Each rank holds attention output for all heads but only a local shard
    of the KV cache. This function:
    1. Exchanges partial outputs across ranks via All-to-All
    2. Exchanges LSE values via All-to-All
    3. Combines them with exact LSE-weighted reduction (Triton kernel)

    Tensor flow:
        Input:   cp_attn_out [B, H, D] - all heads, local KV shard
        Reshape: [N, B, H/N, D] - split heads across ranks
        A2A:     Two all_to_all_single calls (output and LSE)
        Combine: recv [N, B, H/N, D] + lse [N, B, H/N] -> [B, H/N, D]

    Args:
        cp_attn_out: [B, H, D] where B=num_tokens, H=total_heads, D=head_dim
        cp_attn_lse: [B, H] log-sum-exp values (fp32)
        cp_group: GroupCoordinator for DCP communication
        ctx: CPTritonContext (unused, for signature compatibility)
        return_lse: If True, also return the combined global LSE
        is_lse_base_on_e: If True, LSE is base e; if False, base 2

    Returns:
        Combined output [B, H/N, D] (head-scattered)
        If return_lse=True, also returns global_lse [B, H/N]

    Raises:
        ValueError: If the total head count H is not divisible by the DCP
            world size; heads must scatter evenly across ranks.
    """
    world_size = cp_group.world_size

    # Single-rank group: nothing to exchange; the local partials are final.
    if world_size == 1:
        if return_lse:
            return cp_attn_out, cp_attn_lse
        return cp_attn_out

    # all_to_all_single operates on flat, contiguous buffers.
    local_output = cp_attn_out.contiguous()
    local_lse = cp_attn_lse.contiguous()

    B, H, D = local_output.shape
    # Validate head divisibility up front: otherwise the view() below fails
    # with an opaque shape error instead of an actionable message.
    if H % world_size != 0:
        raise ValueError(
            f"Total head count ({H}) must be divisible by the DCP world "
            f"size ({world_size}) for head-scattered All-to-All."
        )
    H_per_rank = H // world_size

    # Reshape for All-to-All: [B, H, D] -> [N, B, H/N, D]
    # Split heads into N chunks, each destined for a different rank
    send_output = (
        local_output.view(B, world_size, H_per_rank, D).permute(1, 0, 2, 3).contiguous()
    )
    recv_output = torch.empty_like(send_output)

    # Same for LSE: [B, H] -> [N, B, H/N]
    send_lse = local_lse.view(B, world_size, H_per_rank).permute(1, 0, 2).contiguous()
    recv_lse = torch.empty_like(send_lse)

    # All-to-All for partial attention outputs and LSE values; both are
    # issued asynchronously so the two transfers can overlap.
    work_output = dist.all_to_all_single(
        recv_output.view(-1),
        send_output.view(-1),
        group=cp_group.device_group,
        async_op=True,
    )
    work_lse = dist.all_to_all_single(
        recv_lse.view(-1),
        send_lse.view(-1),
        group=cp_group.device_group,
        async_op=True,
    )
    work_output.wait()
    work_lse.wait()

    # LSE-weighted combination via Triton kernel (local, no communication)
    return dcp_lse_combine_triton(
        recv_output,
        recv_lse,
        return_lse=return_lse,
        is_lse_base_on_e=is_lse_base_on_e,
    )
|
||||||
Reference in New Issue
Block a user