[Misc] Add get_name to missing AttentionBackends (#32698)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
@@ -22,6 +22,10 @@ from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
|
||||
|
||||
class GDNAttentionBackend(AttentionBackend):
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "GDN_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]:
|
||||
return GDNAttentionMetadataBuilder
|
||||
|
||||
@@ -16,6 +16,10 @@ from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
|
||||
|
||||
class LinearAttentionBackend(AttentionBackend):
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "LINEAR_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]:
|
||||
return LinearAttentionMetadataBuilder
|
||||
|
||||
@@ -11,6 +11,10 @@ from vllm.v1.attention.backends.mamba_attn import (
|
||||
|
||||
|
||||
class Mamba1AttentionBackend(AttentionBackend):
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "MAMBA1_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]:
|
||||
return Mamba1AttentionMetadataBuilder
|
||||
|
||||
@@ -7,7 +7,10 @@ import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backend import AttentionBackend, CommonAttentionMetadata
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.mamba_attn import (
|
||||
BaseMambaAttentionMetadata,
|
||||
BaseMambaAttentionMetadataBuilder,
|
||||
@@ -85,6 +88,10 @@ def compute_varlen_chunk_metadata(
|
||||
|
||||
|
||||
class Mamba2AttentionBackend(AttentionBackend):
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "MAMBA2_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]:
|
||||
return Mamba2AttentionMetadataBuilder
|
||||
|
||||
@@ -25,6 +25,10 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DeepseekV32IndexerBackend(AttentionBackend):
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "DEEPSEEK_V32_INDEXER"
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
return [1 if current_platform.is_rocm() else 64]
|
||||
|
||||
@@ -10,6 +10,10 @@ from vllm.v1.attention.backends.mamba_attn import (
|
||||
|
||||
|
||||
class ShortConvAttentionBackend(AttentionBackend):
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "SHORT_CONV_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]:
|
||||
return ShortConvAttentionMetadataBuilder
|
||||
|
||||
Reference in New Issue
Block a user