[Misc] ModularKernel : Perform WeightAndReduce inside TritonExperts & DeepGemmExperts (#20725)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
Varun Sundar Rabindranath
2025-07-15 01:17:16 +05:30
committed by GitHub
parent 8bb43b9c9e
commit c0569dbc82
9 changed files with 203 additions and 157 deletions

View File

@@ -26,7 +26,7 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
TopKWeightAndReduceNoOP)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, moe_kernel_quantize_input)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
@@ -1606,8 +1606,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
return TopKWeightAndReduceNoOP()
def workspace_shapes(
self,
@@ -1620,9 +1619,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
workspace1 = (M, topk, max(N * 2, K))
workspace2 = (M, topk, N)
output = (M, topk, K)
workspace1 = (M, topk, max(N // 2, K))
workspace2 = (M, topk, max(N, K))
output = (M, K)
return (workspace1, workspace2, output, a.dtype)
def apply(
@@ -1631,6 +1630,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
@@ -1644,6 +1644,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
# Check constraints.
if self.use_int4_w4a16:
@@ -1696,37 +1697,39 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
raise ValueError(
f"Unsupported compute_type: {hidden_states.dtype}")
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
intermediate_cache1 = _resize_cache(workspace13,
# Note that the output tensor might be in workspace1
intermediate_cache1 = _resize_cache(workspace2,
(num_tokens, top_k_num, N))
intermediate_cache2 = _resize_cache(workspace2,
intermediate_cache2 = _resize_cache(workspace13,
(num_tokens * top_k_num, N // 2))
intermediate_cache3 = _resize_cache(workspace2,
(num_tokens, top_k_num, K))
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'],
global_num_experts, expert_map))
invoke_fused_moe_kernel(hidden_states,
w1,
intermediate_cache1,
a1q_scale,
w1_scale,
w1_zp,
None,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
top_k_num,
config,
compute_type=compute_type,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape)
invoke_fused_moe_kernel(
hidden_states,
w1,
intermediate_cache1,
a1q_scale,
w1_scale,
w1_zp,
None, # topk_weights
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False, # mul_routed_weights
top_k_num,
config,
compute_type=compute_type,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape)
self.activation(activation, intermediate_cache2,
intermediate_cache1.view(-1, N))
@@ -1739,15 +1742,15 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
invoke_fused_moe_kernel(qintermediate_cache2,
w2,
output,
intermediate_cache3,
a2q_scale,
w2_scale,
w2_zp,
None,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
not apply_router_weight_on_input,
1,
config,
compute_type=compute_type,
@@ -1758,6 +1761,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape)
ops.moe_sum(intermediate_cache3, output)
def modular_triton_fused_moe(
use_fp8_w8a8: bool,