[Misc] MoE ModularKernel : Introduce TopKWeightAndReduce (#20648)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
Varun Sundar Rabindranath
2025-07-10 17:40:38 -04:00
committed by GitHub
parent 574ad60db9
commit f0c98cae27
14 changed files with 297 additions and 59 deletions

View File

@@ -23,7 +23,7 @@ from vllm.utils import cdiv
#
# [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine]
#
# Each component will be independent of the others except for
# Each component will be independent of (but may inform) the others except for
# [Quantize-Dispatch] and [Combine] (see below). The components can then be
# mixed and matched so that DP+EP can be supported easily for multiple
# MoE kernel implementations.
@@ -32,13 +32,19 @@ from vllm.utils import cdiv
# * FusedMoEPrepareAndFinalize - an abstract base class for preparation of MoE
# inputs (e.g. quantization, distribution) and finalization of MoE outputs.
# The prepare method must take care of any needed quantization and the
# finalize method must apply weights and do the final reduction of the output.
# finalize method, informed by the FusedMoEPermuteExpertsUnpermute method,
# may apply weights and/or do the final reduction of the output.
# * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused
# MoE operation. One important feature to note is that this class does not
# apply topk weights or reduce the final output.
# MoE operation, i.e., matmul + act_mul + optionally quant + matmul.
# Some FusedMoEPermuteExpertsUnpermute implementations may choose to do
# the weight application and/or reduction. The class communicates this
# to [Finalize] via a TopKWeightAndReduce object.
# * FusedMoEModularKernel - an interface class that combines a
# FusedMoEPrepareAndFinalize and a FusedMoEPermuteExpertsUnpermute to
# provide the standard fused MoE kernel interface.
# * TopKWeightAndReduce - A TopKWeightAndReduce implementation chosen
# by the FusedMoEPermuteExpertsUnpermute implementation that is passed
# on to [Finalize].
#
# [Quantize-Prepare] and [Finalize] functionality are bundled into a single
# class `FusedMoEPrepareAndFinalize` since they could use collective
@@ -117,6 +123,24 @@ class ExpertTokensMetadata:
expert_num_tokens_cpu=expert_num_tokens_cpu)
class TopKWeightAndReduce(ABC):
    """
    Interface for strategies that apply the top-k weights to the fused
    experts output and/or reduce that output.
    """

    @abstractmethod
    def apply(
        self,
        output: Optional[torch.Tensor],
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
    ) -> torch.Tensor:
        """
        Apply `topk_weights` to `fused_expert_output` and/or reduce it.

        `output` may be None; in that case the implementation is expected
        to create the result tensor itself.

        This base method has no behavior of its own and always raises
        NotImplementedError.
        """
        raise NotImplementedError
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
class FusedMoEPrepareAndFinalize(ABC):
"""
@@ -173,6 +197,7 @@ class FusedMoEPrepareAndFinalize(ABC):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: TopKWeightAndReduce,
) -> None:
"""
Perform any combine plus apply weights and perform a reduction on the
@@ -184,6 +209,8 @@ class FusedMoEPrepareAndFinalize(ABC):
- topk_ids: The topk_ids.
- apply_router_weight_on_input: When False, apply the weights to
fused_expert_output.
- weight_and_reduce_impl: An optional TopKWeightAndReduce
implementation.
"""
raise NotImplementedError
@@ -323,6 +350,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
return envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and \
self.supports_chunking()
def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce:
    """
    Return the TopKWeightAndReduce implementation that should be handed to
    the prepare/finalize step for this fused-experts implementation.

    This base version provides no implementation and always raises
    NotImplementedError.
    NOTE(review): presumably every concrete subclass is expected to
    override this, although it is not marked @abstractmethod — confirm.
    """
    raise NotImplementedError
@abstractmethod
def apply(
self,
@@ -702,7 +732,9 @@ class FusedMoEModularKernel(torch.nn.Module):
a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta)
self.prepare_finalize.finalize(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input)
self.prepare_finalize.finalize(
output, fused_out, topk_weights, topk_ids,
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl())
return output