dynamic dispatch of fp8 kernels (#14245)
Signed-off-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
@@ -22,10 +22,6 @@ from vllm.utils import direct_register_custom_op
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
current_platform_fp8_dtype = (torch.float8_e4m3fnuz
|
||||
if current_platform.is_rocm() else
|
||||
torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
|
||||
if isinstance(x, torch.Tensor):
|
||||
@@ -165,9 +161,7 @@ def input_to_float8(
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""This function quantizes input values to float8 values "
|
||||
"with tensor-wise quantization."""
|
||||
if dtype is None:
|
||||
dtype = (torch.float8_e4m3fnuz
|
||||
if current_platform.is_rocm() else torch.float8_e4m3fn)
|
||||
dtype = current_platform.fp8_dtype() if dtype is None else dtype
|
||||
finfo = torch.finfo(dtype)
|
||||
min_val, max_val = x.aminmax()
|
||||
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
|
||||
@@ -311,9 +305,7 @@ def per_token_group_quant_fp8(
|
||||
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
|
||||
scaling factor for quantization.
|
||||
"""
|
||||
if dtype is None:
|
||||
dtype = (torch.float8_e4m3fnuz
|
||||
if current_platform.is_rocm() else torch.float8_e4m3fn)
|
||||
dtype = current_platform.fp8_dtype() if dtype is None else dtype
|
||||
assert (x.shape[-1] % group_size == 0), (
|
||||
f"the last dimension of `x` {x.shape[-1]} must be divisible "
|
||||
f"by `group_size` {group_size}")
|
||||
|
||||
Reference in New Issue
Block a user