[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)

Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin@redhat.com>
This commit is contained in:
Jinzhen Lin
2025-11-29 23:19:33 +08:00
committed by GitHub
parent fa59fe417f
commit 1656ad3704
46 changed files with 4371 additions and 2240 deletions

View File

@@ -846,6 +846,13 @@ def torch_experts(
or (expert_map is not None and global_num_experts == expert_map.shape[0])
)
if quant_dtype in [torch.float16, torch.bfloat16]:
quant_dtype = None
quant_input_only = quant_dtype is not None and w1_scale is None and w2_scale is None
if quant_input_only:
assert a1_scale is None and a2_scale is None
assert per_act_token_quant
M, K = a.shape
topk = topk_ids.shape[1]
@@ -863,6 +870,9 @@ def torch_experts(
a, a1_scale, quant_dtype, per_act_token_quant, block_shape
)
if quant_input_only:
a = (a.float() * a_scale.view(-1, 1)).to(w1.dtype)
num_experts = w1.shape[0]
topk_ids = topk_ids.view(-1)
@@ -882,6 +892,14 @@ def torch_experts(
out[mask] = tmp2 @ w2[i].transpose(0, 1)
if b_bias2 is not None:
out[mask] = out[mask] + b_bias2[i].view(1, -1).to(tmp1.dtype)
elif quant_input_only:
tmp1 = a[mask] @ w1[i].transpose(0, 1)
tmp2 = SiluAndMul()(tmp1)
tmp2, tmp2_scale = moe_kernel_quantize_input(
tmp2, None, quant_dtype, per_act_token_quant
)
tmp2 = (tmp2.float() * tmp2_scale.view(-1, 1)).to(w2.dtype)
out[mask] = tmp2 @ w2[i].transpose(0, 1)
elif block_shape is not None:
# block quantized
assert (