[V1][Bugfix] Standardize quantized kv cache rejection for attention backends (#14221)
Signed-off-by: mgoin <mgoin64@gmail.com>
@@ -294,3 +294,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         raise NotImplementedError
+
+
+def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
+    return kv_cache_dtype != "auto"
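The new helper treats any `kv_cache_dtype` other than `"auto"` as a quantized KV cache. A minimal sketch of the intended behavior (the dtype strings below are illustrative values, not an exhaustive list):

    from vllm.attention.backends.abstract import is_quantized_kv_cache

    # "auto" means the KV cache follows the model dtype, i.e. not quantized.
    assert not is_quantized_kv_cache("auto")
    # Any other value, such as an FP8 variant, is treated as quantized and
    # lets each backend reject it with a uniform NotImplementedError.
    assert is_quantized_kv_cache("fp8")
    assert is_quantized_kv_cache("fp8_e5m2")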
@@ -8,11 +8,15 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
 import torch
 
 from vllm import _custom_ops as ops
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
-                                              AttentionType)
+                                              AttentionType,
+                                              is_quantized_kv_cache)
+# yapf: enable
 from vllm.attention.backends.utils import (
     PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
     compute_slot_mapping_start_idx, get_flash_attn_version,
@@ -626,6 +630,9 @@ class FlashAttentionImpl(AttentionImpl):
         self.sliding_window = ((sliding_window - 1,
                                 0) if sliding_window is not None else (-1, -1))
         self.kv_cache_dtype = kv_cache_dtype
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FlashAttention with FP8 KV cache not yet supported")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
@@ -6,7 +6,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
 
 import torch
 
-from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.abstract import (AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.mla.common import (MLACommonBackend,
                                                 MLACommonImpl,
                                                 MLACommonMetadata,
@@ -207,6 +208,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
                                       "are not implemented for "
                                       "FlashMLAImpl")
 
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FlashMLA with FP8 KV cache not yet supported")
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,
@@ -215,8 +220,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
         attn_metadata: FlashMLAMetadata,
     ) -> torch.Tensor:
         assert kv_c_and_k_pe_cache.numel() > 0
-        if self.kv_cache_dtype.startswith("fp8"):
-            raise NotImplementedError("FP8 FlashMLA not yet supported")
 
         decode_meta = attn_metadata.decode_metadata
         assert decode_meta is not None
@@ -15,7 +15,8 @@ from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata, AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
                                                HPUPagedAttentionMetadata)
@@ -158,6 +159,10 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
                                       "are not implemented for "
                                       "HPUAttentionImpl")
 
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "HPUAttention with FP8 KV cache not yet supported")
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -9,7 +9,8 @@ import torch
 from vllm._ipex_ops import ipex_ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata, AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.paged_attn import (PagedAttention,
                                            PagedAttentionMetadata)
@@ -145,7 +146,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {supported_head_sizes}.")
-        if kv_cache_dtype != "auto":
+        if is_quantized_kv_cache(kv_cache_dtype):
             raise NotImplementedError(
                 "IPEX backend does not support FP8 KV cache. "
                 "Please use xFormers backend instead.")
@@ -8,7 +8,8 @@ import torch_xla.experimental.custom_kernel  # Required to register custom ops.
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata, AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.utils import CommonAttentionState
 
 
@@ -119,7 +120,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
             raise NotImplementedError("Alibi slopes is not supported.")
         if sliding_window is not None:
             raise NotImplementedError("Sliding window is not supported.")
-        if kv_cache_dtype != "auto":
+        if is_quantized_kv_cache(kv_cache_dtype):
             raise NotImplementedError("FP8 KV cache dtype is not supported.")
         if blocksparse_params is not None:
             raise NotImplementedError("Blocksparse is not supported.")
@@ -7,11 +7,15 @@ from typing import Any, Dict, List, Optional, Tuple, Type
 import torch
 from torch.nn.functional import scaled_dot_product_attention
 
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
-                                              AttentionType)
+                                              AttentionType,
+                                              is_quantized_kv_cache)
+# yapf: enable
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.ipex_attn import PagedAttention
 from vllm.attention.ops.paged_attn import PagedAttentionMetadata
@@ -427,7 +431,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {supported_head_sizes}.")
-        if kv_cache_dtype != "auto":
+        if is_quantized_kv_cache(kv_cache_dtype):
             raise NotImplementedError(
                 "Torch SDPA backend does not support FP8 KV cache. "
                 "Please use xFormers backend instead.")
@@ -4,7 +4,8 @@ from typing import Any, Dict, List, Optional, Type
 
 import torch
 
-from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.abstract import (AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.mla.common import (MLACommonBackend,
                                                 MLACommonImpl,
                                                 MLACommonMetadata)
@@ -58,6 +59,10 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
                                       "are not implemented for "
                                       "TritonMLAImpl")
 
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "TritonMLA with FP8 KV cache not yet supported")
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,
@@ -66,8 +71,6 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
         attn_metadata: MLACommonMetadata,
     ) -> torch.Tensor:
         assert kv_c_and_k_pe_cache.numel() > 0
-        if self.kv_cache_dtype.startswith("fp8"):
-            raise NotImplementedError("FP8 Triton MLA not yet supported")
 
         decode_meta = attn_metadata.decode_metadata
         assert decode_meta is not None
@@ -7,7 +7,8 @@ import numpy as np
 import torch
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata, AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.utils import get_flash_attn_version
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
 from vllm.logger import init_logger
@@ -180,6 +181,9 @@ class FlashAttentionImpl(AttentionImpl):
         else:
             self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FlashAttention V1 with FP8 KV cache not yet supported")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
@@ -5,7 +5,8 @@ from typing import Any, Optional
 
 import torch
 
-from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.abstract import (AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,
                                          is_flashmla_supported)
@@ -115,6 +116,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
                                       "are not implemented for "
                                       "FlashMLAImpl")
 
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FlashMLA V1 with FP8 KV cache not yet supported")
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,
@@ -125,9 +130,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
         assert kv_c_and_k_pe_cache.numel() > 0
         assert attn_metadata.decode is not None
 
-        if self.kv_cache_dtype.startswith("fp8"):
-            raise NotImplementedError("FP8 FlashMLA not yet supported")
-
         q = torch.cat([q_nope, q_pe], dim=-1)\
             .unsqueeze(1)  # Add seqlen dim of 1 (decode)
 
@@ -4,7 +4,8 @@ from typing import Any, Optional
 
 import torch
 
-from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.abstract import (AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
@@ -61,6 +62,10 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
                                       "are not implemented for "
                                       "TritonMLAImpl")
 
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "TritonMLA V1 with FP8 KV cache not yet supported")
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,