[Quantization] Fp8 Channelwise Dynamic Per Token GroupedGEMM (#15587)

Signed-off-by: ElizaWszola <eliza@neuralmagic.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
Co-authored-by: ElizaWszola <eliza@neuralmagic.com>
Co-authored-by: Lucas Wilkinson <wilkinson.lucas@gmail.com>
Co-authored-by: ElizaWszola <ewszola@redhat.com>
This commit is contained in:
Robert Shaw
2025-03-27 02:47:25 -04:00
committed by GitHub
parent f4c98b4d4c
commit 43ed4143c4
2 changed files with 66 additions and 65 deletions

View File

@@ -885,32 +885,6 @@ class FusedMoE(torch.nn.Module):
]
]
def _load_fp8_scale(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str,
shard_id: str, expert_id: int) -> None:
param_data = param.data
# Input scales can be loaded directly and should be equal.
if "input_scale" in weight_name:
if param_data[expert_id] != 1 and (param_data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param_data[expert_id]} "
f"vs. {loaded_weight}")
param_data[expert_id] = loaded_weight
# Weight scales
elif "weight_scale" in weight_name:
# If we are in merged column case (gate_up_proj)
if shard_id in ("w1", "w3"):
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx = 0 if shard_id == "w1" else 1
param_data[expert_id][idx] = loaded_weight
# If we are in the row parallel case (down_proj)
else:
param_data[expert_id] = loaded_weight
def extra_repr(self) -> str:
s = (