[Kernel] Use moe_wna16 kernel for compressed tensors wna16 moe models (#16038)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2025-04-10 01:08:47 -06:00
committed by GitHub
parent a5d11a54dc
commit c70cf0fe06
5 changed files with 254 additions and 15 deletions

View File

@@ -512,7 +512,9 @@ class FusedMoE(torch.nn.Module):
}
# need full intermediate size pre-sharding for WNA16 act order
if (self.quant_method.__class__.__name__
- in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
+ in ("GPTQMarlinMoEMethod",
+ "CompressedTensorsWNA16MarlinMoEMethod",
+ "CompressedTensorsWNA16MoEMethod")):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.quant_method.create_weights(layer=self, **moe_quant_params)
@@ -648,9 +650,10 @@ class FusedMoE(torch.nn.Module):
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
- loaded_weight = loaded_weight.t().contiguous() if (
- self.quant_method.__class__.__name__
- == "CompressedTensorsWNA16MoEMethod") else loaded_weight
+ if self.quant_method.__class__.__name__ in (
+ "CompressedTensorsWNA16MarlinMoEMethod",
+ "CompressedTensorsWNA16MoEMethod"):
+ loaded_weight = loaded_weight.t().contiguous()
if shard_id not in ("w1", "w2", "w3"):
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "