[Attn,KV-cache] Use per-head scales in the attention selector (#34281)
Signed-off-by: Your Name <you@example.com> Signed-off-by: Eldar Kurtic <research@neuralmagic.com> Co-authored-by: Eldar Kurtic <research@neuralmagic.com> Co-authored-by: Your Name <you@example.com>
This commit is contained in:
@@ -291,3 +291,57 @@ def test_invalid_backend():
|
||||
):
|
||||
# Invalid backend name should raise ValueError when creating enum
|
||||
AttentionConfig(backend=AttentionBackendEnum["INVALID"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"backend_name,flash_attn_version,should_succeed",
|
||||
[
|
||||
("FLASH_ATTN", 3, True), # FA3 supports per-head quant scales
|
||||
("FLASH_ATTN", 2, False), # FA2 does not support per-head quant scales
|
||||
("FLASHINFER", None, False), # FlashInfer does not support
|
||||
("FLEX_ATTENTION", None, False), # Flex does not support
|
||||
],
|
||||
)
|
||||
def test_per_head_quant_scales_backend_selection(
|
||||
backend_name: str, flash_attn_version: int | None, should_succeed: bool
|
||||
):
|
||||
"""Test backend selection when use_per_head_quant_scales=True."""
|
||||
# Clear cache to ensure fresh backend selection
|
||||
_cached_get_attn_backend.cache_clear()
|
||||
|
||||
attention_config = AttentionConfig(
|
||||
backend=AttentionBackendEnum[backend_name],
|
||||
flash_attn_version=flash_attn_version,
|
||||
)
|
||||
vllm_config = VllmConfig(attention_config=attention_config)
|
||||
|
||||
with (
|
||||
set_current_vllm_config(vllm_config),
|
||||
patch("vllm.platforms.current_platform", CudaPlatform()),
|
||||
):
|
||||
if backend_name == "FLASH_ATTN" and flash_attn_version == 3:
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("FA3 requires CUDA")
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] != 9:
|
||||
pytest.skip("FA3 is only supported on Hopper (SM 9.x) GPUs")
|
||||
|
||||
if should_succeed:
|
||||
backend = get_attn_backend(
|
||||
head_size=128,
|
||||
dtype=torch.float16,
|
||||
kv_cache_dtype="fp8",
|
||||
block_size=64,
|
||||
use_per_head_quant_scales=True,
|
||||
)
|
||||
assert backend.get_name() == backend_name
|
||||
else:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
get_attn_backend(
|
||||
head_size=128,
|
||||
dtype=torch.float16,
|
||||
kv_cache_dtype="fp8",
|
||||
block_size=64,
|
||||
use_per_head_quant_scales=True,
|
||||
)
|
||||
assert backend_name in str(exc_info.value)
|
||||
|
||||
@@ -229,13 +229,20 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
calculate_kv_scales = False
|
||||
|
||||
# llm-compressor mdls need to set cache_dtype to "fp8" manually.
|
||||
if getattr(quant_config, "kv_cache_scheme", None) is not None:
|
||||
kv_cache_scheme = getattr(quant_config, "kv_cache_scheme", None)
|
||||
if kv_cache_scheme is not None:
|
||||
kv_cache_dtype = "fp8"
|
||||
calculate_kv_scales = False
|
||||
if cache_config is not None:
|
||||
cache_config.cache_dtype = "fp8"
|
||||
cache_config.calculate_kv_scales = False
|
||||
|
||||
# Check if per-head quant scales are required based on kv_cache_scheme
|
||||
use_per_head_quant_scales = (
|
||||
kv_cache_scheme is not None
|
||||
and kv_cache_scheme.get("strategy") == "attn_head"
|
||||
)
|
||||
|
||||
self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
|
||||
kv_cache_dtype, vllm_config.model_config
|
||||
)
|
||||
@@ -272,6 +279,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
use_mla=False,
|
||||
has_sink=self.has_sink,
|
||||
use_mm_prefix=self.use_mm_prefix,
|
||||
use_per_head_quant_scales=use_per_head_quant_scales,
|
||||
attn_type=attn_type,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -985,14 +985,7 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
|
||||
self.quant_config.kv_cache_scheme["strategy"]
|
||||
)
|
||||
|
||||
if strategy == QuantizationStrategy.ATTN_HEAD:
|
||||
assert layer.impl.supports_per_head_quant_scales, (
|
||||
f"Layer {layer.__class__.__name__} with implementation "
|
||||
f"{layer.impl.__class__.__name__} does not support per-head scales."
|
||||
)
|
||||
n_scales = int(layer.num_kv_heads)
|
||||
else:
|
||||
n_scales = 1
|
||||
n_scales = int(layer.num_kv_heads) if strategy == "attn_head" else 1
|
||||
|
||||
layer.k_scale = torch.nn.Parameter(
|
||||
torch.ones(n_scales, requires_grad=False, dtype=torch.float32)
|
||||
|
||||
@@ -187,6 +187,10 @@ class AttentionBackend(ABC):
|
||||
def is_sparse(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def supports_per_head_quant_scales(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||
"""Check if backend supports a given attention type.
|
||||
@@ -225,6 +229,7 @@ class AttentionBackend(ABC):
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
use_mm_prefix: bool,
|
||||
use_per_head_quant_scales: bool,
|
||||
device_capability: "DeviceCapability",
|
||||
attn_type: str,
|
||||
) -> list[str]:
|
||||
@@ -253,6 +258,8 @@ class AttentionBackend(ABC):
|
||||
invalid_reasons.append("sparse not supported")
|
||||
else:
|
||||
invalid_reasons.append("non-sparse not supported")
|
||||
if use_per_head_quant_scales and not cls.supports_per_head_quant_scales():
|
||||
invalid_reasons.append("per-head quant scales not supported")
|
||||
if not cls.supports_compute_capability(device_capability):
|
||||
invalid_reasons.append("compute capability not supported")
|
||||
if not cls.supports_attn_type(attn_type):
|
||||
@@ -635,7 +642,6 @@ class AttentionImplBase(ABC, Generic[T]):
|
||||
# TODO add support to more backends:
|
||||
# https://github.com/vllm-project/vllm/issues/25584
|
||||
supports_quant_query_input: bool = False
|
||||
supports_per_head_quant_scales: bool = False
|
||||
|
||||
dcp_world_size: int
|
||||
dcp_rank: int
|
||||
|
||||
@@ -95,6 +95,11 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
AttentionType.ENCODER_DECODER,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def supports_per_head_quant_scales(cls) -> bool:
|
||||
fa_version = get_flash_attn_version()
|
||||
return fa_version is not None and fa_version >= 3
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashAttentionImpl"]:
|
||||
return FlashAttentionImpl
|
||||
@@ -595,11 +600,6 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
)
|
||||
|
||||
self.supports_quant_query_input = True
|
||||
self.supports_per_head_quant_scales = (
|
||||
self.vllm_flash_attn_version >= 3
|
||||
if self.vllm_flash_attn_version is not None
|
||||
else False
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -27,6 +27,7 @@ class AttentionSelectorConfig(NamedTuple):
|
||||
has_sink: bool = False
|
||||
use_sparse: bool = False
|
||||
use_mm_prefix: bool = False
|
||||
use_per_head_quant_scales: bool = False
|
||||
attn_type: str = AttentionType.DECODER
|
||||
|
||||
def __repr__(self):
|
||||
@@ -39,6 +40,7 @@ class AttentionSelectorConfig(NamedTuple):
|
||||
f"has_sink={self.has_sink}, "
|
||||
f"use_sparse={self.use_sparse}, "
|
||||
f"use_mm_prefix={self.use_mm_prefix}, "
|
||||
f"use_per_head_quant_scales={self.use_per_head_quant_scales}, "
|
||||
f"attn_type={self.attn_type})"
|
||||
)
|
||||
|
||||
@@ -52,6 +54,7 @@ def get_attn_backend(
|
||||
has_sink: bool = False,
|
||||
use_sparse: bool = False,
|
||||
use_mm_prefix: bool = False,
|
||||
use_per_head_quant_scales: bool = False,
|
||||
attn_type: str | None = None,
|
||||
num_heads: int | None = None,
|
||||
) -> type[AttentionBackend]:
|
||||
@@ -77,6 +80,7 @@ def get_attn_backend(
|
||||
has_sink=has_sink,
|
||||
use_sparse=use_sparse,
|
||||
use_mm_prefix=use_mm_prefix,
|
||||
use_per_head_quant_scales=use_per_head_quant_scales,
|
||||
attn_type=attn_type or AttentionType.DECODER,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user