[V1][Bugfix] Standardize quantized kv cache rejection for attention backends (#14221)

Signed-off-by: mgoin <mgoin64@gmail.com>
Author: Michael Goin
Date: 2025-03-06 17:18:29 -05:00
Committed by: GitHub
Parent: 6b2ef5cd17
Commit: 6832707e90
11 changed files with 59 additions and 20 deletions
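
The checks added below all go through a single is_quantized_kv_cache helper exported from vllm.attention.backends.abstract, instead of each backend re-implementing its own dtype string test. As a minimal sketch of what such a predicate can look like (assuming it only inspects the kv_cache_dtype string the engine hands to every attention backend; the exact vLLM implementation may differ):

    # Hedged sketch, not the verbatim vLLM helper.
    def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
        # "auto" means the KV cache reuses the model dtype (unquantized);
        # anything else, e.g. "fp8", "fp8_e4m3" or "fp8_e5m2", is quantized.
        return kv_cache_dtype != "auto"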


@@ -7,7 +7,8 @@ import numpy as np
 import torch
 
-from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionMetadata, AttentionType)
+from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionMetadata, AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.utils import get_flash_attn_version
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
 from vllm.logger import init_logger
@@ -180,6 +181,9 @@ class FlashAttentionImpl(AttentionImpl):
         else:
             self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FlashAttention V1 with FP8 KV cache not yet supported")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
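
With the check in __init__, an unsupported quantized KV cache is rejected when the backend is constructed rather than on the first forward pass. A hedged, illustrative test of that behavior is sketched below; DummyV1AttentionImpl and its constructor are stand-ins, not the real FlashAttentionImpl signature:

    import pytest

    from vllm.attention.backends.abstract import is_quantized_kv_cache

    class DummyV1AttentionImpl:
        """Stand-in backend using the same constructor-time rejection pattern."""

        def __init__(self, kv_cache_dtype: str = "auto") -> None:
            self.kv_cache_dtype = kv_cache_dtype
            if is_quantized_kv_cache(self.kv_cache_dtype):
                raise NotImplementedError(
                    "Dummy V1 backend with FP8 KV cache not yet supported")

    def test_quantized_kv_cache_rejected_at_init():
        with pytest.raises(NotImplementedError):
            DummyV1AttentionImpl(kv_cache_dtype="fp8")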


@@ -5,7 +5,8 @@ from typing import Any, Optional
 
 import torch
 
-from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.abstract import (AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,
                                          is_flashmla_supported)
@@ -115,6 +116,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
                 "are not implemented for "
                 "FlashMLAImpl")
 
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FlashMLA V1 with FP8 KV cache not yet supported")
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,
@@ -125,9 +130,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
         assert kv_c_and_k_pe_cache.numel() > 0
         assert attn_metadata.decode is not None
 
-        if self.kv_cache_dtype.startswith("fp8"):
-            raise NotImplementedError("FP8 FlashMLA not yet supported")
-
         q = torch.cat([q_nope, q_pe], dim=-1)\
             .unsqueeze(1)  # Add seqlen dim of 1 (decode)
 
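
Note that the old guard lived inside _forward_decode and only matched dtypes starting with "fp8", while the new one runs once in __init__ and covers any quantized dtype the helper recognizes. A small comparison, assuming the helper behaves like the non-"auto" sketch above:

    from vllm.attention.backends.abstract import is_quantized_kv_cache

    for dtype in ("auto", "fp8", "fp8_e4m3", "fp8_e5m2"):
        old_check = dtype.startswith("fp8")       # former decode-time guard
        new_check = is_quantized_kv_cache(dtype)  # new constructor-time guard
        print(f"{dtype}: old={old_check} new={new_check}")
    # For these dtypes the two guards agree; the practical difference is that
    # the new one fails fast at backend construction instead of at decode time.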


@@ -4,7 +4,8 @@ from typing import Any, Optional
 
 import torch
 
-from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.abstract import (AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
@@ -61,6 +62,10 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
                 "are not implemented for "
                 "TritonMLAImpl")
 
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "TritonMLA V1 with FP8 KV cache not yet supported")
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,