[Kernels] Modular kernel refactor (#24812)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2025-10-08 17:51:52 -04:00
committed by GitHub
parent f08919b7d1
commit da364615fc
22 changed files with 665 additions and 573 deletions

View File

@@ -366,10 +366,11 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
# topk weights and reduction are fused in moe_unpermute cuda kernel
return TopKWeightAndReduceNoOP()
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
    """Select a dtype, preferring the configured output dtype.

    Args:
        act_dtype: Fallback dtype, used when ``self.out_dtype`` is ``None``.

    Returns:
        ``self.out_dtype`` when it is set, otherwise ``act_dtype``.
    """
    configured = self.out_dtype
    if configured is None:
        return act_dtype
    return configured
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
@@ -377,16 +378,11 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
workspace1 = (M * topk, max(N, K))
workspace2 = (M * topk, max(N // 2, K))
output = (M, K)
return (
workspace1,
workspace2,
output,
self.out_dtype if self.out_dtype is not None else a.dtype,
)
return (workspace1, workspace2, output)
class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
@@ -428,11 +424,11 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
def supports_expert_map(self) -> bool:
    """Always answer ``False``: this class reports no expert-map support."""
    expert_map_supported = False
    return expert_map_supported
# TODO(bnell): maybe remove need for passing aq to workspace_shapes
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
    """Return the explicit output dtype if one is set, else ``act_dtype``.

    Args:
        act_dtype: Dtype returned when ``self.out_dtype`` is ``None``.

    Returns:
        ``self.out_dtype`` unless it is ``None``, in which case ``act_dtype``.
    """
    if self.out_dtype is None:
        return act_dtype
    return self.out_dtype
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
@@ -440,19 +436,13 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
padded_M = aq.size(1)
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
num_dp = self.num_dispatchers
assert num_dp is not None
workspace1 = (self.max_experts_per_worker, padded_M * num_dp, max(N, K))
workspace2 = (self.max_experts_per_worker, padded_M * num_dp, max(N // 2, K))
output = (self.max_experts_per_worker, padded_M, K)
return (
workspace1,
workspace2,
output,
self.out_dtype if self.out_dtype is not None else a.dtype,
)
workspace1 = (self.max_experts_per_worker, M * num_dp, max(N, K))
workspace2 = (self.max_experts_per_worker, M * num_dp, max(N // 2, K))
output = (self.max_experts_per_worker, M, K)
return (workspace1, workspace2, output)
def cutlass_moe_fp8(
@@ -767,10 +757,11 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
    """Return a freshly constructed ``TopKWeightAndReduceNoOP`` instance."""
    no_op_reduce = TopKWeightAndReduceNoOP()
    return no_op_reduce
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
    """Resolve a dtype from ``self.out_dtype`` with ``act_dtype`` as fallback.

    Args:
        act_dtype: Dtype to fall back to when no output dtype is configured.

    Returns:
        ``act_dtype`` if ``self.out_dtype`` is ``None``; otherwise
        ``self.out_dtype``.
    """
    override = self.out_dtype
    has_override = override is not None
    if has_override:
        return override
    return act_dtype
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
@@ -778,25 +769,19 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
workspace1: tuple[int, ...] = ()
workspace2: tuple[int, ...] = ()
output: tuple[int, ...] = ()
if self.use_batched_format:
padded_M = aq.size(1)
workspace1 = (self.max_experts_per_worker, padded_M, max(N, K))
workspace2 = (self.max_experts_per_worker, padded_M, (N // 2))
output = (self.max_experts_per_worker, padded_M, K)
workspace1 = (self.max_experts_per_worker, M, max(N, K))
workspace2 = (self.max_experts_per_worker, M, (N // 2))
output = (self.max_experts_per_worker, M, K)
else:
workspace1 = (M * topk, max(2 * N, K))
workspace2 = (M * topk, N)
output = (M, K)
return (
workspace1,
workspace2,
output,
self.out_dtype if self.out_dtype is not None else a.dtype,
)
return (workspace1, workspace2, output)
def apply(
self,