[EPLB] Reduce EPLB Inference Overhead (#24573)

Signed-off-by: Bowen Wang <abmfy@icloud.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Bowen Wang
2025-09-22 09:31:05 -07:00
committed by GitHub
parent 175811e3b5
commit 06a41334c7
2 changed files with 92 additions and 50 deletions

View File

@@ -1017,6 +1017,79 @@ def grouped_topk(
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def eplb_map_to_physical_and_record(
topk_ids: torch.Tensor,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
indices_type: Optional[torch.dtype] = None) -> torch.Tensor:
'''
Map the logical expert ids to physical expert ids
and record the expert load metrics.
This will select a pseudo-random replica for each logical expert.
Only used for EPLB.
Args:
topk_ids: The logical expert ids.
expert_load_view: The expert load view.
logical_to_physical_map: The logical to physical map.
logical_replica_count: The logical replica count.
indices_type: The indices type.
Returns:
The physical expert ids.
'''
# 1. Convert the logical expert ids to physical expert ids
# Directly select a random replica for each logical expert
# In case `indices_type` is not `torch.long` or `torch.int`,
# e.g. `torch.uint32` as required by dispatch/combine kernels
topk_ids_long = topk_ids.long()
# Use (token position) modulo (replica count)
# to deterministically choose a replica
replica_count = logical_replica_count[topk_ids_long]
# Flatten-position based index, reshaped back to `topk_ids` shape
pos_indices = torch.arange(topk_ids.numel(),
device=topk_ids.device,
dtype=torch.long).reshape_as(topk_ids)
# Compute pseudo-random indices by modulo
replica_indices = (pos_indices % replica_count).unsqueeze(-1)
physical_ids = logical_to_physical_map[topk_ids_long].gather(
-1, replica_indices).squeeze(-1)
topk_ids = physical_ids
# 2. Record expert load metrics.
# TODO(bowen): When using `FusedMoEModularKernel`, this
# can be done in a more unified way, since
# `FusedMoEPrepareAndFinalize` will return the expert
# token count, in some cases directly from the kernel.
# However, now there are many code paths not using
# the modular kernel, e.g. calling `fused_experts`,
# so we decide to keep the logic here.
#
# If later refactor moved all the MoE kernel calls
# to the modular kernel, we can move this logic there
# to achieve better efficiency.
# `expert_load_view`: (num_physical_experts,)
# `torch.bincount` is not compilable, so use `scatter_add_` instead.
topk_ids_flatten = topk_ids.flatten()
expert_load_view.scatter_add_(
dim=0,
index=topk_ids_flatten.long(),
src=torch.ones_like(topk_ids_flatten).to(expert_load_view))
if indices_type is not None:
topk_ids = topk_ids.to(dtype=indices_type)
return topk_ids
def fused_grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,