[Core] Add All-to-All communication backend for DCP (#34883)
Signed-off-by: Sungsoo Ha <sungsooh@nvidia.com> Signed-off-by: sungsoo ha <hasungsoo@gmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -203,8 +203,17 @@ from tqdm import tqdm
 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm._aiter_ops import rocm_aiter_ops
-from vllm.config import CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config
-from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
+from vllm.config import (
+    CacheConfig,
+    ModelConfig,
+    VllmConfig,
+    get_current_vllm_config,
+    get_current_vllm_config_or_none,
+)
+from vllm.distributed.parallel_state import (
+    get_dcp_group,
+    is_global_first_rank,
+)
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
@@ -253,6 +262,7 @@ from vllm.v1.attention.backends.utils import (
     split_decodes_and_prefills,
 )
 from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
+from vllm.v1.attention.ops.dcp_alltoall import dcp_a2a_lse_reduce
 from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
 from vllm.v1.attention.selector import get_attn_backend
 from vllm.v1.kv_cache_interface import (
@@ -393,6 +403,13 @@ class MLAAttention(nn.Module, AttentionLayerBase):
 
         self.use_sparse = use_sparse
 
+        vllm_config = get_current_vllm_config_or_none()
+        self.dcp_a2a = (
+            vllm_config is not None
+            and vllm_config.parallel_config.decode_context_parallel_size > 1
+            and vllm_config.parallel_config.dcp_comm_backend == "a2a"
+        )
+
         # Initialize q/k/v range constants.
         self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
         self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
@@ -647,12 +664,20 @@ class MLAAttention(nn.Module, AttentionLayerBase):
 
         # correct dcp attn_out with lse.
         if self.impl.dcp_world_size > 1:
-            attn_out = cp_lse_ag_out_rs(
-                attn_out,
-                lse,
-                get_dcp_group(),
-                is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False),
-            )
+            if self.dcp_a2a:
+                attn_out = dcp_a2a_lse_reduce(
+                    attn_out,
+                    lse,
+                    get_dcp_group(),
+                    is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False),
+                )
+            else:
+                attn_out = cp_lse_ag_out_rs(
+                    attn_out,
+                    lse,
+                    get_dcp_group(),
+                    is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False),
+                )
 
         # v_up projection
         self._v_up_proj(attn_out, out=mqa_output_slice)
Reference in New Issue
Block a user