[Bugfix][Wide EP] Fix redundant work when using DeepEP, TP Attn, and EP MoE (#24134)

Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Tyler Michael Smith
2025-09-08 22:01:51 -04:00
committed by GitHub
parent 4f87abdcc6
commit 955c624915
4 changed files with 132 additions and 59 deletions
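
To illustrate the redundant work being removed: with tensor-parallel attention and an expert-parallel MoE using a DeepEP dispatcher, the input to the MoE is sequence parallel, so each of the sp_size dispatchers only has to cover roughly 1/sp_size of the tokens. The chunked MoE loop, however, was still sized to max_tokens_across_dp on every rank, so the extra iterations effectively re-ran clamped slices of tokens that were already covered. The sketch below uses the variable names from the patch, but the sizes and the chunks_before/chunks_after variables are made up purely for illustration.

    # Hypothetical sizes; only illustrates the loop-bound change in this commit.
    def cdiv(a: int, b: int) -> int:
        return -(a // -b)  # ceiling division, standing in for vllm.utils.cdiv

    max_tokens_across_dp = 8192          # max tokens on any DP rank this step
    moe_dp_chunk_size_per_rank = 2048
    sp_size = 4                          # TP size reused as the sequence-parallel size

    # Before: every rank chunked over the full DP-wide maximum.
    chunks_before = cdiv(max_tokens_across_dp, moe_dp_chunk_size_per_rank)      # 4

    # After: the bound is first divided across the sp_size dispatchers.
    max_tokens_across_dispatchers = cdiv(max_tokens_across_dp, sp_size)         # 2048
    chunks_after = cdiv(max_tokens_across_dispatchers,
                        moe_dp_chunk_size_per_rank)                             # 1
    print(chunks_before, chunks_after)   # 4 1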


@@ -35,7 +35,7 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
-from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx,
+from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx,
                         round_up)
 
 if current_platform.is_cuda_alike():
@@ -786,6 +786,7 @@ class FusedMoE(CustomOp):
         enable_eplb: bool = False,
         num_redundant_experts: int = 0,
         has_bias: bool = False,
+        is_sequence_parallel=False,
     ):
         super().__init__()
         if params_dtype is None:
@@ -797,6 +798,10 @@ class FusedMoE(CustomOp):
         dp_size_ = (dp_size
                     if dp_size is not None else get_dp_group().world_size)
 
+        self.is_sequence_parallel = is_sequence_parallel
+        if self.is_sequence_parallel:
+            self.sp_size = tp_size_
+
         vllm_config = get_current_vllm_config()
         self.moe_parallel_config: FusedMoEParallelConfig = (
             FusedMoEParallelConfig.make(
@@ -1699,14 +1704,22 @@ class FusedMoE(CustomOp):
         ctx = get_forward_context()
         # flashinfer_cutlass_kernels can handle: optional DP + TP/EP
-        max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
+        max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu
         moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
+
+        # If the input to the MoE is sequence parallel then divide by sp_size
+        # to find the maximum number of tokens for any individual dispatcher.
+        if self.is_sequence_parallel:
+            max_tokens_across_dispatchers = cdiv(max_tokens_across_dispatchers,
+                                                 self.sp_size)
+
         num_tokens = full_hidden_states.size(0)
         for chunk_idx, chunk_start_ in enumerate(
-                range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank)):
+                range(0, max_tokens_across_dispatchers,
+                      moe_dp_chunk_size_per_rank)):
             chunk_start = chunk_start_
             chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
-                            max_tokens_across_dp)
+                            max_tokens_across_dispatchers)
             # clamp start and end
             chunk_start = min(chunk_start, num_tokens - 1)
             chunk_end = min(chunk_end, num_tokens)
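
For reference, a minimal runnable sketch of the updated chunk loop; it mirrors the last hunk above but is not the actual FusedMoE code, and all sizes (including num_tokens) are invented for illustration.

    # Standalone sketch of the new per-dispatcher chunking bound; sizes are invented.
    def cdiv(a: int, b: int) -> int:
        return -(a // -b)

    is_sequence_parallel = True
    sp_size = 4
    max_tokens_across_dispatchers = 8192  # stand-in for ctx.dp_metadata.max_tokens_across_dp_cpu
    moe_dp_chunk_size_per_rank = 2048
    num_tokens = 1500                     # tokens actually present on this rank

    if is_sequence_parallel:
        # Each dispatcher only needs to cover ~1/sp_size of the tokens.
        max_tokens_across_dispatchers = cdiv(max_tokens_across_dispatchers, sp_size)

    for chunk_idx, chunk_start_ in enumerate(
            range(0, max_tokens_across_dispatchers, moe_dp_chunk_size_per_rank)):
        chunk_end = min(chunk_start_ + moe_dp_chunk_size_per_rank,
                        max_tokens_across_dispatchers)
        # Clamp to the tokens this rank actually holds, as the patch does.
        chunk_start = min(chunk_start_, num_tokens - 1)
        chunk_end = min(chunk_end, num_tokens)
        print(chunk_idx, chunk_start, chunk_end)  # -> 0 0 1500 (a single chunk)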