[Feat][Perf] Enable deepep-low-latency with round-robin expert placement. (#28449)

Signed-off-by: bruceszchen <bruceszchen@tencent.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Chen Bruce
Date: 2025-11-19 20:46:24 +08:00
Committed by: GitHub
Parent: ba558c029a
Commit: da2f6800e0

8 changed files with 208 additions and 37 deletions
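
Note (illustration only, not part of the diff): with the default linear placement, each EP rank owns a contiguous block of global expert IDs, which lines up with DeepEP's contiguous per-rank physical slot layout. Round-robin placement instead assigns global expert g to rank g % ep_size, so the global expert IDs produced by the router no longer match the physical slot IDs that DeepEP's low-latency dispatch expects. This commit builds lookup tables that translate between the two and threads them into the DeepEP low-latency prepare/finalize path. A minimal standalone sketch of the even-split case (the function and variable names below are illustrative, not taken from the diff):

    import torch

    def round_robin_tables(global_num_experts: int, ep_size: int, ep_rank: int):
        g = torch.arange(global_num_experts)
        owner = g % ep_size                  # rank that owns each global expert
        local_index = g // ep_size           # slot index within that rank
        per_rank = global_num_experts // ep_size
        # physical slots are contiguous per rank:
        # rank r owns slots [r * per_rank, (r + 1) * per_rank)
        global_to_physical = owner * per_rank + local_index
        physical_to_global = torch.empty_like(global_to_physical)
        physical_to_global[global_to_physical] = g
        local_expert_global_ids = torch.arange(ep_rank, global_num_experts, ep_size)
        return global_to_physical, physical_to_global, local_expert_global_ids

    g2p, p2g, local_ids = round_robin_tables(global_num_experts=8, ep_size=2, ep_rank=0)
    print(g2p.tolist())        # [0, 4, 1, 5, 2, 6, 3, 7]
    print(p2g.tolist())        # [0, 2, 4, 6, 1, 3, 5, 7]
    print(local_ids.tolist())  # [0, 2, 4, 6] -> global experts hosted on rank 0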

View File

@@ -67,6 +67,7 @@ def maybe_roundup_layer_hidden_size(
 def maybe_make_prepare_finalize(
     moe: FusedMoEConfig,
     quant_config: FusedMoEQuantConfig | None,
+    routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
 ) -> FusedMoEPrepareAndFinalize | None:
     if not moe.moe_parallel_config.use_all2all_kernels:
         return None
@@ -134,6 +135,13 @@ def maybe_make_prepare_finalize(
     elif moe.use_deepep_ll_kernels:
         assert quant_config is not None
+        global_to_physical = physical_to_global = local_expert_global_ids = None
+        if routing_tables is not None:
+            (
+                global_to_physical,
+                physical_to_global,
+                local_expert_global_ids,
+            ) = routing_tables
         all_to_all_args = dict(
             max_num_tokens_per_dp_rank=moe.max_num_tokens,
             token_hidden_size=moe.hidden_dim,
@@ -155,6 +163,9 @@ def maybe_make_prepare_finalize(
             max_tokens_per_rank=moe.max_num_tokens,
             num_dispatchers=all2all_manager.world_size,
             use_fp8_dispatch=use_fp8_dispatch,
+            global_to_physical=global_to_physical,
+            physical_to_global=physical_to_global,
+            local_expert_global_ids=local_expert_global_ids,
         )
         return prepare_finalize

View File

@@ -85,6 +85,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         max_tokens_per_rank: int,
         num_dispatchers: int,
         use_fp8_dispatch: bool = False,
+        global_to_physical: torch.Tensor | None = None,
+        physical_to_global: torch.Tensor | None = None,
+        local_expert_global_ids: torch.Tensor | None = None,
     ):
         super().__init__()
@@ -97,6 +100,17 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         self.handles: list[tuple | None] = [None, None]
         self.num_dispatchers_ = num_dispatchers
+
+        topk_indices_dtype = self.topk_indices_dtype()
+
+        def _maybe_cast(tensor: torch.Tensor | None) -> torch.Tensor | None:
+            if tensor is None or topk_indices_dtype is None:
+                return tensor
+            return tensor.to(dtype=topk_indices_dtype)
+
+        self.global_to_physical = _maybe_cast(global_to_physical)
+        self.physical_to_global = _maybe_cast(physical_to_global)
+        self.local_expert_global_ids = _maybe_cast(local_expert_global_ids)
+
         # We don't have enough information to determine if we should dispatch
         # activation scales in a packed ue8m0 format during object construction
         # time. This setting is handled by post_init_setup.
@@ -136,6 +150,16 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
     def topk_indices_dtype(self) -> torch.dtype | None:
         return torch.int64
+    def _map_global_to_physical_ids(self, topk_ids: torch.Tensor) -> torch.Tensor:
+        if self.global_to_physical is None:
+            return topk_ids
+        return self.global_to_physical[topk_ids]
+
+    def _map_local_to_global_ids(self, expert_topk_ids: torch.Tensor) -> torch.Tensor:
+        if self.local_expert_global_ids is None:
+            return expert_topk_ids
+        return self.local_expert_global_ids[expert_topk_ids]
+
     def _do_quant(
         self,
         x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
@@ -226,9 +250,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             a1 = a1 * topk_weights.to(a1.dtype)
 
         # Dispatch
+        dispatch_topk_ids = self._map_global_to_physical_ids(topk_ids)
         expert_x, expert_num_tokens, handle, _, hook = self.buffer.low_latency_dispatch(
             a1,
-            topk_ids,
+            dispatch_topk_ids,
             self.max_tokens_per_rank,
             num_experts,
             use_fp8=self.use_fp8_dispatch,
@@ -313,11 +338,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             # weights have already been applied.
             combine_topk_weights = torch.ones_like(topk_weights)
+        combine_topk_ids = self._map_global_to_physical_ids(topk_ids)
 
         # TODO (varun) : Enable zero copy mode
         dbo_maybe_run_recv_hook()
         _, _, recv_hook = self.buffer.low_latency_combine(
             fused_expert_output,
-            topk_ids,
+            combine_topk_ids,
             combine_topk_weights,
             handle,
             async_finish=False,
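
Note (illustration only, not part of the diff): the router still emits global expert IDs, so the hunks above gather them through global_to_physical before low_latency_dispatch, and again before low_latency_combine so the IDs match the layout of the dispatch handle. In the other direction, _map_local_to_global_ids uses local_expert_global_ids to turn rank-local expert indices back into global expert IDs. A tiny standalone example of what the global-to-physical lookup amounts to (values made up for illustration):

    import torch

    global_to_physical = torch.tensor([0, 4, 1, 5, 2, 6, 3, 7])  # 8 experts, ep_size=2
    topk_ids = torch.tensor([[1, 6], [0, 3]])           # global IDs from the router
    dispatch_topk_ids = global_to_physical[topk_ids]    # IDs in DeepEP's physical layout
    print(dispatch_topk_ids.tolist())                   # [[4, 3], [0, 5]]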

View File

@@ -50,10 +50,15 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         """
         return False
 
-    def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
+    def maybe_make_prepare_finalize(
+        self,
+        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
+    ) -> FusedMoEPrepareAndFinalize | None:
         from .all2all_utils import maybe_make_prepare_finalize
 
-        return maybe_make_prepare_finalize(self.moe, self.moe_quant_config)
+        return maybe_make_prepare_finalize(
+            self.moe, self.moe_quant_config, routing_tables
+        )
 
     def select_gemm_impl(
         self,
View File

@@ -5,7 +5,7 @@ from collections.abc import Callable, Iterable
 from contextlib import nullcontext
 from enum import Enum
 from functools import partial
-from typing import Literal, get_args, overload
+from typing import Literal, cast, get_args, overload
 
 import torch
 import torch.nn.functional as F
@@ -192,6 +192,42 @@ def determine_expert_map(
     return (local_num_experts, expert_map, expert_mask)
 
 
+def determine_expert_placement_strategy(
+    expert_placement_strategy: ExpertPlacementStrategy,
+    moe_parallel_config: FusedMoEParallelConfig,
+    num_expert_group: int | None,
+    num_redundant_experts: int,
+    enable_eplb: bool,
+) -> ExpertPlacementStrategy:
+    if expert_placement_strategy == "round_robin":
+        round_robin_supported = (
+            (num_expert_group is not None and num_expert_group > 1)
+            and num_redundant_experts == 0
+            and not enable_eplb
+        )
+
+        if not round_robin_supported:
+            logger.warning(
+                "Round-robin expert placement is only supported for "
+                "models with multiple expert groups and no redundant "
+                "experts. Falling back to linear expert placement."
+            )
+            return "linear"
+
+        if (
+            moe_parallel_config.use_all2all_kernels
+            and not moe_parallel_config.use_deepep_ll_kernels
+        ):
+            logger.warning(
+                "Round-robin expert placement currently only supports "
+                "the DeepEP low-latency backend, but '%s' was configured. "
+                "Falling back to linear expert placement.",
+                moe_parallel_config.all2all_backend,
+            )
+            return "linear"
+
+    return expert_placement_strategy
+
+
 def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
     """
     Compresses the expert map by removing any -1 entries.
@@ -400,6 +436,9 @@ class FusedMoE(CustomOp):
         self.expert_load_view: torch.Tensor | None = None
         self.logical_to_physical_map: torch.Tensor | None = None
         self.logical_replica_count: torch.Tensor | None = None
+        self.expert_placement_strategy: ExpertPlacementStrategy = (
+            vllm_config.parallel_config.expert_placement_strategy
+        )
 
         # ROCm aiter shared experts fusion
         self.rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
@@ -433,38 +472,27 @@ class FusedMoE(CustomOp):
                 "Redundant experts are only supported with EPLB."
             )
 
-        expert_placement_strategy = (
-            vllm_config.parallel_config.expert_placement_strategy
+        self.expert_placement_strategy = determine_expert_placement_strategy(
+            expert_placement_strategy=self.expert_placement_strategy,
+            moe_parallel_config=self.moe_parallel_config,
+            num_expert_group=num_expert_group,
+            num_redundant_experts=num_redundant_experts,
+            enable_eplb=self.enable_eplb,
         )
-        if expert_placement_strategy == "round_robin":
-            # TODO(Bruce): will support round robin expert placement with
-            # EPLB enabled in the future.
-            round_robin_supported = (
-                (num_expert_group is not None and num_expert_group > 1)
-                and num_redundant_experts == 0
-                and not self.enable_eplb
-            )
-            if not round_robin_supported:
-                logger.warning(
-                    "Round-robin expert placement is only supported for "
-                    "models with multiple expert groups and no redundant "
-                    "experts. Falling back to linear expert placement."
-                )
-                expert_placement_strategy = "linear"
 
         self.expert_map: torch.Tensor | None
         local_num_experts, expert_map, expert_mask = determine_expert_map(
             ep_size=self.ep_size,
             ep_rank=self.ep_rank,
             global_num_experts=self.global_num_experts,
-            expert_placement_strategy=expert_placement_strategy,
+            expert_placement_strategy=self.expert_placement_strategy,
             num_fused_shared_experts=self.num_fused_shared_experts,
             return_expert_mask=self.rocm_aiter_fmoe_enabled,
         )
         self.local_num_experts = local_num_experts
         self.register_buffer("expert_map", expert_map)
         self.register_buffer("expert_mask", expert_mask)
+        self._maybe_init_expert_routing_tables()
         logger.info_once(
             "[EP Rank %s/%s] Expert parallelism is enabled. Expert "
             "placement strategy: %s. Local/global"
@@ -472,7 +500,7 @@ class FusedMoE(CustomOp):
             " %s.",
             self.ep_rank,
             self.ep_size,
-            expert_placement_strategy,
+            self.expert_placement_strategy,
             self.local_num_experts,
             self.global_num_experts,
             get_compressed_expert_map(self.expert_map),
@@ -621,7 +649,12 @@ class FusedMoE(CustomOp):
     # should be safe to swap out the quant_method.
     def maybe_init_modular_kernel(self) -> None:
         self.ensure_moe_quant_config_init()
-        prepare_finalize = self.quant_method.maybe_make_prepare_finalize()
+        # routing_tables only needed for round-robin expert placement with
+        # DeepEP all2all backend.
+        routing_tables = self._maybe_init_expert_routing_tables()
+        prepare_finalize = self.quant_method.maybe_make_prepare_finalize(
+            routing_tables=routing_tables
+        )
         if prepare_finalize is not None:
             logger.debug(
                 "%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self)
@@ -703,6 +736,84 @@ class FusedMoE(CustomOp):
         # By default, router/gate is called before FusedMoE forward pass
         return False
 
+    def _maybe_init_expert_routing_tables(
+        self,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None:
+        # Currently routing_tables only needed for round-robin expert placement
+        # with DeepEP-ll all2all backend.
+        if (
+            self.expert_placement_strategy != "round_robin"
+            or not self.use_deepep_ll_kernels
+        ):
+            return None
+
+        if hasattr(self, "expert_global_to_physical"):
+            return cast(
+                tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+                (
+                    self.expert_global_to_physical,
+                    self.expert_physical_to_global,
+                    self.expert_local_to_global,
+                ),
+            )
+
+        if self.expert_map is None:
+            return None
+
+        routing_tables = self.ensure_round_robin_expert_routing_tables(
+            global_num_experts=self.global_num_experts,
+            ep_size=self.ep_size,
+            ep_rank=self.ep_rank,
+            local_num_experts=self.local_num_experts,
+            device=self.expert_map.device,
+        )
+        global_to_physical, physical_to_global, local_global = routing_tables
+        self.register_buffer("expert_global_to_physical", global_to_physical)
+        self.register_buffer("expert_physical_to_global", physical_to_global)
+        self.register_buffer("expert_local_to_global", local_global)
+        return routing_tables
+
+    @staticmethod
+    def ensure_round_robin_expert_routing_tables(
+        global_num_experts: int,
+        ep_size: int,
+        ep_rank: int,
+        local_num_experts: int,
+        device: torch.device | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        device_kwargs = {"device": device} if device is not None else {}
+        global_indices = torch.arange(
+            global_num_experts, dtype=torch.long, **device_kwargs
+        )
+        owner = torch.remainder(global_indices, ep_size)
+        local_index = torch.div(global_indices, ep_size, rounding_mode="floor")
+
+        base = global_num_experts // ep_size
+        remainder = global_num_experts % ep_size
+        physical_offset = owner * base
+        if remainder > 0:
+            remainder_tensor = torch.tensor(
+                remainder, dtype=torch.long, **device_kwargs
+            )
+            physical_offset = physical_offset + torch.minimum(owner, remainder_tensor)
+
+        global_to_physical = physical_offset + local_index
+        physical_to_global = torch.empty_like(global_to_physical)
+        physical_to_global[global_to_physical] = global_indices
+
+        local_global = torch.arange(
+            ep_rank,
+            global_num_experts,
+            ep_size,
+            dtype=torch.long,
+            **device_kwargs,
+        )
+        if local_global.numel() != local_num_experts:
+            local_global = local_global[:local_num_experts]
+        return (global_to_physical, physical_to_global, local_global)
+
     def update_expert_map(self):
         # ep_size and ep_rank should already be updated
         assert self.expert_map is not None
@@ -711,12 +822,14 @@ class FusedMoE(CustomOp):
             ep_size=self.ep_size,
             ep_rank=self.ep_rank,
             global_num_experts=self.global_num_experts,
+            expert_placement_strategy=self.expert_placement_strategy,
             num_fused_shared_experts=self.num_fused_shared_experts,
             return_expert_mask=self.rocm_aiter_fmoe_enabled,
         )
         self.local_num_experts = local_num_experts
         self.register_buffer("expert_map", expert_map)
         self.register_buffer("expert_mask", expert_mask)
+        self._maybe_init_expert_routing_tables()
         if self.aiter_fmoe_shared_expert_enabled:
             self._init_aiter_shared_experts_topK_buffer(
                 vllm_config=get_current_vllm_config(),
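
Note (illustration only, not part of the diff): a hand-worked trace of ensure_round_robin_expert_routing_tables for an uneven split, showing how the remainder handling keeps physical slots contiguous per rank. The sizes are made up for illustration: global_num_experts=7, ep_size=2, ep_rank=0, local_num_experts=4.

    # owner              = [0, 1, 0, 1, 0, 1, 0]   # global_id % ep_size
    # local_index        = [0, 0, 1, 1, 2, 2, 3]   # global_id // ep_size
    # base, remainder    = 3, 1
    # physical_offset    = owner * base + min(owner, remainder)
    #                    = [0, 4, 0, 4, 0, 4, 0]
    # global_to_physical = physical_offset + local_index = [0, 4, 1, 5, 2, 6, 3]
    # physical_to_global = [0, 2, 4, 6, 1, 3, 5]
    # local_global (rank 0) = arange(0, 7, 2) = [0, 2, 4, 6]

Rank 0 ends up with physical slots 0-3 holding global experts 0, 2, 4, 6 and rank 1 with slots 4-6 holding 1, 3, 5, consistent with the per-rank expert counts that round-robin placement implies.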

View File

@@ -108,11 +108,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     def allow_inplace(self) -> bool:
         return True
 
-    def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
+    def maybe_make_prepare_finalize(
+        self,
+        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
+    ) -> FusedMoEPrepareAndFinalize | None:
         if self.rocm_aiter_moe_enabled:
             return None
         else:
-            return super().maybe_make_prepare_finalize()
+            return super().maybe_make_prepare_finalize(routing_tables)
 
     def select_gemm_impl(
         self,

View File

@@ -380,11 +380,14 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
             (layer.w2_input_global_scale), requires_grad=False
         )
 
-    def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
+    def maybe_make_prepare_finalize(
+        self,
+        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
+    ) -> mk.FusedMoEPrepareAndFinalize | None:
         if self.use_marlin:
             return None
         elif not self.allow_flashinfer:
-            return super().maybe_make_prepare_finalize()
+            return super().maybe_make_prepare_finalize(routing_tables)
 
         prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe)
         logger.debug_once("%s", prepare_finalize.__class__.__name__)
@@ -890,11 +893,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             layer.w2_weight_scale
         )
 
-    def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
+    def maybe_make_prepare_finalize(
+        self,
+        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
+    ) -> mk.FusedMoEPrepareAndFinalize | None:
         if self.use_marlin or self.rocm_aiter_moe_enabled:
             return None
         else:
-            return super().maybe_make_prepare_finalize()
+            return super().maybe_make_prepare_finalize(routing_tables)
 
     def select_gemm_impl(
         self,

View File

@@ -1018,7 +1018,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             del layer.w13_input_scale
             del layer.w2_input_scale
 
-    def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
+    def maybe_make_prepare_finalize(
+        self,
+        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
+    ) -> mk.FusedMoEPrepareAndFinalize | None:
         if (
             self.rocm_aiter_moe_enabled
             or self.use_marlin
@@ -1039,7 +1042,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             logger.debug_once("%s", prepare_finalize.__class__.__name__)
             return prepare_finalize
         else:
-            return super().maybe_make_prepare_finalize()
+            return super().maybe_make_prepare_finalize(routing_tables)
 
     def select_gemm_impl(
         self,

View File

@@ -373,6 +373,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
     def maybe_make_prepare_finalize(
         self,
+        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
     ) -> mk.FusedMoEPrepareAndFinalize | None:
         # TRT LLM not supported with all2all yet.
         if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
@@ -384,7 +385,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             logger.debug_once("%s", prepare_finalize.__class__.__name__)
             return prepare_finalize
         else:
-            return super().maybe_make_prepare_finalize()
+            return super().maybe_make_prepare_finalize(routing_tables)
 
     def select_gemm_impl(
         self,
@@ -1179,7 +1180,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             " for ModelOptNvFp4FusedMoE."
         )
 
-    def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
+    def maybe_make_prepare_finalize(
+        self,
+        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
+    ) -> mk.FusedMoEPrepareAndFinalize | None:
         if self.use_marlin or (
             self.allow_flashinfer
             and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
@@ -1196,7 +1200,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             logger.debug_once("%s", prepare_finalize.__class__.__name__)
             return prepare_finalize
         else:
-            return super().maybe_make_prepare_finalize()
+            return super().maybe_make_prepare_finalize(routing_tables)
 
     def select_gemm_impl(
         self,