[Core] Add All-to-All communication backend for DCP (#34883)

Signed-off-by: Sungsoo Ha <sungsooh@nvidia.com>
Signed-off-by: sungsoo ha <hasungsoo@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
sungsoo ha
2026-03-04 07:01:57 -08:00
committed by GitHub
parent ead7bde1ab
commit 6cb901093f
8 changed files with 658 additions and 17 deletions

View File

@@ -203,8 +203,17 @@ from tqdm import tqdm
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
from vllm.config import (
CacheConfig,
ModelConfig,
VllmConfig,
get_current_vllm_config,
get_current_vllm_config_or_none,
)
from vllm.distributed.parallel_state import (
get_dcp_group,
is_global_first_rank,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
@@ -253,6 +262,7 @@ from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills,
)
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.dcp_alltoall import dcp_a2a_lse_reduce
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import (
@@ -393,6 +403,13 @@ class MLAAttention(nn.Module, AttentionLayerBase):
self.use_sparse = use_sparse
vllm_config = get_current_vllm_config_or_none()
self.dcp_a2a = (
vllm_config is not None
and vllm_config.parallel_config.decode_context_parallel_size > 1
and vllm_config.parallel_config.dcp_comm_backend == "a2a"
)
# Initialize q/k/v range constants.
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
@@ -647,12 +664,20 @@ class MLAAttention(nn.Module, AttentionLayerBase):
# correct dcp attn_out with lse.
if self.impl.dcp_world_size > 1:
attn_out = cp_lse_ag_out_rs(
attn_out,
lse,
get_dcp_group(),
is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False),
)
if self.dcp_a2a:
attn_out = dcp_a2a_lse_reduce(
attn_out,
lse,
get_dcp_group(),
is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False),
)
else:
attn_out = cp_lse_ag_out_rs(
attn_out,
lse,
get_dcp_group(),
is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False),
)
# v_up projection
self._v_up_proj(attn_out, out=mqa_output_slice)