[Feature] Support recording expert indices for rollout router replay (#28284)

Signed-off-by: xhx1022 <1737006628@qq.com>
Signed-off-by: Hongxin Xu <70438206+xhx1022@users.noreply.github.com>
Signed-off-by: arlenxu <arlenxu@tencent.com>
Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: arlenxu <arlenxu@tencent.com>
Author: Hongxin Xu
Date: 2026-01-12 22:23:04 +08:00
Committed by: GitHub
Parent: 0565f1fdec
Commit: 49e6b86c91
11 changed files with 463 additions and 3 deletions

@@ -35,6 +35,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
init_aiter_topK_meta_data,
)
from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
RoutedExpertsCapturer,
)
from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
@@ -701,6 +704,13 @@ class FusedMoE(CustomOp):
def shared_experts(self) -> torch.nn.Module | None:
return None
@property
def layer_id(self) -> int:
# Delayed import to avoid circular dependency
from vllm.model_executor.models.utils import extract_layer_index
return extract_layer_index(self.layer_name)
@property
def gate(self) -> torch.nn.Module | None:
return None
@@ -1650,6 +1660,18 @@ class FusedMoE(CustomOp):
assert topk_ids.dtype == indices_type or indices_type is None
if (
self.vllm_config.model_config is not None
and self.vllm_config.model_config.enable_return_routed_experts
):
# In dummy runs, the capturer is not initialized and may be None.
capturer = RoutedExpertsCapturer.get_instance()
if capturer is not None:
capturer.capture( # noqa
layer_id=self.layer_id,
topk_ids=topk_ids,
)
return topk_weights, topk_ids
def must_reduce_shared_expert_outputs(self) -> bool:

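The hunk above records expert indices only when model_config.enable_return_routed_experts is set, and tolerates a missing capturer so dummy/profiling runs do not fail. A minimal sketch of that call path in isolation, against the capturer API added in the new routed_experts_capturer module below (the token count, expert count, and layer id are illustrative assumptions, not values from the diff):

import torch
from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
    RoutedExpertsCapturer,
)

# Illustrative routing output: 8 tokens, top-2 experts chosen from 64.
topk_ids = torch.randint(0, 64, (8, 2), dtype=torch.int32, device="cuda")
capturer = RoutedExpertsCapturer.get_instance()  # None before create()/init_buffer()
if capturer is not None:
    capturer.capture(layer_id=0, topk_ids=topk_ids)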
@@ -0,0 +1,324 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/sgl-project/sglang/blob/bed301a5acaa9577c9aa706468bdf242f6a43051/python/sglang/srt/layers/moe/routed_experts_capturer.py
from __future__ import annotations
import fcntl
import logging
import os
import tempfile
from collections.abc import Generator
from contextlib import contextmanager
from multiprocessing import shared_memory
from unittest.mock import patch
import numpy as np
import torch
from vllm.config import ModelConfig
from vllm.distributed import get_tensor_model_parallel_rank
logger = logging.getLogger(__name__)
# Constants
_TMP_DIR = tempfile.gettempdir()
_LOCK_FILE_PREFIX = os.path.join(_TMP_DIR, "vllm_routed_experts")
_BUFFER_PREFIX = "vllm_routed_experts_buffer"
# Global singleton instances
_global_experts_capturer: RoutedExpertsCapturer | None = None
_global_experts_reader: RoutedExpertsReader | None = None
@contextmanager
def _file_lock(lock_file: str, mode: str = "wb+") -> Generator[None, None, None]:
"""Context manager for file-based locking."""
with open(lock_file, mode) as fp:
fcntl.flock(fp, fcntl.LOCK_EX)
try:
yield
finally:
fcntl.flock(fp, fcntl.LOCK_UN)
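# Usage sketch (illustrative path):
#   with _file_lock("/tmp/vllm_routed_experts_demo.lock"):
#       ...  # critical section guarded across processes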
def _create_or_attach_shared_memory(
name: str, size: int, lock_file: str
) -> shared_memory.SharedMemory:
"""Create or attach to shared memory with proper locking."""
# Ensure lock file exists before acquiring lock
with open(lock_file, "wb"):
pass
with _file_lock(lock_file):
try:
shm = shared_memory.SharedMemory(name=name, create=True, size=size)
except FileExistsError:
shm = shared_memory.SharedMemory(name=name, create=False, size=size)
if shm.size != size:
logger.warning(
"Shared memory %s size mismatch; recreating",
name,
)
shm.close()
shm.unlink()
try:
shm = shared_memory.SharedMemory(name=name, create=True, size=size)
logger.info("Created shared memory %s", name)
except FileExistsError:
shm = shared_memory.SharedMemory(name=name, create=False, size=size)
logger.info("Linked to existing shared memory %s", name)
return shm
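# Illustration (not part of the module's control flow): two workers racing to
# create the same segment both end up attached to one mapping, e.g.
#   a = _create_or_attach_shared_memory("seg", 4096, "/tmp/seg.lock")
#   b = _create_or_attach_shared_memory("seg", 4096, "/tmp/seg.lock")
#   a.buf[0] = 7  # immediately visible as b.buf[0]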
class RoutedExpertsCapturer:
"""
Capturer for routed experts, backed by a device buffer and an optional shared-memory buffer.
This class captures expert routing decisions during model forward passes
and optionally stores them in shared memory for cross-process access.
"""
_instance: RoutedExpertsCapturer | None = None
def __init__(self) -> None:
self._device_buffer: torch.Tensor | None = None
self._shm: shared_memory.SharedMemory | None = None
self._host_buffer_view: np.ndarray | None = None
self._lock_file: str | None = None
self._shm_name: str | None = None
@classmethod
def create(cls) -> RoutedExpertsCapturer:
"""Create a global singleton instance."""
global _global_experts_capturer
if _global_experts_capturer is not None:
raise RuntimeError("Experts capturer already created.")
_global_experts_capturer = cls()
return _global_experts_capturer
@staticmethod
def get_instance() -> RoutedExpertsCapturer | None:
"""Get the global singleton instance."""
return _global_experts_capturer
def init_buffer(
self,
max_num_batched_tokens: int,
max_num_kv_tokens: int,
model_config: ModelConfig,
instance_id: str,
) -> None:
"""
Initialize the device buffer and, on tensor-parallel rank 0, the shared-memory buffer.
Args:
max_num_batched_tokens: Maximum number of tokens in a batch.
max_num_kv_tokens: Maximum number of KV tokens for shared memory.
model_config: Model configuration containing layer and expert info.
instance_id: Unique identifier for the shared memory buffer.
"""
if self._device_buffer is not None:
raise RuntimeError("Device buffer has already been initialized")
hf_config = model_config.hf_text_config
num_layers = hf_config.num_hidden_layers
num_experts_per_tok = hf_config.num_experts_per_tok
# Initialize device buffer
self._device_buffer = torch.zeros(
(max_num_batched_tokens, num_layers, num_experts_per_tok),
dtype=torch.int32,
device="cuda",
)
if get_tensor_model_parallel_rank() != 0:
return
# Initialize shared memory
shape = (max_num_kv_tokens, num_layers, num_experts_per_tok)
buffer_size = int(np.prod(shape)) * np.dtype(np.int32).itemsize
self._lock_file = f"{_LOCK_FILE_PREFIX}_{instance_id}.lock"
self._shm_name = f"{_BUFFER_PREFIX}_{instance_id}"
self._shm = _create_or_attach_shared_memory(
self._shm_name, buffer_size, self._lock_file
)
self._host_buffer_view = np.ndarray(shape, dtype=np.int32, buffer=self._shm.buf)
self._host_buffer_view.fill(0)
logger.debug(
"Created shared memory buffer '%s' with shape %s",
self._shm.name,
shape,
)
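# Layout note (illustrative numbers): for a 4-layer model with top-2 routing,
# the device buffer is (max_num_batched_tokens, 4, 2) int32 on the GPU, and
# the shared host buffer is (max_num_kv_tokens, 4, 2) int32, addressed by the
# caller-provided token slots in save_captured_experts().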
def capture(self, layer_id: int, topk_ids: torch.Tensor) -> None:
"""
Capture expert routing decisions for a specific layer.
Args:
layer_id: The layer index.
topk_ids: Tensor of shape (num_tokens, num_experts_per_tok).
"""
if self._device_buffer is None:
raise RuntimeError("Buffer not initialized. Call init_buffer() first.")
if layer_id >= self._device_buffer.shape[1]:
return
batch_size = topk_ids.shape[0]
self._device_buffer[:batch_size, layer_id, :] = topk_ids
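# E.g. topk_ids of shape (num_tokens, top_k) for layer_id=3 lands in
# self._device_buffer[:num_tokens, 3, :].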
def clear_buffer(self) -> None:
"""Clear the device buffer."""
if self._device_buffer is not None:
self._device_buffer.zero_()
def save_captured_experts(self, indices: np.ndarray) -> None:
"""
Save captured experts from device buffer to shared memory.
Args:
indices: Host-buffer token slots to write; device-buffer row i is stored at slot indices[i].
"""
if get_tensor_model_parallel_rank() != 0:
return
if self._lock_file is None:
raise RuntimeError("Shared memory not initialized.")
if self._host_buffer_view is None:
return
if self._device_buffer is None:
raise RuntimeError("Device buffer not initialized.")
num_tokens = len(indices)
data = self._device_buffer[:num_tokens, :, :].cpu().numpy()
with _file_lock(self._lock_file):
self._host_buffer_view[indices, :, :] = data
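# E.g. indices=np.array([5, 9]) scatters device-buffer rows 0 and 1 into
# host-buffer slots 5 and 9 under the file lock.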
def cleanup(self) -> None:
"""Explicitly clean up shared memory resources."""
if self._shm is not None:
try:
self._shm.close()
self._shm.unlink()
except Exception:
logger.debug("Exception during cleanup for capturer", exc_info=True)
finally:
self._shm = None
def __del__(self) -> None:
"""Clean up shared memory on destruction."""
self.cleanup()
class RoutedExpertsReader:
"""
Reader for routed experts from shared memory.
This class attaches to shared memory created by RoutedExpertsCapturer
and reads expert routing decisions.
"""
_instance: RoutedExpertsReader | None = None
def __init__(self) -> None:
self._shm: shared_memory.SharedMemory | None = None
self._host_buffer_view: np.ndarray | None = None
self._lock_file: str | None = None
@classmethod
def create(cls) -> RoutedExpertsReader:
"""Create a global singleton instance."""
global _global_experts_reader
if _global_experts_reader is not None:
raise RuntimeError("Experts reader already created.")
_global_experts_reader = cls()
return _global_experts_reader
@staticmethod
def get_instance() -> RoutedExpertsReader | None:
"""Get the global singleton instance."""
if _global_experts_reader is None:
logger.info("Experts reader not initialized.")
return _global_experts_reader
def attach_buffer(
self,
max_num_kv_tokens: int,
model_config: ModelConfig,
instance_id: str,
) -> None:
"""
Attach to an existing shared memory buffer.
Args:
max_num_kv_tokens: Maximum number of KV tokens.
model_config: Model configuration.
instance_id: Unique identifier for the shared memory buffer.
"""
if self._shm is not None:
logger.warning("Already attached to shared memory buffer.")
return # Already attached
hf_config = model_config.hf_text_config
shape = (
max_num_kv_tokens,
hf_config.num_hidden_layers,
hf_config.num_experts_per_tok,
)
self._lock_file = f"{_LOCK_FILE_PREFIX}_{instance_id}.lock"
shm_name = f"{_BUFFER_PREFIX}_{instance_id}"
with _file_lock(self._lock_file, mode="rb+"):
# Prevent the resource_tracker from registering this segment;
# otherwise the reader process would unlink it on exit even
# though the capturer owns its lifetime.
with patch(
"multiprocessing.resource_tracker.register",
lambda *args, **kwargs: None,
):
self._shm = shared_memory.SharedMemory(name=shm_name)
self._host_buffer_view = np.ndarray(
shape, dtype=np.int32, buffer=self._shm.buf
)
def get_routed_experts(self, indices: np.ndarray) -> np.ndarray:
"""
Read routed expert data from shared memory.
Args:
indices: Array of indices to read.
Returns:
Copy of the expert routing data for the given indices.
"""
if self._host_buffer_view is None:
raise RuntimeError("Buffer not attached. Call attach_buffer() first.")
if self._lock_file is None:
raise RuntimeError("Lock file not initialized.")
with _file_lock(self._lock_file, mode="rb+"):
return self._host_buffer_view[indices, :, :].copy()
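# E.g. reader.get_routed_experts(np.array([5, 9])) returns an int32 array of
# shape (2, num_layers, num_experts_per_tok) for those token slots.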
def cleanup(self) -> None:
"""Explicitly clean up resources (close without unlink)."""
if self._shm is not None:
try:
self._shm.close()
except Exception:
logger.debug("Exception during cleanup for reader", exc_info=True)
finally:
self._shm = None
def __del__(self) -> None:
"""Close shared memory on destruction (do not unlink)."""
self.cleanup()
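
Putting the two halves together, here is a hedged end-to-end sketch of the producer/consumer flow this commit enables. The token counts, instance id, and indices are illustrative assumptions, and model_config is assumed to be a populated ModelConfig in scope; inside vLLM these calls happen in already-initialized distributed worker and replay processes:

import numpy as np
from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
    RoutedExpertsCapturer,
    RoutedExpertsReader,
)

# Producer (worker process, TP rank 0): create and size the buffers.
capturer = RoutedExpertsCapturer.create()
capturer.init_buffer(
    max_num_batched_tokens=8192,
    max_num_kv_tokens=65536,
    model_config=model_config,  # assumed in scope
    instance_id="demo",
)
# ... forward passes call capturer.capture(layer_id, topk_ids) per MoE layer ...
capturer.save_captured_experts(np.arange(128))  # flush rows 0..127 to slots 0..127

# Consumer (rollout replay process): attach to the same segment and read back.
reader = RoutedExpertsReader.create()
reader.attach_buffer(
    max_num_kv_tokens=65536,
    model_config=model_config,
    instance_id="demo",
)
routed = reader.get_routed_experts(np.arange(128))  # (128, num_layers, top_k)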