Add llmcompressor fp8 kv-cache quant (per-tensor and per-attn_head) (#30141)

Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>
Signed-off-by: eldarkurtic <8884008+eldarkurtic@users.noreply.github.com>
Eldar Kurtić
2026-01-22 21:29:57 +01:00
committed by GitHub
parent 955b43a5a5
commit 44f08af3a7
18 changed files with 558 additions and 263 deletions


@@ -620,6 +620,7 @@ class AttentionImpl(ABC, Generic[T]):
     # TODO add support to more backends:
     # https://github.com/vllm-project/vllm/issues/25584
     supports_quant_query_input: bool = False
+    supports_per_head_quant_scales: bool = False
     dcp_world_size: int
     dcp_rank: int

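The first hunk adds a capability flag to the abstract backend interface: `supports_per_head_quant_scales` advertises whether an attention implementation can consume one fp8 kv-cache scale per attention head instead of a single per-tensor scale. A minimal sketch of how a caller might validate scales against the flag; the backend class and `check_kv_scales` helper are illustrative, not part of this PR:

```python
import torch

class IllustrativeBackendImpl:
    # Mirrors the defaults on AttentionImpl: backends opt in explicitly.
    supports_quant_query_input: bool = False
    supports_per_head_quant_scales: bool = False

def check_kv_scales(impl, k_scale: torch.Tensor, num_kv_heads: int) -> None:
    # Per-tensor quantization carries a single scalar scale; per-attn_head
    # quantization carries one scale per kv head. A backend that leaves the
    # flag False should only ever see the scalar form.
    if k_scale.numel() == num_kv_heads and num_kv_heads > 1:
        assert impl.supports_per_head_quant_scales, (
            "per-head kv-cache scales require backend support"
        )
    else:
        assert k_scale.numel() == 1
```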

@@ -576,6 +576,11 @@ class FlashAttentionImpl(AttentionImpl):
         )
         self.supports_quant_query_input = True
+        self.supports_per_head_quant_scales = (
+            self.vllm_flash_attn_version >= 3
+            if self.vllm_flash_attn_version is not None
+            else False
+        )

     def forward(
         self,
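FlashAttention only honors genuinely per-head descale values from version 3 onward, so `FlashAttentionImpl` derives the flag from the detected kernel version, and an undetected version (`None`) conservatively reports no support. The same gate as a standalone function, with `fa_version` standing in for `self.vllm_flash_attn_version`:

```python
from typing import Optional

def per_head_scales_supported(fa_version: Optional[int]) -> bool:
    # Equivalent to the conditional expression in __init__: only FA3+
    # qualifies, and an unknown version is treated as unsupported.
    return fa_version is not None and fa_version >= 3

assert per_head_scales_supported(3)
assert not per_head_scales_supported(2)
assert not per_head_scales_supported(None)
```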
@@ -691,6 +696,10 @@ class FlashAttentionImpl(AttentionImpl):
+        descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
+        q_descale = layer._q_scale.expand(descale_shape)
+        k_descale = layer._k_scale.expand(descale_shape)
+        v_descale = layer._v_scale.expand(descale_shape)
         if self.dcp_world_size > 1:
             self._forward_with_dcp(
                 query[:num_actual_tokens],
@@ -700,9 +709,9 @@ class FlashAttentionImpl(AttentionImpl):
                 value_cache,
                 output[:num_actual_tokens],
                 attn_metadata,
-                q_descale=layer._q_scale.expand(descale_shape),
-                k_descale=layer._k_scale.expand(descale_shape),
-                v_descale=layer._v_scale.expand(descale_shape),
+                q_descale=q_descale,
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
             return output
         else:
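The remaining hunks hoist the descale computation above the `if self.dcp_world_size > 1` branch so both call sites share one set of tensors instead of each repeating `.expand(descale_shape)` three times. `expand` also makes the per-tensor and per-head layouts uniform: a scalar scale broadcasts across every head, while a per-head scale broadcasts only along the batch dimension. A small shape-only sketch, with the sizes chosen for illustration:

```python
import torch

num_seqs, num_kv_heads = 4, 8
descale_shape = (num_seqs, num_kv_heads)  # (cu_seqlens_q.shape[0] - 1, num_kv_heads)

per_tensor_scale = torch.tensor(0.02)               # 0-dim: one scale for the cache
per_head_scale = torch.full((num_kv_heads,), 0.02)  # one scale per kv head

# expand() broadcasts without copying; both layouts reach the kernel as
# (num_seqs, num_kv_heads) descale tensors.
assert per_tensor_scale.expand(descale_shape).shape == descale_shape
assert per_head_scale.expand(descale_shape).shape == descale_shape
```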
@@ -728,9 +737,9 @@ class FlashAttentionImpl(AttentionImpl):
                 softcap=self.logits_soft_cap,
                 scheduler_metadata=scheduler_metadata,
                 fa_version=self.vllm_flash_attn_version,
-                q_descale=layer._q_scale.expand(descale_shape),
-                k_descale=layer._k_scale.expand(descale_shape),
-                v_descale=layer._v_scale.expand(descale_shape),
+                q_descale=q_descale,
+                k_descale=k_descale,
+                v_descale=v_descale,
                 num_splits=attn_metadata.max_num_splits,
                 s_aux=self.sinks,
             )
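End to end, these scales originate in the checkpoint: llmcompressor records the kv-cache quantization in a `kv_cache_scheme` that vLLM reads at load time. A sketch of what a recipe might look like; the `QuantizationModifier` call follows llmcompressor's fp8 kv-cache examples, but the per-head strategy string (`"attn_head"`) is inferred from this PR's title and may differ upstream:

```python
from llmcompressor.modifiers.quantization import QuantizationModifier

# Per-tensor kv-cache quantization uses strategy "tensor"; the per-head
# variant is assumed to be "attn_head" here, after this PR's title.
recipe = QuantizationModifier(
    targets="Linear",
    scheme="FP8_DYNAMIC",
    ignore=["lm_head"],
    kv_cache_scheme={
        "num_bits": 8,
        "type": "float",
        "strategy": "attn_head",  # or "tensor" for a single per-cache scale
        "dynamic": False,
        "symmetric": True,
    },
)
```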