[Kernel] Use moe_wna16 kernel for compressed tensors wna16 moe models (#16038)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -512,7 +512,9 @@ class FusedMoE(torch.nn.Module):
|
||||
}
|
||||
# need full intermediate size pre-sharding for WNA16 act order
|
||||
if (self.quant_method.__class__.__name__
|
||||
- in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
|
||||
+ in ("GPTQMarlinMoEMethod",
|
||||
+ "CompressedTensorsWNA16MarlinMoEMethod",
|
||||
+ "CompressedTensorsWNA16MoEMethod")):
|
||||
moe_quant_params["intermediate_size_full"] = intermediate_size
|
||||
|
||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
@@ -648,9 +650,10 @@ class FusedMoE(torch.nn.Module):
|
||||
# compressed-tensors checkpoints with packed weights are stored flipped
|
||||
# TODO (mgoin): check self.quant_method.quant_config.quant_format
|
||||
# against known CompressionFormat enum values that have this quality
|
||||
- loaded_weight = loaded_weight.t().contiguous() if (
|
||||
- self.quant_method.__class__.__name__
|
||||
- == "CompressedTensorsWNA16MoEMethod") else loaded_weight
|
||||
+ if self.quant_method.__class__.__name__ in (
|
||||
+ "CompressedTensorsWNA16MarlinMoEMethod",
|
||||
+ "CompressedTensorsWNA16MoEMethod"):
|
||||
+ loaded_weight = loaded_weight.t().contiguous()
|
||||
|
||||
if shard_id not in ("w1", "w2", "w3"):
|
||||
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
|
||||
|
||||
Reference in New Issue
Block a user