[ROCm][Kernel] Add gfx950 support for skinny gemms (#18010)

Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
Charlie Fu
2025-05-31 09:40:05 -05:00
committed by GitHub
parent f2c3f66d59
commit 306d60401d
5 changed files with 91 additions and 54 deletions

View File

@@ -8,7 +8,7 @@ from vllm.platforms import current_platform
# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for rocm.
ROCM_FP8_MAX = 224.0
ROCM_FP8FNUZ_MAX = 224.0
FP8_DTYPE = current_platform.fp8_dtype()
@@ -26,9 +26,11 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
else torch.finfo(quant_dtype)
qtype_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
qtype_traits_max = ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \
and current_platform.is_fp8_fnuz() \
else qtype_traits.max
qtype_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \
qtype_traits_min = -ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \
and current_platform.is_fp8_fnuz() \
else qtype_traits.min
qtype_max = as_float32_tensor(qtype_traits_max)
s_1 = as_float32_tensor(1.0)
@@ -70,9 +72,11 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
-> tuple[torch.tensor, torch.tensor]:
fp8_traits = torch.finfo(FP8_DTYPE)
fp8_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
fp8_traits_max = ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \
and current_platform.is_fp8_fnuz() \
else fp8_traits.max
fp8_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \
fp8_traits_min = -ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \
and current_platform.is_fp8_fnuz() \
else fp8_traits.min
fp8_max = as_float32_tensor(fp8_traits_max)
one = as_float32_tensor(1.0)