diff --git a/vllm/config/model.py b/vllm/config/model.py
index 249fb5668..da3a4618c 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -198,6 +198,8 @@ class ModelConfig:
     graph and always execute the model in eager mode. If False, we will use
     CUDA graph and eager execution in hybrid for maximal performance and
     flexibility."""
+    enable_return_routed_experts: bool = False
+    """Whether to capture and return the MoE expert IDs routed to each token."""
     max_logprobs: int = 20
     """Maximum number of log probabilities to return when `logprobs` is
     specified in `SamplingParams`. The default value comes from the default for the
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 9eaa55f1f..63bfd056b 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -1352,6 +1352,7 @@ class VllmConfig:
             f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, "  # noqa
             f"quantization={self.model_config.quantization}, "
             f"enforce_eager={self.model_config.enforce_eager}, "
+            f"enable_return_routed_experts={self.model_config.enable_return_routed_experts}, "  # noqa
             f"kv_cache_dtype={self.cache_config.cache_dtype}, "
             f"device_config={self.device_config.device}, "
             f"structured_outputs_config={self.structured_outputs_config!r}, "
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 94608b13d..7631cd61d 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -354,6 +354,7 @@ class EngineArgs:
     """Arguments for vLLM engine."""

     model: str = ModelConfig.model
+    enable_return_routed_experts: bool = ModelConfig.enable_return_routed_experts
     model_weights: str = ModelConfig.model_weights
     served_model_name: str | list[str] | None = ModelConfig.served_model_name
     tokenizer: str | None = ModelConfig.tokenizer
@@ -657,6 +658,10 @@ class EngineArgs:
             **model_kwargs["allow_deprecated_quantization"],
         )
         model_group.add_argument("--enforce-eager", **model_kwargs["enforce_eager"])
+        model_group.add_argument(
+            "--enable-return-routed-experts",
+            **model_kwargs["enable_return_routed_experts"],
+        )
         model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"])
         model_group.add_argument("--logprobs-mode", **model_kwargs["logprobs_mode"])
         model_group.add_argument(
@@ -1239,6 +1244,7 @@ class EngineArgs:
             quantization=self.quantization,
             allow_deprecated_quantization=self.allow_deprecated_quantization,
             enforce_eager=self.enforce_eager,
+            enable_return_routed_experts=self.enable_return_routed_experts,
             max_logprobs=self.max_logprobs,
             logprobs_mode=self.logprobs_mode,
             disable_sliding_window=self.disable_sliding_window,
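For reference, a minimal usage sketch of the new flag (not part of the patch; the model name and sampling settings are illustrative, assuming an MoE model and this patch applied):

```python
import numpy as np
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen1.5-MoE-A2.7B",  # hypothetical MoE model choice
    enable_return_routed_experts=True,
)
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=8))

# Per the CompletionOutput change below: np.ndarray [seq_len, layer_num, topk],
# populated once the request finishes (None otherwise).
experts = outputs[0].outputs[0].routed_experts
```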
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 9c10e28c2..a7aa9a569 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -158,6 +158,7 @@ class LLM:
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode. If
            False, we will use CUDA graph and eager execution in hybrid.
+        enable_return_routed_experts: Whether to capture and return routed MoE expert IDs.
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
        hf_token: The token to use as HTTP bearer authorization for remote files
@@ -209,6 +210,7 @@ class LLM:
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
+        enable_return_routed_experts: bool = False,
        disable_custom_all_reduce: bool = False,
        hf_token: bool | str | None = None,
        hf_overrides: HfOverrides | None = None,
@@ -317,6 +319,7 @@ class LLM:
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
+            enable_return_routed_experts=enable_return_routed_experts,
            disable_custom_all_reduce=disable_custom_all_reduce,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 60e8ef9f7..3b3a789f6 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -35,6 +35,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
     init_aiter_topK_meta_data,
 )
+from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
+    RoutedExpertsCapturer,
+)
 from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig,
@@ -701,6 +704,13 @@ class FusedMoE(CustomOp):
     def shared_experts(self) -> torch.nn.Module | None:
         return None

+    @property
+    def layer_id(self) -> int:
+        # Delayed import to avoid a circular dependency
+        from vllm.model_executor.models.utils import extract_layer_index
+
+        return extract_layer_index(self.layer_name)
+
     @property
     def gate(self) -> torch.nn.Module | None:
         return None
@@ -1650,6 +1660,18 @@ class FusedMoE(CustomOp):

         assert topk_ids.dtype == indices_type or indices_type is None

+        if (
+            self.vllm_config.model_config is not None
+            and self.vllm_config.model_config.enable_return_routed_experts
+        ):
+            # The capturer may be None during dummy runs, when it is not initialized.
+            capturer = RoutedExpertsCapturer.get_instance()
+            if capturer is not None:
+                capturer.capture(  # noqa
+                    layer_id=self.layer_id,
+                    topk_ids=topk_ids,
+                )
+
         return topk_weights, topk_ids

     def must_reduce_shared_expert_outputs(self) -> bool:
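For context, a toy sketch (not part of the patch) of the tensor the hook above receives: `topk_ids` is the `(num_tokens, top_k)` integer tensor a MoE router produces, here illustrated with a plain `torch.topk` over random expert logits (real routers vary):

```python
import torch

router_logits = torch.randn(4, 8)  # 4 tokens, 8 experts
topk_weights, topk_ids = torch.topk(router_logits.softmax(dim=-1), k=2, dim=-1)
print(topk_ids.shape)  # torch.Size([4, 2]) -> these rows land in the device buffer
```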
diff --git a/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py b/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py
new file mode 100644
index 000000000..0fd788ea5
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py
@@ -0,0 +1,324 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Adapted from
+# https://github.com/sgl-project/sglang/blob/bed301a5acaa9577c9aa706468bdf242f6a43051/python/sglang/srt/layers/moe/routed_experts_capturer.py
+
+from __future__ import annotations
+
+import fcntl
+import logging
+import os
+import tempfile
+from collections.abc import Generator
+from contextlib import contextmanager
+from multiprocessing import shared_memory
+from unittest.mock import patch
+
+import numpy as np
+import torch
+
+from vllm.config import ModelConfig
+from vllm.distributed import get_tensor_model_parallel_rank
+
+logger = logging.getLogger(__name__)
+
+# Constants
+_TMP_DIR = tempfile.gettempdir()
+_LOCK_FILE_PREFIX = os.path.join(_TMP_DIR, "vllm_routed_experts")
+_BUFFER_PREFIX = "vllm_routed_experts_buffer"
+
+# Global singleton instances
+_global_experts_capturer: RoutedExpertsCapturer | None = None
+_global_experts_reader: RoutedExpertsReader | None = None
+
+
+@contextmanager
+def _file_lock(lock_file: str, mode: str = "wb+") -> Generator[None, None, None]:
+    """Context manager for file-based locking."""
+    with open(lock_file, mode) as fp:
+        fcntl.flock(fp, fcntl.LOCK_EX)
+        try:
+            yield
+        finally:
+            fcntl.flock(fp, fcntl.LOCK_UN)
+
+
+def _create_or_attach_shared_memory(
+    name: str, size: int, lock_file: str
+) -> shared_memory.SharedMemory:
+    """Create or attach to shared memory with proper locking."""
+    # Ensure lock file exists before acquiring lock
+    with open(lock_file, "wb"):
+        pass
+
+    with _file_lock(lock_file):
+        try:
+            shm = shared_memory.SharedMemory(name=name, create=True, size=size)
+        except FileExistsError:
+            shm = shared_memory.SharedMemory(name=name, create=False, size=size)
+
+        if shm.size != size:
+            logger.warning(
+                "Shared memory %s size mismatch; recreating",
+                name,
+            )
+            shm.close()
+            shm.unlink()
+            try:
+                shm = shared_memory.SharedMemory(name=name, create=True, size=size)
+                logger.info("Created shared memory %s", name)
+            except FileExistsError:
+                shm = shared_memory.SharedMemory(name=name, create=False, size=size)
+                logger.info("Linked to existing shared memory %s", name)
+
+        return shm
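A minimal standalone sketch of the create-or-attach pattern used above (illustrative names; assumes writer and reader agree on `name` and `size` out of band):

```python
import fcntl
from multiprocessing import shared_memory

def attach_rw(name: str, size: int, lock_path: str) -> shared_memory.SharedMemory:
    # An exclusive flock serializes the create/attach race between processes.
    with open(lock_path, "wb+") as fp:
        fcntl.flock(fp, fcntl.LOCK_EX)
        try:
            try:
                return shared_memory.SharedMemory(name=name, create=True, size=size)
            except FileExistsError:
                # Someone else created it first; attach to the existing segment.
                return shared_memory.SharedMemory(name=name, create=False, size=size)
        finally:
            fcntl.flock(fp, fcntl.LOCK_UN)
```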
+ """ + + _instance: RoutedExpertsCapturer | None = None + + def __init__(self) -> None: + self._device_buffer: torch.Tensor | None = None + self._shm: shared_memory.SharedMemory | None = None + self._host_buffer_view: np.ndarray | None = None + self._lock_file: str | None = None + self._shm_name: str | None = None + + @classmethod + def create(cls) -> RoutedExpertsCapturer: + """Create a global singleton instance.""" + global _global_experts_capturer + if _global_experts_capturer is not None: + raise RuntimeError("Experts capturer already created.") + + _global_experts_capturer = cls() + return _global_experts_capturer + + @staticmethod + def get_instance() -> RoutedExpertsCapturer | None: + """Get the global singleton instance.""" + return _global_experts_capturer + + def init_buffer( + self, + max_num_batched_tokens: int, + max_num_kv_tokens: int, + model_config: ModelConfig, + instance_id: str, + ) -> None: + """ + Initialize the device buffer and optionally shared memory buffer. + + Args: + max_num_batched_tokens: Maximum number of tokens in a batch. + max_num_kv_tokens: Maximum number of KV tokens for shared memory. + model_config: Model configuration containing layer and expert info. + instance_id: Unique identifier for the shared memory buffer. + """ + + if self._device_buffer is not None: + raise RuntimeError("Device buffer has already been initialized") + + hf_config = model_config.hf_text_config + num_layers = hf_config.num_hidden_layers + num_experts_per_tok = hf_config.num_experts_per_tok + + # Initialize device buffer + self._device_buffer = torch.zeros( + (max_num_batched_tokens, num_layers, num_experts_per_tok), + dtype=torch.int32, + device="cuda", + ) + + if get_tensor_model_parallel_rank() != 0: + return + + # Initialize shared memory + shape = (max_num_kv_tokens, num_layers, num_experts_per_tok) + buffer_size = int(np.prod(shape)) * np.dtype(np.int32).itemsize + + self._lock_file = f"{_LOCK_FILE_PREFIX}_{instance_id}.lock" + self._shm_name = f"{_BUFFER_PREFIX}_{instance_id}" + + self._shm = _create_or_attach_shared_memory( + self._shm_name, buffer_size, self._lock_file + ) + self._host_buffer_view = np.ndarray(shape, dtype=np.int32, buffer=self._shm.buf) + self._host_buffer_view.fill(0) + + logger.debug( + "Created shared memory buffer '%s' with shape %s", + self._shm.name, + shape, + ) + + def capture(self, layer_id: int, topk_ids: torch.Tensor) -> None: + """ + Capture expert routing decisions for a specific layer. + + Args: + layer_id: The layer index. + topk_ids: Tensor of shape (batch_size, num_routed_experts). + """ + if self._device_buffer is None: + raise RuntimeError("Buffer not initialized. Call init_buffer() first.") + + if layer_id >= self._device_buffer.shape[1]: + return + + batch_size = topk_ids.shape[0] + self._device_buffer[:batch_size, layer_id, :] = topk_ids + + def clear_buffer(self) -> None: + """Clear the device buffer.""" + if self._device_buffer is not None: + self._device_buffer.zero_() + + def save_captured_experts(self, indices: np.ndarray) -> None: + """ + Save captured experts from device buffer to shared memory. + + Args: + indices: Array of indices indicating where to store the data. 
+ """ + if get_tensor_model_parallel_rank() != 0: + return + if self._lock_file is None: + raise RuntimeError("Shared memory not initialized.") + if self._host_buffer_view is None: + return + if self._device_buffer is None: + raise RuntimeError("Device buffer not initialized.") + + num_tokens = len(indices) + data = self._device_buffer[:num_tokens, :, :].cpu().numpy() + + with _file_lock(self._lock_file): + self._host_buffer_view[indices, :, :] = data + + def cleanup(self) -> None: + """Explicitly clean up shared memory resources.""" + if self._shm is not None: + try: + self._shm.close() + self._shm.unlink() + except Exception: + logger.debug("Exception during cleanup for capturer", exc_info=True) + finally: + self._shm = None + + def __del__(self) -> None: + """Clean up shared memory on destruction.""" + self.cleanup() + + +class RoutedExpertsReader: + """ + Reader for routed experts from shared memory. + + This class attaches to shared memory created by RoutedExpertsCapturer + and reads expert routing decisions. + """ + + _instance: RoutedExpertsReader | None = None + + def __init__(self) -> None: + self._shm: shared_memory.SharedMemory | None = None + self._host_buffer_view: np.ndarray | None = None + self._lock_file: str | None = None + + @classmethod + def create(cls) -> RoutedExpertsReader: + """Create a global singleton instance.""" + global _global_experts_reader + if _global_experts_reader is not None: + raise RuntimeError("Experts reader already created.") + + _global_experts_reader = cls() + return _global_experts_reader + + @staticmethod + def get_instance() -> RoutedExpertsReader | None: + """Get the global singleton instance.""" + if _global_experts_reader is None: + logger.info("Experts reader not initialized.") + return _global_experts_reader + + def attach_buffer( + self, + max_num_kv_tokens: int, + model_config: ModelConfig, + instance_id: str, + ) -> None: + """ + Attach to an existing shared memory buffer. + + Args: + max_num_kv_tokens: Maximum number of KV tokens. + model_config: Model configuration. + instance_id: Unique identifier for the shared memory buffer. + """ + if self._shm is not None: + logger.warning("Already attached to shared memory buffer.") + return # Already attached + + hf_config = model_config.hf_text_config + shape = ( + max_num_kv_tokens, + hf_config.num_hidden_layers, + hf_config.num_experts_per_tok, + ) + + self._lock_file = f"{_LOCK_FILE_PREFIX}_{instance_id}.lock" + shm_name = f"{_BUFFER_PREFIX}_{instance_id}" + + with _file_lock(self._lock_file, mode="rb+"): + # Avoid resource_tracker registering the shared memory + with patch( + "multiprocessing.resource_tracker.register", + lambda *args, **kwargs: None, + ): + self._shm = shared_memory.SharedMemory(name=shm_name) + + self._host_buffer_view = np.ndarray( + shape, dtype=np.int32, buffer=self._shm.buf + ) + + def get_routed_experts(self, indices: np.ndarray) -> np.ndarray: + """ + Read routed expert data from shared memory. + + Args: + indices: Array of indices to read. + + Returns: + Copy of the expert routing data for the given indices. + """ + if self._host_buffer_view is None: + raise RuntimeError("Buffer not attached. 
+    def get_routed_experts(self, indices: np.ndarray) -> np.ndarray:
+        """
+        Read routed expert data from shared memory.
+
+        Args:
+            indices: Array of indices to read.
+
+        Returns:
+            Copy of the expert routing data for the given indices.
+        """
+        if self._host_buffer_view is None:
+            raise RuntimeError("Buffer not attached. Call attach_buffer() first.")
+        if self._lock_file is None:
+            raise RuntimeError("Lock file not initialized.")
+
+        with _file_lock(self._lock_file, mode="rb+"):
+            return self._host_buffer_view[indices, :, :].copy()
+
+    def cleanup(self) -> None:
+        """Explicitly clean up resources (close without unlink)."""
+        if self._shm is not None:
+            try:
+                self._shm.close()
+            except Exception:
+                logger.debug("Exception during cleanup for reader", exc_info=True)
+            finally:
+                self._shm = None
+
+    def __del__(self) -> None:
+        """Close shared memory on destruction (do not unlink)."""
+        self.cleanup()
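A hedged sketch of consuming the buffer from another process (assumes the capturer process already created the segment; `model_config` and `instance_id` are placeholders that must match the engine's values):

```python
import numpy as np
from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
    RoutedExpertsReader,
)

reader = RoutedExpertsReader.create()
reader.attach_buffer(
    max_num_kv_tokens=65536,    # must match the capturer's computed value
    model_config=model_config,  # the same ModelConfig the engine uses
    instance_id=instance_id,    # shared out of band with the engine
)
experts = reader.get_routed_experts(indices=np.array([0, 1, 2]))
print(experts.shape)  # (3, num_hidden_layers, num_experts_per_tok)
```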
diff --git a/vllm/outputs.py b/vllm/outputs.py
index 74e534ef0..cf23745c4 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -6,6 +6,7 @@ from collections.abc import Sequence as GenericSequence
 from dataclasses import dataclass
 from typing import Any, Generic

+import numpy as np
 import torch
 from typing_extensions import TypeVar

@@ -42,6 +43,7 @@ class CompletionOutput:
     token_ids: GenericSequence[int]
     cumulative_logprob: float | None
     logprobs: SampleLogprobs | None
+    routed_experts: np.ndarray | None = None  # [seq_len, layer_num, topk]
     finish_reason: str | None = None
     stop_reason: int | str | None = None
     lora_request: LoRARequest | None = None
@@ -54,6 +56,7 @@ class CompletionOutput:
             f"CompletionOutput(index={self.index}, "
             f"text={self.text!r}, "
             f"token_ids={self.token_ids}, "
+            f"routed_experts={self.routed_experts}, "
             f"cumulative_logprob={self.cumulative_logprob}, "
             f"logprobs={self.logprobs}, "
             f"finish_reason={self.finish_reason}, "
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 9a2a589cf..bdd6d2a3c 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -6,6 +6,8 @@ from collections import defaultdict
 from collections.abc import Iterable
 from typing import Any

+import numpy as np
+
 from vllm import envs
 from vllm.compilation.cuda_graph import CUDAGraphStat
 from vllm.config import VllmConfig
@@ -24,6 +26,9 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import (
 from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
 from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
 from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
+    RoutedExpertsReader,
+)
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.v1.core.encoder_cache_manager import (
     EncoderCacheManager,
@@ -219,11 +224,31 @@ class Scheduler(SchedulerInterface):
         )
         self.use_pp = self.parallel_config.pipeline_parallel_size > 1
         self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
-
         self.perf_metrics: ModelMetrics | None = None
         if self.log_stats and vllm_config.observability_config.enable_mfu_metrics:
             self.perf_metrics = ModelMetrics(vllm_config)

+        if self.vllm_config.model_config.enable_return_routed_experts:
+            assert self.dcp_world_size == 1 and self.pcp_world_size == 1, (
+                "enable_return_routed_experts does not support context parallelism "
+                "(dcp_world_size > 1 or pcp_world_size > 1)"
+            )
+
+            self.routed_experts_reader = RoutedExpertsReader.create()
+
+            assert len(kv_cache_config.kv_cache_groups) > 0, (
+                "enable_return_routed_experts requires at least one kv cache group"
+            )
+            self.max_num_kv_tokens = (
+                kv_cache_config.num_blocks // len(kv_cache_config.kv_cache_groups) + 1
+            ) * self.block_size
+
+            self.routed_experts_reader.attach_buffer(
+                max_num_kv_tokens=self.max_num_kv_tokens,
+                model_config=self.vllm_config.model_config,
+                instance_id=self.vllm_config.instance_id,
+            )
+
     def schedule(self) -> SchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
         # There's no "decoding phase" nor "prefill phase" in the scheduler.
@@ -1162,7 +1187,30 @@ class Scheduler(SchedulerInterface):
                 request.status = RequestStatus.FINISHED_STOPPED
                 stopped = True

+            routed_experts = None
             if stopped:
+                if self.vllm_config.model_config.enable_return_routed_experts:
+                    kv_blocks = self.kv_cache_manager.get_blocks(request.request_id)
+                    block_ids = kv_blocks.get_block_ids()[0]
+                    num_tokens = request.num_tokens - 1
+
+                    # compute slot mapping
+                    block_ids_array = np.array(block_ids, dtype=np.int32)
+                    num_blocks = len(block_ids)
+                    block_size = self.block_size
+
+                    # generate block offsets
+                    block_offsets = np.arange(0, block_size)
+
+                    # compute slot mapping: slot = block_id * block_size + offset
+                    slot_mapping = (
+                        block_offsets.reshape((1, block_size))
+                        + block_ids_array.reshape((num_blocks, 1)) * block_size
+                    ).flatten()[:num_tokens]
+
+                    routed_experts = self.routed_experts_reader.get_routed_experts(
+                        indices=slot_mapping
+                    )
                 kv_transfer_params = self._free_request(request)
                 if status_before_stop == RequestStatus.RUNNING:
                     stopped_running_reqs.add(request)
@@ -1209,6 +1257,7 @@ class Scheduler(SchedulerInterface):
                     kv_transfer_params=kv_transfer_params,
                     trace_headers=request.trace_headers,
                     num_cached_tokens=request.num_cached_tokens,
+                    routed_experts=routed_experts,
                     num_nans_in_logits=request.num_nans_in_logits,
                 )
             )
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index 27d34f1c6..0ffb97206 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -7,6 +7,7 @@ from collections.abc import Mapping
 from typing import Any

 import msgspec
+import numpy as np
 import torch

 from vllm.lora.request import LoRARequest
@@ -139,7 +140,7 @@ class EngineCoreOutput(
     trace_headers: Mapping[str, str] | None = None
     # The number of tokens with prefix cache hits.
     num_cached_tokens: int = 0
-
+    routed_experts: np.ndarray | None = None
     # The number of NaNs in logits.
     # A value greater than 0 indicates that the output is corrupted.
     num_nans_in_logits: int = 0
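A worked example of the slot-mapping arithmetic in the scheduler hunk above, `slot = block_id * block_size + offset` broadcast over every offset in each block:

```python
import numpy as np

block_ids = np.array([5, 2], dtype=np.int32)  # two KV blocks, in request order
block_size = 4
offsets = np.arange(block_size)               # [0, 1, 2, 3]
slots = (
    offsets.reshape((1, block_size))
    + block_ids.reshape((-1, 1)) * block_size
).flatten()
# -> [20 21 22 23  8  9 10 11]; truncated to num_tokens for the request
print(slots[:6])  # [20 21 22 23  8  9]
```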
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index 7f762bcbb..f461e56ff 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -7,6 +7,7 @@ from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Any, cast

+import numpy as np
 import torch

 from vllm.lora.request import LoRARequest
@@ -213,6 +214,7 @@ class RequestState:
         finish_reason: FinishReason | None,
         stop_reason: int | str | None,
         kv_transfer_params: dict[str, Any] | None = None,
+        routed_experts: np.ndarray | None = None,
     ) -> RequestOutput | PoolingRequestOutput | None:
         finished = finish_reason is not None
         final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
@@ -253,7 +255,9 @@ class RequestState:
             finished,
         )

-        output = self._new_completion_output(new_token_ids, finish_reason, stop_reason)
+        output = self._new_completion_output(
+            new_token_ids, finish_reason, stop_reason, routed_experts
+        )

         if self.parent_req is None:
             outputs = [output]
@@ -316,6 +320,7 @@ class RequestState:
         token_ids: list[int],
         finish_reason: FinishReason | None,
         stop_reason: int | str | None,
+        routed_experts: np.ndarray | None = None,
     ) -> CompletionOutput:
         assert self.detokenizer is not None
         assert self.logprobs_processor is not None
@@ -336,6 +341,7 @@ class RequestState:
             index=self.request_index,
             text=text,
             token_ids=token_ids,
+            routed_experts=routed_experts,
             logprobs=logprobs,
             cumulative_logprob=self.logprobs_processor.cumulative_logprob,
             finish_reason=str(finish_reason) if finished else None,
@@ -527,6 +533,7 @@ class OutputProcessor:
         finish_reason = engine_core_output.finish_reason
         stop_reason = engine_core_output.stop_reason
         kv_transfer_params = engine_core_output.kv_transfer_params
+        routed_experts = engine_core_output.routed_experts
         req_state.num_cached_tokens = engine_core_output.num_cached_tokens
         req_state.is_prefilling = False
@@ -552,6 +559,7 @@ class OutputProcessor:
             finish_reason,
             stop_reason,
             kv_transfer_params,
+            routed_experts,
         ):
             if req_state.queue is not None:
                 # AsyncLLM: put into queue for handling by generate().
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 476a89bb7..e08463c40 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -50,6 +50,9 @@ from vllm.forward_context import (
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping, LoRAMappingType
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
+    RoutedExpertsCapturer,
+)
 from vllm.model_executor.layers.rotary_embedding import (
     MRotaryEmbedding,
     XDRotaryEmbedding,
@@ -1633,6 +1636,8 @@ class GPUModelRunner(
             return blk_table_tensor, slot_mapping

         block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0)
+        if self.model_config.enable_return_routed_experts:
+            self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy()
         cm_base = CommonAttentionMetadata(
             query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
             query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1],
@@ -3112,6 +3117,13 @@ class GPUModelRunner(
                 "after execute_model() returns None."
             )

+        if self.vllm_config.model_config.enable_return_routed_experts:
+            capturer = RoutedExpertsCapturer.get_instance()
+            if capturer is not None:
+                capturer.clear_buffer()  # noqa
+            else:
+                logger.error("RoutedExpertsCapturer not initialized.")
+
         if scheduler_output.preempted_req_ids and has_kv_transfer_group():
             get_kv_transfer_group().handle_preemptions(
                 scheduler_output.preempted_req_ids
@@ -3485,6 +3497,13 @@ class GPUModelRunner(
             self.eplb_step()

         with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"):
+            if self.model_config.enable_return_routed_experts:
+                capturer = RoutedExpertsCapturer.get_instance()
+                if capturer is not None:
+                    capturer.save_captured_experts(indices=self.slot_mapping)  # noqa
+                else:
+                    logger.error("RoutedExpertsCapturer not initialized.")
+
             output = ModelRunnerOutput(
                 req_ids=req_ids_output_copy,
                 req_id_to_index=req_id_to_index_output_copy,
@@ -5646,6 +5665,28 @@ class GPUModelRunner(
             kv_transfer_group.register_kv_caches(kv_caches)
             kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)

+        if self.model_config.enable_return_routed_experts:
+            self.init_routed_experts_capturer()
+
+    def init_routed_experts_capturer(self) -> None:
+        logger.info(
+            "Initializing routed experts capturer, enable_return_routed_experts: %s",
+            self.model_config.enable_return_routed_experts,
+        )
+        routed_experts_capturer = RoutedExpertsCapturer.create()
+        block_size = self.cache_config.block_size
+        self.max_num_kv_tokens = (
+            self.kv_cache_config.num_blocks // len(self.kv_cache_config.kv_cache_groups)
+            + 1
+        ) * block_size
+
+        routed_experts_capturer.init_buffer(
+            max_num_batched_tokens=self.scheduler_config.max_num_batched_tokens,
+            max_num_kv_tokens=self.max_num_kv_tokens,
+            model_config=self.model_config,
+            instance_id=self.vllm_config.instance_id,
+        )
+
    def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
        """
        Add encoder-only layers to the KV cache config.
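A quick sanity check of the shared-buffer capacity formula, which the scheduler and the model runner must compute identically so reader and writer agree on the segment size (the numbers are illustrative, not defaults):

```python
num_blocks, num_groups, block_size = 8192, 2, 16
max_num_kv_tokens = (num_blocks // num_groups + 1) * block_size
print(max_num_kv_tokens)  # 65552 -- one extra block's worth of headroom
```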