[Feature]: Remove Chunking From FusedMoE (#34086)
Signed-off-by: SouthWest7 <am1ao@qq.com> Signed-off-by: Southwest <1403572259@qq.com> Signed-off-by: southwest <am1ao@qq.com> Signed-off-by: Xinan Miao <1403572259@qq.com> Co-authored-by: SouthWest7 <am1ao@qq.com>
This commit is contained in:
@@ -9,8 +9,6 @@ from typing import final
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.forward_context import get_forward_context, is_forward_context_available
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import (
|
||||
MoEActivation,
|
||||
@@ -24,14 +22,12 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_resize_cache,
|
||||
count_expert_num_tokens,
|
||||
disable_inplace,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.worker.ubatching import (
|
||||
dbo_enabled,
|
||||
dbo_maybe_run_recv_hook,
|
||||
@@ -719,15 +715,6 @@ class FusedMoEExperts(ABC):
|
||||
def g2_alphas(self) -> torch.Tensor | None:
|
||||
return self.quant_config.g2_alphas
|
||||
|
||||
# TODO (bnell): make this return a CHUNK_SIZE or None instead?
|
||||
@abstractmethod
def supports_chunking(self) -> bool:
    """
    Report whether this experts implementation supports activation chunking.

    This is an abstract declaration: the base implementation carries no
    answer of its own and concrete subclasses are expected to override it.

    Returns:
        True if the implementation supports activation chunking,
        False otherwise.

    Raises:
        NotImplementedError: always, when the base declaration is invoked
            directly instead of an override.
    """
    raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def supports_expert_map(self) -> bool:
|
||||
"""
|
||||
@@ -742,11 +729,6 @@ class FusedMoEExperts(ABC):
|
||||
"""
|
||||
return False
|
||||
|
||||
def enable_chunking(self):
    """
    Decide whether activation chunking should be used for these experts.

    Chunking is used only when the
    ``VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING`` environment setting is
    truthy *and* ``self.supports_chunking()`` says the implementation can
    handle it.

    Returns:
        The result of the ``and`` expression: the (falsy) environment value
        when chunking is disabled globally, otherwise whatever
        ``supports_chunking()`` returns.
    """
    # Short-circuit: supports_chunking() is only consulted when the
    # environment flag is truthy.
    if not envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING:
        return envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING
    return self.supports_chunking()
|
||||
|
||||
|
||||
class FusedMoEExpertsModular(FusedMoEExperts):
|
||||
"""
|
||||
@@ -995,17 +977,6 @@ class FusedMoEExpertsMonolithic(FusedMoEExperts):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _slice_scales(
|
||||
scales: torch.Tensor | None, start: int, end: int
|
||||
) -> torch.Tensor | None:
|
||||
if scales is not None:
|
||||
if scales.numel() == 1:
|
||||
return scales
|
||||
else:
|
||||
return scales[start:end]
|
||||
return None
|
||||
|
||||
|
||||
################################################################################
|
||||
# Kernel
|
||||
################################################################################
|
||||
@@ -1032,26 +1003,6 @@ class FusedMoEKernelModularImpl:
|
||||
and moe_parallel_config.use_ep
|
||||
)
|
||||
|
||||
def _chunk_info(self, M: int) -> tuple[int, int]:
    """
    Compute the number of chunks and the chunk size for ``M`` tokens.

    When chunking is not enabled for ``self.fused_experts`` the chunk size
    is ``M`` itself, which yields a single chunk. When it is enabled the
    chunk size is capped at ``VLLM_FUSED_MOE_CHUNK_SIZE``. In both cases the
    chunk size is clamped to at least 1 so the division below never divides
    by zero.

    Args:
        M: Number of tokens to process.

    Returns:
        ``(num_chunks, chunk_size)``. With no tokens (``M == 0``) the number
        of chunks is zero.
    """
    if self.fused_experts.enable_chunking():
        chunk_size = min(M, envs.VLLM_FUSED_MOE_CHUNK_SIZE)
    else:
        chunk_size = M
    # Clamp so that M == 0 does not produce a zero divisor.
    chunk_size = max(1, chunk_size)

    num_chunks = cdiv(M, chunk_size)
    # If there are no tokens, then there should be no loop iterations.
    assert M > 0 or num_chunks == 0
    return num_chunks, chunk_size
|
||||
|
||||
def _allocate_buffers(
|
||||
self,
|
||||
out_dtype: torch.dtype,
|
||||
@@ -1076,40 +1027,8 @@ class FusedMoEKernelModularImpl:
|
||||
"""
|
||||
assert M_full > 0 and M_chunk > 0
|
||||
|
||||
num_chunks, _ = self._chunk_info(M_full)
|
||||
workspace_dtype = self.fused_experts.workspace_dtype(out_dtype)
|
||||
|
||||
# Force worst-case allocation in profiling run for
|
||||
# "mk.FusedMoEKernel.Standard" formats where this is only bounded
|
||||
# by `VLLM_FUSED_MOE_CHUNK_SIZE` and may not be seen during profiling with
|
||||
# DP+EP due to the random token routing.
|
||||
is_profile_run = (
|
||||
is_forward_context_available()
|
||||
and get_forward_context().attn_metadata is None
|
||||
)
|
||||
if is_profile_run and self.fused_experts.enable_chunking() and self.is_dp_ep:
|
||||
max_workspace_13, max_workspace_2, max_fused_out_shape = (
|
||||
self.fused_experts.workspace_shapes(
|
||||
envs.VLLM_FUSED_MOE_CHUNK_SIZE,
|
||||
N,
|
||||
K,
|
||||
top_k,
|
||||
global_num_experts,
|
||||
local_num_experts,
|
||||
# expert_tokens_meta help in allocating optimal/minimal
|
||||
# amount of workspace. Mark it None, so we allocate for
|
||||
# the worst-case scenario.
|
||||
expert_tokens_meta=None,
|
||||
activation=activation,
|
||||
)
|
||||
)
|
||||
|
||||
current_workspace_manager().get_simultaneous(
|
||||
(max_workspace_13, workspace_dtype),
|
||||
(max_workspace_2, workspace_dtype),
|
||||
(max_fused_out_shape, out_dtype),
|
||||
)
|
||||
|
||||
# Get intermediate workspace shapes based off the chunked M size.
|
||||
workspace13_shape, workspace2_shape, _ = self.fused_experts.workspace_shapes(
|
||||
M_chunk,
|
||||
@@ -1136,80 +1055,17 @@ class FusedMoEKernelModularImpl:
|
||||
|
||||
# We can reuse the memory between cache1 and cache3 because by the
|
||||
# time we need cache3, we're done with cache1.
|
||||
# Construct the entire output that can then be processed in chunks.
|
||||
# Reuse workspace13 for the output in the non-chunked case.
|
||||
# This will not always be the case for standard
|
||||
# format experts and with experts that have empty workspaces.
|
||||
if num_chunks == 1:
|
||||
max_shape_size = max(prod(workspace13_shape), prod(fused_out_shape))
|
||||
common_workspace, workspace2 = current_workspace_manager().get_simultaneous(
|
||||
((max_shape_size,), workspace_dtype),
|
||||
(workspace2_shape, workspace_dtype),
|
||||
)
|
||||
workspace13 = _resize_cache(common_workspace, workspace13_shape)
|
||||
fused_out = _resize_cache(common_workspace, fused_out_shape)
|
||||
else:
|
||||
workspace13, workspace2, fused_out = (
|
||||
current_workspace_manager().get_simultaneous(
|
||||
(workspace13_shape, workspace_dtype),
|
||||
(workspace2_shape, workspace_dtype),
|
||||
(fused_out_shape, out_dtype),
|
||||
)
|
||||
)
|
||||
# Reuse workspace13 for the output since there is only one chunk.
|
||||
max_shape_size = max(prod(workspace13_shape), prod(fused_out_shape))
|
||||
common_workspace, workspace2 = current_workspace_manager().get_simultaneous(
|
||||
((max_shape_size,), workspace_dtype),
|
||||
(workspace2_shape, workspace_dtype),
|
||||
)
|
||||
workspace13 = _resize_cache(common_workspace, workspace13_shape)
|
||||
fused_out = _resize_cache(common_workspace, fused_out_shape)
|
||||
|
||||
return workspace13, workspace2, fused_out
|
||||
|
||||
@staticmethod
def _slice_output_tensor(
    fused_out: torch.Tensor,
    chunk_idx: int,
    num_chunks: int,
    CHUNK_SIZE: int,
    M: int,
) -> torch.Tensor:
    """
    Return the rows of ``fused_out`` that belong to chunk ``chunk_idx``.

    Args:
        fused_out: Full output tensor; its first dimension must be a
            multiple of ``M``.
        chunk_idx: Zero-based index of the chunk being processed.
        num_chunks: Total number of chunks.
        CHUNK_SIZE: Number of input tokens per chunk.
        M: Total number of input tokens.

    Returns:
        ``fused_out`` itself when there is a single chunk; otherwise a
        slice along dim 0 covering this chunk. The last chunk is truncated
        at the end of ``fused_out``.
    """
    if num_chunks == 1:
        return fused_out

    total_rows = fused_out.size(0)
    assert total_rows % M == 0, f"fused_out shape {fused_out.shape} vs M {M}"
    # fused_out may hold several rows per input token; scale the chunk size
    # by that ratio so the output slice lines up with the input slice.
    rows_per_token = total_rows // M
    out_chunk_size = CHUNK_SIZE * rows_per_token

    start = chunk_idx * out_chunk_size
    end = min(start + out_chunk_size, total_rows)
    return fused_out[start:end]
|
||||
|
||||
@staticmethod
def _slice_expert_tokens_metadata(
    num_chunks: int,
    full_expert_tokens_meta: ExpertTokensMetadata | None,
    chunk_topk_ids: torch.Tensor,
    local_num_experts: int,
    expert_map: torch.Tensor | None,
) -> ExpertTokensMetadata | None:
    """
    Build the expert-token metadata that applies to a single chunk.

    Args:
        num_chunks: Total number of chunks the input is split into.
        full_expert_tokens_meta: Metadata computed for the whole input, or
            ``None`` if no metadata is available.
        chunk_topk_ids: The ``topk_ids`` rows belonging to this chunk.
        local_num_experts: Passed through to ``count_expert_num_tokens``.
        expert_map: Optional expert map, passed through to
            ``count_expert_num_tokens``.

    Returns:
        ``full_expert_tokens_meta`` unchanged when there is only one chunk
        or when it is ``None``; otherwise a new ``ExpertTokensMetadata``
        with per-expert token counts recomputed for this chunk.
    """
    if num_chunks == 1 or full_expert_tokens_meta is None:
        return full_expert_tokens_meta

    # The existing expert_num_tokens describes the entire a1q input, so
    # chunking forces the per-expert token counts to be recomputed.
    chunk_counts = count_expert_num_tokens(
        chunk_topk_ids, local_num_experts, expert_map
    )

    chunk_counts_cpu = None
    # Only produce a CPU copy when the full metadata carried one.
    if full_expert_tokens_meta.expert_num_tokens_cpu is not None:
        # This is blocking as some implementations need the count on the
        # CPU to determine appropriate input/output fused-moe buffers.
        chunk_counts_cpu = chunk_counts.to("cpu", non_blocking=False)

    return ExpertTokensMetadata(
        expert_num_tokens=chunk_counts,
        expert_num_tokens_cpu=chunk_counts_cpu,
    )
|
||||
|
||||
def _prepare(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -1318,18 +1174,6 @@ class FusedMoEKernelModularImpl:
|
||||
a1q, w1, w2, topk_ids
|
||||
)
|
||||
|
||||
num_chunks, CHUNK_SIZE = self._chunk_info(M_full)
|
||||
|
||||
def input_chunk_range(chunk_idx: int) -> tuple[int, int]:
|
||||
if num_chunks == 1:
|
||||
# Use a1q.size(0) here since batched format does not
|
||||
# keep M in the first dimension.
|
||||
return 0, a1q.size(0)
|
||||
else:
|
||||
s = chunk_idx * CHUNK_SIZE
|
||||
e = min(s + CHUNK_SIZE, M_full)
|
||||
return s, e
|
||||
|
||||
# This happens when none of the tokens from the all2all reach this
|
||||
# EP rank. Also, note that this is only relevant for CUDAGraph
|
||||
# incompatible all2all kernels like the DeepEP high-throughput
|
||||
@@ -1337,58 +1181,39 @@ class FusedMoEKernelModularImpl:
|
||||
# low-latency kernels are always batched and can never run into
|
||||
# the tensor.numel() == 0 case.
|
||||
if M_full == 0:
|
||||
assert num_chunks == 0
|
||||
workspace13 = None
|
||||
workspace2 = None
|
||||
fused_out = torch.empty_like(a1q, dtype=in_dtype)
|
||||
else:
|
||||
assert num_chunks > 0
|
||||
workspace13, workspace2, fused_out = self._allocate_buffers(
|
||||
in_dtype,
|
||||
a1q.device,
|
||||
CHUNK_SIZE,
|
||||
M_full,
|
||||
N,
|
||||
K,
|
||||
top_k,
|
||||
global_num_experts,
|
||||
local_num_experts,
|
||||
expert_tokens_meta,
|
||||
activation,
|
||||
)
|
||||
return torch.empty_like(a1q, dtype=in_dtype)
|
||||
|
||||
for chunk_idx in range(num_chunks):
|
||||
s, e = input_chunk_range(chunk_idx)
|
||||
workspace13, workspace2, fused_out = self._allocate_buffers(
|
||||
in_dtype,
|
||||
a1q.device,
|
||||
M_full,
|
||||
M_full,
|
||||
N,
|
||||
K,
|
||||
top_k,
|
||||
global_num_experts,
|
||||
local_num_experts,
|
||||
expert_tokens_meta,
|
||||
activation,
|
||||
)
|
||||
|
||||
c_expert_tokens_meta = self._slice_expert_tokens_metadata(
|
||||
num_chunks,
|
||||
expert_tokens_meta,
|
||||
topk_ids[s:e],
|
||||
local_num_experts,
|
||||
expert_map,
|
||||
)
|
||||
|
||||
c_fused_out = self._slice_output_tensor(
|
||||
fused_out, chunk_idx, num_chunks, CHUNK_SIZE, M_full
|
||||
)
|
||||
|
||||
self.fused_experts.apply(
|
||||
output=c_fused_out,
|
||||
hidden_states=a1q[s:e],
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights[s:e],
|
||||
topk_ids=topk_ids[s:e],
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
a1q_scale=_slice_scales(a1q_scale, s, e),
|
||||
a2_scale=_slice_scales(self.fused_experts.a2_scale, s, e),
|
||||
workspace13=workspace13,
|
||||
workspace2=workspace2,
|
||||
expert_tokens_meta=c_expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
self.fused_experts.apply(
|
||||
output=fused_out,
|
||||
hidden_states=a1q,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=self.fused_experts.a2_scale,
|
||||
workspace13=workspace13,
|
||||
workspace2=workspace2,
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
return fused_out
|
||||
|
||||
|
||||
Reference in New Issue
Block a user