Add llmcompressor fp8 kv-cache quant (per-tensor and per-attn_head) (#30141)

Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>
Signed-off-by: eldarkurtic <8884008+eldarkurtic@users.noreply.github.com>
This commit is contained in:
Eldar Kurtić
2026-01-22 21:29:57 +01:00
committed by GitHub
parent 955b43a5a5
commit 44f08af3a7
18 changed files with 558 additions and 263 deletions

View File

@@ -8,6 +8,7 @@ import torch
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import scaled_dequantize
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
@@ -19,6 +20,7 @@ NUM_HEADS = [8] # Arbitrary values for testing
HEAD_SIZES = [64, 80, 256]
BLOCK_SIZES = [8, 16, 32]
CACHE_LAYOUTS = ["NHD", "HND"]
KV_SCALE_TYPES = ["tensor", "attn_head"]
# Parameters for MLA tests.
KV_LORA_RANKS = [512]
@@ -170,6 +172,7 @@ def test_reshape_and_cache(
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
@pytest.mark.parametrize("kv_scale_type", KV_SCALE_TYPES)
@pytest.mark.parametrize("implementation", RESHAPE_FLASH_IMPLEMENTATIONS)
@torch.inference_mode()
def test_reshape_and_cache_flash(
@@ -184,6 +187,7 @@ def test_reshape_and_cache_flash(
device: str,
kv_cache_dtype: str,
kv_cache_layout: str,
kv_scale_type: str,
implementation: str,
) -> None:
set_random_seed(seed)
@@ -193,6 +197,9 @@ def test_reshape_and_cache_flash(
if implementation == "triton" and kv_cache_layout == "HND":
pytest.skip("Triton implementation only supports NHD layout.")
if kv_scale_type == "attn_head" and implementation != "cuda":
pytest.skip("Only CUDA implementation supports attn_head scaling.")
# fp8 conversion requires a contiguous memory buffer. Reduce the number of
# blocks and tokens to consume less memory.
num_tokens = num_tokens // 2
@@ -220,8 +227,12 @@ def test_reshape_and_cache_flash(
del key_caches
del value_caches
k_scale = (key.amax() / 64.0).to(torch.float32)
v_scale = (value.amax() / 64.0).to(torch.float32)
if kv_scale_type == "tensor":
k_scale = (key.amax() / 64.0).to(torch.float32)
v_scale = (value.amax() / 64.0).to(torch.float32)
else: # "attn_head"
k_scale = (key.amax(dim=(0, 2)) / 64.0).to(torch.float32)
v_scale = (value.amax(dim=(0, 2)) / 64.0).to(torch.float32)
def permute_and_compact(x):
y = x if kv_cache_layout == "NHD" else x.permute(0, 2, 1, 3)
@@ -230,15 +241,27 @@ def test_reshape_and_cache_flash(
key_cache_compact = permute_and_compact(key_cache)
value_cache_compact = permute_and_compact(value_cache)
def convert_fp8_local(output, input, scale, kv_dtype):
fp8_input = input.view(torch.float8_e4m3fn)
if scale.numel() == 1: # per-tensor
result = scaled_dequantize(
fp8_input.flatten(0, 2), scale, group_shape=None, out_dtype=output.dtype
).reshape(*input.shape)
else: # per-head: broadcast scale along the head dimension
# Original code uses dim 2 for NHD, dim 1 for HND
if kv_cache_layout == "NHD":
result = fp8_input.to(output.dtype) * scale.view(1, 1, -1, 1)
else:
result = fp8_input.to(output.dtype) * scale.view(1, -1, 1, 1)
output.copy_(result)
# Clone the KV caches.
if kv_cache_dtype == "fp8":
cloned_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16)
ops.convert_fp8(
cloned_key_cache, key_cache_compact, k_scale.item(), kv_cache_dtype
)
convert_fp8_local(cloned_key_cache, key_cache_compact, k_scale, kv_cache_dtype)
cloned_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16)
ops.convert_fp8(
cloned_value_cache, value_cache_compact, v_scale.item(), kv_cache_dtype
convert_fp8_local(
cloned_value_cache, value_cache_compact, v_scale, kv_cache_dtype
)
else:
cloned_key_cache = key_cache_compact.clone()
@@ -289,15 +312,13 @@ def test_reshape_and_cache_flash(
if kv_cache_dtype == "fp8":
result_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16)
ops.convert_fp8(
result_key_cache, key_cache_compact, k_scale.item(), kv_dtype=kv_cache_dtype
)
convert_fp8_local(result_key_cache, key_cache_compact, k_scale, kv_cache_dtype)
result_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16)
ops.convert_fp8(
convert_fp8_local(
result_value_cache,
value_cache_compact,
v_scale.item(),
kv_dtype=kv_cache_dtype,
v_scale,
kv_cache_dtype,
)
# Run the reference implementation.