[Misc] Support FP8 MoE for compressed-tensors (#8588)

Michael Goin
2024-09-25 12:43:36 -04:00
committed by GitHub
parent 64840dfae4
commit 873edda6cf
5 changed files with 226 additions and 8 deletions


@@ -323,10 +323,12 @@ class FusedMoE(torch.nn.Module):
loaded_weight: torch.Tensor, weight_name: str,
shard_id: str, expert_id: int) -> None:
-# compressed-tensors represents weights on disk which are flipped
+# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
loaded_weight = loaded_weight.t().contiguous() if (
self.quant_method.__class__.__name__
== "CompressedTensorsMoEMethod") else loaded_weight
== "CompressedTensorsWNA16MoEMethod") else loaded_weight
if shard_id not in ("w1", "w2", "w3"):
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
@@ -353,6 +355,9 @@ class FusedMoE(torch.nn.Module):
# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
+# this is needed for compressed-tensors only
+loaded_weight = loaded_weight.to(param.data.device)
if param.data[expert_id] != 1 and (param.data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
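For FP8 MoE, the fused kernel uses a single input scale per expert, so scales loaded from different shards into the same expert slot must agree. The sketch below mirrors the hunk above but is not the actual FusedMoE code; the function name `load_input_scale`, the error message, and the driver lines at the bottom are made up for illustration.

```python
import torch


def load_input_scale(param: torch.nn.Parameter, loaded_weight: torch.Tensor,
                     expert_id: int) -> None:
    # The checkpoint scale may live on a different device than the destination
    # parameter; move it over before comparing (as in the diff above).
    loaded_weight = loaded_weight.to(param.data.device)
    # A slot still holding the default value 1 has not been written yet. If it
    # has been written and the new scale disagrees beyond a small tolerance,
    # the per-expert input scales from different shards are inconsistent.
    if param.data[expert_id] != 1 and (param.data[expert_id] -
                                       loaded_weight).abs() > 1e-5:
        raise ValueError(
            f"input_scale for expert {expert_id} already set to "
            f"{param.data[expert_id]}, got {loaded_weight}")
    param.data[expert_id] = loaded_weight


scales = torch.nn.Parameter(torch.ones(8), requires_grad=False)  # one slot per expert
load_input_scale(scales, torch.tensor(0.02), expert_id=3)
load_input_scale(scales, torch.tensor(0.02), expert_id=3)   # same scale: accepted
# load_input_scale(scales, torch.tensor(0.5), expert_id=3)  # would raise ValueError
```

The explicit `.to(param.data.device)` added in this commit is what compressed-tensors checkpoints need: their scales can arrive on a different device than the already-allocated parameter, and the comparison would otherwise fail before the consistency check even runs.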