[Kernel] Enable fp8 support for pplx and BatchedTritonExperts. (#18864)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2025-07-03 17:55:40 -04:00
committed by GitHub
parent 2f2fcb31b8
commit 78fe77534b
25 changed files with 1277 additions and 663 deletions

View File

@@ -41,10 +41,7 @@ def run_cutlass_moe_fp8(
assert w2_scale is not None
assert w1.dtype == torch.float8_e4m3fn
assert w2.dtype == torch.float8_e4m3fn
if expert_num_tokens is None:
assert a1q.size(1) == w1.size(2), "Hidden size mismatch w1"
else:
assert a1q.size(2) == w1.size(2), "Hidden size mismatch w1"
assert a1q.size(-1) == w1.size(2), "Hidden size mismatch w1"
assert w1.size(1) == w2.size(2) * 2, "Hidden size mismatch w2"
assert w1_scale.dim() == 1 or w1_scale.size(
1) == 1 or w1_scale.shape[1] == w1.size(1), "W1 scale shape mismatch"
@@ -178,6 +175,8 @@ def run_cutlass_moe_fp8(
c2 = _resize_cache(workspace2, (M * topk, N))
c3 = _resize_cache(workspace13, (M * topk, K))
c1.fill_(0)
ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets,
problem_sizes1, ab_strides1, ab_strides1, c_strides1,
per_act_token, per_out_ch)
@@ -213,6 +212,7 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
per_act_token_quant: bool,
per_out_ch_quant: bool,
block_shape: Optional[list[int]] = None,
num_dispatchers: Optional[int] = None,
use_batched_format: bool = False,
):
super().__init__(
@@ -223,7 +223,9 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
block_shape=block_shape,
))
assert max_experts_per_worker > 0
assert not use_batched_format or num_dispatchers is not None
self.max_experts_per_worker = max_experts_per_worker
self.num_dispatchers = num_dispatchers
self.out_dtype = out_dtype
self.use_batched_format = use_batched_format
@@ -260,8 +262,12 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
output: tuple[int, ...] = ()
if self.use_batched_format:
padded_M = aq.size(1)
workspace1 = (self.max_experts_per_worker, padded_M, max(N, K))
workspace2 = (self.max_experts_per_worker, padded_M, (N // 2))
num_dp = self.num_dispatchers
assert num_dp is not None
workspace1 = (self.max_experts_per_worker, padded_M * num_dp,
max(N, K))
workspace2 = (self.max_experts_per_worker, padded_M * num_dp,
(N // 2))
output = (self.max_experts_per_worker, padded_M, K)
else:
workspace1 = (M * topk, max(2 * N, K))