[Kernels] Add activation chunking logic to FusedMoEModularKernel (#19168)
Signed-off-by: Bill Nell <bnell@redhat.com>
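TritonExperts now opts into chunked execution: supports_chunking() returns True, workspace_shapes() reports shapes (workspace1, workspace2, and the output) instead of pre-multiplied element counts, and apply() writes into a caller-provided output tensor rather than allocating and returning intermediate_cache3. Illustrative sketches of the resulting call patterns follow the relevant hunks below.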
@@ -1542,6 +1542,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
             use_int4_w4a16=use_int4_w4a16)
         self.per_channel_quant = per_channel_quant
 
+    def supports_chunking(self) -> bool:
+        return True
+
     def workspace_shapes(
         self,
         a: torch.Tensor,
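A minimal sketch of the chunked driver loop that supports_chunking() opts into. This is not the actual FusedMoEModularKernel code; the experts object, chunk_size, and the keyword passthrough are illustrative assumptions.

# Hedged sketch, not vLLM's actual driver: shows how a modular kernel can
# split the token dimension into chunks once an expert implementation
# reports supports_chunking(). chunk_size and kwargs are assumptions.
import torch

def chunked_apply(experts, output: torch.Tensor, hidden_states: torch.Tensor,
                  chunk_size: int, **kwargs) -> None:
    if not experts.supports_chunking():
        experts.apply(output, hidden_states, **kwargs)
        return
    num_tokens = hidden_states.size(0)
    for start in range(0, num_tokens, chunk_size):
        end = min(start + chunk_size, num_tokens)
        # Each chunk writes straight into its slice of the shared output,
        # so peak activation memory is bounded by the chunk size.
        experts.apply(output[start:end], hidden_states[start:end], **kwargs)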
@@ -1551,14 +1554,15 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         K: int,
         topk: int,
         num_experts: int,
-    ) -> tuple[int, int, torch.dtype]:
-        factor = num_experts if a.dim() == 3 else 1
-        workspace1 = M * topk * max(N * 2, K) * factor
-        workspace2 = M * topk * N * factor
-        return (workspace1, workspace2, a.dtype)
+    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
+        workspace1 = (M, topk, max(N * 2, K))
+        workspace2 = (M, topk, N)
+        output = (M, topk, K)
+        return (workspace1, workspace2, output, a.dtype)
 
     def apply(
         self,
+        output: torch.Tensor,
         hidden_states: torch.Tensor,
         w1: torch.Tensor,
         w2: torch.Tensor,
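With workspace_shapes() now returning shapes rather than pre-multiplied element counts, a caller can size flat buffers once for the largest chunk and re-view them per chunk. A sketch under that assumption; the allocator below is illustrative, not the modular kernel's code.

# Hedged sketch: allocate buffers from the shapes returned above. The
# argument order matches the signature in this diff; everything else
# (buffer reuse policy, device placement) is an assumption.
import math
import torch

def alloc_buffers(experts, a: torch.Tensor, M: int, N: int, K: int,
                  topk: int, num_experts: int):
    ws1_shape, ws2_shape, out_shape, dtype = experts.workspace_shapes(
        a, M, N, K, topk, num_experts)
    # Flat workspaces can be re-viewed per chunk (cf. _resize_cache below).
    workspace13 = torch.empty(math.prod(ws1_shape), device=a.device, dtype=dtype)
    workspace2 = torch.empty(math.prod(ws2_shape), device=a.device, dtype=dtype)
    output = torch.empty(out_shape, device=a.device, dtype=dtype)
    return workspace13, workspace2, output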
@@ -1575,7 +1579,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_num_tokens: Optional[torch.Tensor],
-    ) -> torch.Tensor:
+    ):
         # Check constraints.
         if self.use_int4_w4a16:
             assert hidden_states.size(-1) // 2 == w1.size(2), (
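apply() no longer returns a tensor; it fills the output buffer passed in by the caller. A sketch of the resulting call-site pattern; only output as the first positional argument comes from this diff, the remaining arguments are assumptions.

# Hedged sketch of the new out-parameter contract. Argument names beyond
# those visible in this diff are illustrative.
import torch

def invoke(experts, hidden_states, w1, w2, topk, **kwargs) -> torch.Tensor:
    M, K = hidden_states.shape
    output = torch.empty((M, topk, K),
                         device=hidden_states.device,
                         dtype=hidden_states.dtype)
    experts.apply(output, hidden_states, w1, w2, **kwargs)  # returns None now
    return output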
@@ -1632,8 +1636,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
                                              (num_tokens, top_k_num, N))
         intermediate_cache2 = _resize_cache(workspace2,
                                             (num_tokens * top_k_num, N // 2))
-        intermediate_cache3 = _resize_cache(workspace13,
-                                            (num_tokens, top_k_num, K))
 
         sorted_token_ids, expert_ids, num_tokens_post_padded = (
             moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'],
@@ -1671,7 +1673,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
 
         invoke_fused_moe_kernel(qintermediate_cache2,
                                 w2,
-                                intermediate_cache3,
+                                output,
                                 a2q_scale,
                                 w2_scale,
                                 w2_zp,
@@ -1690,8 +1692,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
                                 per_channel_quant=self.per_channel_quant,
                                 block_shape=self.block_shape)
 
-        return intermediate_cache3
-
 
 def modular_triton_fused_moe(
         use_fp8_w8a8: bool,