[BugFix][Router Replay] Capture Logical Experts with EPLB (#33013)

Signed-off-by: Hollow Man <hollowman@opensuse.org>
This commit is contained in:
ℍ𝕠𝕝𝕝𝕠𝕨 𝕄𝕒𝕟
2026-01-31 17:12:17 +02:00
committed by GitHub
parent 15f40b20aa
commit 13b842f271
4 changed files with 185 additions and 21 deletions

View File

@@ -0,0 +1,160 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import types
import pytest
import torch
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
pytestmark = pytest.mark.cpu_test
class DummyRouter(BaseRouter):
@property
def routing_method_type(self) -> RoutingMethodType:
return RoutingMethodType.FUSED_TOPK
def _compute_routing(self, hidden_states, router_logits, indices_type):
topk_ids = torch.tensor([[1, 2], [3, 4]], dtype=torch.int64)
topk_weights = torch.ones_like(topk_ids, dtype=torch.float32)
return topk_weights, topk_ids
def _apply_eplb_mapping(self, topk_ids: torch.Tensor) -> torch.Tensor:
# Make mapping observable without requiring CUDA EPLB path.
return topk_ids + 10
def _make_router() -> DummyRouter:
return DummyRouter(
top_k=2,
global_num_experts=16,
eplb_state=EplbLayerState(),
enable_eplb=False,
indices_type_getter=None,
)
def test_base_router_capture_pre_eplb_mapping():
router = _make_router()
captured = []
def capture_fn(ids):
captured.append(ids.clone())
router.set_capture_fn(capture_fn)
topk_weights, topk_ids = router.select_experts(
hidden_states=torch.empty(1),
router_logits=torch.empty(1),
)
assert topk_weights.shape == topk_ids.shape
assert len(captured) == 1
assert torch.equal(captured[0], torch.tensor([[1, 2], [3, 4]]))
assert torch.equal(topk_ids, torch.tensor([[11, 12], [13, 14]]))
def test_base_router_capture_with_eplb_enabled():
router = _make_router()
router.enable_eplb = True
router.eplb_state.expert_load_view = torch.zeros(32, dtype=torch.int64)
router.eplb_state.logical_to_physical_map = torch.arange(32).view(32, 1)
router.eplb_state.logical_replica_count = torch.ones(32, dtype=torch.int64)
captured = []
def capture_fn(ids):
captured.append(ids.clone())
router.set_capture_fn(capture_fn)
_, topk_ids = router.select_experts(
hidden_states=torch.empty(1),
router_logits=torch.empty(1),
)
assert len(captured) == 1
# Capture should see logical ids pre-EPLB mapping.
assert torch.equal(captured[0], torch.tensor([[1, 2], [3, 4]]))
# Our DummyRouter mapping adds +10.
assert torch.equal(topk_ids, torch.tensor([[11, 12], [13, 14]]))
def test_gpu_model_runner_binds_router_capture(monkeypatch):
from vllm.v1.worker import gpu_model_runner as gmr
class DummyFusedMoE:
def __init__(self):
self.layer_id = 7
self.router = _make_router()
class DummyCapturer:
def __init__(self):
self.calls = []
def capture(self, layer_id, topk_ids):
self.calls.append((layer_id, topk_ids))
dummy_module = DummyFusedMoE()
# Patch the runtime import inside _bind_routed_experts_capturer.
import vllm.model_executor.layers.fused_moe.layer as fused_moe_layer
monkeypatch.setattr(fused_moe_layer, "FusedMoE", DummyFusedMoE)
dummy_self = types.SimpleNamespace(
compilation_config=types.SimpleNamespace(
static_forward_context={"dummy": dummy_module}
)
)
capturer = DummyCapturer()
gmr.GPUModelRunner._bind_routed_experts_capturer(dummy_self, capturer)
assert dummy_module.router.capture_fn is not None
dummy_module.router.capture_fn(torch.tensor([[5, 6]]))
assert len(capturer.calls) == 1
layer_id, topk_ids = capturer.calls[0]
assert layer_id == 7
assert torch.equal(topk_ids, torch.tensor([[5, 6]]))
def test_gpu_model_runner_binding_stage(monkeypatch):
from vllm.v1.worker import gpu_model_runner as gmr
class DummyFusedMoE:
def __init__(self):
self.layer_id = 11
self.router = _make_router()
class DummyCapturer:
def __init__(self):
self.calls = []
def capture(self, layer_id, topk_ids):
self.calls.append((layer_id, topk_ids))
dummy_module = DummyFusedMoE()
import vllm.model_executor.layers.fused_moe.layer as fused_moe_layer
monkeypatch.setattr(fused_moe_layer, "FusedMoE", DummyFusedMoE)
dummy_self = types.SimpleNamespace(
compilation_config=types.SimpleNamespace(
static_forward_context={"dummy": dummy_module}
)
)
# Before binding, no capture hook.
assert dummy_module.router.capture_fn is None
capturer = DummyCapturer()
gmr.GPUModelRunner._bind_routed_experts_capturer(dummy_self, capturer)
# After binding, hook should exist and be callable.
assert callable(dummy_module.router.capture_fn)
dummy_module.router.capture_fn(torch.tensor([[9, 10]]))
assert len(capturer.calls) == 1

View File

@@ -44,9 +44,6 @@ from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
init_aiter_topK_meta_data,
)
from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
RoutedExpertsCapturer,
)
from vllm.model_executor.layers.fused_moe.router.router_factory import (
create_fused_moe_router,
)
@@ -523,18 +520,6 @@ class FusedMoE(CustomOp):
self.apply_router_weight_on_input = apply_router_weight_on_input
self.activation = activation
self.capture: Callable[[torch.Tensor], None] | None = None
if (
self.vllm_config.model_config is not None
and self.vllm_config.model_config.enable_return_routed_experts
):
# In dummy runs, the capturer is not initialized.
capturer = RoutedExpertsCapturer.get_instance()
if capturer is not None:
self.capture = lambda topk_ids: capturer.capture(
self.layer_id, topk_ids
)
self.router = create_fused_moe_router(
top_k=top_k,
global_num_experts=self.global_num_experts,
@@ -1688,9 +1673,6 @@ class FusedMoE(CustomOp):
router_logits=staged_router_logits,
)
if self.capture is not None:
self.capture(topk_ids)
final_hidden_states = self.quant_method.apply(
layer=self,
x=staged_hidden_states,
@@ -1883,9 +1865,6 @@ class FusedMoE(CustomOp):
router_logits=router_logits,
)
if self.capture is not None:
self.capture(topk_ids)
final_hidden_states = self.quant_method.apply(
layer=self,
x=x, # The type signture of this is wrong due to the hack.

View File

@@ -127,6 +127,11 @@ class BaseRouter(FusedMoERouter):
self.eplb_state = eplb_state
self.enable_eplb = enable_eplb
self.indices_type_getter = indices_type_getter
self.capture_fn: Callable[[torch.Tensor], None] | None = None
def set_capture_fn(self, capture_fn: Callable[[torch.Tensor], None] | None) -> None:
"""Set a capture callback for logical routed expert IDs."""
self.capture_fn = capture_fn
def _validate_eplb_state(self) -> None:
"""Validate that EPLB state is properly initialized if EPLB is enabled."""
@@ -231,6 +236,10 @@ class BaseRouter(FusedMoERouter):
hidden_states, router_logits, indices_type
)
# Capture logical ids before EPLB mapping.
if self.capture_fn is not None:
self.capture_fn(topk_ids)
# Step 4: Apply EPLB mapping
topk_ids = self._apply_eplb_mapping(topk_ids)

View File

@@ -6068,6 +6068,22 @@ class GPUModelRunner(
max_num_kv_tokens=self.max_num_kv_tokens,
vllm_config=self.vllm_config,
)
self._bind_routed_experts_capturer(routed_experts_capturer)
def _bind_routed_experts_capturer(self, capturer: RoutedExpertsCapturer) -> None:
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.fused_moe.router.base_router import (
BaseRouter,
)
for module in self.compilation_config.static_forward_context.values():
if isinstance(module, FusedMoE) and isinstance(module.router, BaseRouter):
layer_id = module.layer_id
def _capture_fn(topk_ids, _layer_id=layer_id, _capturer=capturer):
_capturer.capture(_layer_id, topk_ids)
module.router.set_capture_fn(_capture_fn)
def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
"""