[Kernels] Modular kernel refactor (#24812)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2025-10-08 17:51:52 -04:00
committed by GitHub
parent f08919b7d1
commit da364615fc
22 changed files with 665 additions and 573 deletions

View File

@@ -1954,8 +1954,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
@@ -1963,11 +1961,11 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
workspace1 = (M, topk, max(N // 2, K))
workspace2 = (M, topk, max(N, K))
output = (M, K)
return (workspace1, workspace2, output, a.dtype)
return (workspace1, workspace2, output)
def apply(
self,