[ROCm] [Feature] [Doc] [Dockerfile] [BugFix] Support Per-Token-Activation Per-Channel-Weight FP8 Quantization Inferencing (#12501)

Author: TJian
Date: 2025-02-08 00:13:43 +08:00
Committed by: GitHub
Parent: 0630d4537a
Commit: eaa92d4437
8 changed files with 295 additions and 32 deletions
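For readers unfamiliar with the scheme named in the title: per-token-activation per-channel-weight (PTPC) FP8 quantization assigns every activation row (token) its own scale and every weight output channel its own scale, instead of one scale per tensor. The following sketch is illustrative only; it assumes the OCP float8_e4m3fn format (max 448) and hypothetical function names, and does not reflect vLLM's fused kernels:

```python
import torch

FP8_MAX = 448.0  # torch.finfo(torch.float8_e4m3fn).max

def quantize_per_token(x: torch.Tensor):
    """One scale per token (row) of an activation matrix [num_tokens, hidden]."""
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX
    return (x / scale).to(torch.float8_e4m3fn), scale

def quantize_per_channel(w: torch.Tensor):
    """One scale per output channel (row) of a weight matrix [out, in]."""
    scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX
    return (w / scale).to(torch.float8_e4m3fn), scale

def ptpc_fp8_linear(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """y = (x_q @ w_q.T) * x_scale * w_scale.T. The matmul is emulated in
    fp32 here; real kernels run it in FP8 and fuse the dequantization."""
    xq, xs = quantize_per_token(x)
    wq, ws = quantize_per_channel(w)
    y = xq.to(torch.float32) @ wq.to(torch.float32).T
    return y * xs * ws.T
```

Per-token scales track activation outliers that vary from token to token, while per-channel scales preserve weight precision across output channels; this is the scheme the PR enables for ROCm inference.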


@@ -55,10 +55,21 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):
             assert isinstance(attn.quant_method, Fp8KVCacheMethod)
-            # NOTE: it is valid for scales to be 1.0 (default value), but
-            # we know these checkpoints have scales < 1.0
-            assert 0.0 < attn._k_scale < 1.0
-            assert 0.0 < attn._v_scale < 1.0
+            if not current_platform.is_rocm():
+                # NOTE: This code path requires validation on non-CUDA platforms
+                # NOTE: it is valid for scales to be 1.0 (default value), but
+                # we know these checkpoints have scales < 1.0
+                assert 0.0 < attn._k_scale < 1.0
+                assert 0.0 < attn._v_scale < 1.0
+            else:
+                # NOTE: This code path is for the ROCm platform
+                # NOTE: it is valid for scales to be 1.0 (default value), but
+                # we know these checkpoints have scales < 1.0
+                # However, on ROCm the _k_scale and _v_scale are scaled by a
+                # factor of 2, as described in
+                # vllm/model_executor/layers/quantization/kv_cache.py
+                assert 0.0 < attn._k_scale < (1.0 * 2.0)
+                assert 0.0 < attn._v_scale < (1.0 * 2.0)
         llm.apply_model(check_model)
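The factor of 2 in the ROCm branch above comes from the FP8 format difference: ROCm uses torch.float8_e4m3fnuz, whose exponent bias of 8 (versus 7 for torch.float8_e4m3fn) means the same bit pattern encodes half the value, so scales taken from an e4m3fn checkpoint are doubled when the tensors are converted, as referenced in vllm/model_executor/layers/quantization/kv_cache.py. A rough sketch of that re-scaling, using a hypothetical helper name rather than vLLM's actual conversion function:

```python
import torch

def to_rocm_fnuz(weight_fn: torch.Tensor, scale: torch.Tensor):
    # Hypothetical illustration: halve the stored values when moving to the
    # FNUZ format and double the dequantization scale so that
    # weight * scale stays (approximately) unchanged.
    weight_fnuz = (weight_fn.to(torch.float32) / 2.0).to(torch.float8_e4m3fnuz)
    return weight_fnuz, scale * 2.0
```

That is why the ROCm assertions allow _k_scale and _v_scale up to 1.0 * 2.0 rather than 1.0.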
@@ -91,13 +102,29 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
                 assert attn._k_scale == 1.0
                 assert attn._v_scale == 1.0
-            if current_platform.has_device_capability(89) and not force_marlin:
-                # For GPUs with hardware support, we keep weights in fp8
-                assert fc1.weight.dtype == torch.float8_e4m3fn
-            else:
-                # For GPUs without hardware support, we pack the fp8 weights
-                # for weight-only quantization using Marlin kernels
-                assert fc1.weight.dtype == torch.int32
+            if current_platform.is_cuda():
+                if current_platform.has_device_capability(
+                        89) and not force_marlin:
+                    # For GPUs with hardware support, we keep weights in fp8
+                    assert fc1.weight.dtype == torch.float8_e4m3fn
+                else:
+                    # For GPUs without hardware support, we pack the fp8
+                    # weights for weight-only quantization using Marlin kernels
+                    assert fc1.weight.dtype == torch.int32
+            elif current_platform.is_rocm():
+                # Only MI300 and above support quantization='fp8'
+                if current_platform.has_device_capability(
+                        94) and not force_marlin:
+                    # For GPUs with hardware support, we keep weights in fp8
+                    assert fc1.weight.dtype == torch.float8_e4m3fnuz
+                else:  # unsupported ROCm platform
+                    pytest.skip("Skip `test_load_fp16_model`. "
+                                "It only runs on ROCm platforms with FP8 "
+                                "compute, e.g. MI300X and above.")
+            else:  # unsupported platform
+                pytest.skip("Skip `test_load_fp16_model`. "
+                            "It only runs on CUDA and ROCm platforms.")
         llm.apply_model(check_model)
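Summarizing the platform dispatch introduced above, the expected weight dtype can be restated as a small helper. This is not a vLLM API, only a paraphrase of the test's assertions using the same current_platform interface that appears in the diff:

```python
import torch
from vllm.platforms import current_platform

def expected_fp8_weight_dtype(force_marlin: bool = False):
    # Illustrative restatement of the test's expectations, not vLLM code.
    if current_platform.is_cuda():
        # SM 8.9 (Ada) and newer have native FP8; older CUDA GPUs fall back
        # to Marlin weight-only kernels, which pack weights into int32.
        if current_platform.has_device_capability(89) and not force_marlin:
            return torch.float8_e4m3fn
        return torch.int32
    if current_platform.is_rocm():
        # gfx94x (MI300 series) reports device capability 94 and uses the
        # FNUZ variant of FP8.
        if current_platform.has_device_capability(94) and not force_marlin:
            return torch.float8_e4m3fnuz
    return None  # the test skips on other platforms
```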