[Kernel] Enable fp8 support for pplx and BatchedTritonExperts. (#18864)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -1094,6 +1094,8 @@ def torch_experts(
|
||||
if expert_map is not None:
|
||||
topk_ids = expert_map[topk_ids]
|
||||
|
||||
f32 = torch.float32
|
||||
|
||||
for i in range(num_experts):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
@@ -1109,7 +1111,8 @@ def torch_experts(
|
||||
out.dtype)
|
||||
tmp2 = SiluAndMul()(tmp1)
|
||||
tmp2, b_scale = moe_kernel_quantize_input(
|
||||
tmp2, None, quant_dtype, per_act_token_quant, block_shape)
|
||||
tmp2, a2_scale, quant_dtype, per_act_token_quant,
|
||||
block_shape)
|
||||
|
||||
out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale,
|
||||
w2_scale[i], block_shape,
|
||||
@@ -1117,7 +1120,6 @@ def torch_experts(
|
||||
else:
|
||||
assert (a_scale is not None and w1_scale is not None
|
||||
and w2_scale is not None)
|
||||
f32 = torch.float32
|
||||
scales = a_scale if a_scale.numel() == 1 else a_scale[mask]
|
||||
tmp1 = a[mask].to(f32) * scales
|
||||
w1_dq = (w1[i].to(f32) * w1_scale[i]).transpose(0, 1)
|
||||
@@ -1126,8 +1128,8 @@ def torch_experts(
|
||||
w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1)
|
||||
out[mask] = (tmp2 @ w2_dq).to(out.dtype)
|
||||
|
||||
return (out.view(M, -1, w2.shape[1]) *
|
||||
topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
|
||||
return (out.view(M, -1, w2.shape[1]).to(f32) *
|
||||
topk_weight.view(M, -1, 1)).sum(dim=1).to(out.dtype)
|
||||
|
||||
|
||||
def torch_moe(a: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user