dynamic dispatch of fp8 kernels (#14245)
Signed-off-by: Jeff Daily <jeff.daily@amd.com>
@@ -103,8 +103,7 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
             assert attn._v_scale == 1.0
 
         if current_platform.is_cuda():
-            if current_platform.has_device_capability(
-                    89) and not force_marlin:
+            if current_platform.supports_fp8() and not force_marlin:
                 # For GPUs with hardware support, we keep weights in fp8
                 assert fc1.weight.dtype == torch.float8_e4m3fn
             else:
@@ -112,11 +111,9 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
                 # for weight-only quantization using Marlin kernels
                 assert fc1.weight.dtype == torch.int32
         elif current_platform.is_rocm():
-            # Only MI300 and above support quantization='fp8'
-            if current_platform.has_device_capability(
-                    94) and not force_marlin:
+            if current_platform.supports_fp8() and not force_marlin:
                 # For GPUs with hardware support, we keep weights in fp8
-                assert fc1.weight.dtype == torch.float8_e4m3fnuz
+                assert fc1.weight.dtype == current_platform.fp8_dtype()
             else:  # unsupported ROCm platform
                 pytest.skip(
                     "Skip `test_load_fp16_model`. "
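For context, the test now defers to platform-level helpers instead of hard-coded compute-capability checks, so fp8 behavior is dispatched per device at runtime. Below is a minimal sketch of what such helpers could look like; only supports_fp8() and fp8_dtype() appear in the diff above, while the class names, detection logic, and thresholds are illustrative assumptions rather than vLLM's actual implementation.

# Illustrative sketch only -- not vLLM's real platform classes.
import torch


class CudaPlatformSketch:
    """Hypothetical CUDA platform helper."""

    def has_device_capability(self, capability: int) -> bool:
        # Compare against the current device's SM version (e.g. 89 == SM 8.9).
        major, minor = torch.cuda.get_device_capability()
        return major * 10 + minor >= capability

    def supports_fp8(self) -> bool:
        # Assumption: Ada (SM 8.9) and newer expose native fp8 support.
        return self.has_device_capability(89)

    def fp8_dtype(self) -> torch.dtype:
        # CUDA devices use the OCP e4m3 format, matching the CUDA assert.
        return torch.float8_e4m3fn


class RocmPlatformSketch:
    """Hypothetical ROCm platform helper."""

    def supports_fp8(self) -> bool:
        # Assumption: detect MI300-class (gfx94x) hardware by arch name.
        arch = torch.cuda.get_device_properties(0).gcnArchName
        return "gfx94" in arch

    def fp8_dtype(self) -> torch.dtype:
        # Assumption: MI300 uses the fnuz variant of fp8 e4m3,
        # matching the assert this commit replaces.
        return torch.float8_e4m3fnuz

With helpers along these lines, the test only asks "does this device support fp8, and which fp8 dtype does it use?", and the per-vendor details stay behind current_platform.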