[Feature][Hardware][Amd] Add fp8 Linear Layer for Rocm (#7210)
@@ -2,6 +2,13 @@ from typing import Optional, Tuple, Union
 
 import torch
 
+from vllm.utils import is_hip
+
+# Using the default value (240.0) from pytorch will cause accuracy
+# issue on dynamic quantization models. Here use 224.0 for rocm.
+ROCM_FP8_MAX = 224.0
+FP8_DTYPE = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+
 
 def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
     return torch.as_tensor(x, dtype=torch.float32, device='cuda')
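For context, a minimal sketch (not part of the commit) of why the bound differs per platform: torch.finfo reports max = 448.0 for float8_e4m3fn but only 240.0 for the ROCm-side float8_e4m3fnuz, and the comment above notes that 240.0 itself caused accuracy issues on dynamic quantization models, hence the tighter 224.0.

import torch

# torch.finfo exposes the representable range of each fp8 variant:
#   float8_e4m3fn   (CUDA path): min = -448.0, max = 448.0
#   float8_e4m3fnuz (ROCm path): min = -240.0, max = 240.0
for dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz):
    info = torch.finfo(dtype)
    print(dtype, info.min, info.max)

# The commit clamps to +/-224.0 on ROCm instead of the 240.0 hardware
# limit, per the accuracy note in the added comment.
ROCM_FP8_MAX = 224.0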
@@ -11,13 +18,15 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
                                 scale_ub: Optional[torch.tensor] = None) \
         -> Tuple[torch.tensor, torch.tensor]:
 
-    assert quant_dtype in [torch.int8, torch.float8_e4m3fn]
+    assert quant_dtype in [torch.int8, FP8_DTYPE]
     if scale_ub is not None:
-        assert quant_dtype == torch.float8_e4m3fn
+        assert quant_dtype == FP8_DTYPE
 
     qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
             else torch.finfo(quant_dtype)
-    qtype_max = as_float32_tensor(qtype_traits.max)
+    qtype_traits_max = ROCM_FP8_MAX if is_hip() else qtype_traits.max
+    qtype_traits_min = -ROCM_FP8_MAX if is_hip() else qtype_traits.min
+    qtype_max = as_float32_tensor(qtype_traits_max)
     s_1 = as_float32_tensor(1.0)
     s_512 = as_float32_tensor(512.0)
 
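A standalone sketch of the dispatch the new qtype_traits_max / qtype_traits_min lines perform, assuming is_hip() can be approximated by checking torch.version.hip (a stand-in for vllm.utils.is_hip):

import torch

def fp8_clamp_bounds():
    # Mirrors qtype_traits_min / qtype_traits_max in the diff: on ROCm,
    # clamp to +/-224.0 rather than the dtype's own finfo limits.
    on_rocm = torch.version.hip is not None  # stand-in for vllm.utils.is_hip()
    if on_rocm:
        return -224.0, 224.0
    info = torch.finfo(torch.float8_e4m3fn)
    return info.min, info.max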
@@ -37,15 +46,15 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
         iscales = as_float32_tensor(s_1 / scales)
         torch_out = as_float32_tensor(x) * iscales
         torch_out = torch_out.round()
-        torch_out = torch_out.clamp(qtype_traits.min,
-                                    qtype_traits.max).to(quant_dtype)
+        torch_out = torch_out.clamp(qtype_traits_min,
+                                    qtype_traits_max).to(quant_dtype)
     else:
-        assert quant_dtype == torch.float8_e4m3fn
+        assert quant_dtype == FP8_DTYPE
         min_scaling_factor = s_1 / (qtype_max * s_512)
         scales = scales.clamp(min=min_scaling_factor)
         torch_out = as_float32_tensor(x) / scales
-        torch_out = torch_out.clamp(qtype_traits.min,
-                                    qtype_traits.max).to(quant_dtype)
+        torch_out = torch_out.clamp(qtype_traits_min,
+                                    qtype_traits_max).to(quant_dtype)
 
     return torch_out, scales
 
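A usage sketch for the per-token path; the import path is assumed, since the diff does not show the filename of these test helpers:

import torch
# Assumed import path; the diff does not name the file.
from quant_utils import FP8_DTYPE, ref_dynamic_per_token_quant

x = torch.randn(16, 128, dtype=torch.float16, device='cuda')
# Dynamic per-token quantization: one scale per row, output clamped to
# the platform's fp8 range (+/-224.0 on ROCm, finfo limits elsewhere).
out, scales = ref_dynamic_per_token_quant(x, FP8_DTYPE)
assert out.dtype == FP8_DTYPE
assert out.shape == x.shape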
@@ -56,8 +65,10 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
 def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
         -> Tuple[torch.tensor, torch.tensor]:
 
-    fp8_traits = torch.finfo(torch.float8_e4m3fn)
-    fp8_max = as_float32_tensor(fp8_traits.max)
+    fp8_traits = torch.finfo(FP8_DTYPE)
+    fp8_traits_max = ROCM_FP8_MAX if is_hip() else fp8_traits.max
+    fp8_traits_min = -ROCM_FP8_MAX if is_hip() else fp8_traits.min
+    fp8_max = as_float32_tensor(fp8_traits_max)
     one = as_float32_tensor(1.0)
 
     # For fp8, in order to match the cuda kernel output, we have to do exactly
@@ -68,5 +79,5 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
     ref_scale = x_max / fp8_max
     ref_iscale = one / ref_scale
     ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
-        fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn)
+        fp8_traits_min, fp8_traits_max).to(FP8_DTYPE)
     return ref_out, ref_scale.view((1, ))
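And a matching sketch for the per-tensor path, under the same assumed import:

import torch
# Assumed import path, as above.
from quant_utils import FP8_DTYPE, ref_dynamic_per_tensor_fp8_quant

x = torch.randn(16, 128, dtype=torch.float16, device='cuda')
out, scale = ref_dynamic_per_tensor_fp8_quant(x)
# A single scale for the whole tensor, returned with shape (1,)
# via ref_scale.view((1, )) above.
assert out.dtype == FP8_DTYPE and scale.shape == (1, )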