[Kernels] MoE refactor (#19636)
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Co-authored-by: ElizaWszola <ewszola@redhat.com>
@@ -1,12 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from abc import ABC, abstractmethod
+from enum import Enum
 from math import prod
-from typing import Optional
+from typing import Optional, final
 
 import torch
 
 import vllm.envs as envs
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.utils import cdiv
@@ -82,6 +84,18 @@ def _moe_problem_size(
     return E, M, N, K, topk
 
 
+class FusedMoEActivationFormat(Enum):
+    """
+    The standard activation format (num_tokens, hidden dim).
+    """
+    Standard = "standard",
+    """
+    The batched experts format (num experts, max tokens per expert, hidden dim)
+    """
+    BatchedExperts = "batched_experts",
+
+
+# TODO: pass FusedMoEParallelConfig in as ctor parameter?
 class FusedMoEPrepareAndFinalize(ABC):
     """
     An abstract base class for the [Quantize-Prepare] and [Finalize] steps
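For orientation (this snippet is not part of the commit): the two activation formats differ only in how the activations are laid out. A minimal sketch with made-up sizes:

import torch

# Hypothetical sizes, for illustration only.
num_tokens, hidden_dim = 8, 16
num_experts, max_tokens_per_expert = 4, 8

# Standard format: one row per token.
standard = torch.randn(num_tokens, hidden_dim)

# Batched experts format: tokens pre-grouped per expert and padded to a
# fixed per-expert maximum, as used by batched / all-to-all style kernels.
batched = torch.randn(num_experts, max_tokens_per_expert, hidden_dim)

print(standard.shape)  # torch.Size([8, 16])
print(batched.shape)   # torch.Size([4, 8, 16])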
@@ -99,6 +113,7 @@ class FusedMoEPrepareAndFinalize(ABC):
         num_experts: int,
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
+        quant_config: FusedMoEQuantConfig,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
                Optional[torch.Tensor], Optional[torch.Tensor]]:
         """
@@ -148,6 +163,15 @@ class FusedMoEPrepareAndFinalize(ABC):
         """
         raise NotImplementedError
 
+    @property
+    @abstractmethod
+    def activation_format(self) -> FusedMoEActivationFormat:
+        """
+        A property indicating the output format of the activations for the
+        'prepare' method.
+        """
+        raise NotImplementedError
+
     @abstractmethod
     def topk_indices_dtype(self) -> Optional[torch.dtype]:
         """
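As a rough, standalone illustration of this contract (not the vLLM classes themselves; every name below is made up), a pass-through prepare/finalize for the standard format might look like:

from typing import Optional

import torch


class NoopPrepareFinalize:
    """Illustrative stand-in: hand activations to the experts unchanged."""

    # Mirrors the new activation_format property: prepare() emits the
    # standard (num_tokens, hidden_dim) layout.
    activation_format = "standard"

    def topk_indices_dtype(self) -> Optional[torch.dtype]:
        # No special dtype requirement for the routing indices.
        return None

    def prepare(self, a1: torch.Tensor, topk_weights: torch.Tensor,
                topk_ids: torch.Tensor):
        # A real implementation would use the quant_config it now receives to
        # quantize a1 here; this sketch simply passes everything through.
        return a1, topk_weights, topk_ids

    def finalize(self, output: torch.Tensor, fused_out: torch.Tensor) -> None:
        # Write the combined expert outputs back into the caller's buffer.
        output.copy_(fused_out)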
@@ -176,6 +200,41 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
     above.
     """
 
+    def __init__(
+        self,
+        quant_config: Optional[FusedMoEQuantConfig],
+    ):
+        if quant_config is not None:
+            self.quant_config = quant_config
+        else:
+            self.quant_config = FusedMoEQuantConfig()
+
+    @property
+    @abstractmethod
+    def activation_formats(
+            self) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]:
+        """
+        A property which is a tuple of the input and output activation formats
+        for the 'apply' method.
+        """
+        raise NotImplementedError
+
+    @property
+    def quant_dtype(self) -> Optional[torch.dtype]:
+        return self.quant_config.quant_dtype
+
+    @property
+    def block_shape(self) -> Optional[list[int]]:
+        return self.quant_config.block_shape
+
+    @property
+    def per_act_token_quant(self) -> bool:
+        return self.quant_config.per_act_token_quant
+
+    @property
+    def per_out_ch_quant(self) -> bool:
+        return self.quant_config.per_out_ch_quant
+
     # TODO (bnell): make this return a CHUNK_SIZE or None instead?
     @abstractmethod
     def supports_chunking(self) -> bool:
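A matching standalone sketch of the experts side (again not the vLLM class, only the shape of the contract, with made-up names): it consumes standard-format activations, advertises its capabilities, and would normally derive its quantization behaviour from the quant_config stored by the constructor above.

import torch


class NaiveExperts:
    """Illustrative stand-in: unquantized expert MLPs in the standard format."""

    # Mirrors activation_formats: (input format, output format) of apply().
    activation_formats = ("standard", "standard")

    def __init__(self, w1: torch.Tensor, w2: torch.Tensor):
        # w1: (num_experts, hidden, inter), w2: (num_experts, inter, hidden).
        self.w1, self.w2 = w1, w2

    def supports_chunking(self) -> bool:
        return True

    def supports_expert_map(self) -> bool:
        return False

    def apply(self, a: torch.Tensor, topk_weights: torch.Tensor,
              topk_ids: torch.Tensor) -> torch.Tensor:
        # Dense reference computation: each token visits its top-k experts and
        # the results are combined with the router weights.
        out = torch.zeros_like(a)
        for k in range(topk_ids.shape[1]):
            eids = topk_ids[:, k]
            h = torch.nn.functional.silu(
                torch.einsum("td,tdi->ti", a, self.w1[eids]))
            out += topk_weights[:, k, None] * torch.einsum(
                "ti,tih->th", h, self.w2[eids])
        return out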
@@ -185,6 +244,13 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         """
         raise NotImplementedError
 
+    @abstractmethod
+    def supports_expert_map(self) -> bool:
+        """
+        A flag indicating whether or not this class supports expert maps
+        """
+        raise NotImplementedError
+
     @abstractmethod
     def workspace_shapes(
         self,
@@ -297,6 +363,7 @@ def _chunk_scales(scales: Optional[torch.Tensor], start: int,
     return None
 
 
+@final
 class FusedMoEModularKernel(torch.nn.Module):
     """
     This class combines a FusedMoEPrepareAndFinalize instance and
@@ -318,6 +385,12 @@ class FusedMoEModularKernel(torch.nn.Module):
         super().__init__()
         self.prepare_finalize = prepare_finalize
         self.fused_experts = fused_experts
+        assert prepare_finalize.activation_format == \
+            fused_experts.activation_formats[0], (
+                f"{prepare_finalize.__class__.__name__}."
+                f"{prepare_finalize.activation_format} == "
+                f"{fused_experts.__class__.__name__}."
+                f"{fused_experts.activation_formats[0]}")
 
     def forward(
         self,
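Putting the two sketches above together in the spirit of this class (standalone illustration, assuming NoopPrepareFinalize and NaiveExperts from earlier are in scope): the format check mirrors the assertion added here.

import torch

# Hypothetical sizes for a quick end-to-end run of the sketches above.
T, D, INTER, E, K = 8, 16, 32, 4, 2

pf = NoopPrepareFinalize()
experts = NaiveExperts(w1=torch.randn(E, D, INTER), w2=torch.randn(E, INTER, D))

# The real kernel refuses to combine pieces whose activation formats disagree.
assert pf.activation_format == experts.activation_formats[0]

a1 = torch.randn(T, D)
topk_weights = torch.rand(T, K)
topk_ids = torch.randint(0, E, (T, K))

a1q, topk_weights, topk_ids = pf.prepare(a1, topk_weights, topk_ids)
fused_out = experts.apply(a1q, topk_weights, topk_ids)

output = torch.empty_like(a1)
pf.finalize(output, fused_out)
print(output.shape)  # torch.Size([8, 16])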
@@ -383,8 +456,16 @@ class FusedMoEModularKernel(torch.nn.Module):
         (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
          _expert_topk_weights) = self.prepare_finalize.prepare(
-             a1, a1_scale, a2_scale, topk_weights, topk_ids,
-             global_num_experts, expert_map, apply_router_weight_on_input)
+             a1,
+             a1_scale,
+             a2_scale,
+             topk_weights,
+             topk_ids,
+             global_num_experts,
+             expert_map,
+             apply_router_weight_on_input,
+             self.fused_experts.quant_config,
+         )
 
         # Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
         topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
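The practical effect of threading self.fused_experts.quant_config into prepare() is that both stages read quantization settings from one place instead of each guessing. A standalone sketch of that idea (made-up helper and config class, not vLLM code):

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class QuantConfigSketch:
    """Made-up stand-in for a quant config; None means no quantization."""
    quant_dtype: Optional[torch.dtype] = None


def prepare_sketch(a1: torch.Tensor, quant_config: QuantConfigSketch):
    # prepare() now receives the experts' quant config, so it only quantizes
    # activations when the expert kernels actually expect quantized input.
    if quant_config.quant_dtype is None:
        return a1, None
    scale = a1.abs().amax().clamp(min=1e-8) / torch.finfo(
        quant_config.quant_dtype).max
    return (a1 / scale).to(quant_config.quant_dtype), scale


a1 = torch.randn(8, 16)
a1q, a1q_scale = prepare_sketch(a1, QuantConfigSketch())  # unquantized passthrough
a1q8, s8 = prepare_sketch(a1, QuantConfigSketch(torch.float8_e4m3fn))  # fp8 path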