[Kernels] Add activation chunking logic to FusedMoEModularKernel (#19168)

Signed-off-by: Bill Nell <bnell@redhat.com>
Author: bnellnm
Date: 2025-06-11 12:53:10 -04:00
Committed by: GitHub
Parent: b2d9be6f7d
Commit: 29fa5cac1c
15 changed files with 458 additions and 396 deletions
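The diff below adds the pieces TritonExperts needs to participate in chunked execution: it advertises chunking support, reports its scratch-buffer requirements as shapes rather than flat element counts, and writes its result into a caller-provided output tensor instead of returning an intermediate cache. A minimal sketch of the driver-side idea follows; the names (chunked_apply, apply_chunk, CHUNK_SIZE) are illustrative and not the actual FusedMoEModularKernel code.

CHUNK_SIZE = 8192  # illustrative; the real chunk size comes from vLLM configuration

def chunked_apply(experts, hidden_states, topk_ids, output, apply_chunk):
    # Process the token batch in CHUNK_SIZE slices when the expert
    # implementation supports it; each slice writes straight into the
    # matching view of the pre-allocated `output` tensor.
    M = hidden_states.size(0)
    if not experts.supports_chunking() or M <= CHUNK_SIZE:
        apply_chunk(output, hidden_states, topk_ids)
        return output
    for start in range(0, M, CHUNK_SIZE):
        end = min(start + CHUNK_SIZE, M)
        apply_chunk(output[start:end], hidden_states[start:end],
                    topk_ids[start:end])
    return output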


@@ -1542,6 +1542,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
             use_int4_w4a16=use_int4_w4a16)
         self.per_channel_quant = per_channel_quant
 
+    def supports_chunking(self) -> bool:
+        return True
+
     def workspace_shapes(
         self,
         a: torch.Tensor,
@@ -1551,14 +1554,15 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         K: int,
         topk: int,
         num_experts: int,
-    ) -> tuple[int, int, torch.dtype]:
-        factor = num_experts if a.dim() == 3 else 1
-        workspace1 = M * topk * max(N * 2, K) * factor
-        workspace2 = M * topk * N * factor
-        return (workspace1, workspace2, a.dtype)
+    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
+        workspace1 = (M, topk, max(N * 2, K))
+        workspace2 = (M, topk, N)
+        output = (M, topk, K)
+        return (workspace1, workspace2, output, a.dtype)
 
     def apply(
         self,
+        output: torch.Tensor,
         hidden_states: torch.Tensor,
         w1: torch.Tensor,
         w2: torch.Tensor,
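Since workspace_shapes now returns shape tuples (including the output shape) instead of precomputed element counts, the caller can size its scratch buffers per chunk and allocate the output tensor itself. A rough, self-contained illustration with placeholder sizes (not values taken from this commit):

import math
import torch

# Placeholder sizes for a single chunk (illustrative only).
chunk_M, topk, N, K = 1024, 2, 4096, 2048
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16

workspace1_shape = (chunk_M, topk, max(N * 2, K))  # mirrors workspace1 above
workspace2_shape = (chunk_M, topk, N)              # mirrors workspace2 above
output_shape = (chunk_M, topk, K)                  # mirrors output above

# Flat scratch buffers (later viewed with per-call shapes) and the output
# tensor that the expert's apply() will fill in place.
workspace13 = torch.empty(math.prod(workspace1_shape), device=device, dtype=dtype)
workspace2 = torch.empty(math.prod(workspace2_shape), device=device, dtype=dtype)
output = torch.empty(output_shape, device=device, dtype=dtype)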
@@ -1575,7 +1579,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_num_tokens: Optional[torch.Tensor],
-    ) -> torch.Tensor:
+    ):
         # Check constraints.
         if self.use_int4_w4a16:
             assert hidden_states.size(-1) // 2 == w1.size(2), (
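With apply() now filling a caller-owned output tensor and returning nothing, the hunks below can drop the intermediate_cache3 allocation and the trailing return, and a chunked caller can hand each invocation a slice of one shared result buffer. A toy, non-vLLM example of why the out-parameter style composes with chunking:

import torch

def fake_expert_apply(output: torch.Tensor, x: torch.Tensor) -> None:
    # Stand-in for an out-parameter kernel: writes in place, returns nothing.
    output.copy_(x * 2.0)

x = torch.randn(10, 4)
out = torch.empty_like(x)
for start in range(0, x.size(0), 4):  # chunk size 4
    end = min(start + 4, x.size(0))
    fake_expert_apply(out[start:end], x[start:end])

assert torch.equal(out, x * 2.0)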
@@ -1632,8 +1636,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
                                             (num_tokens, top_k_num, N))
         intermediate_cache2 = _resize_cache(workspace2,
                                             (num_tokens * top_k_num, N // 2))
-        intermediate_cache3 = _resize_cache(workspace13,
-                                            (num_tokens, top_k_num, K))
 
         sorted_token_ids, expert_ids, num_tokens_post_padded = (
             moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'],
@@ -1671,7 +1673,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
 
         invoke_fused_moe_kernel(qintermediate_cache2,
                                 w2,
-                                intermediate_cache3,
+                                output,
                                 a2q_scale,
                                 w2_scale,
                                 w2_zp,
@@ -1690,8 +1692,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
                                 per_channel_quant=self.per_channel_quant,
                                 block_shape=self.block_shape)
 
-        return intermediate_cache3
-
 
 def modular_triton_fused_moe(
     use_fp8_w8a8: bool,