[Kernel] Zero point support in fused MarlinMoE kernel + AWQ Fused MoE (#8973)

Co-authored-by: Dipika <dipikasikka1@gmail.com>
Co-authored-by: Dipika Sikka <ds3822@columbia.edu>
This commit is contained in:
ElizaWszola
2024-10-04 20:34:44 +02:00
committed by GitHub
parent 0dcc8cbe5a
commit 05d686432f
23 changed files with 969 additions and 223 deletions

View File

@@ -208,6 +208,7 @@ def marlin_moe_permute_scales(
device=s.device,
dtype=s.dtype,
)
for e in range(num_experts):
output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
return output
@@ -258,6 +259,20 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
return marlin_zp
def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
size_n: int, num_bits: int):
num_experts = q_zp_packed.shape[0]
output = torch.empty(
(num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
device=q_zp_packed.device,
dtype=q_zp_packed.dtype,
)
for e in range(num_experts):
output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n,
num_bits)
return output
def apply_gptq_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,