[MoE Refactor] Create MK for TRTLLM Kernels (#32564)
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
Signed-off-by: Robert Shaw <robertgshaw2@gmail.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
@@ -20,6 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
    FusedMoEQuantConfig,
    RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.utils import (
    _resize_cache,
@@ -56,25 +57,25 @@ logger = init_logger(__name__)
# MoE kernel implementations.
#
# The following main classes are defined:
# * FusedMoEPrepareAndFinalize - an abstract base class for preparation of MoE
# * FusedMoEPrepareAndFinalizeModular - an abstract base class for preparation of MoE
#   inputs (e.g. quantization, distribution) and finalization of MoE outputs.
#   The prepare method must take care of any needed quantization and the
#   finalize method, informed by the FusedMoEPermuteExpertsUnpermute method,
#   finalize method, informed by the FusedMoEExpertsModular method,
#   may apply weights and/or do the final reduction of the output.
# * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused
# * FusedMoEExpertsModular - an abstract base class for the main fused
#   MoE operation, i.e. matmul + act_mul + optionally quant + matmul.
#   Some FusedMoEPermuteExpertsUnpermute implementations may choose to do
#   Some FusedMoEExpertsModular implementations may choose to do
#   the weight application and/or reduction. The class communicates this
#   to [Finalize] via a TopKWeightAndReduce object.
# * FusedMoEModularKernel - an interface class that combines a
#   FusedMoEPrepareAndFinalize and a FusedMoEPermuteExpertsUnpermute to
#   FusedMoEPrepareAndFinalizeModular and a FusedMoEExpertsModular to
#   provide the standard fused MoE kernel interface.
# * TopKWeightAndReduce - A TopKWeightAndReduce implementation chosen
#   by the FusedMoEPermuteExpertsUnpermute implementation that is passed
#   by the FusedMoEExpertsModular implementation that is passed
#   on to [Finalize].
#
# [Quantize-Prepare] and [Finalize] functionality are bundled into a single
# class `FusedMoEPrepareAndFinalize` since they could use collective
# class `FusedMoEPrepareAndFinalizeModular` since they could use collective
# communication mechanisms that need to be consistent.
#

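To make the split concrete, below is a minimal, self-contained sketch of the modular flow (prepare, fused experts, finalize) on dense tensors. The names and the ReLU activation are illustrative only and are not the vLLM interfaces; real implementations also handle quantization, token dispatch, and distributed communication.

import torch

def toy_prepare(a: torch.Tensor) -> tuple[torch.Tensor, None]:
    # [Quantize-Prepare]: a real implementation would quantize and/or
    # dispatch tokens across ranks; here the activations pass through.
    return a, None

def toy_fused_experts(
    a: torch.Tensor,          # (M, K) hidden states
    w1: torch.Tensor,         # (E, N, K) first expert weights
    w2: torch.Tensor,         # (E, K, N) second expert weights
    topk_ids: torch.Tensor,   # (M, topk) expert assignment per token
) -> torch.Tensor:
    # [Permute-Experts-Unpermute]: matmul + activation + matmul per expert,
    # returning an unweighted, unreduced (M, topk, K) tensor.
    M, topk = topk_ids.shape
    out = torch.empty(M, topk, a.shape[-1], dtype=a.dtype)
    for m in range(M):
        for j in range(topk):
            e = int(topk_ids[m, j])
            h = torch.relu(w1[e] @ a[m])   # toy activation in place of act_mul
            out[m, j] = w2[e] @ h
    return out

def toy_finalize(expert_out: torch.Tensor, topk_weights: torch.Tensor) -> torch.Tensor:
    # [Finalize]: apply the topk weights and reduce over the topk dimension
    # (the TopKWeightAndReduce role in the real code).
    return (expert_out * topk_weights.unsqueeze(-1)).sum(dim=1)

E, M, N, K, topk = 4, 3, 8, 6, 2
a = torch.randn(M, K)
w1, w2 = torch.randn(E, N, K), torch.randn(E, K, N)
topk_ids = torch.randint(0, E, (M, topk))
topk_weights = torch.softmax(torch.randn(M, topk), dim=-1)

a1, _ = toy_prepare(a)
out = toy_finalize(toy_fused_experts(a1, w1, w2, topk_ids), topk_weights)
assert out.shape == (M, K)
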
@@ -155,25 +156,96 @@ PrepareResultType = tuple[
    torch.Tensor | None,
]

#
# PrepareMonolithicResultType is a tuple of:
# - quantized + dispatched a.
# - quantized + dispatched a1_scales.
# - dispatched router logits.
#
# See `prepare_monolithic` method below.
#
PrepareMonolithicResultType = tuple[
    torch.Tensor,
    torch.Tensor | None,
    torch.Tensor,
]

ReceiverType = Callable[[], PrepareResultType]

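As a small standalone illustration (dummy shapes; not part of this patch), a PrepareMonolithicResultType value unpacks into the dispatched activations, their optional scales, and the dispatched router logits:

import torch

# Illustrative sketch only: unquantized paths carry None scales.
result: tuple[torch.Tensor, torch.Tensor | None, torch.Tensor] = (
    torch.randn(4, 512),   # quantized + dispatched a
    None,                  # a1_scales (None when inputs stay unquantized)
    torch.randn(4, 8),     # dispatched router logits
)
a1q, a1q_scale, router_logits = result
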
################################################################################
# Prepare/Finalize
################################################################################


# TODO: pass FusedMoEParallelConfig in as ctor parameter?
class FusedMoEPrepareAndFinalize(ABC):
    """
    An abstract base class for the [Quantize-Prepare] and [Finalize] steps
    described above.

    There are two variants of this class:
    * FusedMoEPrepareAndFinalizeModular - this operates on topk ids and weights
    * FusedMoEPrepareAndFinalizeMonolithic - this operates on router_logits
    """

    def post_init_setup(self, fused_experts: "FusedMoEPermuteExpertsUnpermute"):
    def post_init_setup(self, fused_experts: "FusedMoEExperts"):
        """
        Initialize FusedMoEPrepareAndFinalize settings that depend on
        FusedMoEPermuteExpertsUnpermute experts object.
        The FusedMoEPrepareAndFinalize implementations that have such
        Initialize FusedMoEPrepareAndFinalizeModular settings that depend on
        FusedMoEExpertsModular experts object.
        The FusedMoEPrepareAndFinalizeModular implementations that have such
        dependencies may choose to override this function.
        """
        return

    @property
    @abstractmethod
    def activation_format(self) -> FusedMoEActivationFormat:
        """
        A property indicating the output format of the activations for the
        'prepare' method.
        """
        raise NotImplementedError

    @abstractmethod
    def topk_indices_dtype(self) -> torch.dtype | None:
        """
        The PrepareFinalize All2All implementations generally constrain the
        dtype of the topk_ids they support. This function returns the
        required topk indices dtype so it can be respected.
        Return None if there are no such restrictions.
        """
        raise NotImplementedError

    @abstractmethod
    def max_num_tokens_per_rank(self) -> int | None:
        """
        Some PrepareFinalize All2All implementations are batched. Meaning,
        they can process only a set of tokens at a time. This
        function returns the batch size, i.e. the maximum number of tokens
        the implementation can process at a time.
        Return None if there are no such restrictions.
        """
        raise NotImplementedError

    @abstractmethod
    def num_dispatchers(self) -> int:
        raise NotImplementedError

    @abstractmethod
    def output_is_reduced(self) -> bool:
        """
        Indicates whether or not the output of finalize is reduced across all
        ranks.
        """
        raise NotImplementedError


# TODO: pass FusedMoEParallelConfig in as ctor parameter?
class FusedMoEPrepareAndFinalizeModular(FusedMoEPrepareAndFinalize):
    """
    An abstract base class for the [Quantize-Prepare] and [Finalize] steps
    described above for the Modular case.
    """

    @abstractmethod
    def prepare(
        self,
@@ -198,7 +270,7 @@ class FusedMoEPrepareAndFinalize(ABC):
          activations, before quantization + dispatching.
        - quant_config: Quantization info provided by the fused experts.
        - defer_input_quant: Runtime parameter indicating whether or not to
          defer input quantization to the FusedMoEPermuteExpertsUnpermute
          defer input quantization to the FusedMoEExpertsModular
          in cases where the compute kernel expects unquantized inputs

        Returns a tuple of:
@@ -245,7 +317,7 @@ class FusedMoEPrepareAndFinalize(ABC):
        - apply_router_weight_on_input: When True, apply the weights to the
          activations, before quantization + dispatching.
        - defer_input_quant: Runtime parameter indicating whether or not to
          defer input quantization to the FusedMoEPermuteExpertsUnpermute
          defer input quantization to the FusedMoEExpertsModular
          in cases where the compute kernel expects unquantized inputs

        Returns a callback or a hook callback pair that, when invoked, waits for
@@ -338,56 +410,58 @@ class FusedMoEPrepareAndFinalize(ABC):
        """
        raise NotImplementedError

    @property


class FusedMoEPrepareAndFinalizeMonolithic(FusedMoEPrepareAndFinalize):
    """
    An abstract base class for the [Quantize-Prepare] and [Finalize] steps
    described above for the monolithic case.
    """

    @abstractmethod
    def activation_format(self) -> FusedMoEActivationFormat:
    def prepare(
        self,
        a1: torch.Tensor,
        router_logits: torch.Tensor,
        quant_config: FusedMoEQuantConfig,
        defer_input_quant: bool = False,
    ) -> PrepareMonolithicResultType:
        """
        A property indicating the output format of the activations for the
        'prepare' method.
        Optional method for subclasses compatible with monolithic
        FusedMoEExpertsModular kernels.

        Perform any quantization (and/or) dispatching needed for this kernel.
        - a1: The (unquantized) input to the MoE layer.
        - quant_config: Quantization info provided by the fused experts.
        - defer_input_quant: Runtime parameter indicating whether or not to
          defer input quantization to the FusedMoEExpertsModular

        Returns a tuple of:
        - quantized + dispatched a.
        - Optional quantized + dispatched a1_scales.
        """
        raise NotImplementedError

    @abstractmethod
    def topk_indices_dtype(self) -> torch.dtype | None:
    def finalize(self, fused_expert_output: torch.Tensor) -> torch.Tensor:
        """
        The PrepareFinalize All2All implementations generally constrain the
        dtype of the topk_ids they support. This function returns the
        required topk indices dtype so it can be respected.
        Return None if there are no such restrictions.
        Optional method for subclasses compatible with monolithic
        FusedMoEExpertsModular kernels.

        Perform any combine, apply the weights, and perform a reduction on the
        fused experts output.
        - fused_expert_output: The unweighted, unreduced output of the fused
          experts; it will have (M, topk, K) shape.
        """
        raise NotImplementedError

    @abstractmethod
    def max_num_tokens_per_rank(self) -> int | None:
        """
        Some PrepareFinalize All2All implementations are batched. Meaning,
        they can process only a set of tokens at a time. This
        function returns the batch size, i.e. the maximum number of tokens
        the implementation can process at a time.
        Return None if there are no such restrictions.
        """
        raise NotImplementedError

    @abstractmethod
    def num_dispatchers(self) -> int:
        raise NotImplementedError

    @abstractmethod
    def output_is_reduced(self) -> bool:
        """
        Indicates whether or not the output of finalize is reduced across all
        ranks.
        """
        raise NotImplementedError

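For orientation, here is a minimal standalone sketch of a monolithic prepare/finalize pair (a toy stand-in, not the vLLM class): prepare passes activations and logits through with no quantization or communication, and finalize reduces the (M, topk, K) expert output over the topk dimension.

import torch

class ToyMonolithicPrepareFinalize:
    """Illustrative stand-in: a no-communication, no-quantization
    monolithic prepare/finalize."""

    def prepare(
        self, a1: torch.Tensor, router_logits: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]:
        # No quantization and no dispatch: pass activations and logits through.
        return a1, None, router_logits

    def finalize(self, fused_expert_output: torch.Tensor) -> torch.Tensor:
        # fused_expert_output is (M, topk, K); in this toy the weights were
        # already applied by the experts, so finalize only reduces over topk.
        return fused_expert_output.sum(dim=1)

pf = ToyMonolithicPrepareFinalize()
a1q, a1q_scale, logits = pf.prepare(torch.randn(4, 16), torch.randn(4, 8))
out = pf.finalize(torch.randn(4, 2, 16))
assert out.shape == (4, 16)
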
################################################################################
# Experts
################################################################################


# TODO: add supported activations method (return string)
class FusedMoEPermuteExpertsUnpermute(ABC):
    """
    An abstract base class for the [Permute-Experts-Unpermute] step described
    above.
    """

class FusedMoEExperts(ABC):
    def __init__(
        self,
        moe_config: FusedMoEConfig,
@@ -419,6 +493,10 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
        self.max_num_tokens = max_num_tokens
        self.num_dispatchers = num_dispatchers

    @staticmethod
    def is_monolithic() -> bool:
        raise NotImplementedError("Implemented by subclasses.")

    @property
    def expects_unquantized_inputs(self) -> bool:
        """
@@ -439,49 +517,6 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
        """
        raise NotImplementedError

    def moe_problem_size(
        self,
        a1: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_ids: torch.Tensor,
    ) -> tuple[int, int, int, int, int]:
        """
        Extract the MoE problem size from the given tensor arguments:
        - a: The hidden states, input to the MoE layer.
        - w1: The first set of expert weights.
        - w2: The second set of expert weights.
        - topk_ids: The topk ids.

        Note: extracting the problem shape from the weight and activation
        tensors is not obvious. It needs to be done this way specifically
        due to subtle issues with particular kernels, e.g. the int4 kernels
        divide the trailing dimension by two, so it's not "correct" to
        extract N or K from the trailing dimension of w1 or w2. Similarly,
        some kernels transpose the weights, so this needs to be kept in mind.

        Note: This implementation covers most cases. However, if experts
        require a specialized implementation, like MarlinExperts, they are free
        to override this function.
        """
        assert w1.dim() == 3 and w2.dim() == 3
        E, N, _ = w1.size()
        K = a1.size(-1)

        if a1.dim() == 2:
            # Make sure we are using the correct a1 (pre-permute).
            assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}"
            M = a1.size(0)
        else:
            assert a1.dim() == 3
            assert a1.size(0) == E, f"{a1.size(0)} == {E}"
            M = a1.size(1)  # This is max_num_tokens

        assert topk_ids.dim() == 2
        topk = topk_ids.size(1)

        return E, M, N, K, topk

    #
    # Various helpers for registering support for various features.
    # Used by the oracle to select a particular kernel for a deployment.
@@ -489,7 +524,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):

    @staticmethod
    def is_supported_config(
        cls: type["FusedMoEPermuteExpertsUnpermute"],
        cls: type["FusedMoEExperts"],
        moe_config: FusedMoEConfig,
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
@@ -512,6 +547,21 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
            return False, _make_reason(
                f"parallel config {moe_config.moe_parallel_config}"
            )
        elif not cls._supports_routing_method(
            moe_config.routing_method, weight_key, activation_key
        ):
            return False, _make_reason(f"routing method {moe_config.routing_method}")
        elif not cls._supports_router_logits_dtype(
            moe_config.router_logits_dtype,
            moe_config.routing_method,
        ):
            return False, _make_reason(
                f"router logits dtype {moe_config.router_logits_dtype}"
            )
        elif not cls._supports_shape(moe_config.hidden_dim):
            return False, _make_reason(
                f"{moe_config.hidden_dim} hidden dim is not supported"
            )
        elif activation_format != cls.activation_format():
            return False, _make_reason(f"{activation_format.value} activation format")
        return True, None
@@ -554,10 +604,48 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
    @abstractmethod
    def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
        """
        Whether the kernel supports deployment in expert parallel.
        Whether the kernel supports deployment in a particular parallel config.

        Can be overridden if a kernel does not support EP, SP or some other
        configuration.
        """
        raise NotImplementedError

    @staticmethod
    def _supports_routing_method(
        routing_method: RoutingMethodType,
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """
        Whether the kernel supports a routing method (e.g. GroupedTopK).

        Can be overridden by monolithic kernels that execute the router
        in addition to the experts if certain routers are not supported.
        """
        return True

    @staticmethod
    def _supports_router_logits_dtype(
        router_logits_dtype: torch.dtype | None,
        routing_method: RoutingMethodType,
    ) -> bool:
        """
        Whether a kernel supports a particular dtype for router logits input.

        Can be overridden by monolithic kernels that execute the router
        in addition to the experts if certain dtypes are not supported.
        """
        return True

    @staticmethod
    def _supports_shape(hidden_dim: int) -> bool:
        """
        Whether a kernel supports a particular shape. Can be overridden if a kernel
        has specific shape requirements.
        """
        return True

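As a sketch of how these support predicates are meant to be specialized (everything below is a toy stand-in with invented limits and simplified signatures, not a real vLLM kernel), a fused-router kernel might restrict the routing method, logits dtype, and hidden size like this:

import torch

# Standalone illustration: the real hooks live on vLLM's FusedMoEExperts
# subclasses; this minimal base mirrors their default "supports everything".
class _ToyExpertsBase:
    @staticmethod
    def _supports_routing_method(routing_method: str) -> bool:
        return True

    @staticmethod
    def _supports_router_logits_dtype(router_logits_dtype: torch.dtype | None) -> bool:
        return True

    @staticmethod
    def _supports_shape(hidden_dim: int) -> bool:
        return True


class _ToyTrtllmLikeExperts(_ToyExpertsBase):
    @staticmethod
    def _supports_routing_method(routing_method: str) -> bool:
        # A fused-router kernel might only accept plain softmax top-k routing.
        return routing_method == "topk_softmax"

    @staticmethod
    def _supports_router_logits_dtype(router_logits_dtype: torch.dtype | None) -> bool:
        # ... and only bfloat16 router logits.
        return router_logits_dtype == torch.bfloat16

    @staticmethod
    def _supports_shape(hidden_dim: int) -> bool:
        # ... and only hidden dims that match its (invented) tiling.
        return hidden_dim % 128 == 0


assert _ToyTrtllmLikeExperts._supports_shape(4096)
assert not _ToyTrtllmLikeExperts._supports_shape(1000)
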
    #
    # Various helpers for accessing quantization parameters from the
    # quant_config.
@@ -654,6 +742,65 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
        """
        return False

    def enable_chunking(self):
        return (
            envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and self.supports_chunking()
        )


class FusedMoEExpertsModular(FusedMoEExperts):
    """
    An abstract base class for the [Permute-Experts-Unpermute] step described
    above.
    """

    @staticmethod
    def is_monolithic() -> bool:
        return False

    def moe_problem_size(
        self,
        a1: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_ids: torch.Tensor,
    ) -> tuple[int, int, int, int, int]:
        """
        Extract the MoE problem size from the given tensor arguments:
        - a: The hidden states, input to the MoE layer.
        - w1: The first set of expert weights.
        - w2: The second set of expert weights.
        - topk_ids: The topk ids.

        Note: extracting the problem shape from the weight and activation
        tensors is not obvious. It needs to be done this way specifically
        due to subtle issues with particular kernels, e.g. the int4 kernels
        divide the trailing dimension by two, so it's not "correct" to
        extract N or K from the trailing dimension of w1 or w2. Similarly,
        some kernels transpose the weights, so this needs to be kept in mind.

        Note: This implementation covers most cases. However, if experts
        require a specialized implementation, like MarlinExperts, they are free
        to override this function.
        """
        assert w1.dim() == 3 and w2.dim() == 3
        E, N, _ = w1.size()
        K = a1.size(-1)

        if a1.dim() == 2:
            # Make sure we are using the correct a1 (pre-permute).
            assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}"
            M = a1.size(0)
        else:
            assert a1.dim() == 3
            assert a1.size(0) == E, f"{a1.size(0)} == {E}"
            M = a1.size(1)  # This is max_num_tokens

        assert topk_ids.dim() == 2
        topk = topk_ids.size(1)

        return E, M, N, K, topk

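For concreteness, a standalone re-statement of the shape extraction above on dummy tensors (standard 2-D, pre-permute activation case):

import torch

E, M, N, K, topk = 8, 4, 1024, 512, 2     # experts, tokens, intermediate, hidden, topk
a1 = torch.randn(M, K)
w1 = torch.randn(E, N, K)
w2 = torch.randn(E, K, N)
topk_ids = torch.randint(0, E, (M, topk))

# E and N are read from w1, K from the activations, M from a1 (pre-permute),
# and topk from topk_ids -- never from the trailing dims of w1/w2, which
# packed-weight kernels (e.g. int4) may shrink or transpose.
extracted = (w1.size(0), a1.size(0), w1.size(1), a1.size(-1), topk_ids.size(1))
assert extracted == (E, M, N, K, topk)
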
    def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
        """
        Workspace type: The dtype to use for the workspace tensors.
@@ -726,11 +873,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
    ) -> None:
        apply_moe_activation(activation, output, input)

    def enable_chunking(self):
        return (
            envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and self.supports_chunking()
        )

    @abstractmethod
    def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce:
        raise NotImplementedError

@@ -791,6 +934,67 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
        raise NotImplementedError


class FusedMoEExpertsMonolithic(FusedMoEExperts):
    """
    An abstract base class for the [Permute-Experts-Unpermute] step described
    above, but with the monolithic interface (accepts router logits
    rather than topk ids and weights).
    """

    @staticmethod
    def _supports_routing_method(
        routing_method: RoutingMethodType,
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """
        Whether the kernel supports a routing method (e.g. GroupedTopK).

        Monolithic kernels should explicitly opt-in to support.
        """
        raise NotImplementedError

    @staticmethod
    def _supports_router_logits_dtype(
        router_logits_dtype: torch.dtype | None,
        routing_method: RoutingMethodType,
    ) -> bool:
        """
        Whether the kernel supports a dtype for router logits.

        Monolithic kernels should opt-in to support.
        """
        raise NotImplementedError

    @staticmethod
    def is_monolithic() -> bool:
        return True

    def apply(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        router_logits: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        # grouped topk + fused topk bias parameters
        num_expert_group: int | None = None,
        e_score_correction_bias: torch.Tensor | None = None,
        routed_scaling_factor: float | None = None,
        topk_group: int | None = None,
    ) -> torch.Tensor:
        """
        Same as apply(), except uses router_logits as opposed
        to the topk_ids and topk_weights. This is useful for kernels
        with fused router and fused_experts (e.g. FLASHINFER_TRTLLM).
        """
        raise NotImplementedError

def _slice_scales(
    scales: torch.Tensor | None, start: int, end: int
) -> torch.Tensor | None:
@@ -802,75 +1006,32 @@ def _slice_scales(
    return None


################################################################################
# Kernel
################################################################################


@final
class FusedMoEModularKernel(torch.nn.Module):
    """
    This class combines a FusedMoEPrepareAndFinalize instance and
    a FusedMoEPermuteExpertsUnpermute to provide an interface that
    is compatible with the `fused_experts` function in fused_moe.py.

    It takes care of managing any required scratch space.

    Note: Instances of this class should only be used for a single model
    layer due to any layer specific state that may be used by the component
    objects.
    """

class FusedMoEKernelModularImpl:
    def __init__(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        fused_experts: FusedMoEPermuteExpertsUnpermute,
        prepare_finalize: FusedMoEPrepareAndFinalizeModular,
        fused_experts: FusedMoEExpertsModular,
        shared_experts: torch.nn.Module | None = None,
        moe_parallel_config: FusedMoEParallelConfig | None = None,
        inplace: bool = False,
    ):
        super().__init__()
        self.prepare_finalize = prepare_finalize
        self.fused_experts = fused_experts
        self.shared_experts = shared_experts
        self.moe_parallel_config = moe_parallel_config
        self.inplace = inplace

        # prefer an explicit FusedMoEParallelConfig when available (from
        # FusedMoE layers / tests).
        # if not provided, assume this kernel is
        # running in a non-DP+EP context
        self.moe_parallel_config: FusedMoEParallelConfig | None = moe_parallel_config
        self.is_dp_ep = (
            moe_parallel_config is not None
            and moe_parallel_config.dp_size > 1
            and moe_parallel_config.use_ep
        )

        self._post_init_setup()
        assert (
            prepare_finalize.activation_format == fused_experts.activation_format()
        ), (
            f"{prepare_finalize.__class__.__name__}."
            f"{prepare_finalize.activation_format} == "
            f"{fused_experts.__class__.__name__}."
            f"{fused_experts.activation_format()}"
        )

    def _post_init_setup(self):
        """
        Resolve any leftover setup dependencies between self.prepare_finalize
        and self.fused_experts here.
        """
        self.prepare_finalize.post_init_setup(self.fused_experts)

    def supports_expert_map(self) -> bool:
        """
        A flag indicating whether or not this class supports expert maps.
        """
        return self.fused_experts.supports_expert_map()

    def output_is_reduced(self) -> bool:
        """
        Indicates whether or not the output of the fused MoE kernel
        is reduced across all ranks.
        """
        return self.prepare_finalize.output_is_reduced()

    def _chunk_info(self, M: int) -> tuple[int, int]:
        """
        Compute the number of chunks and the chunk size for a given M.
@@ -919,7 +1080,7 @@ class FusedMoEModularKernel(torch.nn.Module):
        workspace_dtype = self.fused_experts.workspace_dtype(out_dtype)

        # Force worst-case allocation in profiling run for
        # "mk.FusedMoEModularKernel.Standard" formats where this is only bounded
        # "mk.FusedMoEKernel.Standard" formats where this is only bounded
        # by `VLLM_FUSED_MOE_CHUNK_SIZE` and may not be seen during profiling with
        # DP+EP due to the random token routing.
        is_profile_run = (
@@ -1313,13 +1474,13 @@ class FusedMoEModularKernel(torch.nn.Module):
        assert shared_output is not None
        return shared_output, output

    def forward(
    def apply(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        activation: MoEActivation = MoEActivation.SILU,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
@@ -1334,8 +1495,7 @@ class FusedMoEModularKernel(torch.nn.Module):
        - hidden_states (torch.Tensor): The input tensor to the MoE layer.
        - w1 (torch.Tensor): The first set of expert weights.
        - w2 (torch.Tensor): The second set of expert weights.
        - topk_weights (torch.Tensor): The topk weights applied at the end of
          the layer.
        - topk_weights (torch.Tensor): The topk weights applied at the end of the layer.
        - topk_ids (torch.Tensor): A map of row to expert id.
        - activation (MoEActivation): The activation function to apply after the first
          MoE layer.
@@ -1354,7 +1514,6 @@ class FusedMoEModularKernel(torch.nn.Module):
        Returns:
        - torch.Tensor: The output tensor after applying the MoE layer.
        """

        if self.inplace:
            assert self.shared_experts is None
            assert not disable_inplace()
@@ -1400,3 +1559,206 @@ class FusedMoEModularKernel(torch.nn.Module):
            apply_router_weight_on_input,
            shared_experts_input=shared_experts_input,
        )

@final
class FusedMoEKernelMonolithicImpl:
    def __init__(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalizeMonolithic,
        fused_experts: FusedMoEExpertsMonolithic,
    ):
        self.prepare_finalize = prepare_finalize
        self.fused_experts = fused_experts

    def apply(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        router_logits: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        # grouped topk + fused topk bias parameters
        num_expert_group: int | None = None,
        e_score_correction_bias: torch.Tensor | None = None,
        routed_scaling_factor: float | None = None,
        topk_group: int | None = None,
    ) -> torch.Tensor:
        """
        Same as forward(), except uses router_logits as opposed
        to the topk_ids and topk_weights. This is used for kernels
        that have fused router + experts (e.g. FLASHINFER_TRTLLM).
        """

        # TODO(rob): add inplace support.
        a1q, a1q_scale, router_logits = self.prepare_finalize.prepare(
            hidden_states,
            router_logits=router_logits,
            quant_config=self.fused_experts.quant_config,
            defer_input_quant=self.fused_experts.expects_unquantized_inputs,
        )

        fused_out = self.fused_experts.apply(
            hidden_states=a1q,
            w1=w1,
            w2=w2,
            router_logits=router_logits,
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input,
            a1q_scale=a1q_scale,
            # grouped topk + fused topk bias parameters
            num_expert_group=num_expert_group,
            e_score_correction_bias=e_score_correction_bias,
            routed_scaling_factor=routed_scaling_factor,
            topk_group=topk_group,
        )

        output = self.prepare_finalize.finalize(fused_out)

        return output


@final
class FusedMoEKernel:
    def __init__(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        fused_experts: FusedMoEExperts,
        shared_experts: torch.nn.Module | None = None,
        moe_parallel_config: FusedMoEParallelConfig | None = None,
        inplace: bool = False,
    ):
        super().__init__()
        self.shared_experts = shared_experts  # NOTE: check if we can remove

        # Initialize the implementation (monolithic or modular).
        self.impl: FusedMoEKernelModularImpl | FusedMoEKernelMonolithicImpl
        if isinstance(
            prepare_finalize, FusedMoEPrepareAndFinalizeModular
        ) and isinstance(fused_experts, FusedMoEExpertsModular):
            self.impl = FusedMoEKernelModularImpl(
                prepare_finalize,
                fused_experts,
                shared_experts,
                moe_parallel_config,
                inplace,
            )

        elif isinstance(
            prepare_finalize, FusedMoEPrepareAndFinalizeMonolithic
        ) and isinstance(fused_experts, FusedMoEExpertsMonolithic):
            assert shared_experts is None
            assert not inplace
            self.impl = FusedMoEKernelMonolithicImpl(
                prepare_finalize,
                fused_experts,
            )

        else:
            raise ValueError(
                "prepare_finalize and fused_experts must both be either monolithic "
                f"or non-monolithic but got {prepare_finalize.__class__.__name__} "
                f"and {fused_experts.__class__.__name__}"
            )

        self._post_init_setup()

    @property
    def is_monolithic(self) -> bool:
        return isinstance(self.impl, FusedMoEKernelMonolithicImpl)

    @property
    def prepare_finalize(self) -> FusedMoEPrepareAndFinalize:
        return self.impl.prepare_finalize

    @property
    def fused_experts(self) -> FusedMoEExperts:
        return self.impl.fused_experts

    def _post_init_setup(self):
        """
        Resolve any leftover setup dependencies between self.prepare_finalize
        and self.fused_experts here.
        """
        self.prepare_finalize.post_init_setup(self.impl.fused_experts)
        assert (
            self.prepare_finalize.activation_format
            == self.fused_experts.activation_format()
        )

    def supports_expert_map(self) -> bool:
        """
        A flag indicating whether or not this class supports expert maps.
        """
        return self.fused_experts.supports_expert_map()

    def output_is_reduced(self) -> bool:
        """
        Indicates whether or not the output of the fused MoE kernel
        is reduced across all ranks.
        """
        return self.prepare_finalize.output_is_reduced()

    def apply_monolithic(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        router_logits: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        # grouped topk + fused topk bias parameters
        num_expert_group: int | None = None,
        e_score_correction_bias: torch.Tensor | None = None,
        routed_scaling_factor: float | None = None,
        topk_group: int | None = None,
    ) -> torch.Tensor:
        assert isinstance(self.impl, FusedMoEKernelMonolithicImpl)
        return self.impl.apply(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            router_logits=router_logits,
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input,
            num_expert_group=num_expert_group,
            e_score_correction_bias=e_score_correction_bias,
            routed_scaling_factor=routed_scaling_factor,
            topk_group=topk_group,
        )

    def apply(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        shared_experts_input: torch.Tensor | None = None,
    ) -> torch.Tensor:
        assert isinstance(self.impl, FusedMoEKernelModularImpl)
        return self.impl.apply(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input,
            shared_experts_input=shared_experts_input,
        )

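The pairing rule enforced by FusedMoEKernel's constructor can be summarized with a standalone sketch (toy classes, not the vLLM types): a modular prepare/finalize must be paired with modular experts, a monolithic prepare/finalize with monolithic experts, and mixed pairs are rejected; callers then pick apply() or apply_monolithic() based on is_monolithic.

class ToyModularPF: ...
class ToyMonolithicPF: ...
class ToyModularExperts: ...
class ToyMonolithicExperts: ...

class ToyKernel:
    """Toy version of the FusedMoEKernel pairing rule."""

    def __init__(self, prepare_finalize, fused_experts):
        modular = isinstance(prepare_finalize, ToyModularPF) and isinstance(
            fused_experts, ToyModularExperts
        )
        monolithic = isinstance(prepare_finalize, ToyMonolithicPF) and isinstance(
            fused_experts, ToyMonolithicExperts
        )
        if not (modular or monolithic):
            raise ValueError("prepare_finalize and fused_experts must match")
        self.is_monolithic = monolithic


ToyKernel(ToyModularPF(), ToyModularExperts())          # ok: modular pair
ToyKernel(ToyMonolithicPF(), ToyMonolithicExperts())    # ok: monolithic pair
try:
    ToyKernel(ToyModularPF(), ToyMonolithicExperts())   # mixed pair is rejected
except ValueError:
    pass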