[Model] Correct Mixtral FP8 checkpoint loading (#5231)
This commit is contained in:
@@ -300,14 +300,15 @@ def all_close_1d(x: torch.Tensor) -> bool:
|
||||
|
||||
|
||||
def per_tensor_quantize(tensor: torch.Tensor,
|
||||
inv_scale: float) -> torch.Tensor:
|
||||
inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
|
||||
return qweight.to(torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def per_tensor_dequantize(tensor: torch.Tensor,
|
||||
inv_scale: float) -> torch.Tensor:
|
||||
def per_tensor_dequantize(
|
||||
tensor: torch.Tensor, inv_scale: Union[float,
|
||||
torch.Tensor]) -> torch.Tensor:
|
||||
fake_qweight = tensor.to(torch.float16)
|
||||
dq_weight = fake_qweight * inv_scale
|
||||
return dq_weight
|
||||
|
||||
Reference in New Issue
Block a user