Add llmcompressor fp8 kv-cache quant (per-tensor and per-attn_head) (#30141)
Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com> Signed-off-by: eldarkurtic <8884008+eldarkurtic@users.noreply.github.com>
This commit is contained in:
@@ -620,6 +620,7 @@ class AttentionImpl(ABC, Generic[T]):
|
||||
# TODO add support to more backends:
|
||||
# https://github.com/vllm-project/vllm/issues/25584
|
||||
supports_quant_query_input: bool = False
|
||||
supports_per_head_quant_scales: bool = False
|
||||
|
||||
dcp_world_size: int
|
||||
dcp_rank: int
|
||||
|
||||
@@ -576,6 +576,11 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
)
|
||||
|
||||
self.supports_quant_query_input = True
|
||||
self.supports_per_head_quant_scales = (
|
||||
self.vllm_flash_attn_version >= 3
|
||||
if self.vllm_flash_attn_version is not None
|
||||
else False
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -691,6 +696,10 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
|
||||
descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
|
||||
|
||||
q_descale = layer._q_scale.expand(descale_shape)
|
||||
k_descale = layer._k_scale.expand(descale_shape)
|
||||
v_descale = layer._v_scale.expand(descale_shape)
|
||||
|
||||
if self.dcp_world_size > 1:
|
||||
self._forward_with_dcp(
|
||||
query[:num_actual_tokens],
|
||||
@@ -700,9 +709,9 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
value_cache,
|
||||
output[:num_actual_tokens],
|
||||
attn_metadata,
|
||||
q_descale=layer._q_scale.expand(descale_shape),
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
q_descale=q_descale,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
)
|
||||
return output
|
||||
else:
|
||||
@@ -728,9 +737,9 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
softcap=self.logits_soft_cap,
|
||||
scheduler_metadata=scheduler_metadata,
|
||||
fa_version=self.vllm_flash_attn_version,
|
||||
q_descale=layer._q_scale.expand(descale_shape),
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
q_descale=q_descale,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
num_splits=attn_metadata.max_num_splits,
|
||||
s_aux=self.sinks,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user