diff --git a/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py b/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py
index 0fd788ea5..7608e06aa 100644
--- a/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py
+++ b/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py
@@ -17,8 +17,9 @@ from unittest.mock import patch
 
 import numpy as np
 import torch
 
-from vllm.config import ModelConfig
+from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_rank
+from vllm.forward_context import get_forward_context
 
 logger = logging.getLogger(__name__)
@@ -89,7 +90,6 @@ class RoutedExpertsCapturer:
         self._shm: shared_memory.SharedMemory | None = None
         self._host_buffer_view: np.ndarray | None = None
         self._lock_file: str | None = None
-        self._shm_name: str | None = None
 
     @classmethod
     def create(cls) -> RoutedExpertsCapturer:
@@ -110,8 +110,7 @@ class RoutedExpertsCapturer:
         self,
         max_num_batched_tokens: int,
         max_num_kv_tokens: int,
-        model_config: ModelConfig,
-        instance_id: str,
+        vllm_config: VllmConfig,
     ) -> None:
         """
         Initialize the device buffer and optionally shared memory buffer.
@@ -119,14 +118,13 @@ class RoutedExpertsCapturer:
         Args:
             max_num_batched_tokens: Maximum number of tokens in a batch.
             max_num_kv_tokens: Maximum number of KV tokens for shared memory.
-            model_config: Model configuration containing layer and expert info.
-            instance_id: Unique identifier for the shared memory buffer.
+            vllm_config: vllm configuration containing layer and expert info.
         """
 
         if self._device_buffer is not None:
             raise RuntimeError("Device buffer has already been initialized")
 
-        hf_config = model_config.hf_text_config
+        hf_config = vllm_config.model_config.hf_text_config
 
         num_layers = hf_config.num_hidden_layers
         num_experts_per_tok = hf_config.num_experts_per_tok
@@ -136,6 +134,7 @@ class RoutedExpertsCapturer:
             dtype=torch.int32,
             device="cuda",
         )
+        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
 
         if get_tensor_model_parallel_rank() != 0:
             return
@@ -143,19 +142,19 @@ class RoutedExpertsCapturer:
         # Initialize shared memory
         shape = (max_num_kv_tokens, num_layers, num_experts_per_tok)
         buffer_size = int(np.prod(shape)) * np.dtype(np.int32).itemsize
-
-        self._lock_file = f"{_LOCK_FILE_PREFIX}_{instance_id}.lock"
-        self._shm_name = f"{_BUFFER_PREFIX}_{instance_id}"
+        instance_id = vllm_config.instance_id
+        self._lock_file = f"{_LOCK_FILE_PREFIX}_{instance_id}_{self.dp_rank}.lock"
+        shm_name = f"{_BUFFER_PREFIX}_{instance_id}_{self.dp_rank}"
         self._shm = _create_or_attach_shared_memory(
-            self._shm_name, buffer_size, self._lock_file
+            shm_name, buffer_size, self._lock_file
         )
         self._host_buffer_view = np.ndarray(shape, dtype=np.int32, buffer=self._shm.buf)
         self._host_buffer_view.fill(0)
 
         logger.debug(
             "Created shared memory buffer '%s' with shape %s",
-            self._shm.name,
+            shm_name,
             shape,
         )
@@ -170,11 +169,24 @@ class RoutedExpertsCapturer:
         if self._device_buffer is None:
             raise RuntimeError("Buffer not initialized. Call init_buffer() first.")
 
+        ctx = get_forward_context()
+        if ctx.dp_metadata is None:  # single dp
+            start_loc = 0
+            end_loc = topk_ids.shape[0]
+            token_num_per_dp = topk_ids.shape[0]
+        else:  # multi dp
+            token_num_per_dp = ctx.dp_metadata.num_tokens_across_dp_cpu[self.dp_rank]
+            cumsum = torch.cumsum(ctx.dp_metadata.num_tokens_across_dp_cpu, dim=0)
+            assert cumsum[-1] == topk_ids.shape[0]
+            end_loc = cumsum[self.dp_rank]
+            start_loc = end_loc - token_num_per_dp
+
         if layer_id >= self._device_buffer.shape[1]:
             return
 
-        batch_size = topk_ids.shape[0]
-        self._device_buffer[:batch_size, layer_id, :] = topk_ids
+        self._device_buffer[:token_num_per_dp, layer_id, :] = topk_ids[
+            start_loc:end_loc, :
+        ]
 
     def clear_buffer(self) -> None:
         """Clear the device buffer."""
@@ -254,30 +266,30 @@ class RoutedExpertsReader:
     def attach_buffer(
         self,
         max_num_kv_tokens: int,
-        model_config: ModelConfig,
-        instance_id: str,
+        vllm_config: VllmConfig,
     ) -> None:
         """
         Attach to an existing shared memory buffer.
 
         Args:
             max_num_kv_tokens: Maximum number of KV tokens.
-            model_config: Model configuration.
-            instance_id: Unique identifier for the shared memory buffer.
+            vllm_config: vllm configuration.
         """
         if self._shm is not None:
             logger.warning("Already attached to shared memory buffer.")
             return  # Already attached
 
-        hf_config = model_config.hf_text_config
+        hf_config = vllm_config.model_config.hf_text_config
         shape = (
             max_num_kv_tokens,
             hf_config.num_hidden_layers,
             hf_config.num_experts_per_tok,
         )
 
-        self._lock_file = f"{_LOCK_FILE_PREFIX}_{instance_id}.lock"
-        shm_name = f"{_BUFFER_PREFIX}_{instance_id}"
+        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
+        instance_id = vllm_config.instance_id
+        self._lock_file = f"{_LOCK_FILE_PREFIX}_{instance_id}_{self.dp_rank}.lock"
+        shm_name = f"{_BUFFER_PREFIX}_{instance_id}_{self.dp_rank}"
 
         with _file_lock(self._lock_file, mode="rb+"):
             # Avoid resource_tracker registering the shared memory
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 1660d5189..a6d6ae93e 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -245,8 +245,7 @@ class Scheduler(SchedulerInterface):
 
             self.routed_experts_reader.attach_buffer(
                 max_num_kv_tokens=self.max_num_kv_tokens,
-                model_config=self.vllm_config.model_config,
-                instance_id=self.vllm_config.instance_id,
+                vllm_config=self.vllm_config,
             )
 
     def schedule(self) -> SchedulerOutput:
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 00e401f41..5691a7698 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -5675,12 +5675,10 @@ class GPUModelRunner(
             self.kv_cache_config.num_blocks // len(self.kv_cache_config.kv_cache_groups)
             + 1
         ) * block_size
-
         routed_experts_capturer.init_buffer(
             max_num_batched_tokens=self.scheduler_config.max_num_batched_tokens,
             max_num_kv_tokens=self.max_num_kv_tokens,
-            model_config=self.model_config,
-            instance_id=self.vllm_config.instance_id,
+            vllm_config=self.vllm_config,
         )
 
     def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
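
Note on the new data-parallel handling in capture(): with DP enabled, topk_ids holds the tokens of every DP rank, so each rank now slices out only its own rows before writing to the device buffer. A minimal sketch of the offset arithmetic, using made-up per-rank token counts in place of ctx.dp_metadata.num_tokens_across_dp_cpu (not real values from vLLM):

    import torch

    # Hypothetical per-DP-rank token counts, standing in for
    # ctx.dp_metadata.num_tokens_across_dp_cpu in the diff above.
    num_tokens_across_dp_cpu = torch.tensor([3, 5, 2, 4])
    dp_rank = 2  # this worker's data-parallel rank

    token_num_per_dp = num_tokens_across_dp_cpu[dp_rank]    # 2 tokens on rank 2
    cumsum = torch.cumsum(num_tokens_across_dp_cpu, dim=0)  # tensor([3, 8, 10, 14])
    end_loc = cumsum[dp_rank]                                # 10
    start_loc = end_loc - token_num_per_dp                   # 8

    # topk_ids carries routed-expert ids for all DP ranks' tokens
    # (total_tokens x experts_per_token); rank 2 keeps rows 8 and 9.
    topk_ids = torch.zeros(14, 8, dtype=torch.int32)
    local_topk_ids = topk_ids[start_loc:end_loc, :]
    assert local_topk_ids.shape == (2, 8)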
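Relatedly, the lock file and shared-memory segment are now named per (instance_id, dp_rank) pair so that each DP rank's capturer and reader attach to their own buffer. A rough illustration of that create/attach handshake with multiprocessing.shared_memory; the name prefix, shape, and ids below are placeholders, not the module's real _BUFFER_PREFIX or locking logic:

    from multiprocessing import shared_memory

    import numpy as np

    instance_id = "abc123"  # placeholder instance id
    dp_rank = 0
    shm_name = f"routed_experts_{instance_id}_{dp_rank}"

    shape = (16, 4, 8)  # (max_num_kv_tokens, num_layers, num_experts_per_tok)
    size = int(np.prod(shape)) * np.dtype(np.int32).itemsize

    # Writer side: create the segment and zero it through a numpy view.
    writer = shared_memory.SharedMemory(name=shm_name, create=True, size=size)
    np.ndarray(shape, dtype=np.int32, buffer=writer.buf).fill(0)

    # Reader side: attach by the same name and see the writer's data.
    reader = shared_memory.SharedMemory(name=shm_name)
    view = np.ndarray(shape, dtype=np.int32, buffer=reader.buf)
    assert int(view.sum()) == 0

    reader.close()
    writer.close()
    writer.unlink()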