[Kernel] Enable fp8 support for pplx and BatchedTritonExperts. (#18864)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -41,10 +41,7 @@ def run_cutlass_moe_fp8(
|
||||
assert w2_scale is not None
|
||||
assert w1.dtype == torch.float8_e4m3fn
|
||||
assert w2.dtype == torch.float8_e4m3fn
|
||||
- if expert_num_tokens is None:
|
||||
- assert a1q.size(1) == w1.size(2), "Hidden size mismatch w1"
|
||||
- else:
|
||||
- assert a1q.size(2) == w1.size(2), "Hidden size mismatch w1"
|
||||
+ assert a1q.size(-1) == w1.size(2), "Hidden size mismatch w1"
|
||||
assert w1.size(1) == w2.size(2) * 2, "Hidden size mismatch w2"
|
||||
assert w1_scale.dim() == 1 or w1_scale.size(
|
||||
1) == 1 or w1_scale.shape[1] == w1.size(1), "W1 scale shape mismatch"
|
||||
@@ -178,6 +175,8 @@ def run_cutlass_moe_fp8(
|
||||
c2 = _resize_cache(workspace2, (M * topk, N))
|
||||
c3 = _resize_cache(workspace13, (M * topk, K))
|
||||
|
||||
+ c1.fill_(0)
|
||||
|
||||
ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets,
|
||||
problem_sizes1, ab_strides1, ab_strides1, c_strides1,
|
||||
per_act_token, per_out_ch)
|
||||
@@ -213,6 +212,7 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
per_act_token_quant: bool,
|
||||
per_out_ch_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
+ num_dispatchers: Optional[int] = None,
|
||||
use_batched_format: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
@@ -223,7 +223,9 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
block_shape=block_shape,
|
||||
))
|
||||
assert max_experts_per_worker > 0
|
||||
+ assert not use_batched_format or num_dispatchers is not None
|
||||
self.max_experts_per_worker = max_experts_per_worker
|
||||
+ self.num_dispatchers = num_dispatchers
|
||||
self.out_dtype = out_dtype
|
||||
self.use_batched_format = use_batched_format
|
||||
|
||||
@@ -260,8 +262,12 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
output: tuple[int, ...] = ()
|
||||
if self.use_batched_format:
|
||||
padded_M = aq.size(1)
|
||||
- workspace1 = (self.max_experts_per_worker, padded_M, max(N, K))
|
||||
- workspace2 = (self.max_experts_per_worker, padded_M, (N // 2))
|
||||
+ num_dp = self.num_dispatchers
|
||||
+ assert num_dp is not None
|
||||
+ workspace1 = (self.max_experts_per_worker, padded_M * num_dp,
|
||||
+ max(N, K))
|
||||
+ workspace2 = (self.max_experts_per_worker, padded_M * num_dp,
|
||||
+ (N // 2))
|
||||
output = (self.max_experts_per_worker, padded_M, K)
|
||||
else:
|
||||
workspace1 = (M * topk, max(2 * N, K))
|
||||
|
||||
Reference in New Issue
Block a user