[Kernel] Enable fp8 support for pplx and BatchedTritonExperts. (#18864)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2025-07-03 17:55:40 -04:00
committed by GitHub
parent 2f2fcb31b8
commit 78fe77534b
25 changed files with 1277 additions and 663 deletions

View File

@@ -1094,6 +1094,8 @@ def torch_experts(
if expert_map is not None:
topk_ids = expert_map[topk_ids]
f32 = torch.float32
for i in range(num_experts):
mask = topk_ids == i
if mask.sum():
@@ -1109,7 +1111,8 @@ def torch_experts(
out.dtype)
tmp2 = SiluAndMul()(tmp1)
tmp2, b_scale = moe_kernel_quantize_input(
tmp2, None, quant_dtype, per_act_token_quant, block_shape)
tmp2, a2_scale, quant_dtype, per_act_token_quant,
block_shape)
out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale,
w2_scale[i], block_shape,
@@ -1117,7 +1120,6 @@ def torch_experts(
else:
assert (a_scale is not None and w1_scale is not None
and w2_scale is not None)
f32 = torch.float32
scales = a_scale if a_scale.numel() == 1 else a_scale[mask]
tmp1 = a[mask].to(f32) * scales
w1_dq = (w1[i].to(f32) * w1_scale[i]).transpose(0, 1)
@@ -1126,8 +1128,8 @@ def torch_experts(
w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1)
out[mask] = (tmp2 @ w2_dq).to(out.dtype)
return (out.view(M, -1, w2.shape[1]) *
topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
return (out.view(M, -1, w2.shape[1]).to(f32) *
topk_weight.view(M, -1, 1)).sum(dim=1).to(out.dtype)
def torch_moe(a: torch.Tensor,