enable skipping of SW attention layers when using FP8 KV cache (#33695)
Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
This commit is contained in:
@@ -466,3 +466,26 @@ def test_fp8_reloading(
|
||||
weight_loader(param, torch.zeros(shape)) # cannot use empty
|
||||
|
||||
method.process_weights_after_loading(layer)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("fp8"),
|
||||
reason="FP8 is not supported on this GPU type.",
|
||||
)
|
||||
def test_kv_cache_dtype_skip_layers(vllm_runner, monkeypatch):
|
||||
"""Test that kv_cache_dtype_skip_layers skips quantization for specified layers."""
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
with vllm_runner(
|
||||
"facebook/opt-125m",
|
||||
kv_cache_dtype="fp8",
|
||||
kv_cache_dtype_skip_layers=["0", "2"],
|
||||
enforce_eager=True,
|
||||
) as llm:
|
||||
|
||||
def check_layers(model):
|
||||
for i, layer in enumerate(model.model.decoder.layers):
|
||||
expected = "auto" if str(i) in ["0", "2"] else "fp8"
|
||||
assert layer.self_attn.attn.kv_cache_dtype == expected
|
||||
|
||||
llm.apply_model(check_layers)
|
||||
|
||||
@@ -87,6 +87,9 @@ class CacheConfig:
|
||||
It enables dynamic calculation of `k_scale` and `v_scale` when
|
||||
kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model
|
||||
checkpoint if available. Otherwise, the scales will default to 1.0."""
|
||||
kv_cache_dtype_skip_layers: list[str] = field(default_factory=list)
|
||||
"""Layer patterns to skip KV cache quantization. Accepts layer indices
|
||||
(e.g., '0', '2', '4') or attention type names (e.g., 'sliding_window')."""
|
||||
cpu_kvcache_space_bytes: int | None = None
|
||||
"""(CPU backend only) CPU key-value cache space."""
|
||||
mamba_page_size_padded: int | None = None
|
||||
|
||||
@@ -597,6 +597,9 @@ class EngineArgs:
|
||||
attention_backend: AttentionBackendEnum | None = AttentionConfig.backend
|
||||
|
||||
calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
|
||||
kv_cache_dtype_skip_layers: list[str] = get_field(
|
||||
CacheConfig, "kv_cache_dtype_skip_layers"
|
||||
)
|
||||
mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
|
||||
mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
|
||||
mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size")
|
||||
@@ -1003,6 +1006,9 @@ class EngineArgs:
|
||||
cache_group.add_argument(
|
||||
"--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"]
|
||||
)
|
||||
cache_group.add_argument(
|
||||
"--kv-cache-dtype-skip-layers", **cache_kwargs["kv_cache_dtype_skip_layers"]
|
||||
)
|
||||
cache_group.add_argument(
|
||||
"--kv-sharing-fast-prefill", **cache_kwargs["kv_sharing_fast_prefill"]
|
||||
)
|
||||
@@ -1578,6 +1584,7 @@ class EngineArgs:
|
||||
enable_prefix_caching=self.enable_prefix_caching,
|
||||
prefix_caching_hash_algo=self.prefix_caching_hash_algo,
|
||||
calculate_kv_scales=self.calculate_kv_scales,
|
||||
kv_cache_dtype_skip_layers=self.kv_cache_dtype_skip_layers,
|
||||
kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
|
||||
mamba_cache_dtype=self.mamba_cache_dtype,
|
||||
mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
|
||||
|
||||
@@ -240,6 +240,31 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
and kv_cache_scheme.get("strategy") == "attn_head"
|
||||
)
|
||||
|
||||
# Skip quantization for specified layers
|
||||
if cache_config is not None and cache_config.kv_cache_dtype_skip_layers:
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
|
||||
skip = False
|
||||
# Check attention type
|
||||
if (
|
||||
sliding_window is not None
|
||||
and "sliding_window" in cache_config.kv_cache_dtype_skip_layers
|
||||
):
|
||||
skip = True
|
||||
# Check layer index
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
if str(layer_idx) in cache_config.kv_cache_dtype_skip_layers:
|
||||
skip = True
|
||||
if skip:
|
||||
kv_cache_dtype = "auto"
|
||||
calculate_kv_scales = False
|
||||
logger.info(
|
||||
"Layer %s: kv_cache_dtype=%s, sliding_window=%s",
|
||||
prefix,
|
||||
kv_cache_dtype,
|
||||
sliding_window,
|
||||
)
|
||||
|
||||
self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
|
||||
kv_cache_dtype, vllm_config.model_config
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user