[Kernel] Zero point support in fused MarlinMoE kernel + AWQ Fused MoE (#8973)
Co-authored-by: Dipika <dipikasikka1@gmail.com> Co-authored-by: Dipika Sikka <ds3822@columbia.edu>
This commit is contained in:
@@ -208,6 +208,7 @@ def marlin_moe_permute_scales(
|
||||
device=s.device,
|
||||
dtype=s.dtype,
|
||||
)
|
||||
|
||||
for e in range(num_experts):
|
||||
output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
|
||||
return output
|
||||
@@ -258,6 +259,20 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
|
||||
return marlin_zp
|
||||
|
||||
|
||||
def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
|
||||
size_n: int, num_bits: int):
|
||||
num_experts = q_zp_packed.shape[0]
|
||||
output = torch.empty(
|
||||
(num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
|
||||
device=q_zp_packed.device,
|
||||
dtype=q_zp_packed.dtype,
|
||||
)
|
||||
for e in range(num_experts):
|
||||
output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n,
|
||||
num_bits)
|
||||
return output
|
||||
|
||||
|
||||
def apply_gptq_marlin_linear(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user