[ Kernel ] FP8 Dynamic Per Token Quant - Add scale_ub (#6593)

Co-authored-by: Varun Sundar Rabindranth <varun@neuralmagic.com>
This commit is contained in:
Varun Sundar Rabindranath
2024-07-19 21:15:26 -04:00
committed by GitHub
parent e81522e879
commit 2e26564259
6 changed files with 86 additions and 39 deletions

View File

@@ -1,4 +1,4 @@
from typing import Tuple, Union
from typing import Optional, Tuple, Union
import torch
@@ -7,13 +7,19 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
return torch.as_tensor(x, dtype=torch.float32, device='cuda')
def ref_dynamic_per_token_quant(x: torch.tensor,
quant_dtype: torch.dtype) \
quant_dtype: torch.dtype,
scale_ub: Optional[torch.tensor] = None) \
-> Tuple[torch.tensor, torch.tensor]:
assert quant_dtype in [torch.int8, torch.float8_e4m3fn]
if scale_ub is not None:
assert quant_dtype == torch.float8_e4m3fn
qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
else torch.finfo(quant_dtype)
qtype_max = as_float32_tensor(qtype_traits.max)
s_1 = as_float32_tensor(1.0)
s_512 = as_float32_tensor(512.0)
# For fp8, in order to match the cuda kernel output, we have to do exactly
# the same operations as in the corresponding fp8 kernel to prevent
@@ -22,14 +28,24 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
# Compute scales
x_token_max, _ = x.abs().max(dim=-1)
x_token_max = as_float32_tensor(x_token_max)
if scale_ub is not None:
x_token_max = x_token_max.clamp(max=scale_ub)
scales = (x_token_max / qtype_max)[:, None]
# Quant
iscales = (qtype_max / x_token_max)[:, None]
torch_out = as_float32_tensor(x) * iscales
torch_out = torch_out.round() if quant_dtype == torch.int8 else torch_out
torch_out = torch_out.clamp(qtype_traits.min,
qtype_traits.max).to(quant_dtype)
if quant_dtype == torch.int8:
iscales = as_float32_tensor(s_1 / scales)
torch_out = as_float32_tensor(x) * iscales
torch_out = torch_out.round()
torch_out = torch_out.clamp(qtype_traits.min,
qtype_traits.max).to(quant_dtype)
else:
assert quant_dtype == torch.float8_e4m3fn
min_scaling_factor = s_1 / (qtype_max * s_512)
scales = scales.clamp(min=min_scaling_factor)
torch_out = as_float32_tensor(x) / scales
torch_out = torch_out.clamp(qtype_traits.min,
qtype_traits.max).to(quant_dtype)
return torch_out, scales