Fix routed experts capture for hybrid models (Mamba + Attention) (#35744)
Signed-off-by: arlenxu <arlenxu@tencent.com> Signed-off-by: xhx1022 <1737006628@qq.com> Co-authored-by: arlenxu <arlenxu@tencent.com>
This commit is contained in:
384
examples/offline_inference/routed_experts_e2e.py
Normal file
384
examples/offline_inference/routed_experts_e2e.py
Normal file
@@ -0,0 +1,384 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
End-to-end example for routed experts capture with hybrid models.
|
||||
|
||||
Validates that:
|
||||
1. routed_experts is returned in CompletionOutput for MoE models.
|
||||
2. Expert IDs are within valid range.
|
||||
3. Results are deterministic across runs (baseline vs reference).
|
||||
|
||||
Usage:
|
||||
python examples/offline_inference/routed_experts_e2e.py \
|
||||
--model Qwen/Qwen3-30B-A3B \
|
||||
--tp 4 \
|
||||
--max-model-len 4096 \
|
||||
--num-prompts 20 \
|
||||
--max-new-tokens 50
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import numpy as np
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_MODEL = "Qwen/Qwen3-30B-A3B"
|
||||
|
||||
TEST_PROMPTS = [
|
||||
"Hello, my name is",
|
||||
"The capital of France is",
|
||||
"Explain quantum computing in simple terms:",
|
||||
"Write a Python function that sorts a list:",
|
||||
"The meaning of life is",
|
||||
"In a distant galaxy, there was a",
|
||||
"The best way to learn programming is",
|
||||
"Once upon a time in a land far away,",
|
||||
"The theory of relativity states that",
|
||||
"How does photosynthesis work?",
|
||||
"Describe the process of machine learning:",
|
||||
"What are the benefits of exercise?",
|
||||
"The history of artificial intelligence began",
|
||||
"Translate the following to French: Hello world",
|
||||
"Summarize the plot of Romeo and Juliet:",
|
||||
"What is the difference between TCP and UDP?",
|
||||
"The water cycle consists of",
|
||||
"Explain how a neural network learns:",
|
||||
"The periodic table organizes elements by",
|
||||
"Write a haiku about the ocean:",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceResult:
|
||||
"""Result from a single inference run."""
|
||||
|
||||
experts_list: list[np.ndarray] = field(default_factory=list)
|
||||
token_ids_list: list[list[int]] = field(default_factory=list)
|
||||
num_experts: int = 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Inference helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _run_async_inference(
|
||||
engine_args: AsyncEngineArgs,
|
||||
prompts: list[str],
|
||||
max_new_tokens: int,
|
||||
) -> InferenceResult:
|
||||
"""Run inference using AsyncLLM."""
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
|
||||
engine = AsyncLLM.from_engine_args(engine_args)
|
||||
|
||||
hf_config = engine.model_config.hf_text_config
|
||||
num_experts: int = getattr(hf_config, "num_experts", 0) or getattr(
|
||||
hf_config, "num_local_experts", 0
|
||||
)
|
||||
assert num_experts > 0, "Could not determine num_experts from model config"
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
max_tokens=max_new_tokens,
|
||||
)
|
||||
|
||||
async def _generate_one(prompt: str, idx: int):
|
||||
request_id = str(uuid.uuid4())
|
||||
final_output = None
|
||||
async for output in engine.generate(prompt, sampling_params, request_id):
|
||||
final_output = output
|
||||
assert final_output is not None
|
||||
|
||||
completion = final_output.outputs[0]
|
||||
routed = completion.routed_experts
|
||||
num_prompt_tokens = len(final_output.prompt_token_ids)
|
||||
num_generated_tokens = len(completion.token_ids)
|
||||
expected_len = num_prompt_tokens + num_generated_tokens - 1
|
||||
assert routed is not None, f"Prompt {idx}: routed_experts is None"
|
||||
assert routed.shape[0] == expected_len, (
|
||||
f"Prompt {idx}: routed_experts length {routed.shape[0]} != "
|
||||
f"prompt ({num_prompt_tokens}) + generated ({num_generated_tokens})"
|
||||
f" - 1 = {expected_len}"
|
||||
)
|
||||
return idx, routed, list(completion.token_ids)
|
||||
|
||||
tasks = [_generate_one(p, i) for i, p in enumerate(prompts)]
|
||||
outputs = await asyncio.gather(*tasks)
|
||||
|
||||
# Sort by original index to maintain prompt order
|
||||
outputs.sort(key=lambda x: x[0])
|
||||
|
||||
result = InferenceResult(num_experts=num_experts)
|
||||
for _, routed, token_ids in outputs:
|
||||
result.experts_list.append(routed)
|
||||
result.token_ids_list.append(token_ids)
|
||||
|
||||
engine.shutdown()
|
||||
return result
|
||||
|
||||
|
||||
def run_inference(
|
||||
model: str,
|
||||
prompts: list[str],
|
||||
max_new_tokens: int = 50,
|
||||
tp: int = 1,
|
||||
max_model_len: int = 4096,
|
||||
) -> InferenceResult:
|
||||
"""Run inference with routed experts capture enabled via AsyncLLM."""
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=model,
|
||||
enable_return_routed_experts=True,
|
||||
tensor_parallel_size=tp,
|
||||
max_model_len=max_model_len,
|
||||
disable_log_stats=True,
|
||||
attention_backend="FLASH_ATTN",
|
||||
)
|
||||
|
||||
result = asyncio.run(_run_async_inference(engine_args, prompts, max_new_tokens))
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
current_platform.empty_cache()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Validation helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def validate_expert_ids(
|
||||
experts_list: list[np.ndarray],
|
||||
num_experts: int,
|
||||
) -> None:
|
||||
"""Check that all expert IDs are within valid range [0, num_experts)."""
|
||||
for i, experts in enumerate(experts_list):
|
||||
assert np.all(experts >= 0), (
|
||||
f"Prompt {i}: negative expert IDs found, min={experts.min()}"
|
||||
)
|
||||
assert np.all(experts < num_experts), (
|
||||
f"Prompt {i}: expert ID out of range [0, {num_experts}), "
|
||||
f"max={experts.max()}"
|
||||
)
|
||||
|
||||
|
||||
def validate_shapes(experts_list: list[np.ndarray]) -> None:
|
||||
"""Check that all routed_experts arrays have at least 2 dimensions."""
|
||||
for i, experts in enumerate(experts_list):
|
||||
assert experts.ndim >= 2, (
|
||||
f"Prompt {i}: expected at least 2D array, got shape {experts.shape}"
|
||||
)
|
||||
logger.info("Prompt %d: routed_experts shape = %s", i, experts.shape)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Comparison helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compare_token_ids(
|
||||
baseline: list[list[int]],
|
||||
reference: list[list[int]],
|
||||
) -> float:
|
||||
"""Compare token IDs from two runs. Returns mismatch ratio."""
|
||||
assert len(baseline) == len(reference), (
|
||||
f"Length mismatch: {len(baseline)} vs {len(reference)}"
|
||||
)
|
||||
|
||||
total_tokens = 0
|
||||
total_mismatches = 0
|
||||
|
||||
for i, (base, ref) in enumerate(zip(baseline, reference)):
|
||||
min_len = min(len(base), len(ref))
|
||||
max_len = max(len(base), len(ref))
|
||||
matches = 0
|
||||
for a, b in zip(base[:min_len], ref[:min_len]):
|
||||
if a != b:
|
||||
break
|
||||
matches += 1
|
||||
|
||||
total_mismatches += max_len - matches
|
||||
total_tokens += max_len
|
||||
|
||||
if matches < min_len or len(base) != len(ref):
|
||||
print(
|
||||
f" Prompt {i}: token_ids len={len(base)} vs {len(ref)}, "
|
||||
f"mismatches={max_len - matches}/{max_len}"
|
||||
)
|
||||
|
||||
if total_tokens == 0:
|
||||
raise ValueError("No tokens to compare")
|
||||
|
||||
mismatch_ratio = total_mismatches / total_tokens
|
||||
print(
|
||||
f"Token ID mismatches: {total_mismatches}/{total_tokens} ({mismatch_ratio:.4%})"
|
||||
)
|
||||
return mismatch_ratio
|
||||
|
||||
|
||||
def compare_routed_experts(
|
||||
baseline: list[np.ndarray],
|
||||
reference: list[np.ndarray],
|
||||
threshold: float = 0.05,
|
||||
) -> float:
|
||||
"""Compare two runs of routed experts. Returns mismatch ratio.
|
||||
|
||||
Raises AssertionError if ratio exceeds threshold.
|
||||
"""
|
||||
assert len(baseline) == len(reference), (
|
||||
f"Length mismatch: {len(baseline)} vs {len(reference)}"
|
||||
)
|
||||
|
||||
total_elements = 0
|
||||
total_mismatches = 0
|
||||
|
||||
for i, (base, ref) in enumerate(zip(baseline, reference)):
|
||||
min_len = min(len(base), len(ref))
|
||||
max_len = max(len(base), len(ref))
|
||||
if min_len == 0:
|
||||
continue
|
||||
|
||||
base_trimmed = base[:min_len]
|
||||
ref_trimmed = ref[:min_len]
|
||||
|
||||
matches = 0
|
||||
for a, b in zip(base_trimmed, ref_trimmed):
|
||||
if a.sum() != b.sum():
|
||||
break
|
||||
matches += 1
|
||||
|
||||
total_mismatches += max_len - matches
|
||||
total_elements += max_len
|
||||
|
||||
if matches < min_len or len(base) != len(ref):
|
||||
print(
|
||||
f" Prompt {i}: routed_experts len={len(base)} vs {len(ref)}, "
|
||||
f"mismatches={max_len - matches}/{max_len}"
|
||||
)
|
||||
|
||||
if total_elements == 0:
|
||||
raise ValueError("No elements to compare")
|
||||
|
||||
mismatch_ratio = total_mismatches / total_elements
|
||||
print(
|
||||
f"Routed experts mismatches: {total_mismatches}/{total_elements} "
|
||||
f"({mismatch_ratio:.4%})"
|
||||
)
|
||||
|
||||
assert mismatch_ratio < threshold, (
|
||||
f"Too many mismatches: {total_mismatches}/{total_elements} "
|
||||
f"({mismatch_ratio:.4%}) exceeds threshold {threshold:.4%}"
|
||||
)
|
||||
|
||||
return mismatch_ratio
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main():
|
||||
os.environ.setdefault("VLLM_BATCH_INVARIANT", "1")
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Test routed experts capture for MoE models"
|
||||
)
|
||||
parser.add_argument("--model", type=str, default=DEFAULT_MODEL)
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--max-model-len", type=int, default=4096)
|
||||
parser.add_argument("--num-prompts", type=int, default=20)
|
||||
parser.add_argument("--max-new-tokens", type=int, default=50)
|
||||
parser.add_argument(
|
||||
"--deterministic",
|
||||
action="store_true",
|
||||
help="Run twice and compare results for determinism check",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--threshold",
|
||||
type=float,
|
||||
default=0.05,
|
||||
help="Maximum allowed mismatch ratio for determinism check",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
prompts = TEST_PROMPTS[: args.num_prompts]
|
||||
|
||||
print(f"Model: {args.model}")
|
||||
print(f"TP: {args.tp}")
|
||||
print(f"Prompts: {len(prompts)}")
|
||||
print(f"Max new tokens: {args.max_new_tokens}")
|
||||
print()
|
||||
|
||||
print("=== Run 1 (baseline) ===")
|
||||
baseline = run_inference(
|
||||
model=args.model,
|
||||
prompts=prompts,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
tp=args.tp,
|
||||
max_model_len=args.max_model_len,
|
||||
)
|
||||
print(f"num_experts (from model config): {baseline.num_experts}")
|
||||
|
||||
print("\n=== Validation ===")
|
||||
validate_shapes(baseline.experts_list)
|
||||
validate_expert_ids(baseline.experts_list, num_experts=baseline.num_experts)
|
||||
print(f"All {len(baseline.experts_list)} results passed validation.")
|
||||
|
||||
for i, experts in enumerate(baseline.experts_list):
|
||||
print(
|
||||
f" Prompt {i}: shape={experts.shape}, "
|
||||
f"min={experts.min()}, max={experts.max()}"
|
||||
)
|
||||
|
||||
if args.deterministic:
|
||||
print("\n=== Run 2 (reference) ===")
|
||||
reference = run_inference(
|
||||
model=args.model,
|
||||
prompts=prompts,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
tp=args.tp,
|
||||
max_model_len=args.max_model_len,
|
||||
)
|
||||
|
||||
print("\n=== Determinism Check ===")
|
||||
validate_expert_ids(reference.experts_list, num_experts=baseline.num_experts)
|
||||
|
||||
print("\n--- Token IDs ---")
|
||||
token_mismatch = compare_token_ids(
|
||||
baseline.token_ids_list, reference.token_ids_list
|
||||
)
|
||||
|
||||
print("\n--- Routed Experts ---")
|
||||
expert_mismatch = compare_routed_experts(
|
||||
baseline.experts_list,
|
||||
reference.experts_list,
|
||||
threshold=args.threshold,
|
||||
)
|
||||
|
||||
print(
|
||||
f"\nDeterminism check passed. "
|
||||
f"Token mismatch: {token_mismatch:.4%}, "
|
||||
f"Expert mismatch: {expert_mismatch:.4%}"
|
||||
)
|
||||
|
||||
print("\nAll tests passed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -52,7 +52,7 @@ from vllm.v1.core.sched.request_queue import (
|
||||
)
|
||||
from vllm.v1.core.sched.utils import check_stop, remove_all
|
||||
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig
|
||||
from vllm.v1.metrics.perf import ModelMetrics, PerfStats
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
|
||||
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
|
||||
@@ -259,9 +259,26 @@ class Scheduler(SchedulerInterface):
|
||||
assert len(kv_cache_config.kv_cache_groups) > 0, (
|
||||
"enable_return_routed_experts requires at least one kv cache group"
|
||||
)
|
||||
# Find the attention group for routed experts indexing.
|
||||
self.routed_experts_attn_gid = 0
|
||||
for gid, group in enumerate(kv_cache_config.kv_cache_groups):
|
||||
if isinstance(group.kv_cache_spec, AttentionSpec):
|
||||
self.routed_experts_attn_gid = gid
|
||||
break
|
||||
min_block_size = min(
|
||||
[
|
||||
group.kv_cache_spec.block_size
|
||||
for group in kv_cache_config.kv_cache_groups
|
||||
]
|
||||
)
|
||||
num_groups = len(kv_cache_config.kv_cache_groups)
|
||||
self.max_num_kv_tokens = (
|
||||
kv_cache_config.num_blocks // len(kv_cache_config.kv_cache_groups) + 1
|
||||
) * self.block_size
|
||||
kv_cache_config.num_blocks // num_groups
|
||||
) * min_block_size
|
||||
dcp_size = self.vllm_config.parallel_config.decode_context_parallel_size
|
||||
pcp_size = self.vllm_config.parallel_config.prefill_context_parallel_size
|
||||
if pcp_size * dcp_size > 1:
|
||||
self.max_num_kv_tokens *= pcp_size * dcp_size
|
||||
|
||||
self.routed_experts_reader.attach_buffer(
|
||||
max_num_kv_tokens=self.max_num_kv_tokens,
|
||||
@@ -1561,13 +1578,14 @@ class Scheduler(SchedulerInterface):
|
||||
return None
|
||||
|
||||
kv_blocks = self.kv_cache_manager.get_blocks(request.request_id)
|
||||
block_ids = kv_blocks.get_block_ids()[0]
|
||||
block_ids = kv_blocks.get_block_ids()[self.routed_experts_attn_gid]
|
||||
num_tokens = request.num_tokens - 1
|
||||
|
||||
# compute slot mapping
|
||||
# compute slot mapping using attention group's block_size
|
||||
block_ids_array = np.array(block_ids, dtype=np.int32)
|
||||
num_blocks = len(block_ids)
|
||||
block_size = self.block_size
|
||||
attn_group = self.kv_cache_config.kv_cache_groups[self.routed_experts_attn_gid]
|
||||
block_size = attn_group.kv_cache_spec.block_size
|
||||
|
||||
# generate block offsets
|
||||
block_offsets = np.arange(0, block_size)
|
||||
|
||||
@@ -422,6 +422,9 @@ class GPUModelRunner(
|
||||
)
|
||||
# This will be overridden in load_model()
|
||||
self.is_multimodal_pruning_enabled = False
|
||||
# Set to True after init_routed_experts_capturer() completes.
|
||||
# Prevents routed experts code from running during profiling/dummy run.
|
||||
self.routed_experts_initialized = False
|
||||
self.max_model_len = model_config.max_model_len
|
||||
|
||||
# Always set to false after the first forward pass
|
||||
@@ -1951,8 +1954,10 @@ class GPUModelRunner(
|
||||
block_table_gid_0 = _get_block_table(0)
|
||||
slot_mapping_gid_0 = slot_mappings[0]
|
||||
|
||||
if self.model_config.enable_return_routed_experts:
|
||||
self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy()
|
||||
if self.routed_experts_initialized:
|
||||
attn_gid = self.routed_experts_attn_gid
|
||||
slot_mapping_attn = slot_mappings[attn_gid]
|
||||
self.slot_mapping = slot_mapping_attn[:num_tokens].cpu().numpy()
|
||||
cm_base = CommonAttentionMetadata(
|
||||
query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
|
||||
query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1],
|
||||
@@ -3540,7 +3545,7 @@ class GPUModelRunner(
|
||||
"after execute_model() returns None."
|
||||
)
|
||||
|
||||
if self.vllm_config.model_config.enable_return_routed_experts:
|
||||
if self.routed_experts_initialized:
|
||||
capturer = RoutedExpertsCapturer.get_instance()
|
||||
if capturer is not None:
|
||||
capturer.clear_buffer() # noqa
|
||||
@@ -4049,7 +4054,7 @@ class GPUModelRunner(
|
||||
self.kv_connector_output = None
|
||||
|
||||
with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"):
|
||||
if self.model_config.enable_return_routed_experts:
|
||||
if self.routed_experts_initialized:
|
||||
capturer = RoutedExpertsCapturer.get_instance()
|
||||
if capturer is not None:
|
||||
capturer.save_captured_experts(indices=self.slot_mapping) # noqa
|
||||
@@ -6531,8 +6536,12 @@ class GPUModelRunner(
|
||||
kv_transfer_group.register_kv_caches(kv_caches)
|
||||
kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)
|
||||
|
||||
if self.model_config.enable_return_routed_experts:
|
||||
self.init_routed_experts_capturer()
|
||||
def _get_attention_kv_cache_gid(self) -> int:
|
||||
"""Find the KV cache group index for attention layers."""
|
||||
for gid, group in enumerate(self.kv_cache_config.kv_cache_groups):
|
||||
if isinstance(group.kv_cache_spec, AttentionSpec):
|
||||
return gid
|
||||
return 0
|
||||
|
||||
def init_routed_experts_capturer(self):
|
||||
logger.info(
|
||||
@@ -6540,17 +6549,29 @@ class GPUModelRunner(
|
||||
self.model_config.enable_return_routed_experts,
|
||||
)
|
||||
routed_experts_capturer = RoutedExpertsCapturer.create()
|
||||
block_size = self.cache_config.block_size
|
||||
self.routed_experts_attn_gid = self._get_attention_kv_cache_gid()
|
||||
min_block_size = min(
|
||||
[
|
||||
group.kv_cache_spec.block_size
|
||||
for group in self.kv_cache_config.kv_cache_groups
|
||||
]
|
||||
)
|
||||
num_groups = len(self.kv_cache_config.kv_cache_groups)
|
||||
self.max_num_kv_tokens = (
|
||||
self.kv_cache_config.num_blocks // len(self.kv_cache_config.kv_cache_groups)
|
||||
+ 1
|
||||
) * block_size
|
||||
self.kv_cache_config.num_blocks // num_groups
|
||||
) * min_block_size
|
||||
dcp_size = self.vllm_config.parallel_config.decode_context_parallel_size
|
||||
pcp_size = self.vllm_config.parallel_config.prefill_context_parallel_size
|
||||
if pcp_size * dcp_size > 1:
|
||||
self.max_num_kv_tokens *= pcp_size * dcp_size
|
||||
|
||||
routed_experts_capturer.init_buffer(
|
||||
max_num_batched_tokens=self.scheduler_config.max_num_batched_tokens,
|
||||
max_num_kv_tokens=self.max_num_kv_tokens,
|
||||
vllm_config=self.vllm_config,
|
||||
)
|
||||
self._bind_routed_experts_capturer(routed_experts_capturer)
|
||||
self.routed_experts_initialized = True
|
||||
|
||||
def _bind_routed_experts_capturer(self, capturer: RoutedExpertsCapturer) -> None:
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
|
||||
@@ -552,6 +552,9 @@ class Worker(WorkerBase):
|
||||
else:
|
||||
self.model_runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
if self.model_config.enable_return_routed_experts:
|
||||
self.model_runner.init_routed_experts_capturer()
|
||||
|
||||
# Build KV-zero metadata outside the CuMem pool so the bookkeeping
|
||||
# GPU tensors (seg_addrs, block-id buffers) use the standard PyTorch
|
||||
# allocator and are not discarded during sleep/wake cycles.
|
||||
|
||||
Reference in New Issue
Block a user