[ROCm] [Feature] [Doc] [Dockerfile] [BugFix] Support Per-Token-Activation Per-Channel-Weight FP8 Quantization Inferencing (#12501)
This commit is contained in:
@@ -55,10 +55,21 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):
         assert isinstance(attn.quant_method, Fp8KVCacheMethod)

-        # NOTE: it is valid for scales to be 1.0 (default value), but
-        # we know these checkpoints have scales < 1.0
-        assert 0.0 < attn._k_scale < 1.0
-        assert 0.0 < attn._v_scale < 1.0
+        if not current_platform.is_rocm():
+            # NOTE: This code path requires validation on Non-CUDA platform
+            # NOTE: it is valid for scales to be 1.0 (default value), but
+            # we know these checkpoints have scales < 1.0
+            assert 0.0 < attn._k_scale < 1.0
+            assert 0.0 < attn._v_scale < 1.0
+        else:
+            # NOTE: This code path is for ROCm platform
+            # NOTE: it is valid for scales to be 1.0 (default value), but
+            # we know these checkpoints have scales < 1.0
+            # However on ROCm platform, the _k_scale and _v_scale will be
+            # scaled by a factor of 2 as described in
+            # vllm/model_executor/layers/quantization/kv_cache.py
+            assert 0.0 < attn._k_scale < (1.0 * 2.0)
+            assert 0.0 < attn._v_scale < (1.0 * 2.0)

     llm.apply_model(check_model)
@@ -91,13 +102,29 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
         assert attn._k_scale == 1.0
         assert attn._v_scale == 1.0

-        if current_platform.has_device_capability(89) and not force_marlin:
-            # For GPUs with hardware support, we keep weights in fp8
-            assert fc1.weight.dtype == torch.float8_e4m3fn
-        else:
-            # For GPUs without hardware support, we pack the fp8 weights
-            # for weight-only quantization using Marlin kernels
-            assert fc1.weight.dtype == torch.int32
+        if current_platform.is_cuda():
+            if current_platform.has_device_capability(
+                    89) and not force_marlin:
+                # For GPUs with hardware support, we keep weights in fp8
+                assert fc1.weight.dtype == torch.float8_e4m3fn
+            else:
+                # For GPUs without hardware support, we pack the fp8 weights
+                # for weight-only quantization using Marlin kernels
+                assert fc1.weight.dtype == torch.int32
+        elif current_platform.is_rocm():
+            # Only MI300 and above support quantization='fp8'
+            if current_platform.has_device_capability(
+                    94) and not force_marlin:
+                # For GPUs with hardware support, we keep weights in fp8
+                assert fc1.weight.dtype == torch.float8_e4m3fnuz
+            else:  # unsupported ROCm platform
+                pytest.skip(
+                    "Skip `test_load_fp16_model`. "
+                    "It only runs on ROCm platform with FP8 compute."
+                    " e.g. MI300X and above.")
+        else:  # unsupported platform
+            pytest.skip("Skip `test_load_fp16_model`. "
+                        "It only runs on CUDA and ROCm platform.")

     llm.apply_model(check_model)
Reference in New Issue
Block a user