[Attn,KV-cache] Use per-head scales in the attention selector (#34281)

Signed-off-by: Your Name <you@example.com>
Signed-off-by: Eldar Kurtic <research@neuralmagic.com>
Co-authored-by: Eldar Kurtic <research@neuralmagic.com>
Co-authored-by: Your Name <you@example.com>
Eldar Kurtić authored 2026-02-24 15:02:43 +01:00, committed by GitHub
parent 761e63e541
commit a87cc50859
6 changed files with 80 additions and 15 deletions
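
In short, "per-head scales" here means one KV-cache quantization scale per attention head instead of a single per-tensor scalar, so the attention backend selector must only hand out backends whose kernels can consume per-head scale tensors. A rough illustration of the difference in shapes, assuming 8 KV heads (illustrative only, not vLLM's exact cache layout):

import torch

num_kv_heads = 8
# Per-tensor quantization: one scalar scale shared by the whole K (or V) cache.
per_tensor_k_scale = torch.tensor(1.0)
# Per-head quantization: one scale per head, so heads with very different
# activation ranges can be quantized independently.
per_head_k_scale = torch.ones(num_kv_heads)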

@@ -291,3 +291,57 @@ def test_invalid_backend():
    ):
        # Invalid backend name should raise ValueError when creating enum
        AttentionConfig(backend=AttentionBackendEnum["INVALID"])


@pytest.mark.parametrize(
    "backend_name,flash_attn_version,should_succeed",
    [
        ("FLASH_ATTN", 3, True),  # FA3 supports per-head quant scales
        ("FLASH_ATTN", 2, False),  # FA2 does not support per-head quant scales
        ("FLASHINFER", None, False),  # FlashInfer does not support per-head scales
        ("FLEX_ATTENTION", None, False),  # FlexAttention does not support per-head scales
    ],
)
def test_per_head_quant_scales_backend_selection(
    backend_name: str, flash_attn_version: int | None, should_succeed: bool
):
    """Test backend selection when use_per_head_quant_scales=True."""
    # Clear cache to ensure fresh backend selection
    _cached_get_attn_backend.cache_clear()
    attention_config = AttentionConfig(
        backend=AttentionBackendEnum[backend_name],
        flash_attn_version=flash_attn_version,
    )
    vllm_config = VllmConfig(attention_config=attention_config)
    with (
        set_current_vllm_config(vllm_config),
        patch("vllm.platforms.current_platform", CudaPlatform()),
    ):
        if backend_name == "FLASH_ATTN" and flash_attn_version == 3:
            if not torch.cuda.is_available():
                pytest.skip("FA3 requires CUDA")
            capability = torch.cuda.get_device_capability()
            if capability[0] != 9:
                pytest.skip("FA3 is only supported on Hopper (SM 9.x) GPUs")
        if should_succeed:
            backend = get_attn_backend(
                head_size=128,
                dtype=torch.float16,
                kv_cache_dtype="fp8",
                block_size=64,
                use_per_head_quant_scales=True,
            )
            assert backend.get_name() == backend_name
        else:
            with pytest.raises(ValueError) as exc_info:
                get_attn_backend(
                    head_size=128,
                    dtype=torch.float16,
                    kv_cache_dtype="fp8",
                    block_size=64,
                    use_per_head_quant_scales=True,
                )
            assert backend_name in str(exc_info.value)
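
For reference, a minimal sketch of the guard this test expects the selector to enforce, derived from the support matrix in the parametrization above; the helper name and message wording are assumptions, not the actual vLLM selector code:

def _check_per_head_quant_scales(
    backend_name: str, flash_attn_version: int | None
) -> None:
    # Hypothetical sketch: only FlashAttention 3 accepts per-head K/V quant
    # scales, matching the (backend, version, should_succeed) cases above.
    supported = backend_name == "FLASH_ATTN" and flash_attn_version == 3
    if not supported:
        # The test asserts that the offending backend name appears in the message.
        raise ValueError(
            f"use_per_head_quant_scales=True is not supported by the "
            f"{backend_name} backend"
        )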