[Quantization] Fp8 Channelwise Dynamic Per Token GroupedGEMM (#15587)
Signed-off-by: ElizaWszola <eliza@neuralmagic.com> Signed-off-by: ElizaWszola <ewszola@redhat.com> Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com> Co-authored-by: ElizaWszola <eliza@neuralmagic.com> Co-authored-by: Lucas Wilkinson <wilkinson.lucas@gmail.com> Co-authored-by: ElizaWszola <ewszola@redhat.com>
This commit is contained in:
@@ -885,32 +885,6 @@ class FusedMoE(torch.nn.Module):
|
||||
]
|
||||
]
|
||||
|
||||
def _load_fp8_scale(self, param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor, weight_name: str,
|
||||
shard_id: str, expert_id: int) -> None:
|
||||
param_data = param.data
|
||||
|
||||
# Input scales can be loaded directly and should be equal.
|
||||
if "input_scale" in weight_name:
|
||||
if param_data[expert_id] != 1 and (param_data[expert_id] -
|
||||
loaded_weight).abs() > 1e-5:
|
||||
raise ValueError(
|
||||
"input_scales of w1 and w3 of a layer "
|
||||
f"must be equal. But got {param_data[expert_id]} "
|
||||
f"vs. {loaded_weight}")
|
||||
param_data[expert_id] = loaded_weight
|
||||
# Weight scales
|
||||
elif "weight_scale" in weight_name:
|
||||
# If we are in merged column case (gate_up_proj)
|
||||
if shard_id in ("w1", "w3"):
|
||||
# We have to keep the weight scales of w1 and w3 because
|
||||
# we need to re-quantize w1/w3 weights after weight loading.
|
||||
idx = 0 if shard_id == "w1" else 1
|
||||
param_data[expert_id][idx] = loaded_weight
|
||||
# If we are in the row parallel case (down_proj)
|
||||
else:
|
||||
param_data[expert_id] = loaded_weight
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
|
||||
s = (
|
||||
|
||||
Reference in New Issue
Block a user