Fix routed experts capture for hybrid models (Mamba + Attention) (#35744)

Signed-off-by: arlenxu <arlenxu@tencent.com>
Signed-off-by: xhx1022 <1737006628@qq.com>
Co-authored-by: arlenxu <arlenxu@tencent.com>
This commit is contained in:
Hongxin Xu
2026-03-11 23:53:10 +08:00
committed by GitHub
parent a3ea760ea5
commit bea02cdf93
4 changed files with 442 additions and 16 deletions

View File

@@ -52,7 +52,7 @@ from vllm.v1.core.sched.request_queue import (
)
from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig
from vllm.v1.metrics.perf import ModelMetrics, PerfStats
from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
@@ -259,9 +259,26 @@ class Scheduler(SchedulerInterface):
assert len(kv_cache_config.kv_cache_groups) > 0, (
"enable_return_routed_experts requires at least one kv cache group"
)
# Find the attention group for routed experts indexing.
self.routed_experts_attn_gid = 0
for gid, group in enumerate(kv_cache_config.kv_cache_groups):
if isinstance(group.kv_cache_spec, AttentionSpec):
self.routed_experts_attn_gid = gid
break
min_block_size = min(
[
group.kv_cache_spec.block_size
for group in kv_cache_config.kv_cache_groups
]
)
num_groups = len(kv_cache_config.kv_cache_groups)
self.max_num_kv_tokens = (
kv_cache_config.num_blocks // len(kv_cache_config.kv_cache_groups) + 1
) * self.block_size
kv_cache_config.num_blocks // num_groups
) * min_block_size
dcp_size = self.vllm_config.parallel_config.decode_context_parallel_size
pcp_size = self.vllm_config.parallel_config.prefill_context_parallel_size
if pcp_size * dcp_size > 1:
self.max_num_kv_tokens *= pcp_size * dcp_size
self.routed_experts_reader.attach_buffer(
max_num_kv_tokens=self.max_num_kv_tokens,
@@ -1561,13 +1578,14 @@ class Scheduler(SchedulerInterface):
return None
kv_blocks = self.kv_cache_manager.get_blocks(request.request_id)
block_ids = kv_blocks.get_block_ids()[0]
block_ids = kv_blocks.get_block_ids()[self.routed_experts_attn_gid]
num_tokens = request.num_tokens - 1
# compute slot mapping
# compute slot mapping using attention group's block_size
block_ids_array = np.array(block_ids, dtype=np.int32)
num_blocks = len(block_ids)
block_size = self.block_size
attn_group = self.kv_cache_config.kv_cache_groups[self.routed_experts_attn_gid]
block_size = attn_group.kv_cache_spec.block_size
# generate block offsets
block_offsets = np.arange(0, block_size)

View File

@@ -422,6 +422,9 @@ class GPUModelRunner(
)
# This will be overridden in load_model()
self.is_multimodal_pruning_enabled = False
# Set to True after init_routed_experts_capturer() completes.
# Prevents routed experts code from running during profiling/dummy run.
self.routed_experts_initialized = False
self.max_model_len = model_config.max_model_len
# Always set to false after the first forward pass
@@ -1951,8 +1954,10 @@ class GPUModelRunner(
block_table_gid_0 = _get_block_table(0)
slot_mapping_gid_0 = slot_mappings[0]
if self.model_config.enable_return_routed_experts:
self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy()
if self.routed_experts_initialized:
attn_gid = self.routed_experts_attn_gid
slot_mapping_attn = slot_mappings[attn_gid]
self.slot_mapping = slot_mapping_attn[:num_tokens].cpu().numpy()
cm_base = CommonAttentionMetadata(
query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1],
@@ -3540,7 +3545,7 @@ class GPUModelRunner(
"after execute_model() returns None."
)
if self.vllm_config.model_config.enable_return_routed_experts:
if self.routed_experts_initialized:
capturer = RoutedExpertsCapturer.get_instance()
if capturer is not None:
capturer.clear_buffer() # noqa
@@ -4049,7 +4054,7 @@ class GPUModelRunner(
self.kv_connector_output = None
with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"):
if self.model_config.enable_return_routed_experts:
if self.routed_experts_initialized:
capturer = RoutedExpertsCapturer.get_instance()
if capturer is not None:
capturer.save_captured_experts(indices=self.slot_mapping) # noqa
@@ -6531,8 +6536,12 @@ class GPUModelRunner(
kv_transfer_group.register_kv_caches(kv_caches)
kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)
if self.model_config.enable_return_routed_experts:
self.init_routed_experts_capturer()
def _get_attention_kv_cache_gid(self) -> int:
    """Return the index of the first attention-backed KV cache group.

    Walks ``self.kv_cache_config.kv_cache_groups`` in order and selects the
    first group whose ``kv_cache_spec`` is an ``AttentionSpec`` instance.

    Returns:
        The matching group's index, or ``0`` when no group matches.
    """
    groups = self.kv_cache_config.kv_cache_groups
    # Lazy scan: evaluation stops at the first attention group encountered.
    attention_gids = (
        idx
        for idx, candidate in enumerate(groups)
        if isinstance(candidate.kv_cache_spec, AttentionSpec)
    )
    return next(attention_gids, 0)
def init_routed_experts_capturer(self):
logger.info(
@@ -6540,17 +6549,29 @@ class GPUModelRunner(
self.model_config.enable_return_routed_experts,
)
routed_experts_capturer = RoutedExpertsCapturer.create()
block_size = self.cache_config.block_size
self.routed_experts_attn_gid = self._get_attention_kv_cache_gid()
min_block_size = min(
[
group.kv_cache_spec.block_size
for group in self.kv_cache_config.kv_cache_groups
]
)
num_groups = len(self.kv_cache_config.kv_cache_groups)
self.max_num_kv_tokens = (
self.kv_cache_config.num_blocks // len(self.kv_cache_config.kv_cache_groups)
+ 1
) * block_size
self.kv_cache_config.num_blocks // num_groups
) * min_block_size
dcp_size = self.vllm_config.parallel_config.decode_context_parallel_size
pcp_size = self.vllm_config.parallel_config.prefill_context_parallel_size
if pcp_size * dcp_size > 1:
self.max_num_kv_tokens *= pcp_size * dcp_size
routed_experts_capturer.init_buffer(
max_num_batched_tokens=self.scheduler_config.max_num_batched_tokens,
max_num_kv_tokens=self.max_num_kv_tokens,
vllm_config=self.vllm_config,
)
self._bind_routed_experts_capturer(routed_experts_capturer)
self.routed_experts_initialized = True
def _bind_routed_experts_capturer(self, capturer: RoutedExpertsCapturer) -> None:
from vllm.model_executor.layers.fused_moe.layer import FusedMoE

View File

@@ -552,6 +552,9 @@ class Worker(WorkerBase):
else:
self.model_runner.initialize_kv_cache(kv_cache_config)
if self.model_config.enable_return_routed_experts:
self.model_runner.init_routed_experts_capturer()
# Build KV-zero metadata outside the CuMem pool so the bookkeeping
# GPU tensors (seg_addrs, block-id buffers) use the standard PyTorch
# allocator and are not discarded during sleep/wake cycles.