[Kernels] Modular kernel refactor (#24812)
Signed-off-by: Bill Nell <bnell@redhat.com>
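
Summary of the hunks below: workspace_shapes() no longer returns a dtype as
its fourth element; the workspace/output dtype now comes from a new
workspace_dtype() method, and the batched experts size their workspaces from
M rather than from the padded dimension of aq. A minimal caller-side sketch
of the new two-method contract, assuming the argument order shown in the
hunks (the parameter elided between K and global_num_experts is presumably
topk); allocate_workspaces and expert are hypothetical names, not code from
this PR:

    import torch

    # Hypothetical helper (not from this PR): allocate the two intermediate
    # workspaces and the output buffer via the refactored interface, where
    # the dtype comes from workspace_dtype() instead of being the fourth
    # element of the workspace_shapes() return value.
    def allocate_workspaces(expert, a, aq, M, N, K, topk,
                            global_num_experts, local_num_experts,
                            expert_tokens_meta=None):
        shape1, shape2, out_shape = expert.workspace_shapes(
            a, aq, M, N, K, topk,
            global_num_experts, local_num_experts, expert_tokens_meta)
        dtype = expert.workspace_dtype(a.dtype)
        workspace1 = torch.empty(shape1, dtype=dtype, device=a.device)
        workspace2 = torch.empty(shape2, dtype=dtype, device=a.device)
        output = torch.empty(out_shape, dtype=dtype, device=a.device)
        return workspace1, workspace2, output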
@@ -366,10 +366,11 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
         # topk weights and reduction are fused in moe_unpermute cuda kernel
         return TopKWeightAndReduceNoOP()
 
+    def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
+        return self.out_dtype if self.out_dtype is not None else act_dtype
+
     def workspace_shapes(
         self,
         a: torch.Tensor,
         aq: torch.Tensor,
         M: int,
         N: int,
         K: int,
@@ -377,16 +378,11 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
         global_num_experts: int,
         local_num_experts: int,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
-    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
+    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
         workspace1 = (M * topk, max(N, K))
         workspace2 = (M * topk, max(N // 2, K))
         output = (M, K)
-        return (
-            workspace1,
-            workspace2,
-            output,
-            self.out_dtype if self.out_dtype is not None else a.dtype,
-        )
+        return (workspace1, workspace2, output)
 
 
 class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
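
Note on the non-batched FP8 shapes above: both workspaces hold M * topk rows,
with columns wide enough for either GEMM stage. A quick numeric check of that
arithmetic; all sizes are illustrative, not from the PR:

    # Illustrative sizes: 8 tokens, top-2 routing, intermediate size
    # N = 3072, hidden size K = 4096.
    M, topk, N, K = 8, 2, 3072, 4096

    workspace1 = (M * topk, max(N, K))       # (16, 4096)
    workspace2 = (M * topk, max(N // 2, K))  # (16, 4096); N // 2 = 1536 < K
    output = (M, K)                          # (8, 4096)
    assert workspace1 == (16, 4096) and workspace2 == (16, 4096)
    assert output == (8, 4096)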
@@ -428,11 +424,11 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
     def supports_expert_map(self) -> bool:
         return False
 
     # TODO(bnell): maybe remove need for passing aq to workspace_shapes
+    def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
+        return self.out_dtype if self.out_dtype is not None else act_dtype
+
     def workspace_shapes(
         self,
         a: torch.Tensor,
         aq: torch.Tensor,
         M: int,
         N: int,
         K: int,
@@ -440,19 +436,13 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
         global_num_experts: int,
         local_num_experts: int,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
-    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
-        padded_M = aq.size(1)
+    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
         num_dp = self.num_dispatchers
         assert num_dp is not None
-        workspace1 = (self.max_experts_per_worker, padded_M * num_dp, max(N, K))
-        workspace2 = (self.max_experts_per_worker, padded_M * num_dp, max(N // 2, K))
-        output = (self.max_experts_per_worker, padded_M, K)
-        return (
-            workspace1,
-            workspace2,
-            output,
-            self.out_dtype if self.out_dtype is not None else a.dtype,
-        )
+        workspace1 = (self.max_experts_per_worker, M * num_dp, max(N, K))
+        workspace2 = (self.max_experts_per_worker, M * num_dp, max(N // 2, K))
+        output = (self.max_experts_per_worker, M, K)
+        return (workspace1, workspace2, output)
 
 
 def cutlass_moe_fp8(
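
Note on the batched FP8 change above: the per-expert row count was previously
derived from the dispatch-padded activation buffer via padded_M = aq.size(1);
it is now computed from M directly, scaled by the number of dispatchers. A
numeric illustration of the new sizing; all values are made up:

    # Illustrative sizes: 4 local experts, 16 tokens per rank, 2 dispatchers,
    # intermediate size N = 3072, hidden size K = 4096.
    max_experts_per_worker, M, num_dp, N, K = 4, 16, 2, 3072, 4096

    workspace1 = (max_experts_per_worker, M * num_dp, max(N, K))
    workspace2 = (max_experts_per_worker, M * num_dp, max(N // 2, K))
    output = (max_experts_per_worker, M, K)
    assert workspace1 == (4, 32, 4096)
    assert workspace2 == (4, 32, 4096)
    assert output == (4, 16, 4096)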
@@ -767,10 +757,11 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
     def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
         return TopKWeightAndReduceNoOP()
 
+    def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
+        return self.out_dtype if self.out_dtype is not None else act_dtype
+
     def workspace_shapes(
         self,
         a: torch.Tensor,
         aq: torch.Tensor,
         M: int,
         N: int,
         K: int,
@@ -778,25 +769,19 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
         global_num_experts: int,
         local_num_experts: int,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
-    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
+    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
         workspace1: tuple[int, ...] = ()
         workspace2: tuple[int, ...] = ()
         output: tuple[int, ...] = ()
         if self.use_batched_format:
-            padded_M = aq.size(1)
-            workspace1 = (self.max_experts_per_worker, padded_M, max(N, K))
-            workspace2 = (self.max_experts_per_worker, padded_M, (N // 2))
-            output = (self.max_experts_per_worker, padded_M, K)
+            workspace1 = (self.max_experts_per_worker, M, max(N, K))
+            workspace2 = (self.max_experts_per_worker, M, (N // 2))
+            output = (self.max_experts_per_worker, M, K)
         else:
             workspace1 = (M * topk, max(2 * N, K))
             workspace2 = (M * topk, N)
             output = (M, K)
-        return (
-            workspace1,
-            workspace2,
-            output,
-            self.out_dtype if self.out_dtype is not None else a.dtype,
-        )
+        return (workspace1, workspace2, output)
 
     def apply(
         self,
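
The workspace_dtype() method added to all three classes has the same one-line
body: prefer the expert's configured out_dtype, otherwise fall back to the
activation dtype. A standalone sketch of that contract; the _Expert class is
a stand-in, not PR code:

    from typing import Optional

    import torch

    class _Expert:
        """Stand-in carrying only the field workspace_dtype() consults."""

        def __init__(self, out_dtype: Optional[torch.dtype] = None):
            self.out_dtype = out_dtype

        def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
            return self.out_dtype if self.out_dtype is not None else act_dtype

    # An explicit out_dtype wins; otherwise the activation dtype is used.
    assert _Expert(torch.float16).workspace_dtype(torch.bfloat16) == torch.float16
    assert _Expert().workspace_dtype(torch.bfloat16) == torch.bfloat16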