[Kernels] Add activation chunking logic to FusedMoEModularKernel (#19168)

Signed-off-by: Bill Nell <bnell@redhat.com>
Author: bnellnm
Date: 2025-06-11 12:53:10 -04:00
Committed by: GitHub
Parent: b2d9be6f7d
Commit: 29fa5cac1c
15 changed files with 458 additions and 396 deletions


@@ -1,10 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from math import prod
from typing import Optional
import torch
import vllm.envs as envs
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.utils import cdiv
#
# This file defines a set of base classes used to make MoE kernels more modular.
# The goal is to be able to utilize different communication mechanisms with
@@ -115,9 +120,9 @@ class FusedMoEPrepareAndFinalize(ABC):
- quantized + dispatched a.
- quantized + dispatched a1_scales.
- Optional tensor, as big as the number of local experts, that contains
the number of tokens assigned to each local expert.
- Optional dispatched expert topk IDs
- Optional dispatched expert topk weight
"""
raise NotImplementedError
@@ -159,7 +164,7 @@ class FusedMoEPrepareAndFinalize(ABC):
Some PrepareFinalize All2All implementations are batched, meaning they
can process only a set of tokens at a time. This function returns the
batch size, i.e. the maximum number of tokens the implementation can
process at a time.
Return None if there are no such restrictions.
"""
raise NotImplementedError
@@ -171,6 +176,15 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
above.
"""
# TODO (bnell): make this return a CHUNK_SIZE or None instead?
@abstractmethod
def supports_chunking(self) -> bool:
"""
A flag indicating whether or not this class supports activation
chunking.
"""
raise NotImplementedError
@abstractmethod
def workspace_shapes(
self,
@@ -181,19 +195,22 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
K: int,
topk: int,
num_experts: int,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
"""
Compute the shapes for the temporary and final outputs of the two gemms
and activation in the fused expert function. Since the gemms are
independent, the workspace for the first gemm can be shared with the
workspace for the last gemm.
Returns a tuple of:
- workspace13 shape tuple: must be large enough to hold the
result of either expert gemm.
- workspace2 shape tuple: must be large enough to hold the
result of the activation function.
- output shape tuple: must be exact size of the final gemm output.
- Workspace type: The dtype to use for the workspace tensors.
- Note: in order for activation chunking to work, the first dimension
of each tuple must be the number of tokens.
"""
raise NotImplementedError
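# Illustrative sketch (not from this commit): one way a concrete subclass
# could satisfy the two hooks above. The class name, the parameter names
# before K, and the exact workspace shapes are assumptions for the example;
# the contract being shown is that supports_chunking() returns a bool and
# that the first dimension of every shape returned by workspace_shapes()
# is the token count, so the caller can re-size everything per chunk.
# apply() is omitted for brevity.
class _ExampleExperts(FusedMoEPermuteExpertsUnpermute):

    def supports_chunking(self) -> bool:
        # Contiguous (non-batched) expert format: slicing the activations
        # along the token dimension is safe.
        return True

    def workspace_shapes(
        self,
        a: torch.Tensor,
        aq: torch.Tensor,
        M: int,
        N: int,
        K: int,
        topk: int,
        num_experts: int,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
        workspace13 = (M, topk, max(N, K))  # big enough for either gemm result
        workspace2 = (M, topk, N // 2)      # holds the activation output
        output = (M, K)                     # exact shape of the final output
        return workspace13, workspace2, output, a.dtype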
@@ -210,6 +227,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
@abstractmethod
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
@@ -226,12 +244,13 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
):
"""
This function computes the intermediate result of a Mixture of Experts
(MoE) layer using two sets of weights, w1 and w2.
Parameters:
- output: (torch.Tensor): The unweighted, unreduced output tensor.
- hidden_states: (torch.Tensor): The (quantized) input tensor to the MoE
layer.
- w1 (torch.Tensor): The first set of expert weights.
@@ -259,13 +278,20 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
function.
- expert_num_tokens: An optional tensor containing the number of tokens
assigned to each expert when using batched experts format input.
"""
raise NotImplementedError
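# Note: with the signature above, an implementation writes its unweighted,
# unreduced result into `output` in place rather than returning it; this is
# what lets the modular kernel below hand each chunk its own slice of the
# shared fused_out tensor.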
def _chunk_scales(scales: Optional[torch.Tensor], start: int,
end: int) -> Optional[torch.Tensor]:
if scales is not None:
if scales.numel() == 1:
return scales
else:
return scales[start:end]
return None
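# Illustration of _chunk_scales (tensor names are made up for the example):
# a per-tensor scale (numel() == 1) is shared by every chunk and passed
# through unchanged, while per-token scales are sliced to the chunk's rows.
#
#   per_tensor = torch.tensor([0.02])         # numel() == 1
#   per_token = torch.randn(1024, 1)          # one scale per token
#   _chunk_scales(per_tensor, 0, 256)         # -> the same (1,) tensor
#   _chunk_scales(per_token, 256, 512).shape  # -> torch.Size([256, 1])
#   _chunk_scales(None, 0, 256)               # -> None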
class FusedMoEModularKernel(torch.nn.Module):
"""
This class combines a FusedMoEPrepareAndFinalize instance and
@@ -288,61 +314,6 @@ class FusedMoEModularKernel(torch.nn.Module):
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
def _do_fused_experts(
self,
a1: torch.Tensor, # input to forward fn
a1q: torch.Tensor, # output of prepare fn
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
expert_num_tokens: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor]) -> torch.Tensor:
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
# Use a1 here to decipher the correct workspace datatype
workspace13_shape, workspace2_shape, workspace_dtype = (
self.fused_experts.workspace_shapes(a1, a1q, M, N, K, top_k,
global_num_experts))
# We can reuse the memory between cache1 and cache3 because by the time
# we need cache3, we're done with cache1
workspace13 = torch.zeros(workspace13_shape,
device=a1.device,
dtype=workspace_dtype)
workspace2 = torch.zeros(workspace2_shape,
device=a1.device,
dtype=workspace_dtype)
fused_out = self.fused_experts.apply(
a1q,
w1,
w2,
topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_num_tokens=expert_num_tokens,
)
return fused_out
def forward(
self,
hidden_states: torch.Tensor,
@@ -408,12 +379,14 @@ class FusedMoEModularKernel(torch.nn.Module):
_expert_topk_weights) = self.prepare_finalize.prepare(
a1, a1_scale, a2_scale, topk_weights, topk_ids,
global_num_experts, expert_map, apply_router_weight_on_input)
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
topk_weights = (topk_weights if _expert_topk_weights is None else
_expert_topk_weights)
fused_out = None
if a1q.numel() == 0:
# This happens when none of the tokens from the all2all reach this
# EP rank. Also, note that this is only relevant for CUDAGraph
@@ -423,22 +396,107 @@ class FusedMoEModularKernel(torch.nn.Module):
# and can never run into the tensor.numel() == 0 case.
fused_out = torch.empty_like(a1q).to(dtype=a1.dtype)
else:
fused_out = self._do_fused_experts(
a1=a1,
a1q=a1q,
w1=w1,
w2=w2,
topk_ids=topk_ids,
expert_num_tokens=expert_num_tokens,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale)
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
if self.fused_experts.supports_chunking():
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
num_chunks = cdiv(M, CHUNK_SIZE)
else:
CHUNK_SIZE = M
num_chunks = 1
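# Worked example (illustrative numbers): with M = 5000 tokens and
# CHUNK_SIZE = 2048, cdiv(5000, 2048) = 3, so the experts run over token
# slices [0:2048], [2048:4096] and [4096:5000]. When chunking is not
# supported, the whole batch is processed as a single chunk of size M.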
if num_chunks == 1:
(workspace13_shape, workspace2_shape, fused_out_shape,
workspace_dtype) = self.fused_experts.workspace_shapes(
a1, a1q, M, N, K, top_k, global_num_experts)
else:
# Use the full M to get the final output shape.
_, _, fused_out_shape, _ = (
self.fused_experts.workspace_shapes(
a1, a1q, M, N, K, top_k, global_num_experts))
# Use the CHUNK_SIZE to get the workspace shapes.
workspace13_shape, workspace2_shape, _, workspace_dtype = (
self.fused_experts.workspace_shapes(
a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts))
# We can reuse the memory between cache1 and cache3 because by the
# time we need cache3, we're done with cache1.
workspace13 = torch.zeros(prod(workspace13_shape),
device=a1.device,
dtype=workspace_dtype)
workspace2 = torch.zeros(prod(workspace2_shape),
device=a1.device,
dtype=workspace_dtype)
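# Note: the workspaces are allocated as flat 1-D buffers sized by
# prod(shape); the single-chunk path below views workspace13 at the output
# shape via _resize_cache, and implementations can likewise view the
# buffers at whatever shape they need. When chunking, the shapes were
# computed with CHUNK_SIZE tokens, so the buffers only have to be large
# enough for one chunk at a time.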
if num_chunks == 1:
fused_out = _resize_cache(workspace13, fused_out_shape)
self.fused_experts.apply(
fused_out,
a1q,
w1,
w2,
topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_num_tokens=expert_num_tokens,
)
else:
# The leading output dimension may not be equal to M, so
# we compute output indices separately.
M_out = fused_out_shape[0]
assert M_out >= M
factor = M_out // M
assert factor > 0
OUT_CHUNK_SIZE = CHUNK_SIZE * factor
fused_out = torch.empty(fused_out_shape,
device=a1q.device,
dtype=workspace_dtype)
assert cdiv(M_out, OUT_CHUNK_SIZE) == num_chunks, (
f"{cdiv(M_out, OUT_CHUNK_SIZE)} == {num_chunks}")
for chunk in range(num_chunks):
begin_chunk_idx = chunk * CHUNK_SIZE
end_chunk_idx = min((chunk + 1) * CHUNK_SIZE, M)
begin_out_idx = chunk * OUT_CHUNK_SIZE
end_out_idx = min((chunk + 1) * OUT_CHUNK_SIZE, M_out)
curr_a1q = a1q[begin_chunk_idx:end_chunk_idx]
curr_a1q_scale = _chunk_scales(a1q_scale, begin_chunk_idx,
end_chunk_idx)
curr_a2_scale = _chunk_scales(a2_scale, begin_chunk_idx,
end_chunk_idx)
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
self.fused_experts.apply(
fused_out[begin_out_idx:end_out_idx],
curr_a1q,
w1,
w2,
curr_topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=curr_a1q_scale,
a2_scale=curr_a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_num_tokens=expert_num_tokens,
)
self.prepare_finalize.finalize(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input)