Reenable features for ROCm attention backends (#36185)
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
@@ -252,7 +252,7 @@ class AttentionBackend(ABC):
         else:
             invalid_reasons.append("non-MLA not supported")
         if has_sink and not cls.supports_sink():
-            invalid_reasons.append("sink setting not supported")
+            invalid_reasons.append("attention sinks not supported")
         if use_sparse != cls.is_sparse():
             if use_sparse:
                 invalid_reasons.append("sparse not supported")
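For context, the hunk above follows the capability-probing pattern these backends use: one classmethod per feature, with human-readable reasons collected whenever a requested combination is unsupported. Below is a minimal, self-contained sketch of that pattern; the class, method defaults, and reason strings are illustrative stand-ins, not vLLM's actual implementation.

# Illustrative sketch only: a simplified stand-in for the capability checks
# exercised in the hunk above. Names and defaults here are hypothetical.
class ToyBackend:
    @classmethod
    def supports_sink(cls) -> bool:
        return False

    @classmethod
    def is_sparse(cls) -> bool:
        return False

    @classmethod
    def validate_configuration(cls, has_sink: bool, use_sparse: bool) -> list[str]:
        invalid_reasons: list[str] = []
        if has_sink and not cls.supports_sink():
            invalid_reasons.append("attention sinks not supported")
        if use_sparse != cls.is_sparse():
            if use_sparse:
                invalid_reasons.append("sparse not supported")
        return invalid_reasons


# An empty list means the backend can serve the requested configuration.
print(ToyBackend.validate_configuration(has_sink=True, use_sparse=False))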
@@ -8,6 +8,7 @@ import torch
 
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.config import VllmConfig
+from vllm.config.cache import CacheDType
 from vllm.model_executor.layers.attention.mla_attention import (
     MLACommonBackend,
     MLACommonDecodeMetadata,
@@ -21,6 +22,15 @@ from vllm.v1.kv_cache_interface import AttentionSpec
 
 
 class AiterMLABackend(MLACommonBackend):
+    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
+    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
+        "auto",
+        "bfloat16",
+        "fp8",
+        "fp8_e4m3",
+        "fp8_e5m2",
+    ]
+
     @staticmethod
     def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
         return [1]
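The backend now declares its dtype, KV-cache dtype, and kernel block-size support as class-level data, so a selector can filter candidates by simple membership tests instead of backend-specific special cases. The sketch below shows that idea with a hypothetical stand-in class and helper; it is not vLLM's actual selection code.

# Hypothetical sketch of how declared capabilities like the ClassVars above
# could be consulted. "ToyMLABackend" and "backend_accepts" are stand-ins.
import torch

class ToyMLABackend:
    supported_dtypes = [torch.float16, torch.bfloat16]
    supported_kv_cache_dtypes = ["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2"]

def backend_accepts(backend, model_dtype: torch.dtype, kv_cache_dtype: str) -> bool:
    # Simple membership tests against the class-level declarations.
    return (
        model_dtype in backend.supported_dtypes
        and kv_cache_dtype in backend.supported_kv_cache_dtypes
    )

print(backend_accepts(ToyMLABackend, torch.bfloat16, "fp8"))      # True
print(backend_accepts(ToyMLABackend, torch.float32, "fp8_e4m3"))  # False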
@@ -9,6 +9,7 @@ import torch
 
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.config import VllmConfig
+from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention.mla_attention import (
     get_mla_dims,
@@ -21,6 +22,7 @@ from vllm.v1.attention.backend import (
     AttentionMetadata,
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
+    MultipleOf,
     SparseMLAAttentionImpl,
 )
 from vllm.v1.attention.backends.mla.flashmla_sparse import (
@@ -77,7 +79,15 @@ def fetch_id_to_ragged_triton(
 
 class ROCMAiterMLASparseBackend(AttentionBackend):
     accept_output_buffer: bool = True
-    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
+    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
+    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
+        "auto",
+        "bfloat16",
+    ]
+
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [1]
 
     @staticmethod
     def get_name() -> str:
@@ -105,10 +115,6 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
     ) -> tuple[int, ...]:
         return (num_blocks, block_size, head_size)
 
-    @classmethod
-    def get_supported_head_sizes(cls) -> list[int]:
-        return [576]
-
     @classmethod
     def is_mla(cls) -> bool:
         return True
@@ -117,11 +123,6 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
     def is_sparse(cls) -> bool:
         return True
 
-    @classmethod
-    def supports_block_size(cls, block_size: int | None) -> bool:
-        # The only supported block_size is 1
-        return block_size is None or block_size == 1
-
 
 @dataclass
 class ROCMAiterMLASparseMetadata(AttentionMetadata):
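The per-backend supports_block_size overrides removed here (and in the TritonMLA hunk below) give way to get_supported_kernel_block_sizes() returning a list of exact sizes or MultipleOf entries, which a generic check can evaluate uniformly. The sketch below illustrates such a check; the MultipleOf class here is a hypothetical stand-in mirroring the one imported from vllm.v1.attention.backend in the hunks above, and the checking function is illustrative rather than vLLM's actual code.

# Sketch of a generic block-size check over a declared list of supported
# kernel block sizes. "MultipleOf" and "block_size_supported" are stand-ins.
from dataclasses import dataclass

@dataclass(frozen=True)
class MultipleOf:
    base: int

def block_size_supported(supported: list, block_size: int) -> bool:
    for entry in supported:
        if isinstance(entry, MultipleOf):
            if block_size % entry.base == 0:
                return True
        elif entry == block_size:
            return True
    return False

# The sparse AITER MLA backend above declares [1]: only block_size 1 passes.
print(block_size_supported([1], 1))                # True
print(block_size_supported([1], 16))               # False
# A backend declaring [MultipleOf(16)] accepts any multiple of 16.
print(block_size_supported([MultipleOf(16)], 32))  # True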
@@ -45,11 +45,6 @@ class TritonMLABackend(MLACommonBackend):
     def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
         return True
 
-    @classmethod
-    def supports_block_size(cls, block_size: int | None) -> bool:
-        # The only unsupported block_size is 1
-        return block_size is None or block_size != 1
-
 
 class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
     can_return_lse_for_decode: bool = True
@@ -9,6 +9,7 @@ import torch
 
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention import Attention
 from vllm.platforms import current_platform
@@ -732,6 +733,13 @@ class AiterFlashAttentionMetadataBuilder(
 class AiterFlashAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
+    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
+        "auto",
+        "bfloat16",
+        "fp8",
+        "fp8_e4m3",
+        "fp8_e5m2",
+    ]
 
     @staticmethod
     def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
@@ -11,7 +11,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
     kFp8StaticTensorSym,
 )
-from vllm.v1.attention.backend import AttentionLayer, AttentionType
+from vllm.v1.attention.backend import AttentionLayer, AttentionType, MultipleOf
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.rocm_attn import (
     RocmAttentionBackend,
@@ -25,6 +25,22 @@ logger = init_logger(__name__)
 class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
     accept_output_buffer: bool = True
 
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [MultipleOf(16)]
+
+    @classmethod
+    def supports_head_size(cls, head_size: int) -> bool:
+        return head_size >= 32
+
+    @classmethod
+    def supports_mm_prefix(cls) -> bool:
+        return True
+
+    @classmethod
+    def supports_sink(cls) -> bool:
+        return True
+
     forward_includes_kv_cache_update: bool = False
 
     @staticmethod
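Note that supports_head_size above is a predicate rather than a fixed list: any head size of at least 32 is accepted, alongside the feature flags for attention sinks and multimodal prefixes. A small, hypothetical illustration of how those predicate-style methods behave is sketched below; the class is a stand-in, not the real backend.

# Hypothetical illustration of the predicate-style capability methods added
# in the hunk above. "ToyUnifiedBackend" is a stand-in.
class ToyUnifiedBackend:
    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        # Accept any head size of at least 32, rather than a fixed list.
        return head_size >= 32

    @classmethod
    def supports_sink(cls) -> bool:
        return True

    @classmethod
    def supports_mm_prefix(cls) -> bool:
        return True

for head_size in (16, 32, 72, 256):
    print(head_size, ToyUnifiedBackend.supports_head_size(head_size))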
@@ -9,6 +9,7 @@ import torch
 
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.config import VllmConfig
+from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
@@ -163,6 +164,13 @@ class RocmAttentionBackend(AttentionBackend):
         torch.bfloat16,
         torch.float32,
     ]
+    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
+        "auto",
+        "bfloat16",
+        "fp8",
+        "fp8_e4m3",
+        "fp8_e5m2",
+    ]
 
     @staticmethod
     def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
@@ -185,15 +193,12 @@ class RocmAttentionBackend(AttentionBackend):
         return [32, 64, 80, 96, 128, 160, 192, 224, 256]
 
     @classmethod
-    def validate_head_size(cls, head_size: int) -> None:
-        if not cls.supports_head_size(head_size):
-            attn_type = cls.__name__.removesuffix("Backend")
-            raise ValueError(
-                f"Head size {head_size} is not supported by {attn_type}. "
-                f"Supported head sizes are: {cls.get_supported_head_sizes()}. "
-                "Set --attention-backend=FLEX_ATTENTION to use "
-                "FlexAttention backend which supports all head sizes."
-            )
+    def supports_mm_prefix(cls) -> bool:
+        return True
+
+    @classmethod
+    def supports_sink(cls) -> bool:
+        return True
 
     forward_includes_kv_cache_update: bool = False
 
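The backend-local validate_head_size method is removed here, and the explicit call to it in RocmAttentionImpl.__init__ is dropped in the next hunk. The removed body shows exactly what the check did, so a compact reconstruction is sketched below; the assumption that an equivalent check now runs once in shared base-class code is not shown in this diff, and the class name below is a stand-in.

# Sketch reconstructing the removed validate_head_size logic as shared code.
# Assumption (not confirmed by this diff): an equivalent check now lives in
# common backend code, so per-backend copies and the explicit call in
# RocmAttentionImpl.__init__ are no longer needed. "ToyRocmBackend" is a stand-in.
class ToyRocmBackend:
    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [32, 64, 80, 96, 128, 160, 192, 224, 256]

    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        return head_size in cls.get_supported_head_sizes()

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        if not cls.supports_head_size(head_size):
            attn_type = cls.__name__.removesuffix("Backend")
            raise ValueError(
                f"Head size {head_size} is not supported by {attn_type}. "
                f"Supported head sizes are: {cls.get_supported_head_sizes()}. "
                "Set --attention-backend=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes."
            )

ToyRocmBackend.validate_head_size(128)   # passes silently
# ToyRocmBackend.validate_head_size(48)  # would raise ValueError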
@@ -275,8 +280,6 @@ class RocmAttentionImpl(AttentionImpl):
 
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        RocmAttentionBackend.validate_head_size(head_size)
-
         self.fp8_dtype = current_platform.fp8_dtype()
 
         self.sinks = sinks