Add llmcompressor fp8 kv-cache quant (per-tensor and per-attn_head) (#30141)
Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com> Signed-off-by: eldarkurtic <8884008+eldarkurtic@users.noreply.github.com>
This commit is contained in:
@@ -8,6 +8,7 @@ import torch
|
||||
|
||||
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import scaled_dequantize
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
@@ -19,6 +20,7 @@ NUM_HEADS = [8] # Arbitrary values for testing
|
||||
HEAD_SIZES = [64, 80, 256]
|
||||
BLOCK_SIZES = [8, 16, 32]
|
||||
CACHE_LAYOUTS = ["NHD", "HND"]
|
||||
KV_SCALE_TYPES = ["tensor", "attn_head"]
|
||||
|
||||
# Parameters for MLA tests.
|
||||
KV_LORA_RANKS = [512]
|
||||
@@ -170,6 +172,7 @@ def test_reshape_and_cache(
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
||||
@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
|
||||
@pytest.mark.parametrize("kv_scale_type", KV_SCALE_TYPES)
|
||||
@pytest.mark.parametrize("implementation", RESHAPE_FLASH_IMPLEMENTATIONS)
|
||||
@torch.inference_mode()
|
||||
def test_reshape_and_cache_flash(
|
||||
@@ -184,6 +187,7 @@ def test_reshape_and_cache_flash(
|
||||
device: str,
|
||||
kv_cache_dtype: str,
|
||||
kv_cache_layout: str,
|
||||
kv_scale_type: str,
|
||||
implementation: str,
|
||||
) -> None:
|
||||
set_random_seed(seed)
|
||||
@@ -193,6 +197,9 @@ def test_reshape_and_cache_flash(
|
||||
if implementation == "triton" and kv_cache_layout == "HND":
|
||||
pytest.skip("Triton implementation only supports NHD layout.")
|
||||
|
||||
if kv_scale_type == "attn_head" and implementation != "cuda":
|
||||
pytest.skip("Only CUDA implementation supports attn_head scaling.")
|
||||
|
||||
# fp8 conversion requires contiguous memory buffer. Reduce the number of
|
||||
# blocks and tokens to consume less memory.
|
||||
num_tokens = num_tokens // 2
|
||||
@@ -220,8 +227,12 @@ def test_reshape_and_cache_flash(
|
||||
del key_caches
|
||||
del value_caches
|
||||
|
||||
k_scale = (key.amax() / 64.0).to(torch.float32)
|
||||
v_scale = (value.amax() / 64.0).to(torch.float32)
|
||||
if kv_scale_type == "tensor":
|
||||
k_scale = (key.amax() / 64.0).to(torch.float32)
|
||||
v_scale = (value.amax() / 64.0).to(torch.float32)
|
||||
else: # "attn_head"
|
||||
k_scale = (key.amax(dim=(0, 2)) / 64.0).to(torch.float32)
|
||||
v_scale = (value.amax(dim=(0, 2)) / 64.0).to(torch.float32)
|
||||
|
||||
def permute_and_compact(x):
|
||||
y = x if kv_cache_layout == "NHD" else x.permute(0, 2, 1, 3)
|
||||
@@ -230,15 +241,27 @@ def test_reshape_and_cache_flash(
|
||||
key_cache_compact = permute_and_compact(key_cache)
|
||||
value_cache_compact = permute_and_compact(value_cache)
|
||||
|
||||
def convert_fp8_local(output, input, scale, kv_dtype):
|
||||
fp8_input = input.view(torch.float8_e4m3fn)
|
||||
if scale.numel() == 1: # per-tensor
|
||||
result = scaled_dequantize(
|
||||
fp8_input.flatten(0, 2), scale, group_shape=None, out_dtype=output.dtype
|
||||
).reshape(*input.shape)
|
||||
else: # per-head: broadcast scale along the head dimension
|
||||
# Original code uses dim 2 for NHD, dim 1 for HND
|
||||
if kv_cache_layout == "NHD":
|
||||
result = fp8_input.to(output.dtype) * scale.view(1, 1, -1, 1)
|
||||
else:
|
||||
result = fp8_input.to(output.dtype) * scale.view(1, -1, 1, 1)
|
||||
output.copy_(result)
|
||||
|
||||
# Clone the KV caches.
|
||||
if kv_cache_dtype == "fp8":
|
||||
cloned_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16)
|
||||
ops.convert_fp8(
|
||||
cloned_key_cache, key_cache_compact, k_scale.item(), kv_cache_dtype
|
||||
)
|
||||
convert_fp8_local(cloned_key_cache, key_cache_compact, k_scale, kv_cache_dtype)
|
||||
cloned_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16)
|
||||
ops.convert_fp8(
|
||||
cloned_value_cache, value_cache_compact, v_scale.item(), kv_cache_dtype
|
||||
convert_fp8_local(
|
||||
cloned_value_cache, value_cache_compact, v_scale, kv_cache_dtype
|
||||
)
|
||||
else:
|
||||
cloned_key_cache = key_cache_compact.clone()
|
||||
@@ -289,15 +312,13 @@ def test_reshape_and_cache_flash(
|
||||
|
||||
if kv_cache_dtype == "fp8":
|
||||
result_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16)
|
||||
ops.convert_fp8(
|
||||
result_key_cache, key_cache_compact, k_scale.item(), kv_dtype=kv_cache_dtype
|
||||
)
|
||||
convert_fp8_local(result_key_cache, key_cache_compact, k_scale, kv_cache_dtype)
|
||||
result_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16)
|
||||
ops.convert_fp8(
|
||||
convert_fp8_local(
|
||||
result_value_cache,
|
||||
value_cache_compact,
|
||||
v_scale.item(),
|
||||
kv_dtype=kv_cache_dtype,
|
||||
v_scale,
|
||||
kv_cache_dtype,
|
||||
)
|
||||
|
||||
# Run the reference implementation.
|
||||
|
||||
Reference in New Issue
Block a user