[Attn,KV-cache] Use per-head scales in the attention selector (#34281)

Signed-off-by: Your Name <you@example.com>
Signed-off-by: Eldar Kurtic <research@neuralmagic.com>
Co-authored-by: Eldar Kurtic <research@neuralmagic.com>
Co-authored-by: Your Name <you@example.com>
This commit is contained in:
Eldar Kurtić
2026-02-24 15:02:43 +01:00
committed by GitHub
parent 761e63e541
commit a87cc50859
6 changed files with 80 additions and 15 deletions

View File

@@ -291,3 +291,57 @@ def test_invalid_backend():
):
# Invalid backend name should raise ValueError when creating enum
AttentionConfig(backend=AttentionBackendEnum["INVALID"])
@pytest.mark.parametrize(
"backend_name,flash_attn_version,should_succeed",
[
("FLASH_ATTN", 3, True), # FA3 supports per-head quant scales
("FLASH_ATTN", 2, False), # FA2 does not support per-head quant scales
("FLASHINFER", None, False), # FlashInfer does not support
("FLEX_ATTENTION", None, False), # Flex does not support
],
)
def test_per_head_quant_scales_backend_selection(
backend_name: str, flash_attn_version: int | None, should_succeed: bool
):
"""Test backend selection when use_per_head_quant_scales=True."""
# Clear cache to ensure fresh backend selection
_cached_get_attn_backend.cache_clear()
attention_config = AttentionConfig(
backend=AttentionBackendEnum[backend_name],
flash_attn_version=flash_attn_version,
)
vllm_config = VllmConfig(attention_config=attention_config)
with (
set_current_vllm_config(vllm_config),
patch("vllm.platforms.current_platform", CudaPlatform()),
):
if backend_name == "FLASH_ATTN" and flash_attn_version == 3:
if not torch.cuda.is_available():
pytest.skip("FA3 requires CUDA")
capability = torch.cuda.get_device_capability()
if capability[0] != 9:
pytest.skip("FA3 is only supported on Hopper (SM 9.x) GPUs")
if should_succeed:
backend = get_attn_backend(
head_size=128,
dtype=torch.float16,
kv_cache_dtype="fp8",
block_size=64,
use_per_head_quant_scales=True,
)
assert backend.get_name() == backend_name
else:
with pytest.raises(ValueError) as exc_info:
get_attn_backend(
head_size=128,
dtype=torch.float16,
kv_cache_dtype="fp8",
block_size=64,
use_per_head_quant_scales=True,
)
assert backend_name in str(exc_info.value)

View File

@@ -229,13 +229,20 @@ class Attention(nn.Module, AttentionLayerBase):
calculate_kv_scales = False
# llm-compressor mdls need to set cache_dtype to "fp8" manually.
if getattr(quant_config, "kv_cache_scheme", None) is not None:
kv_cache_scheme = getattr(quant_config, "kv_cache_scheme", None)
if kv_cache_scheme is not None:
kv_cache_dtype = "fp8"
calculate_kv_scales = False
if cache_config is not None:
cache_config.cache_dtype = "fp8"
cache_config.calculate_kv_scales = False
# Check if per-head quant scales are required based on kv_cache_scheme
use_per_head_quant_scales = (
kv_cache_scheme is not None
and kv_cache_scheme.get("strategy") == "attn_head"
)
self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
kv_cache_dtype, vllm_config.model_config
)
@@ -272,6 +279,7 @@ class Attention(nn.Module, AttentionLayerBase):
use_mla=False,
has_sink=self.has_sink,
use_mm_prefix=self.use_mm_prefix,
use_per_head_quant_scales=use_per_head_quant_scales,
attn_type=attn_type,
)
else:

View File

@@ -985,14 +985,7 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
self.quant_config.kv_cache_scheme["strategy"]
)
if strategy == QuantizationStrategy.ATTN_HEAD:
assert layer.impl.supports_per_head_quant_scales, (
f"Layer {layer.__class__.__name__} with implementation "
f"{layer.impl.__class__.__name__} does not support per-head scales."
)
n_scales = int(layer.num_kv_heads)
else:
n_scales = 1
n_scales = int(layer.num_kv_heads) if strategy == "attn_head" else 1
layer.k_scale = torch.nn.Parameter(
torch.ones(n_scales, requires_grad=False, dtype=torch.float32)

View File

@@ -187,6 +187,10 @@ class AttentionBackend(ABC):
def is_sparse(cls) -> bool:
return False
@classmethod
def supports_per_head_quant_scales(cls) -> bool:
return False
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""Check if backend supports a given attention type.
@@ -225,6 +229,7 @@ class AttentionBackend(ABC):
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
use_per_head_quant_scales: bool,
device_capability: "DeviceCapability",
attn_type: str,
) -> list[str]:
@@ -253,6 +258,8 @@ class AttentionBackend(ABC):
invalid_reasons.append("sparse not supported")
else:
invalid_reasons.append("non-sparse not supported")
if use_per_head_quant_scales and not cls.supports_per_head_quant_scales():
invalid_reasons.append("per-head quant scales not supported")
if not cls.supports_compute_capability(device_capability):
invalid_reasons.append("compute capability not supported")
if not cls.supports_attn_type(attn_type):
@@ -635,7 +642,6 @@ class AttentionImplBase(ABC, Generic[T]):
# TODO add support to more backends:
# https://github.com/vllm-project/vllm/issues/25584
supports_quant_query_input: bool = False
supports_per_head_quant_scales: bool = False
dcp_world_size: int
dcp_rank: int

View File

@@ -95,6 +95,11 @@ class FlashAttentionBackend(AttentionBackend):
AttentionType.ENCODER_DECODER,
)
@classmethod
def supports_per_head_quant_scales(cls) -> bool:
fa_version = get_flash_attn_version()
return fa_version is not None and fa_version >= 3
@staticmethod
def get_impl_cls() -> type["FlashAttentionImpl"]:
return FlashAttentionImpl
@@ -595,11 +600,6 @@ class FlashAttentionImpl(AttentionImpl):
)
self.supports_quant_query_input = True
self.supports_per_head_quant_scales = (
self.vllm_flash_attn_version >= 3
if self.vllm_flash_attn_version is not None
else False
)
def forward(
self,

View File

@@ -27,6 +27,7 @@ class AttentionSelectorConfig(NamedTuple):
has_sink: bool = False
use_sparse: bool = False
use_mm_prefix: bool = False
use_per_head_quant_scales: bool = False
attn_type: str = AttentionType.DECODER
def __repr__(self):
@@ -39,6 +40,7 @@ class AttentionSelectorConfig(NamedTuple):
f"has_sink={self.has_sink}, "
f"use_sparse={self.use_sparse}, "
f"use_mm_prefix={self.use_mm_prefix}, "
f"use_per_head_quant_scales={self.use_per_head_quant_scales}, "
f"attn_type={self.attn_type})"
)
@@ -52,6 +54,7 @@ def get_attn_backend(
has_sink: bool = False,
use_sparse: bool = False,
use_mm_prefix: bool = False,
use_per_head_quant_scales: bool = False,
attn_type: str | None = None,
num_heads: int | None = None,
) -> type[AttentionBackend]:
@@ -77,6 +80,7 @@ def get_attn_backend(
has_sink=has_sink,
use_sparse=use_sparse,
use_mm_prefix=use_mm_prefix,
use_per_head_quant_scales=use_per_head_quant_scales,
attn_type=attn_type or AttentionType.DECODER,
)