diff --git a/examples/offline_inference/routed_experts_e2e.py b/examples/offline_inference/routed_experts_e2e.py
new file mode 100644
index 000000000..bb1d7b411
--- /dev/null
+++ b/examples/offline_inference/routed_experts_e2e.py
@@ -0,0 +1,384 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+End-to-end example for routed experts capture with hybrid models.
+
+Validates that:
+1. routed_experts is returned in CompletionOutput for MoE models.
+2. Expert IDs are within valid range.
+3. Results are deterministic across runs (baseline vs reference).
+
+Usage:
+    python examples/offline_inference/routed_experts_e2e.py \
+        --model Qwen/Qwen3-30B-A3B \
+        --tp 4 \
+        --max-model-len 4096 \
+        --num-prompts 20 \
+        --max-new-tokens 50
+"""
+
+from __future__ import annotations
+
+import argparse
+import asyncio
+import logging
+import os
+import uuid
+from dataclasses import dataclass, field
+
+import numpy as np
+
+from vllm.engine.arg_utils import AsyncEngineArgs
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_MODEL = "Qwen/Qwen3-30B-A3B"
+
+TEST_PROMPTS = [
+    "Hello, my name is",
+    "The capital of France is",
+    "Explain quantum computing in simple terms:",
+    "Write a Python function that sorts a list:",
+    "The meaning of life is",
+    "In a distant galaxy, there was a",
+    "The best way to learn programming is",
+    "Once upon a time in a land far away,",
+    "The theory of relativity states that",
+    "How does photosynthesis work?",
+    "Describe the process of machine learning:",
+    "What are the benefits of exercise?",
+    "The history of artificial intelligence began",
+    "Translate the following to French: Hello world",
+    "Summarize the plot of Romeo and Juliet:",
+    "What is the difference between TCP and UDP?",
+    "The water cycle consists of",
+    "Explain how a neural network learns:",
+    "The periodic table organizes elements by",
+    "Write a haiku about the ocean:",
+]
+
+
+@dataclass
+class InferenceResult:
+    """Result from a single inference run."""
+
+    experts_list: list[np.ndarray] = field(default_factory=list)
+    token_ids_list: list[list[int]] = field(default_factory=list)
+    num_experts: int = 0
+
+
+# ---------------------------------------------------------------------------
+# Inference helpers
+# ---------------------------------------------------------------------------
+
+
+async def _run_async_inference(
+    engine_args: AsyncEngineArgs,
+    prompts: list[str],
+    max_new_tokens: int,
+) -> InferenceResult:
+    """Run inference using AsyncLLM."""
+    from vllm.sampling_params import SamplingParams
+    from vllm.v1.engine.async_llm import AsyncLLM
+
+    engine = AsyncLLM.from_engine_args(engine_args)
+
+    hf_config = engine.model_config.hf_text_config
+    num_experts: int = getattr(hf_config, "num_experts", 0) or getattr(
+        hf_config, "num_local_experts", 0
+    )
+    assert num_experts > 0, "Could not determine num_experts from model config"
+
+    sampling_params = SamplingParams(
+        temperature=0,
+        max_tokens=max_new_tokens,
+    )
+
+    async def _generate_one(prompt: str, idx: int):
+        request_id = str(uuid.uuid4())
+        final_output = None
+        async for output in engine.generate(prompt, sampling_params, request_id):
+            final_output = output
+        assert final_output is not None
+
+        completion = final_output.outputs[0]
+        routed = completion.routed_experts
+        num_prompt_tokens = len(final_output.prompt_token_ids)
+        num_generated_tokens = len(completion.token_ids)
+        expected_len = num_prompt_tokens + num_generated_tokens - 1
+        assert routed is not None, f"Prompt {idx}: routed_experts is None"
+        assert routed.shape[0] == expected_len, (
+            f"Prompt {idx}: routed_experts length {routed.shape[0]} != "
+            f"prompt ({num_prompt_tokens}) + generated ({num_generated_tokens})"
+            f" - 1 = {expected_len}"
+        )
+        return idx, routed, list(completion.token_ids)
+
+    tasks = [_generate_one(p, i) for i, p in enumerate(prompts)]
+    outputs = await asyncio.gather(*tasks)
+
+    # Sort by original index to maintain prompt order
+    outputs.sort(key=lambda x: x[0])
+
+    result = InferenceResult(num_experts=num_experts)
+    for _, routed, token_ids in outputs:
+        result.experts_list.append(routed)
+        result.token_ids_list.append(token_ids)
+
+    engine.shutdown()
+    return result
+
+
+def run_inference(
+    model: str,
+    prompts: list[str],
+    max_new_tokens: int = 50,
+    tp: int = 1,
+    max_model_len: int = 4096,
+) -> InferenceResult:
+    """Run inference with routed experts capture enabled via AsyncLLM."""
+    engine_args = AsyncEngineArgs(
+        model=model,
+        enable_return_routed_experts=True,
+        tensor_parallel_size=tp,
+        max_model_len=max_model_len,
+        disable_log_stats=True,
+        attention_backend="FLASH_ATTN",
+    )
+
+    result = asyncio.run(_run_async_inference(engine_args, prompts, max_new_tokens))
+
+    from vllm.platforms import current_platform
+
+    if current_platform.is_cuda_alike():
+        current_platform.empty_cache()
+
+    return result
+
+
+# ---------------------------------------------------------------------------
+# Validation helpers
+# ---------------------------------------------------------------------------
+
+
+def validate_expert_ids(
+    experts_list: list[np.ndarray],
+    num_experts: int,
+) -> None:
+    """Check that all expert IDs are within valid range [0, num_experts)."""
+    for i, experts in enumerate(experts_list):
+        assert np.all(experts >= 0), (
+            f"Prompt {i}: negative expert IDs found, min={experts.min()}"
+        )
+        assert np.all(experts < num_experts), (
+            f"Prompt {i}: expert ID out of range [0, {num_experts}), "
+            f"max={experts.max()}"
+        )
+
+
+def validate_shapes(experts_list: list[np.ndarray]) -> None:
+    """Check that all routed_experts arrays have at least 2 dimensions."""
+    for i, experts in enumerate(experts_list):
+        assert experts.ndim >= 2, (
+            f"Prompt {i}: expected at least 2D array, got shape {experts.shape}"
+        )
+        logger.info("Prompt %d: routed_experts shape = %s", i, experts.shape)
+
+
+# ---------------------------------------------------------------------------
+# Comparison helpers
+# ---------------------------------------------------------------------------
+
+
+def compare_token_ids(
+    baseline: list[list[int]],
+    reference: list[list[int]],
+) -> float:
+    """Compare token IDs from two runs. Returns mismatch ratio."""
+    assert len(baseline) == len(reference), (
+        f"Length mismatch: {len(baseline)} vs {len(reference)}"
+    )
+
+    total_tokens = 0
+    total_mismatches = 0
+
+    for i, (base, ref) in enumerate(zip(baseline, reference)):
+        min_len = min(len(base), len(ref))
+        max_len = max(len(base), len(ref))
+        matches = 0
+        for a, b in zip(base[:min_len], ref[:min_len]):
+            if a != b:
+                break
+            matches += 1
+
+        total_mismatches += max_len - matches
+        total_tokens += max_len
+
+        if matches < min_len or len(base) != len(ref):
+            print(
+                f"  Prompt {i}: token_ids len={len(base)} vs {len(ref)}, "
+                f"mismatches={max_len - matches}/{max_len}"
+            )
+
+    if total_tokens == 0:
+        raise ValueError("No tokens to compare")
+
+    mismatch_ratio = total_mismatches / total_tokens
+    print(
+        f"Token ID mismatches: {total_mismatches}/{total_tokens} ({mismatch_ratio:.4%})"
+    )
+    return mismatch_ratio
+
+
+def compare_routed_experts(
+    baseline: list[np.ndarray],
+    reference: list[np.ndarray],
+    threshold: float = 0.05,
+) -> float:
+    """Compare two runs of routed experts. Returns mismatch ratio.
+
+    Raises AssertionError if ratio exceeds threshold.
+    """
+    assert len(baseline) == len(reference), (
+        f"Length mismatch: {len(baseline)} vs {len(reference)}"
+    )
+
+    total_elements = 0
+    total_mismatches = 0
+
+    for i, (base, ref) in enumerate(zip(baseline, reference)):
+        min_len = min(len(base), len(ref))
+        max_len = max(len(base), len(ref))
+        if min_len == 0:
+            continue
+
+        base_trimmed = base[:min_len]
+        ref_trimmed = ref[:min_len]
+
+        matches = 0
+        for a, b in zip(base_trimmed, ref_trimmed):
+            if a.sum() != b.sum():
+                break
+            matches += 1
+
+        total_mismatches += max_len - matches
+        total_elements += max_len
+
+        if matches < min_len or len(base) != len(ref):
+            print(
+                f"  Prompt {i}: routed_experts len={len(base)} vs {len(ref)}, "
+                f"mismatches={max_len - matches}/{max_len}"
+            )
+
+    if total_elements == 0:
+        raise ValueError("No elements to compare")
+
+    mismatch_ratio = total_mismatches / total_elements
+    print(
+        f"Routed experts mismatches: {total_mismatches}/{total_elements} "
+        f"({mismatch_ratio:.4%})"
+    )
+
+    assert mismatch_ratio < threshold, (
+        f"Too many mismatches: {total_mismatches}/{total_elements} "
+        f"({mismatch_ratio:.4%}) exceeds threshold {threshold:.4%}"
+    )
+
+    return mismatch_ratio
+
+
+# ---------------------------------------------------------------------------
+# CLI entry point
+# ---------------------------------------------------------------------------
+
+
+def main():
+    os.environ.setdefault("VLLM_BATCH_INVARIANT", "1")
+
+    parser = argparse.ArgumentParser(
+        description="Test routed experts capture for MoE models"
+    )
+    parser.add_argument("--model", type=str, default=DEFAULT_MODEL)
+    parser.add_argument("--tp", type=int, default=1)
+    parser.add_argument("--max-model-len", type=int, default=4096)
+    parser.add_argument("--num-prompts", type=int, default=20)
+    parser.add_argument("--max-new-tokens", type=int, default=50)
+    parser.add_argument(
+        "--deterministic",
+        action="store_true",
+        help="Run twice and compare results for determinism check",
+    )
+    parser.add_argument(
+        "--threshold",
+        type=float,
+        default=0.05,
+        help="Maximum allowed mismatch ratio for determinism check",
+    )
+    args = parser.parse_args()
+
+    logging.basicConfig(level=logging.INFO)
+    prompts = TEST_PROMPTS[: args.num_prompts]
+
+    print(f"Model: {args.model}")
+    print(f"TP: {args.tp}")
+    print(f"Prompts: {len(prompts)}")
+    print(f"Max new tokens: {args.max_new_tokens}")
+    print()
+
+    print("=== Run 1 (baseline) ===")
+    baseline = run_inference(
+        model=args.model,
+        prompts=prompts,
+        max_new_tokens=args.max_new_tokens,
+        tp=args.tp,
+        max_model_len=args.max_model_len,
+    )
+    print(f"num_experts (from model config): {baseline.num_experts}")
+
+    print("\n=== Validation ===")
+    validate_shapes(baseline.experts_list)
+    validate_expert_ids(baseline.experts_list, num_experts=baseline.num_experts)
+    print(f"All {len(baseline.experts_list)} results passed validation.")
+
+    for i, experts in enumerate(baseline.experts_list):
+        print(
+            f"  Prompt {i}: shape={experts.shape}, "
+            f"min={experts.min()}, max={experts.max()}"
+        )
+
+    if args.deterministic:
+        print("\n=== Run 2 (reference) ===")
+        reference = run_inference(
+            model=args.model,
+            prompts=prompts,
+            max_new_tokens=args.max_new_tokens,
+            tp=args.tp,
+            max_model_len=args.max_model_len,
+        )
+
+        print("\n=== Determinism Check ===")
+        validate_expert_ids(reference.experts_list, num_experts=baseline.num_experts)
+
+        print("\n--- Token IDs ---")
+        token_mismatch = compare_token_ids(
+            baseline.token_ids_list, reference.token_ids_list
+        )
+
+        print("\n--- Routed Experts ---")
+        expert_mismatch = compare_routed_experts(
+            baseline.experts_list,
+            reference.experts_list,
+            threshold=args.threshold,
+        )
+
+        print(
+            f"\nDeterminism check passed. "
+            f"Token mismatch: {token_mismatch:.4%}, "
+            f"Expert mismatch: {expert_mismatch:.4%}"
+        )
+
+    print("\nAll tests passed!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 4628e6344..61418692b 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -52,7 +52,7 @@ from vllm.v1.core.sched.request_queue import (
 )
 from vllm.v1.core.sched.utils import check_stop, remove_all
 from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
-from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig
 from vllm.v1.metrics.perf import ModelMetrics, PerfStats
 from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
 from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
@@ -259,9 +259,26 @@ class Scheduler(SchedulerInterface):
             assert len(kv_cache_config.kv_cache_groups) > 0, (
                 "enable_return_routed_experts requires at least one kv cache group"
             )
+            # Find the attention group for routed experts indexing.
+            self.routed_experts_attn_gid = 0
+            for gid, group in enumerate(kv_cache_config.kv_cache_groups):
+                if isinstance(group.kv_cache_spec, AttentionSpec):
+                    self.routed_experts_attn_gid = gid
+                    break
+            min_block_size = min(
+                [
+                    group.kv_cache_spec.block_size
+                    for group in kv_cache_config.kv_cache_groups
+                ]
+            )
+            num_groups = len(kv_cache_config.kv_cache_groups)
             self.max_num_kv_tokens = (
-                kv_cache_config.num_blocks // len(kv_cache_config.kv_cache_groups) + 1
-            ) * self.block_size
+                kv_cache_config.num_blocks // num_groups
+            ) * min_block_size
+            dcp_size = self.vllm_config.parallel_config.decode_context_parallel_size
+            pcp_size = self.vllm_config.parallel_config.prefill_context_parallel_size
+            if pcp_size * dcp_size > 1:
+                self.max_num_kv_tokens *= pcp_size * dcp_size
 
             self.routed_experts_reader.attach_buffer(
                 max_num_kv_tokens=self.max_num_kv_tokens,
@@ -1561,13 +1578,14 @@ class Scheduler(SchedulerInterface):
             return None
 
         kv_blocks = self.kv_cache_manager.get_blocks(request.request_id)
-        block_ids = kv_blocks.get_block_ids()[0]
+        block_ids = kv_blocks.get_block_ids()[self.routed_experts_attn_gid]
         num_tokens = request.num_tokens - 1
 
-        # compute slot mapping
+        # compute slot mapping using attention group's block_size
         block_ids_array = np.array(block_ids, dtype=np.int32)
         num_blocks = len(block_ids)
-        block_size = self.block_size
+        attn_group = self.kv_cache_config.kv_cache_groups[self.routed_experts_attn_gid]
+        block_size = attn_group.kv_cache_spec.block_size
 
         # generate block offsets
         block_offsets = np.arange(0, block_size)
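
The slot-mapping fix in the second scheduler hunk is easier to see in isolation. Previously the scheduler always read KV cache group 0 with the global block_size, which is wrong for hybrid models where group 0 may be a non-attention group with a different block size; the fix indexes the attention group and uses its block_size. Each block then owns block_size consecutive slots, so the flat slot for a token is block_id * block_size + offset, truncated to the request's token count. A rough NumPy sketch with assumed toy values, not the scheduler's actual code path:

import numpy as np

# Assumed toy values, for illustration only.
block_ids = np.array([7, 2, 5], dtype=np.int32)  # blocks owned by one request
block_size = 4                                   # attention group's block_size
num_tokens = 10                                  # request.num_tokens - 1

# Pair every block id with every in-block offset, then flatten: each block
# contributes block_size consecutive slot indices.
block_offsets = np.arange(0, block_size)
slot_mapping = (block_ids[:, None] * block_size + block_offsets[None, :]).reshape(-1)
slot_mapping = slot_mapping[:num_tokens]
print(slot_mapping)  # [28 29 30 31  8  9 10 11 20 21]
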
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index ba40e8e45..b53bd71a1 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -422,6 +422,9 @@ class GPUModelRunner(
         )
         # This will be overridden in load_model()
         self.is_multimodal_pruning_enabled = False
+        # Set to True after init_routed_experts_capturer() completes.
+        # Prevents routed experts code from running during profiling/dummy run.
+        self.routed_experts_initialized = False
 
         self.max_model_len = model_config.max_model_len
         # Always set to false after the first forward pass
@@ -1951,8 +1954,10 @@
         block_table_gid_0 = _get_block_table(0)
         slot_mapping_gid_0 = slot_mappings[0]
-        if self.model_config.enable_return_routed_experts:
-            self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy()
+        if self.routed_experts_initialized:
+            attn_gid = self.routed_experts_attn_gid
+            slot_mapping_attn = slot_mappings[attn_gid]
+            self.slot_mapping = slot_mapping_attn[:num_tokens].cpu().numpy()
 
         cm_base = CommonAttentionMetadata(
             query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
             query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1],
@@ -3540,7 +3545,7 @@
                 "after execute_model() returns None."
             )
 
-        if self.vllm_config.model_config.enable_return_routed_experts:
+        if self.routed_experts_initialized:
             capturer = RoutedExpertsCapturer.get_instance()
             if capturer is not None:
                 capturer.clear_buffer()  # noqa
@@ -4049,7 +4054,7 @@
             self.kv_connector_output = None
 
         with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"):
-            if self.model_config.enable_return_routed_experts:
+            if self.routed_experts_initialized:
                 capturer = RoutedExpertsCapturer.get_instance()
                 if capturer is not None:
                     capturer.save_captured_experts(indices=self.slot_mapping)  # noqa
@@ -6531,8 +6536,12 @@
             kv_transfer_group.register_kv_caches(kv_caches)
             kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)
 
-        if self.model_config.enable_return_routed_experts:
-            self.init_routed_experts_capturer()
+    def _get_attention_kv_cache_gid(self) -> int:
+        """Find the KV cache group index for attention layers."""
+        for gid, group in enumerate(self.kv_cache_config.kv_cache_groups):
+            if isinstance(group.kv_cache_spec, AttentionSpec):
+                return gid
+        return 0
 
     def init_routed_experts_capturer(self):
         logger.info(
@@ -6540,17 +6549,29 @@
             self.model_config.enable_return_routed_experts,
         )
         routed_experts_capturer = RoutedExpertsCapturer.create()
-        block_size = self.cache_config.block_size
+        self.routed_experts_attn_gid = self._get_attention_kv_cache_gid()
+        min_block_size = min(
+            [
+                group.kv_cache_spec.block_size
+                for group in self.kv_cache_config.kv_cache_groups
+            ]
+        )
+        num_groups = len(self.kv_cache_config.kv_cache_groups)
         self.max_num_kv_tokens = (
-            self.kv_cache_config.num_blocks // len(self.kv_cache_config.kv_cache_groups)
-            + 1
-        ) * block_size
+            self.kv_cache_config.num_blocks // num_groups
+        ) * min_block_size
+        dcp_size = self.vllm_config.parallel_config.decode_context_parallel_size
+        pcp_size = self.vllm_config.parallel_config.prefill_context_parallel_size
+        if pcp_size * dcp_size > 1:
+            self.max_num_kv_tokens *= pcp_size * dcp_size
+
        routed_experts_capturer.init_buffer(
             max_num_batched_tokens=self.scheduler_config.max_num_batched_tokens,
             max_num_kv_tokens=self.max_num_kv_tokens,
             vllm_config=self.vllm_config,
         )
         self._bind_routed_experts_capturer(routed_experts_capturer)
+        self.routed_experts_initialized = True
 
     def _bind_routed_experts_capturer(self, capturer: RoutedExpertsCapturer) -> None:
         from vllm.model_executor.layers.fused_moe.layer import FusedMoE
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index b0e13d609..83e12710a 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -552,6 +552,9 @@
         else:
             self.model_runner.initialize_kv_cache(kv_cache_config)
 
+        if self.model_config.enable_return_routed_experts:
+            self.model_runner.init_routed_experts_capturer()
+
         # Build KV-zero metadata outside the CuMem pool so the bookkeeping
         # GPU tensors (seg_addrs, block-id buffers) use the standard PyTorch
         # allocator and are not discarded during sleep/wake cycles.
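
Finally, the buffer sizing now shared by the scheduler and the model runner can be checked by hand: both take a per-group share of the block pool, multiply by the smallest group block size (the old formula added 1 to the share and assumed a single uniform block_size), and scale up when context parallelism is active. A worked example with assumed numbers, not taken from the PR:

# Assumed toy configuration, for illustration only.
num_blocks = 1000        # kv_cache_config.num_blocks, shared by all groups
num_groups = 2           # len(kv_cache_config.kv_cache_groups)
block_sizes = [16, 64]   # per-group kv_cache_spec.block_size
pcp_size, dcp_size = 2, 1  # prefill / decode context parallel sizes

# Per-group share of blocks times the smallest block size gives a
# conservative bound on addressable KV tokens.
max_num_kv_tokens = (num_blocks // num_groups) * min(block_sizes)

# With context parallelism, each rank presumably holds only a slice of the
# KV cache, so the logical token span is pcp * dcp times one rank's share.
if pcp_size * dcp_size > 1:
    max_num_kv_tokens *= pcp_size * dcp_size

print(max_num_kv_tokens)  # (1000 // 2) * 16 * 2 = 16000
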