[Kernel] DeepGemm MoE : Integrate triton permute / unpermute kernels (#20903)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
Varun Sundar Rabindranath
2025-07-17 13:40:37 +05:30
committed by GitHub
parent fdc5b43d20
commit 11dfdf21bf
10 changed files with 490 additions and 58 deletions

View File

@@ -317,6 +317,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
"""
Compute the shapes for the temporary and final outputs of the two gemms
@@ -479,7 +480,8 @@ class FusedMoEModularKernel(torch.nn.Module):
(workspace13_shape, workspace2_shape, fused_out_shape,
workspace_dtype) = self.fused_experts.workspace_shapes(
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts)
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
expert_tokens_meta)
# We can reuse the memory between cache1 and cache3 because by the
# time we need cache3, we're done with cache1.
@@ -572,10 +574,9 @@ class FusedMoEModularKernel(torch.nn.Module):
assert num_chunks > 1
# Construct the entire output that can then be processed in chunks.
(_, _, fused_out_shape,
_) = self.fused_experts.workspace_shapes(a1, a1q, M, N, K, top_k,
global_num_experts,
local_num_experts)
(_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
expert_tokens_meta)
fused_out = torch.empty(fused_out_shape,
device=a1q.device,
dtype=a1.dtype)
@@ -613,8 +614,11 @@ class FusedMoEModularKernel(torch.nn.Module):
need_expert_num_tokens_cpu = (
full_expert_tokens_meta.expert_num_tokens_cpu is not None)
if need_expert_num_tokens_cpu:
# This is blocking as some implementations need the count
# on the CPU to determine appropriate input/out fused-moe
# buffers
c_expert_num_tokens_cpu = c_expert_num_tokens.to(
"cpu", non_blocking=True)
"cpu", non_blocking=False)
return ExpertTokensMetadata(
expert_num_tokens=c_expert_num_tokens,