[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Michael Goin <mgoin@redhat.com>
This commit is contained in:
@@ -846,6 +846,13 @@ def torch_experts(
|
||||
or (expert_map is not None and global_num_experts == expert_map.shape[0])
|
||||
)
|
||||
|
||||
if quant_dtype in [torch.float16, torch.bfloat16]:
|
||||
quant_dtype = None
|
||||
quant_input_only = quant_dtype is not None and w1_scale is None and w2_scale is None
|
||||
if quant_input_only:
|
||||
assert a1_scale is None and a2_scale is None
|
||||
assert per_act_token_quant
|
||||
|
||||
M, K = a.shape
|
||||
topk = topk_ids.shape[1]
|
||||
|
||||
@@ -863,6 +870,9 @@ def torch_experts(
|
||||
a, a1_scale, quant_dtype, per_act_token_quant, block_shape
|
||||
)
|
||||
|
||||
if quant_input_only:
|
||||
a = (a.float() * a_scale.view(-1, 1)).to(w1.dtype)
|
||||
|
||||
num_experts = w1.shape[0]
|
||||
|
||||
topk_ids = topk_ids.view(-1)
|
||||
@@ -882,6 +892,14 @@ def torch_experts(
|
||||
out[mask] = tmp2 @ w2[i].transpose(0, 1)
|
||||
if b_bias2 is not None:
|
||||
out[mask] = out[mask] + b_bias2[i].view(1, -1).to(tmp1.dtype)
|
||||
elif quant_input_only:
|
||||
tmp1 = a[mask] @ w1[i].transpose(0, 1)
|
||||
tmp2 = SiluAndMul()(tmp1)
|
||||
tmp2, tmp2_scale = moe_kernel_quantize_input(
|
||||
tmp2, None, quant_dtype, per_act_token_quant
|
||||
)
|
||||
tmp2 = (tmp2.float() * tmp2_scale.view(-1, 1)).to(w2.dtype)
|
||||
out[mask] = tmp2 @ w2[i].transpose(0, 1)
|
||||
elif block_shape is not None:
|
||||
# block quantized
|
||||
assert (
|
||||
|
||||
Reference in New Issue
Block a user