[Misc] Update w2 scale loading for GPTQMarlinMoE (#12757)

This commit is contained in:
Dipika Sikka
2025-02-06 04:02:14 -05:00
committed by GitHub
parent 0408efc6d0
commit 7ca9934fe7
3 changed files with 21 additions and 8 deletions

View File

@@ -302,8 +302,8 @@ class FusedMoE(torch.nn.Module):
"weight_loader": self.weight_loader,
}
# need full intermediate size pre-sharding for WNA16 act order
-        if (self.quant_method.__class__.__name__ ==
-                "CompressedTensorsWNA16MoEMethod"):
+        if (self.quant_method.__class__.__name__
+                in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.quant_method.create_weights(layer=self, **moe_quant_params)