[BugFix][Router Replay] Capture Logical Experts with EPLB (#33013)
Signed-off-by: Hollow Man <hollowman@opensuse.org>
This commit is contained in:
@@ -44,9 +44,6 @@ from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
init_aiter_topK_meta_data,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
|
||||
RoutedExpertsCapturer,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.router.router_factory import (
|
||||
create_fused_moe_router,
|
||||
)
|
||||
@@ -523,18 +520,6 @@ class FusedMoE(CustomOp):
|
||||
self.apply_router_weight_on_input = apply_router_weight_on_input
|
||||
self.activation = activation
|
||||
|
||||
self.capture: Callable[[torch.Tensor], None] | None = None
|
||||
if (
|
||||
self.vllm_config.model_config is not None
|
||||
and self.vllm_config.model_config.enable_return_routed_experts
|
||||
):
|
||||
# In dummy runs, the capturer is not initialized.
|
||||
capturer = RoutedExpertsCapturer.get_instance()
|
||||
if capturer is not None:
|
||||
self.capture = lambda topk_ids: capturer.capture(
|
||||
self.layer_id, topk_ids
|
||||
)
|
||||
|
||||
self.router = create_fused_moe_router(
|
||||
top_k=top_k,
|
||||
global_num_experts=self.global_num_experts,
|
||||
@@ -1688,9 +1673,6 @@ class FusedMoE(CustomOp):
|
||||
router_logits=staged_router_logits,
|
||||
)
|
||||
|
||||
if self.capture is not None:
|
||||
self.capture(topk_ids)
|
||||
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=staged_hidden_states,
|
||||
@@ -1883,9 +1865,6 @@ class FusedMoE(CustomOp):
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
if self.capture is not None:
|
||||
self.capture(topk_ids)
|
||||
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=x, # The type signture of this is wrong due to the hack.
|
||||
|
||||
@@ -127,6 +127,11 @@ class BaseRouter(FusedMoERouter):
|
||||
self.eplb_state = eplb_state
|
||||
self.enable_eplb = enable_eplb
|
||||
self.indices_type_getter = indices_type_getter
|
||||
self.capture_fn: Callable[[torch.Tensor], None] | None = None
|
||||
|
||||
def set_capture_fn(self, capture_fn: Callable[[torch.Tensor], None] | None) -> None:
|
||||
"""Set a capture callback for logical routed expert IDs."""
|
||||
self.capture_fn = capture_fn
|
||||
|
||||
def _validate_eplb_state(self) -> None:
|
||||
"""Validate that EPLB state is properly initialized if EPLB is enabled."""
|
||||
@@ -231,6 +236,10 @@ class BaseRouter(FusedMoERouter):
|
||||
hidden_states, router_logits, indices_type
|
||||
)
|
||||
|
||||
# Capture logical ids before EPLB mapping.
|
||||
if self.capture_fn is not None:
|
||||
self.capture_fn(topk_ids)
|
||||
|
||||
# Step 4: Apply EPLB mapping
|
||||
topk_ids = self._apply_eplb_mapping(topk_ids)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user