Fix routed experts capture for hybrid models (Mamba + Attention) (#35744)
Signed-off-by: arlenxu <arlenxu@tencent.com> Signed-off-by: xhx1022 <1737006628@qq.com> Co-authored-by: arlenxu <arlenxu@tencent.com>
This commit is contained in:
@@ -52,7 +52,7 @@ from vllm.v1.core.sched.request_queue import (
|
||||
)
|
||||
from vllm.v1.core.sched.utils import check_stop, remove_all
|
||||
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig
|
||||
from vllm.v1.metrics.perf import ModelMetrics, PerfStats
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
|
||||
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
|
||||
@@ -259,9 +259,26 @@ class Scheduler(SchedulerInterface):
|
||||
assert len(kv_cache_config.kv_cache_groups) > 0, (
|
||||
"enable_return_routed_experts requires at least one kv cache group"
|
||||
)
|
||||
# Find the attention group for routed experts indexing.
|
||||
self.routed_experts_attn_gid = 0
|
||||
for gid, group in enumerate(kv_cache_config.kv_cache_groups):
|
||||
if isinstance(group.kv_cache_spec, AttentionSpec):
|
||||
self.routed_experts_attn_gid = gid
|
||||
break
|
||||
min_block_size = min(
|
||||
[
|
||||
group.kv_cache_spec.block_size
|
||||
for group in kv_cache_config.kv_cache_groups
|
||||
]
|
||||
)
|
||||
num_groups = len(kv_cache_config.kv_cache_groups)
|
||||
self.max_num_kv_tokens = (
|
||||
kv_cache_config.num_blocks // len(kv_cache_config.kv_cache_groups) + 1
|
||||
) * self.block_size
|
||||
kv_cache_config.num_blocks // num_groups
|
||||
) * min_block_size
|
||||
dcp_size = self.vllm_config.parallel_config.decode_context_parallel_size
|
||||
pcp_size = self.vllm_config.parallel_config.prefill_context_parallel_size
|
||||
if pcp_size * dcp_size > 1:
|
||||
self.max_num_kv_tokens *= pcp_size * dcp_size
|
||||
|
||||
self.routed_experts_reader.attach_buffer(
|
||||
max_num_kv_tokens=self.max_num_kv_tokens,
|
||||
@@ -1561,13 +1578,14 @@ class Scheduler(SchedulerInterface):
|
||||
return None
|
||||
|
||||
kv_blocks = self.kv_cache_manager.get_blocks(request.request_id)
|
||||
block_ids = kv_blocks.get_block_ids()[0]
|
||||
block_ids = kv_blocks.get_block_ids()[self.routed_experts_attn_gid]
|
||||
num_tokens = request.num_tokens - 1
|
||||
|
||||
# compute slot mapping
|
||||
# compute slot mapping using attention group's block_size
|
||||
block_ids_array = np.array(block_ids, dtype=np.int32)
|
||||
num_blocks = len(block_ids)
|
||||
block_size = self.block_size
|
||||
attn_group = self.kv_cache_config.kv_cache_groups[self.routed_experts_attn_gid]
|
||||
block_size = attn_group.kv_cache_spec.block_size
|
||||
|
||||
# generate block offsets
|
||||
block_offsets = np.arange(0, block_size)
|
||||
|
||||
@@ -422,6 +422,9 @@ class GPUModelRunner(
|
||||
)
|
||||
# This will be overridden in load_model()
|
||||
self.is_multimodal_pruning_enabled = False
|
||||
# Set to True after init_routed_experts_capturer() completes.
|
||||
# Prevents routed experts code from running during profiling/dummy run.
|
||||
self.routed_experts_initialized = False
|
||||
self.max_model_len = model_config.max_model_len
|
||||
|
||||
# Always set to false after the first forward pass
|
||||
@@ -1951,8 +1954,10 @@ class GPUModelRunner(
|
||||
block_table_gid_0 = _get_block_table(0)
|
||||
slot_mapping_gid_0 = slot_mappings[0]
|
||||
|
||||
if self.model_config.enable_return_routed_experts:
|
||||
self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy()
|
||||
if self.routed_experts_initialized:
|
||||
attn_gid = self.routed_experts_attn_gid
|
||||
slot_mapping_attn = slot_mappings[attn_gid]
|
||||
self.slot_mapping = slot_mapping_attn[:num_tokens].cpu().numpy()
|
||||
cm_base = CommonAttentionMetadata(
|
||||
query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
|
||||
query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1],
|
||||
@@ -3540,7 +3545,7 @@ class GPUModelRunner(
|
||||
"after execute_model() returns None."
|
||||
)
|
||||
|
||||
if self.vllm_config.model_config.enable_return_routed_experts:
|
||||
if self.routed_experts_initialized:
|
||||
capturer = RoutedExpertsCapturer.get_instance()
|
||||
if capturer is not None:
|
||||
capturer.clear_buffer() # noqa
|
||||
@@ -4049,7 +4054,7 @@ class GPUModelRunner(
|
||||
self.kv_connector_output = None
|
||||
|
||||
with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"):
|
||||
if self.model_config.enable_return_routed_experts:
|
||||
if self.routed_experts_initialized:
|
||||
capturer = RoutedExpertsCapturer.get_instance()
|
||||
if capturer is not None:
|
||||
capturer.save_captured_experts(indices=self.slot_mapping) # noqa
|
||||
@@ -6531,8 +6536,12 @@ class GPUModelRunner(
|
||||
kv_transfer_group.register_kv_caches(kv_caches)
|
||||
kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)
|
||||
|
||||
if self.model_config.enable_return_routed_experts:
|
||||
self.init_routed_experts_capturer()
|
||||
def _get_attention_kv_cache_gid(self) -> int:
|
||||
"""Find the KV cache group index for attention layers."""
|
||||
for gid, group in enumerate(self.kv_cache_config.kv_cache_groups):
|
||||
if isinstance(group.kv_cache_spec, AttentionSpec):
|
||||
return gid
|
||||
return 0
|
||||
|
||||
def init_routed_experts_capturer(self):
|
||||
logger.info(
|
||||
@@ -6540,17 +6549,29 @@ class GPUModelRunner(
|
||||
self.model_config.enable_return_routed_experts,
|
||||
)
|
||||
routed_experts_capturer = RoutedExpertsCapturer.create()
|
||||
block_size = self.cache_config.block_size
|
||||
self.routed_experts_attn_gid = self._get_attention_kv_cache_gid()
|
||||
min_block_size = min(
|
||||
[
|
||||
group.kv_cache_spec.block_size
|
||||
for group in self.kv_cache_config.kv_cache_groups
|
||||
]
|
||||
)
|
||||
num_groups = len(self.kv_cache_config.kv_cache_groups)
|
||||
self.max_num_kv_tokens = (
|
||||
self.kv_cache_config.num_blocks // len(self.kv_cache_config.kv_cache_groups)
|
||||
+ 1
|
||||
) * block_size
|
||||
self.kv_cache_config.num_blocks // num_groups
|
||||
) * min_block_size
|
||||
dcp_size = self.vllm_config.parallel_config.decode_context_parallel_size
|
||||
pcp_size = self.vllm_config.parallel_config.prefill_context_parallel_size
|
||||
if pcp_size * dcp_size > 1:
|
||||
self.max_num_kv_tokens *= pcp_size * dcp_size
|
||||
|
||||
routed_experts_capturer.init_buffer(
|
||||
max_num_batched_tokens=self.scheduler_config.max_num_batched_tokens,
|
||||
max_num_kv_tokens=self.max_num_kv_tokens,
|
||||
vllm_config=self.vllm_config,
|
||||
)
|
||||
self._bind_routed_experts_capturer(routed_experts_capturer)
|
||||
self.routed_experts_initialized = True
|
||||
|
||||
def _bind_routed_experts_capturer(self, capturer: RoutedExpertsCapturer) -> None:
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
|
||||
@@ -552,6 +552,9 @@ class Worker(WorkerBase):
|
||||
else:
|
||||
self.model_runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
if self.model_config.enable_return_routed_experts:
|
||||
self.model_runner.init_routed_experts_capturer()
|
||||
|
||||
# Build KV-zero metadata outside the CuMem pool so the bookkeeping
|
||||
# GPU tensors (seg_addrs, block-id buffers) use the standard PyTorch
|
||||
# allocator and are not discarded during sleep/wake cycles.
|
||||
|
||||
Reference in New Issue
Block a user