dynamic distpatch of fp8 kernels (#14245)

Signed-off-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
Jeff Daily
2025-03-11 07:54:56 -07:00
committed by GitHub
parent 08a1a1121d
commit a1c8f3796c
25 changed files with 292 additions and 159 deletions

View File

@@ -103,8 +103,7 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
assert attn._v_scale == 1.0
if current_platform.is_cuda():
if current_platform.has_device_capability(
89) and not force_marlin:
if current_platform.supports_fp8() and not force_marlin:
# For GPUs with hardware support, we keep weights in fp8
assert fc1.weight.dtype == torch.float8_e4m3fn
else:
@@ -112,11 +111,9 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
# for weight-only quantization using Marlin kernels
assert fc1.weight.dtype == torch.int32
elif current_platform.is_rocm():
# Only MI300 and above support quantization='fp8'
if current_platform.has_device_capability(
94) and not force_marlin:
if current_platform.supports_fp8() and not force_marlin:
# For GPUs with hardware support, we keep weights in fp8
assert fc1.weight.dtype == torch.float8_e4m3fnuz
assert fc1.weight.dtype == current_platform.fp8_dtype()
else: # unsupported ROCm platform
pytest.skip(
"Skip `test_load_fp16_model`. "