enable skipping of SW attention layers when using FP8 KV cache (#33695)

Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
This commit is contained in:
Jonas M. Kübler
2026-03-27 14:25:02 +01:00
committed by GitHub
parent b111f8a61f
commit 98e7f223b9
4 changed files with 58 additions and 0 deletions

View File

@@ -466,3 +466,26 @@ def test_fp8_reloading(
weight_loader(param, torch.zeros(shape)) # cannot use empty
method.process_weights_after_loading(layer)
@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
def test_kv_cache_dtype_skip_layers(vllm_runner, monkeypatch):
    """Test that kv_cache_dtype_skip_layers skips quantization for specified layers."""
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
    # Layers named here must fall back to the unquantized ("auto") KV cache.
    skip_layers = ["0", "2"]
    with vllm_runner(
        "facebook/opt-125m",
        kv_cache_dtype="fp8",
        kv_cache_dtype_skip_layers=skip_layers,
        enforce_eager=True,
    ) as llm:

        def check_layers(model):
            # Every decoder layer is either explicitly skipped ("auto")
            # or keeps the requested fp8 KV-cache dtype.
            for idx, decoder_layer in enumerate(model.model.decoder.layers):
                expected_dtype = "auto" if str(idx) in skip_layers else "fp8"
                assert decoder_layer.self_attn.attn.kv_cache_dtype == expected_dtype

        llm.apply_model(check_layers)

View File

@@ -87,6 +87,9 @@ class CacheConfig:
It enables dynamic calculation of `k_scale` and `v_scale` when
kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model
checkpoint if available. Otherwise, the scales will default to 1.0."""
kv_cache_dtype_skip_layers: list[str] = field(default_factory=list)
"""Layer patterns to skip KV cache quantization. Accepts layer indices
(e.g., '0', '2', '4') or attention type names (e.g., 'sliding_window')."""
cpu_kvcache_space_bytes: int | None = None
"""(CPU backend only) CPU key-value cache space."""
mamba_page_size_padded: int | None = None

View File

@@ -597,6 +597,9 @@ class EngineArgs:
attention_backend: AttentionBackendEnum | None = AttentionConfig.backend
calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
kv_cache_dtype_skip_layers: list[str] = get_field(
CacheConfig, "kv_cache_dtype_skip_layers"
)
mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size")
@@ -1003,6 +1006,9 @@ class EngineArgs:
cache_group.add_argument(
"--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"]
)
cache_group.add_argument(
"--kv-cache-dtype-skip-layers", **cache_kwargs["kv_cache_dtype_skip_layers"]
)
cache_group.add_argument(
"--kv-sharing-fast-prefill", **cache_kwargs["kv_sharing_fast_prefill"]
)
@@ -1578,6 +1584,7 @@ class EngineArgs:
enable_prefix_caching=self.enable_prefix_caching,
prefix_caching_hash_algo=self.prefix_caching_hash_algo,
calculate_kv_scales=self.calculate_kv_scales,
kv_cache_dtype_skip_layers=self.kv_cache_dtype_skip_layers,
kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
mamba_cache_dtype=self.mamba_cache_dtype,
mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,

View File

@@ -240,6 +240,31 @@ class Attention(nn.Module, AttentionLayerBase):
and kv_cache_scheme.get("strategy") == "attn_head"
)
# Skip quantization for specified layers
if cache_config is not None and cache_config.kv_cache_dtype_skip_layers:
from vllm.model_executor.models.utils import extract_layer_index
skip = False
# Check attention type
if (
sliding_window is not None
and "sliding_window" in cache_config.kv_cache_dtype_skip_layers
):
skip = True
# Check layer index
layer_idx = extract_layer_index(prefix)
if str(layer_idx) in cache_config.kv_cache_dtype_skip_layers:
skip = True
if skip:
kv_cache_dtype = "auto"
calculate_kv_scales = False
logger.info(
"Layer %s: kv_cache_dtype=%s, sliding_window=%s",
prefix,
kv_cache_dtype,
sliding_window,
)
self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
kv_cache_dtype, vllm_config.model_config
)