[Kernel] Enable 8-bit weights in Fused Marlin MoE (#8032)
Co-authored-by: Dipika <dipikasikka1@gmail.com>
This commit is contained in:
@@ -445,7 +445,7 @@ def grouped_topk(hidden_states: torch.Tensor,
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
return topk_weights, topk_ids.to(torch.int32)
|
||||
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
||||
|
||||
|
||||
def get_config_dtype_str(dtype: torch.dtype,
|
||||
|
||||
Reference in New Issue
Block a user