[Kernel] Enable 8-bit weights in Fused Marlin MoE (#8032)

Co-authored-by: Dipika <dipikasikka1@gmail.com>
This commit is contained in:
ElizaWszola
2024-09-16 17:47:19 +02:00
committed by GitHub
parent fc990f9795
commit a091e2da3e
12 changed files with 452 additions and 184 deletions

View File

@@ -445,7 +445,7 @@ def grouped_topk(hidden_states: torch.Tensor,
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-return topk_weights, topk_ids.to(torch.int32)
+return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def get_config_dtype_str(dtype: torch.dtype,