[Feature]: Remove Chunking From FusedMoE (#34086)

Signed-off-by: SouthWest7 <am1ao@qq.com>
Signed-off-by: Southwest <1403572259@qq.com>
Signed-off-by: southwest <am1ao@qq.com>
Signed-off-by: Xinan Miao <1403572259@qq.com>
Co-authored-by: SouthWest7 <am1ao@qq.com>
Author: Xinan Miao
Date: 2026-03-13 02:24:38 +08:00 (committed by GitHub)
Parent: c973ecdead
Commit: 2cdf92228c
28 changed files with 152 additions and 523 deletions


@@ -9,8 +9,6 @@ from typing import final
import torch
import vllm.envs as envs
from vllm.forward_context import get_forward_context, is_forward_context_available
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import (
    MoEActivation,
@@ -24,14 +22,12 @@ from vllm.model_executor.layers.fused_moe.config import (
)
from vllm.model_executor.layers.fused_moe.utils import (
    _resize_cache,
    count_expert_num_tokens,
    disable_inplace,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    QuantKey,
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.v1.worker.ubatching import (
    dbo_enabled,
    dbo_maybe_run_recv_hook,
@@ -719,15 +715,6 @@ class FusedMoEExperts(ABC):
    def g2_alphas(self) -> torch.Tensor | None:
        return self.quant_config.g2_alphas

    # TODO (bnell): make this return a CHUNK_SIZE or None instead?
    @abstractmethod
    def supports_chunking(self) -> bool:
        """
        A flag indicating whether or not this class supports activation
        chunking.
        """
        raise NotImplementedError

    @abstractmethod
    def supports_expert_map(self) -> bool:
        """
@@ -742,11 +729,6 @@ class FusedMoEExperts(ABC):
"""
return False
def enable_chunking(self):
return (
envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and self.supports_chunking()
)
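For context, this is roughly how the removed interface fit together: each experts implementation advertised whether it could process activations in chunks, and enable_chunking() combined that capability with the VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING environment flag. A minimal sketch with hypothetical subclass names (the default flag value here is assumed):

# Illustrative sketch only -- not part of this commit.
import os
from abc import ABC, abstractmethod

class ExpertsBase(ABC):
    @abstractmethod
    def supports_chunking(self) -> bool:
        """Whether this implementation can process activations chunk by chunk."""
        raise NotImplementedError

    def enable_chunking(self) -> bool:
        # Chunking ran only when the env flag and the per-class capability agreed.
        flag = os.environ.get("VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING", "1") == "1"
        return flag and self.supports_chunking()

class ChunkableExperts(ExpertsBase):      # hypothetical: opts in to chunking
    def supports_chunking(self) -> bool:
        return True

class UnchunkableExperts(ExpertsBase):    # hypothetical: always sees the full batch
    def supports_chunking(self) -> bool:
        return False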
class FusedMoEExpertsModular(FusedMoEExperts):
"""
@@ -995,17 +977,6 @@ class FusedMoEExpertsMonolithic(FusedMoEExperts):
        raise NotImplementedError


def _slice_scales(
    scales: torch.Tensor | None, start: int, end: int
) -> torch.Tensor | None:
    if scales is not None:
        if scales.numel() == 1:
            return scales
        else:
            return scales[start:end]
    return None
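The removed helper above treated a one-element tensor as a per-tensor scale (shared by every chunk) and anything larger as per-token scales that must be sliced along with the activation chunk. A small self-contained illustration of that behavior:

# Illustrative only -- mirrors the removed _slice_scales behavior.
import torch

def slice_scales(scales, start, end):
    if scales is None:
        return None
    return scales if scales.numel() == 1 else scales[start:end]

per_tensor = torch.tensor([0.02])   # one scale for the whole tensor
per_token = torch.rand(8, 1)        # one scale row per token
assert slice_scales(per_tensor, 2, 6).numel() == 1    # returned unchanged
assert slice_scales(per_token, 2, 6).shape == (4, 1)  # sliced with the chunk
assert slice_scales(None, 2, 6) is None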
################################################################################
# Kernel
################################################################################
@@ -1032,26 +1003,6 @@ class FusedMoEKernelModularImpl:
            and moe_parallel_config.use_ep
        )

    def _chunk_info(self, M: int) -> tuple[int, int]:
        """
        Compute number of chunks and chunk size for given M.
        If chunking is not supported, set the CHUNK_SIZE to M so we
        get num_chunks == 1. Take max(M, 1) to avoid divide by zero.
        If there are no tokens to process, the number of chunks will be zero.
        """
        CHUNK_SIZE = max(
            1,
            (
                M
                if not self.fused_experts.enable_chunking()
                else min(M, envs.VLLM_FUSED_MOE_CHUNK_SIZE)
            ),
        )
        num_chunks = cdiv(M, CHUNK_SIZE)
        # If there are no tokens, then there should be no loop iterations.
        assert M > 0 or num_chunks == 0
        return num_chunks, CHUNK_SIZE
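For reference, the arithmetic in the removed _chunk_info (and the matching per-chunk ranges computed by input_chunk_range further down in this diff) works out as in this sketch; the chunk-size value used here is assumed purely for illustration, not the actual VLLM_FUSED_MOE_CHUNK_SIZE default:

# Illustrative only -- reproduces the removed chunking arithmetic.
def cdiv(a: int, b: int) -> int:
    return -(-a // b)  # ceiling division

def chunk_info(M: int, chunking_enabled: bool, env_chunk_size: int = 32768):
    chunk_size = max(1, M if not chunking_enabled else min(M, env_chunk_size))
    num_chunks = cdiv(M, chunk_size)
    assert M > 0 or num_chunks == 0
    return num_chunks, chunk_size

# 70,000 tokens with chunking enabled: three chunks covering rows
# [0, 32768), [32768, 65536), [65536, 70000).
assert chunk_info(70_000, True) == (3, 32768)
# Chunking disabled (or unsupported): a single chunk spanning all of M.
assert chunk_info(70_000, False) == (1, 70_000)
# No tokens: zero chunks, so the per-chunk loop never executed.
assert chunk_info(0, True) == (0, 1)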
    def _allocate_buffers(
        self,
        out_dtype: torch.dtype,
@@ -1076,40 +1027,8 @@ class FusedMoEKernelModularImpl:
"""
assert M_full > 0 and M_chunk > 0
num_chunks, _ = self._chunk_info(M_full)
workspace_dtype = self.fused_experts.workspace_dtype(out_dtype)
# Force worst-case allocation in profiling run for
# "mk.FusedMoEKernel.Standard" formats where this is only bounded
# by `VLLM_FUSED_MOE_CHUNK_SIZE` and may not be seen during profiling with
# DP+EP due to the random token routing.
is_profile_run = (
is_forward_context_available()
and get_forward_context().attn_metadata is None
)
if is_profile_run and self.fused_experts.enable_chunking() and self.is_dp_ep:
max_workspace_13, max_workspace_2, max_fused_out_shape = (
self.fused_experts.workspace_shapes(
envs.VLLM_FUSED_MOE_CHUNK_SIZE,
N,
K,
top_k,
global_num_experts,
local_num_experts,
# expert_tokens_meta help in allocating optimal/minimal
# amount of workspace. Mark it None, so we allocate for
# the worst-case scenario.
expert_tokens_meta=None,
activation=activation,
)
)
current_workspace_manager().get_simultaneous(
(max_workspace_13, workspace_dtype),
(max_workspace_2, workspace_dtype),
(max_fused_out_shape, out_dtype),
)
# Get intermediate workspace shapes based off the chunked M size.
workspace13_shape, workspace2_shape, _ = self.fused_experts.workspace_shapes(
M_chunk,
@@ -1136,80 +1055,17 @@ class FusedMoEKernelModularImpl:
        # We can reuse the memory between cache1 and cache3 because by the
        # time we need cache3, we're done with cache1.

        # Construct the entire output that can then be processed in chunks.
        # Reuse workspace13 for the output in the non-chunked case.
        # This will not always be the case for standard
        # format experts and with experts that have empty workspaces.
        if num_chunks == 1:
            max_shape_size = max(prod(workspace13_shape), prod(fused_out_shape))
            common_workspace, workspace2 = current_workspace_manager().get_simultaneous(
                ((max_shape_size,), workspace_dtype),
                (workspace2_shape, workspace_dtype),
            )
            workspace13 = _resize_cache(common_workspace, workspace13_shape)
            fused_out = _resize_cache(common_workspace, fused_out_shape)
        else:
            workspace13, workspace2, fused_out = (
                current_workspace_manager().get_simultaneous(
                    (workspace13_shape, workspace_dtype),
                    (workspace2_shape, workspace_dtype),
                    (fused_out_shape, out_dtype),
                )
            )

        # Reuse workspace13 for the output since there is only one chunk.
        max_shape_size = max(prod(workspace13_shape), prod(fused_out_shape))
        common_workspace, workspace2 = current_workspace_manager().get_simultaneous(
            ((max_shape_size,), workspace_dtype),
            (workspace2_shape, workspace_dtype),
        )
        workspace13 = _resize_cache(common_workspace, workspace13_shape)
        fused_out = _resize_cache(common_workspace, fused_out_shape)

        return workspace13, workspace2, fused_out
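The single-chunk path above aliases workspace13 and fused_out onto one flat allocation, which is safe because workspace13 is no longer needed by the time the fused output is written. A minimal sketch of that aliasing using plain torch views in place of _resize_cache and the workspace manager:

# Illustrative only -- shows the buffer-aliasing idea, not the vLLM API.
from math import prod
import torch

workspace13_shape = (1024, 4096)   # example shapes
fused_out_shape = (1024, 2048)

flat = torch.empty(
    max(prod(workspace13_shape), prod(fused_out_shape)), dtype=torch.bfloat16
)

# Both views share the same storage; reusing it is fine because workspace13
# is dead by the time fused_out is produced.
workspace13 = flat[: prod(workspace13_shape)].view(workspace13_shape)
fused_out = flat[: prod(fused_out_shape)].view(fused_out_shape)
assert workspace13.data_ptr() == fused_out.data_ptr()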
    @staticmethod
    def _slice_output_tensor(
        fused_out: torch.Tensor,
        chunk_idx: int,
        num_chunks: int,
        CHUNK_SIZE: int,
        M: int,
    ) -> torch.Tensor:
        if num_chunks == 1:
            return fused_out

        assert fused_out.size(0) % M == 0, f"fused_out shape {fused_out.shape} vs M {M}"
        factor = fused_out.size(0) // M
        out_chunk_size = CHUNK_SIZE * factor
        s = chunk_idx * out_chunk_size
        e = min(s + out_chunk_size, fused_out.size(0))
        return fused_out[s:e]
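The factor logic above handles layouts where fused_out carries a multiple of M rows (for example, top_k rows per token in some formats), so each input chunk of CHUNK_SIZE tokens maps to CHUNK_SIZE * factor output rows. A hedged numeric illustration:

# Illustrative only -- mirrors the removed output-slicing arithmetic.
M, CHUNK_SIZE, factor = 10, 4, 2    # e.g. fused_out has 2 rows per token
fused_out_rows = M * factor         # 20 rows
out_chunk = CHUNK_SIZE * factor     # 4 input tokens -> 8 output rows per chunk

ranges = []
for chunk_idx in range(3):          # cdiv(10, 4) == 3 chunks
    s = chunk_idx * out_chunk
    e = min(s + out_chunk, fused_out_rows)
    ranges.append((s, e))

assert ranges == [(0, 8), (8, 16), (16, 20)]  # last chunk covers the 2-token remainder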
    @staticmethod
    def _slice_expert_tokens_metadata(
        num_chunks: int,
        full_expert_tokens_meta: ExpertTokensMetadata | None,
        chunk_topk_ids: torch.Tensor,
        local_num_experts: int,
        expert_map: torch.Tensor | None,
    ) -> ExpertTokensMetadata | None:
        if num_chunks == 1 or full_expert_tokens_meta is None:
            return full_expert_tokens_meta

        # The existing expert_num_tokens is for the entire a1q
        # input. Chunking forces recomputation of the number
        # of tokens assigned to each expert.
        c_expert_num_tokens = count_expert_num_tokens(
            chunk_topk_ids, local_num_experts, expert_map
        )

        c_expert_num_tokens_cpu = None
        need_expert_num_tokens_cpu = (
            full_expert_tokens_meta.expert_num_tokens_cpu is not None
        )
        if need_expert_num_tokens_cpu:
            # This is blocking, as some implementations need the count
            # on the CPU to determine appropriate input/output fused-moe
            # buffers.
            c_expert_num_tokens_cpu = c_expert_num_tokens.to("cpu", non_blocking=False)

        return ExpertTokensMetadata(
            expert_num_tokens=c_expert_num_tokens,
            expert_num_tokens_cpu=c_expert_num_tokens_cpu,
        )
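Because the precomputed expert token counts covered the whole batch, each chunk had to recount how many of its tokens land on each local expert. A rough, self-contained stand-in for count_expert_num_tokens (the real implementation is a dedicated kernel and may differ):

# Illustrative only -- a naive stand-in for count_expert_num_tokens.
import torch

def count_expert_num_tokens(topk_ids, local_num_experts, expert_map=None):
    ids = topk_ids.reshape(-1)
    if expert_map is not None:
        ids = expert_map[ids]   # map global expert ids to local ids (-1 = not local)
    ids = ids[ids >= 0]         # drop assignments routed to other ranks
    return torch.bincount(ids, minlength=local_num_experts)

topk_ids = torch.tensor([[0, 2], [1, 2], [3, 0]])  # 3 tokens, top_k = 2
counts = count_expert_num_tokens(topk_ids, local_num_experts=4)
assert counts.tolist() == [2, 1, 2, 1]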
    def _prepare(
        self,
        hidden_states: torch.Tensor,
@@ -1318,18 +1174,6 @@ class FusedMoEKernelModularImpl:
            a1q, w1, w2, topk_ids
        )

        num_chunks, CHUNK_SIZE = self._chunk_info(M_full)

        def input_chunk_range(chunk_idx: int) -> tuple[int, int]:
            if num_chunks == 1:
                # Use a1q.size(0) here since batched format does not
                # keep M in the first dimension.
                return 0, a1q.size(0)
            else:
                s = chunk_idx * CHUNK_SIZE
                e = min(s + CHUNK_SIZE, M_full)
                return s, e
        # This happens when none of the tokens from the all2all reach this
        # EP rank. Also, note that this is only relevant for CUDAGraph
        # incompatible all2all kernels like the DeepEP high-throughput
@@ -1337,58 +1181,39 @@ class FusedMoEKernelModularImpl:
        # low-latency kernels are always batched and can never run into
        # the tensor.numel() == 0 case.
        if M_full == 0:
            assert num_chunks == 0
            workspace13 = None
            workspace2 = None
            fused_out = torch.empty_like(a1q, dtype=in_dtype)
        else:
            assert num_chunks > 0
            workspace13, workspace2, fused_out = self._allocate_buffers(
                in_dtype,
                a1q.device,
                CHUNK_SIZE,
                M_full,
                N,
                K,
                top_k,
                global_num_experts,
                local_num_experts,
                expert_tokens_meta,
                activation,
            )

            return torch.empty_like(a1q, dtype=in_dtype)
        for chunk_idx in range(num_chunks):
            s, e = input_chunk_range(chunk_idx)

        workspace13, workspace2, fused_out = self._allocate_buffers(
            in_dtype,
            a1q.device,
            M_full,
            M_full,
            N,
            K,
            top_k,
            global_num_experts,
            local_num_experts,
            expert_tokens_meta,
            activation,
        )

            c_expert_tokens_meta = self._slice_expert_tokens_metadata(
                num_chunks,
                expert_tokens_meta,
                topk_ids[s:e],
                local_num_experts,
                expert_map,
            )

            c_fused_out = self._slice_output_tensor(
                fused_out, chunk_idx, num_chunks, CHUNK_SIZE, M_full
            )

            self.fused_experts.apply(
                output=c_fused_out,
                hidden_states=a1q[s:e],
                w1=w1,
                w2=w2,
                topk_weights=topk_weights[s:e],
                topk_ids=topk_ids[s:e],
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                a1q_scale=_slice_scales(a1q_scale, s, e),
                a2_scale=_slice_scales(self.fused_experts.a2_scale, s, e),
                workspace13=workspace13,
                workspace2=workspace2,
                expert_tokens_meta=c_expert_tokens_meta,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )

        self.fused_experts.apply(
            output=fused_out,
            hidden_states=a1q,
            w1=w1,
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            a1q_scale=a1q_scale,
            a2_scale=self.fused_experts.a2_scale,
            workspace13=workspace13,
            workspace2=workspace2,
            expert_tokens_meta=expert_tokens_meta,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )

        return fused_out
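After this change the modular path is a single pass: size the workspaces for the full M_full, invoke fused_experts.apply once over the whole quantized activation, and return fused_out. A hedged pseudocode-style summary of the new control flow (the helper name _allocate_buffers_for_full_batch is made up for brevity):

# Illustrative summary of the post-change flow; not the literal implementation.
import torch

def run_fused_experts(impl, a1q, w1, w2, topk_weights, topk_ids, **kwargs):
    if topk_ids.size(0) == 0:
        # No tokens reached this EP rank (e.g. DeepEP high-throughput all2all).
        return torch.empty_like(a1q)

    # One allocation sized for the full batch; no per-chunk resizing.
    workspace13, workspace2, fused_out = impl._allocate_buffers_for_full_batch(
        a1q, w1, w2, topk_ids, **kwargs
    )

    # One kernel invocation over the whole batch; no slicing of scales,
    # outputs, or expert-token metadata.
    impl.fused_experts.apply(
        output=fused_out,
        hidden_states=a1q,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        workspace13=workspace13,
        workspace2=workspace2,
        **kwargs,
    )
    return fused_out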