[Feature] Support recording expert indices for rollout router replay (#28284)
Signed-off-by: xhx1022 <1737006628@qq.com> Signed-off-by: Hongxin Xu <70438206+xhx1022@users.noreply.github.com> Signed-off-by: arlenxu <arlenxu@tencent.com> Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com> Co-authored-by: arlenxu <arlenxu@tencent.com>
This commit is contained in:
@@ -35,6 +35,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
init_aiter_topK_meta_data,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
|
||||
RoutedExpertsCapturer,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
@@ -701,6 +704,13 @@ class FusedMoE(CustomOp):
|
||||
def shared_experts(self) -> torch.nn.Module | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def layer_id(self):
|
||||
# Delayed import to avoid circular dependency
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
|
||||
return extract_layer_index(self.layer_name)
|
||||
|
||||
@property
|
||||
def gate(self) -> torch.nn.Module | None:
|
||||
return None
|
||||
@@ -1650,6 +1660,18 @@ class FusedMoE(CustomOp):
|
||||
|
||||
assert topk_ids.dtype == indices_type or indices_type is None
|
||||
|
||||
if (
|
||||
self.vllm_config.model_config is not None
|
||||
and self.vllm_config.model_config.enable_return_routed_experts
|
||||
):
|
||||
# In dummy runs, the capturer is not initialized.
|
||||
capturer = RoutedExpertsCapturer.get_instance()
|
||||
if capturer is not None: # in dummmy_run may be None
|
||||
capturer.capture( # noqa
|
||||
layer_id=self.layer_id,
|
||||
topk_ids=topk_ids,
|
||||
)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
|
||||
def must_reduce_shared_expert_outputs(self) -> bool:
|
||||
|
||||
324
vllm/model_executor/layers/fused_moe/routed_experts_capturer.py
Normal file
324
vllm/model_executor/layers/fused_moe/routed_experts_capturer.py
Normal file
@@ -0,0 +1,324 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Adapted from
|
||||
# https://github.com/sgl-project/sglang/blob/bed301a5acaa9577c9aa706468bdf242f6a43051/python/sglang/srt/layers/moe/routed_experts_capturer.py
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import fcntl
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from multiprocessing import shared_memory
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants
|
||||
_TMP_DIR = tempfile.gettempdir()
|
||||
_LOCK_FILE_PREFIX = os.path.join(_TMP_DIR, "vllm_routed_experts")
|
||||
_BUFFER_PREFIX = "vllm_routed_experts_buffer"
|
||||
|
||||
# Global singleton instances
|
||||
_global_experts_capturer: RoutedExpertsCapturer | None = None
|
||||
_global_experts_reader: RoutedExpertsReader | None = None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _file_lock(lock_file: str, mode: str = "wb+") -> Generator[None, None, None]:
|
||||
"""Context manager for file-based locking."""
|
||||
with open(lock_file, mode) as fp:
|
||||
fcntl.flock(fp, fcntl.LOCK_EX)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
fcntl.flock(fp, fcntl.LOCK_UN)
|
||||
|
||||
|
||||
def _create_or_attach_shared_memory(
|
||||
name: str, size: int, lock_file: str
|
||||
) -> shared_memory.SharedMemory:
|
||||
"""Create or attach to shared memory with proper locking."""
|
||||
# Ensure lock file exists before acquiring lock
|
||||
with open(lock_file, "wb"):
|
||||
pass
|
||||
|
||||
with _file_lock(lock_file):
|
||||
try:
|
||||
shm = shared_memory.SharedMemory(name=name, create=True, size=size)
|
||||
except FileExistsError:
|
||||
shm = shared_memory.SharedMemory(name=name, create=False, size=size)
|
||||
|
||||
if shm.size != size:
|
||||
logger.warning(
|
||||
"Shared memory %s size mismatch; recreating",
|
||||
name,
|
||||
)
|
||||
shm.close()
|
||||
shm.unlink()
|
||||
try:
|
||||
shm = shared_memory.SharedMemory(name=name, create=True, size=size)
|
||||
logger.info("Created shared memory %s", name)
|
||||
except FileExistsError:
|
||||
shm = shared_memory.SharedMemory(name=name, create=False, size=size)
|
||||
logger.info("Linked to existing shared memory %s", name)
|
||||
|
||||
return shm
|
||||
|
||||
|
||||
class RoutedExpertsCapturer:
|
||||
"""
|
||||
Capturer for routed experts with device and optional shared memory buffer.
|
||||
|
||||
This class captures expert routing decisions during model forward passes
|
||||
and optionally stores them in shared memory for cross-process access.
|
||||
"""
|
||||
|
||||
_instance: RoutedExpertsCapturer | None = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._device_buffer: torch.Tensor | None = None
|
||||
self._shm: shared_memory.SharedMemory | None = None
|
||||
self._host_buffer_view: np.ndarray | None = None
|
||||
self._lock_file: str | None = None
|
||||
self._shm_name: str | None = None
|
||||
|
||||
@classmethod
|
||||
def create(cls) -> RoutedExpertsCapturer:
|
||||
"""Create a global singleton instance."""
|
||||
global _global_experts_capturer
|
||||
if _global_experts_capturer is not None:
|
||||
raise RuntimeError("Experts capturer already created.")
|
||||
|
||||
_global_experts_capturer = cls()
|
||||
return _global_experts_capturer
|
||||
|
||||
@staticmethod
|
||||
def get_instance() -> RoutedExpertsCapturer | None:
|
||||
"""Get the global singleton instance."""
|
||||
return _global_experts_capturer
|
||||
|
||||
def init_buffer(
|
||||
self,
|
||||
max_num_batched_tokens: int,
|
||||
max_num_kv_tokens: int,
|
||||
model_config: ModelConfig,
|
||||
instance_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the device buffer and optionally shared memory buffer.
|
||||
|
||||
Args:
|
||||
max_num_batched_tokens: Maximum number of tokens in a batch.
|
||||
max_num_kv_tokens: Maximum number of KV tokens for shared memory.
|
||||
model_config: Model configuration containing layer and expert info.
|
||||
instance_id: Unique identifier for the shared memory buffer.
|
||||
"""
|
||||
|
||||
if self._device_buffer is not None:
|
||||
raise RuntimeError("Device buffer has already been initialized")
|
||||
|
||||
hf_config = model_config.hf_text_config
|
||||
num_layers = hf_config.num_hidden_layers
|
||||
num_experts_per_tok = hf_config.num_experts_per_tok
|
||||
|
||||
# Initialize device buffer
|
||||
self._device_buffer = torch.zeros(
|
||||
(max_num_batched_tokens, num_layers, num_experts_per_tok),
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
if get_tensor_model_parallel_rank() != 0:
|
||||
return
|
||||
|
||||
# Initialize shared memory
|
||||
shape = (max_num_kv_tokens, num_layers, num_experts_per_tok)
|
||||
buffer_size = int(np.prod(shape)) * np.dtype(np.int32).itemsize
|
||||
|
||||
self._lock_file = f"{_LOCK_FILE_PREFIX}_{instance_id}.lock"
|
||||
self._shm_name = f"{_BUFFER_PREFIX}_{instance_id}"
|
||||
|
||||
self._shm = _create_or_attach_shared_memory(
|
||||
self._shm_name, buffer_size, self._lock_file
|
||||
)
|
||||
self._host_buffer_view = np.ndarray(shape, dtype=np.int32, buffer=self._shm.buf)
|
||||
self._host_buffer_view.fill(0)
|
||||
|
||||
logger.debug(
|
||||
"Created shared memory buffer '%s' with shape %s",
|
||||
self._shm.name,
|
||||
shape,
|
||||
)
|
||||
|
||||
def capture(self, layer_id: int, topk_ids: torch.Tensor) -> None:
|
||||
"""
|
||||
Capture expert routing decisions for a specific layer.
|
||||
|
||||
Args:
|
||||
layer_id: The layer index.
|
||||
topk_ids: Tensor of shape (batch_size, num_routed_experts).
|
||||
"""
|
||||
if self._device_buffer is None:
|
||||
raise RuntimeError("Buffer not initialized. Call init_buffer() first.")
|
||||
|
||||
if layer_id >= self._device_buffer.shape[1]:
|
||||
return
|
||||
|
||||
batch_size = topk_ids.shape[0]
|
||||
self._device_buffer[:batch_size, layer_id, :] = topk_ids
|
||||
|
||||
def clear_buffer(self) -> None:
|
||||
"""Clear the device buffer."""
|
||||
if self._device_buffer is not None:
|
||||
self._device_buffer.zero_()
|
||||
|
||||
def save_captured_experts(self, indices: np.ndarray) -> None:
|
||||
"""
|
||||
Save captured experts from device buffer to shared memory.
|
||||
|
||||
Args:
|
||||
indices: Array of indices indicating where to store the data.
|
||||
"""
|
||||
if get_tensor_model_parallel_rank() != 0:
|
||||
return
|
||||
if self._lock_file is None:
|
||||
raise RuntimeError("Shared memory not initialized.")
|
||||
if self._host_buffer_view is None:
|
||||
return
|
||||
if self._device_buffer is None:
|
||||
raise RuntimeError("Device buffer not initialized.")
|
||||
|
||||
num_tokens = len(indices)
|
||||
data = self._device_buffer[:num_tokens, :, :].cpu().numpy()
|
||||
|
||||
with _file_lock(self._lock_file):
|
||||
self._host_buffer_view[indices, :, :] = data
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Explicitly clean up shared memory resources."""
|
||||
if self._shm is not None:
|
||||
try:
|
||||
self._shm.close()
|
||||
self._shm.unlink()
|
||||
except Exception:
|
||||
logger.debug("Exception during cleanup for capturer", exc_info=True)
|
||||
finally:
|
||||
self._shm = None
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""Clean up shared memory on destruction."""
|
||||
self.cleanup()
|
||||
|
||||
|
||||
class RoutedExpertsReader:
|
||||
"""
|
||||
Reader for routed experts from shared memory.
|
||||
|
||||
This class attaches to shared memory created by RoutedExpertsCapturer
|
||||
and reads expert routing decisions.
|
||||
"""
|
||||
|
||||
_instance: RoutedExpertsReader | None = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._shm: shared_memory.SharedMemory | None = None
|
||||
self._host_buffer_view: np.ndarray | None = None
|
||||
self._lock_file: str | None = None
|
||||
|
||||
@classmethod
|
||||
def create(cls) -> RoutedExpertsReader:
|
||||
"""Create a global singleton instance."""
|
||||
global _global_experts_reader
|
||||
if _global_experts_reader is not None:
|
||||
raise RuntimeError("Experts reader already created.")
|
||||
|
||||
_global_experts_reader = cls()
|
||||
return _global_experts_reader
|
||||
|
||||
@staticmethod
|
||||
def get_instance() -> RoutedExpertsReader | None:
|
||||
"""Get the global singleton instance."""
|
||||
if _global_experts_reader is None:
|
||||
logger.info("Experts reader not initialized.")
|
||||
return _global_experts_reader
|
||||
|
||||
def attach_buffer(
|
||||
self,
|
||||
max_num_kv_tokens: int,
|
||||
model_config: ModelConfig,
|
||||
instance_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Attach to an existing shared memory buffer.
|
||||
|
||||
Args:
|
||||
max_num_kv_tokens: Maximum number of KV tokens.
|
||||
model_config: Model configuration.
|
||||
instance_id: Unique identifier for the shared memory buffer.
|
||||
"""
|
||||
if self._shm is not None:
|
||||
logger.warning("Already attached to shared memory buffer.")
|
||||
return # Already attached
|
||||
|
||||
hf_config = model_config.hf_text_config
|
||||
shape = (
|
||||
max_num_kv_tokens,
|
||||
hf_config.num_hidden_layers,
|
||||
hf_config.num_experts_per_tok,
|
||||
)
|
||||
|
||||
self._lock_file = f"{_LOCK_FILE_PREFIX}_{instance_id}.lock"
|
||||
shm_name = f"{_BUFFER_PREFIX}_{instance_id}"
|
||||
|
||||
with _file_lock(self._lock_file, mode="rb+"):
|
||||
# Avoid resource_tracker registering the shared memory
|
||||
with patch(
|
||||
"multiprocessing.resource_tracker.register",
|
||||
lambda *args, **kwargs: None,
|
||||
):
|
||||
self._shm = shared_memory.SharedMemory(name=shm_name)
|
||||
|
||||
self._host_buffer_view = np.ndarray(
|
||||
shape, dtype=np.int32, buffer=self._shm.buf
|
||||
)
|
||||
|
||||
def get_routed_experts(self, indices: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Read routed expert data from shared memory.
|
||||
|
||||
Args:
|
||||
indices: Array of indices to read.
|
||||
|
||||
Returns:
|
||||
Copy of the expert routing data for the given indices.
|
||||
"""
|
||||
if self._host_buffer_view is None:
|
||||
raise RuntimeError("Buffer not attached. Call attach_buffer() first.")
|
||||
if self._lock_file is None:
|
||||
raise RuntimeError("Lock file not initialized.")
|
||||
|
||||
with _file_lock(self._lock_file, mode="rb+"):
|
||||
return self._host_buffer_view[indices, :, :].copy()
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Explicitly clean up resources (close without unlink)."""
|
||||
if self._shm is not None:
|
||||
try:
|
||||
self._shm.close()
|
||||
except Exception:
|
||||
logger.debug("Exception during cleanup for reader", exc_info=True)
|
||||
finally:
|
||||
self._shm = None
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""Close shared memory on destruction (do not unlink)."""
|
||||
self.cleanup()
|
||||
Reference in New Issue
Block a user