enable skipping of sliding-window (SW) attention layers when using FP8 KV cache (#33695)

Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
Author: Jonas M. Kübler (committed by GitHub)
Date:   2026-03-27 14:25:02 +01:00
Parent: b111f8a61f
Commit: 98e7f223b9

4 changed files with 58 additions and 0 deletions


@@ -466,3 +466,26 @@ def test_fp8_reloading(
    weight_loader(param, torch.zeros(shape))  # cannot use empty
    method.process_weights_after_loading(layer)


@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
def test_kv_cache_dtype_skip_layers(vllm_runner, monkeypatch):
    """Test that kv_cache_dtype_skip_layers skips KV cache quantization
    for the specified layers."""
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
    with vllm_runner(
        "facebook/opt-125m",
        kv_cache_dtype="fp8",
        kv_cache_dtype_skip_layers=["0", "2"],
        enforce_eager=True,
    ) as llm:

        def check_layers(model):
            # Layers listed in kv_cache_dtype_skip_layers keep the default
            # ("auto") KV cache dtype; all other layers are quantized to fp8.
            for i, layer in enumerate(model.model.decoder.layers):
                expected = "auto" if str(i) in ["0", "2"] else "fp8"
                assert layer.self_attn.attn.kv_cache_dtype == expected

        llm.apply_model(check_layers)
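
For context, a minimal usage sketch of the new option from the offline entrypoint.
This is an assumption, not part of the commit: the diff above only shows the option
being passed to the vllm_runner test fixture, so forwarding kv_cache_dtype_skip_layers
directly to vllm.LLM is inferred from how engine arguments are normally plumbed through.

    # Sketch only. Assumes kv_cache_dtype_skip_layers (added in this commit)
    # is forwarded to the engine the same way the test passes it to vllm_runner.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="facebook/opt-125m",
        kv_cache_dtype="fp8",
        # Keep layers 0 and 2 in the default ("auto") KV cache dtype,
        # e.g. sliding-window layers that should not be quantized.
        kv_cache_dtype_skip_layers=["0", "2"],
        enforce_eager=True,
    )
    outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)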