[BugFix] Make sure to allocate worst case MoE workspace during profile run in the DP + EP case (#27426)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson
2025-11-21 14:41:52 -05:00
committed by GitHub
parent 1bed891f72
commit 1840c5cb18
2 changed files with 43 additions and 2 deletions

View File

@@ -10,6 +10,9 @@ from typing import final
import torch
import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.forward_context import get_forward_context, is_forward_context_available
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache,
@@ -26,6 +29,8 @@ from vllm.v1.worker.ubatching import (
dbo_yield,
)
logger = init_logger(__name__)
#
# This file defines a set of base classes used to make MoE kernels more modular.
# The goal is to be able to utilize different communication mechanisms with
@@ -798,6 +803,42 @@ class FusedMoEModularKernel(torch.nn.Module):
buffers = self.shared_buffers[ubatch_idx]
workspace_dtype = self.fused_experts.workspace_dtype(out_dtype)
# Force worst-case allocation in profiling run for
# "mk.FusedMoEModularKernel.Standard" formats where this is only bounded
# by `VLLM_FUSED_MOE_CHUNK_SIZE` and may not be seen during profiling with
# DP+EP due to the random token routing.
is_profile_run = (
is_forward_context_available()
and get_forward_context().attn_metadata is None
)
if is_profile_run and self.fused_experts.supports_chunking():
parallel_config = get_current_vllm_config().parallel_config
is_dp_ep = (
parallel_config.data_parallel_size > 1
and parallel_config.enable_expert_parallel
)
if is_dp_ep:
max_workspace_13, max_workspace_2, max_fused_out_shape = (
self.fused_experts.workspace_shapes(
envs.VLLM_FUSED_MOE_CHUNK_SIZE,
N,
K,
top_k,
global_num_experts,
local_num_experts,
expert_tokens_meta,
)
)
buffers.workspace13.get(
max_workspace_13, device=device, dtype=workspace_dtype
)
buffers.workspace2.get(
max_workspace_2, device=device, dtype=workspace_dtype
)
buffers.fused_out.get(
max_fused_out_shape, device=device, dtype=workspace_dtype
)
# Get intermediate workspace shapes based off the chunked M size.
workspace13_shape, workspace2_shape, _ = self.fused_experts.workspace_shapes(
M_chunk,