[Model] Correct Mixtral FP8 checkpoint loading (#5231)

Author: Cody Yu
Date: 2024-06-05 10:58:50 -07:00
Committed by: GitHub
Parent: ccd4f129e8
Commit: 5563a4dea8
2 changed files with 80 additions and 35 deletions

@@ -300,14 +300,15 @@ def all_close_1d(x: torch.Tensor) -> bool:
 
 def per_tensor_quantize(tensor: torch.Tensor,
-                        inv_scale: float) -> torch.Tensor:
+                        inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
     finfo = torch.finfo(torch.float8_e4m3fn)
     qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
     return qweight.to(torch.float8_e4m3fn)
 
 
-def per_tensor_dequantize(tensor: torch.Tensor,
-                          inv_scale: float) -> torch.Tensor:
+def per_tensor_dequantize(
+        tensor: torch.Tensor, inv_scale: Union[float,
+                                               torch.Tensor]) -> torch.Tensor:
     fake_qweight = tensor.to(torch.float16)
     dq_weight = fake_qweight * inv_scale
     return dq_weight
 
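
For context, a minimal standalone sketch of the two helpers after this change, assuming only PyTorch >= 2.1 (which provides torch.float8_e4m3fn). The round-trip at the end, with a 0-dim tensor scale like those loaded from serialized FP8 checkpoints, is illustrative and not part of the commit:

from typing import Union

import torch


def per_tensor_quantize(tensor: torch.Tensor,
                        inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
    # Divide by the scale into the e4m3 range, clamp to its finite
    # extremes, then cast to FP8. A 0-dim tensor scale broadcasts here
    # exactly like a Python float.
    finfo = torch.finfo(torch.float8_e4m3fn)
    qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
    return qweight.to(torch.float8_e4m3fn)


def per_tensor_dequantize(
        tensor: torch.Tensor,
        inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
    # Upcast from FP8 to float16 before multiplying the scale back in.
    fake_qweight = tensor.to(torch.float16)
    return fake_qweight * inv_scale


# Round trip with a tensor scale (illustrative; 448.0 is the max finite
# e4m3 value, i.e. torch.finfo(torch.float8_e4m3fn).max).
weight = torch.randn(16, 16, dtype=torch.float16)
scale = weight.abs().max().to(torch.float32) / 448.0
q = per_tensor_quantize(weight, scale)
dq = per_tensor_dequantize(q, scale)
print((dq - weight).abs().max())  # small FP8 rounding error

Despite the parameter name, inv_scale is divided into the tensor on quantize and multiplied back on dequantize, so passing the same value to both recovers the weight up to FP8 rounding; presumably that symmetry is what the checkpoint-loading fix relies on when requantizing expert weights under a shared scale.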