[Kernels] MoE refactor (#19636)

Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Co-authored-by: ElizaWszola <ewszola@redhat.com>
Authored by bnellnm on 2025-07-02 09:08:27 -04:00; committed by GitHub.
Parent: b95877509b
Commit: c1909e7e8c
36 changed files with 2698 additions and 1584 deletions


@@ -1,12 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from enum import Enum
from math import prod
from typing import Optional
from typing import Optional, final
import torch
import vllm.envs as envs
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.utils import cdiv
@@ -82,6 +84,18 @@ def _moe_problem_size(
    return E, M, N, K, topk

class FusedMoEActivationFormat(Enum):
    Standard = "standard"
    """
    The standard activation format (num_tokens, hidden_dim).
    """
    BatchedExperts = "batched_experts"
    """
    The batched experts format (num_experts, max_tokens_per_expert,
    hidden_dim).
    """
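
The two members correspond to different activation tensor layouts. A
minimal sketch of the shapes involved (variable names are illustrative,
not part of this commit):

import torch

num_tokens, hidden_dim = 128, 4096
num_experts, max_tokens_per_expert = 8, 64

# FusedMoEActivationFormat.Standard: one row per token.
standard = torch.empty(num_tokens, hidden_dim)

# FusedMoEActivationFormat.BatchedExperts: a fixed-capacity batch per
# expert, padded up to max_tokens_per_expert.
batched = torch.empty(num_experts, max_tokens_per_expert, hidden_dim)
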
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
class FusedMoEPrepareAndFinalize(ABC):
"""
An abstract base class for the [Quantize-Prepare] and [Finalize] steps
@@ -99,6 +113,7 @@ class FusedMoEPrepareAndFinalize(ABC):
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
@@ -148,6 +163,15 @@ class FusedMoEPrepareAndFinalize(ABC):
"""
raise NotImplementedError
@property
@abstractmethod
def activation_format(self) -> FusedMoEActivationFormat:
"""
A property indicating the output format of the activations for the
'prepare' method.
"""
raise NotImplementedError
@abstractmethod
def topk_indices_dtype(self) -> Optional[torch.dtype]:
"""
@@ -176,6 +200,41 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
    above.
    """

    def __init__(
        self,
        quant_config: Optional[FusedMoEQuantConfig],
    ):
        if quant_config is not None:
            self.quant_config = quant_config
        else:
            self.quant_config = FusedMoEQuantConfig()

    @property
    @abstractmethod
    def activation_formats(
            self) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]:
        """
        A property returning a tuple of the input and output activation
        formats for the 'apply' method.
        """
        raise NotImplementedError

    @property
    def quant_dtype(self) -> Optional[torch.dtype]:
        return self.quant_config.quant_dtype

    @property
    def block_shape(self) -> Optional[list[int]]:
        return self.quant_config.block_shape

    @property
    def per_act_token_quant(self) -> bool:
        return self.quant_config.per_act_token_quant

    @property
    def per_out_ch_quant(self) -> bool:
        return self.quant_config.per_out_ch_quant
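
With the quant config owned by the base class, callers can read the
quantization parameters through these passthrough properties instead of
threading them around as separate arguments. A sketch, assuming the
default-constructed FusedMoEQuantConfig describes unquantized,
per-tensor settings:

cfg = FusedMoEQuantConfig()  # the fallback used when quant_config is None
assert cfg.quant_dtype is None        # no activation quantization dtype
assert cfg.block_shape is None        # no block-wise quantization
assert not cfg.per_act_token_quant    # per-tensor, not per-token
assert not cfg.per_out_ch_quant       # per-tensor, not per-channel
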
    # TODO (bnell): make this return a CHUNK_SIZE or None instead?
    @abstractmethod
    def supports_chunking(self) -> bool:
@@ -185,6 +244,13 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
        """
        raise NotImplementedError

    @abstractmethod
    def supports_expert_map(self) -> bool:
        """
        A flag indicating whether or not this class supports expert maps.
        """
        raise NotImplementedError
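
For context, an expert map translates global expert ids into the local
ids of the current expert-parallel rank, with -1 marking experts hosted
on other ranks; implementations that cannot apply such a remapping
return False here. An illustrative example (values are hypothetical):

import torch

# 8 global experts; this EP rank owns global experts 4..7 as local 0..3.
expert_map = torch.tensor([-1, -1, -1, -1, 0, 1, 2, 3], dtype=torch.int32)
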
    @abstractmethod
    def workspace_shapes(
        self,
@@ -297,6 +363,7 @@ def _chunk_scales(scales: Optional[torch.Tensor], start: int,
    return None
@final
class FusedMoEModularKernel(torch.nn.Module):
"""
This class combines a FusedMoEPrepareAndFinalize instance and
@@ -318,6 +385,12 @@ class FusedMoEModularKernel(torch.nn.Module):
        super().__init__()
        self.prepare_finalize = prepare_finalize
        self.fused_experts = fused_experts
        assert prepare_finalize.activation_format == \
            fused_experts.activation_formats[0], (
                f"{prepare_finalize.__class__.__name__}."
                f"{prepare_finalize.activation_format} == "
                f"{fused_experts.__class__.__name__}."
                f"{fused_experts.activation_formats[0]}")
    def forward(
        self,
@@ -383,8 +456,16 @@ class FusedMoEModularKernel(torch.nn.Module):
        (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
         _expert_topk_weights) = self.prepare_finalize.prepare(
             a1, a1_scale, a2_scale, topk_weights, topk_ids,
             global_num_experts, expert_map, apply_router_weight_on_input)
             a1,
             a1_scale,
             a2_scale,
             topk_weights,
             topk_ids,
             global_num_experts,
             expert_map,
             apply_router_weight_on_input,
             self.fused_experts.quant_config,
         )

        # Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
        topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids