[Bugfix][Wide EP] Fix redundant work when using DeepEP, TP Attn, and EP MoE (#24134)
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Committed by: GitHub
Parent: 4f87abdcc6
Commit: 955c624915
@@ -35,7 +35,7 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
-from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx,
+from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx,
                         round_up)
 
 if current_platform.is_cuda_alike():
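The only functional change in this hunk is the new cdiv import, which is used further down to split the DP-wide token bound across sequence-parallel dispatchers. As a minimal sketch, cdiv is assumed to be plain integer ceiling division:

```python
def cdiv(a: int, b: int) -> int:
    # Assumed semantics of vllm.utils.cdiv: integer ceiling division.
    return -(a // -b)

assert cdiv(1024, 4) == 256   # divides evenly
assert cdiv(1025, 4) == 257   # rounds up, never truncates tokens away
```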
@@ -786,6 +786,7 @@ class FusedMoE(CustomOp):
         enable_eplb: bool = False,
         num_redundant_experts: int = 0,
         has_bias: bool = False,
+        is_sequence_parallel=False,
     ):
         super().__init__()
         if params_dtype is None:
@@ -797,6 +798,10 @@ class FusedMoE(CustomOp):
         dp_size_ = (dp_size
                     if dp_size is not None else get_dp_group().world_size)
 
+        self.is_sequence_parallel = is_sequence_parallel
+        if self.is_sequence_parallel:
+            self.sp_size = tp_size_
+
         vllm_config = get_current_vllm_config()
         self.moe_parallel_config: FusedMoEParallelConfig = (
             FusedMoEParallelConfig.make(
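Taken together, the two __init__ hunks above record whether the layer's input is already sharded across the attention TP group and, if so, treat that TP size as the number of dispatchers. A standalone sketch of that bookkeeping (the class name is illustrative, not vllm's API; only the attribute names come from the diff):

```python
class SequenceParallelMoEState:
    """Toy stand-in for the bookkeeping FusedMoE.__init__ gains in this commit."""

    def __init__(self, tp_size: int, is_sequence_parallel: bool = False):
        self.is_sequence_parallel = is_sequence_parallel
        if self.is_sequence_parallel:
            # With TP attention + EP MoE (e.g. DeepEP), the hidden states are
            # already split across the TP group, so each of the tp_size ranks
            # dispatches only its own shard of the tokens.
            self.sp_size = tp_size
```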
@@ -1699,14 +1704,22 @@ class FusedMoE(CustomOp):
 
         ctx = get_forward_context()
         # flashinfer_cutlass_kernels can handle: optional DP + TP/EP
-        max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
+        max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu
         moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
 
+        # If the input to the MoE is sequence parallel then divide by sp_size
+        # to find the maximum number of tokens for any individual dispatcher.
+        if self.is_sequence_parallel:
+            max_tokens_across_dispatchers = cdiv(max_tokens_across_dispatchers,
+                                                 self.sp_size)
+
         num_tokens = full_hidden_states.size(0)
         for chunk_idx, chunk_start_ in enumerate(
-                range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank)):
+                range(0, max_tokens_across_dispatchers,
+                      moe_dp_chunk_size_per_rank)):
             chunk_start = chunk_start_
             chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
-                            max_tokens_across_dp)
+                            max_tokens_across_dispatchers)
             # clamp start and end
             chunk_start = min(chunk_start, num_tokens - 1)
             chunk_end = min(chunk_end, num_tokens)
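This hunk is the actual fix: the per-rank chunking loop is now bounded by the per-dispatcher token count rather than the DP-wide one, so sequence-parallel ranks no longer iterate over chunks they do not hold. A hedged, self-contained sketch of the arithmetic (num_chunks is a hypothetical helper, not part of vllm; the numbers are only an example):

```python
def cdiv(a: int, b: int) -> int:
    # Same ceiling-division semantics assumed for vllm.utils.cdiv above.
    return -(a // -b)


def num_chunks(max_tokens_across_dp: int, chunk_size: int,
               is_sequence_parallel: bool = False, sp_size: int = 1) -> int:
    """Iterations of the chunked MoE forward loop on a single rank."""
    max_tokens_across_dispatchers = max_tokens_across_dp
    if is_sequence_parallel:
        # Each dispatcher only ever sees a 1/sp_size shard of the tokens.
        max_tokens_across_dispatchers = cdiv(max_tokens_across_dispatchers,
                                             sp_size)
    return cdiv(max_tokens_across_dispatchers, chunk_size)


# Example: up to 1024 tokens per DP rank, chunk size 256, TP/SP width 4.
# Before this fix every rank iterated 4 times even though it only held
# roughly 256 tokens; with the fix it iterates once.
assert num_chunks(1024, 256) == 4
assert num_chunks(1024, 256, is_sequence_parallel=True, sp_size=4) == 1
```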