[V0 deprecation] Remove _VLLM_V1 suffixes from attention backend names (#25489)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
This commit is contained in:
@@ -364,7 +364,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
self.impl.process_weights_after_loading(act_dtype)
|
||||
|
||||
# FlashInfer requires attention sinks to be float32
|
||||
if (self.backend == _Backend.FLASHINFER_VLLM_V1
|
||||
if (self.backend == _Backend.FLASHINFER
|
||||
and hasattr(self.impl, 'sinks')):
|
||||
from vllm.v1.attention.backends.flashinfer import FlashInferImpl
|
||||
assert isinstance(self.impl, FlashInferImpl)
|
||||
@@ -420,21 +420,17 @@ class MultiHeadAttention(nn.Module):
|
||||
|
||||
self.attn_backend = backend if backend in {
|
||||
_Backend.TORCH_SDPA,
|
||||
_Backend.TORCH_SDPA_VLLM_V1,
|
||||
_Backend.XFORMERS,
|
||||
_Backend.PALLAS_VLLM_V1,
|
||||
_Backend.PALLAS,
|
||||
_Backend.ROCM_AITER_FA,
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.FLASH_ATTN_VLLM_V1,
|
||||
} else _Backend.TORCH_SDPA
|
||||
|
||||
if (self.attn_backend == _Backend.XFORMERS
|
||||
and not check_xformers_availability()):
|
||||
self.attn_backend = _Backend.TORCH_SDPA
|
||||
|
||||
if self.attn_backend in {
|
||||
_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1
|
||||
}:
|
||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
||||
if use_upstream_fa:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
self._flash_attn_varlen_func = flash_attn_varlen_func
|
||||
@@ -468,11 +464,7 @@ class MultiHeadAttention(nn.Module):
|
||||
key = torch.repeat_interleave(key, num_repeat, dim=2)
|
||||
value = torch.repeat_interleave(value, num_repeat, dim=2)
|
||||
|
||||
if self.attn_backend in {
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.FLASH_ATTN_VLLM_V1,
|
||||
}:
|
||||
|
||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
||||
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
|
||||
step=q_len,
|
||||
dtype=torch.int32,
|
||||
@@ -499,8 +491,7 @@ class MultiHeadAttention(nn.Module):
|
||||
key,
|
||||
value,
|
||||
scale=self.scale)
|
||||
elif (self.attn_backend == _Backend.TORCH_SDPA
|
||||
or self.attn_backend == _Backend.TORCH_SDPA_VLLM_V1):
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
query, key, value = (x.transpose(1, 2)
|
||||
for x in (query, key, value))
|
||||
out = F.scaled_dot_product_attention(query,
|
||||
@@ -508,7 +499,7 @@ class MultiHeadAttention(nn.Module):
|
||||
value,
|
||||
scale=self.scale)
|
||||
out = out.transpose(1, 2)
|
||||
elif self.attn_backend == _Backend.PALLAS_VLLM_V1:
|
||||
elif self.attn_backend == _Backend.PALLAS:
|
||||
query, key, value = (x.transpose(1, 2)
|
||||
for x in (query, key, value))
|
||||
from torch_xla.experimental.custom_kernel import flash_attention
|
||||
|
||||
@@ -186,6 +186,14 @@ def _cached_get_attn_backend(
|
||||
# Check the environment variable and override if specified
|
||||
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
|
||||
if backend_by_env_var is not None:
|
||||
if backend_by_env_var.endswith("_VLLM_V1"):
|
||||
logger.warning(
|
||||
"The suffix '_VLLM_V1' in the environment variable "
|
||||
"%s is no longer necessary as V0 backends have been "
|
||||
"deprecated. Please remove this suffix from your "
|
||||
"environment variable setting.", STR_BACKEND_ENV_VAR)
|
||||
backend_by_env_var = backend_by_env_var.removesuffix(
|
||||
"_VLLM_V1")
|
||||
selected_backend = backend_name_to_enum(backend_by_env_var)
|
||||
if selected_backend is None:
|
||||
raise ValueError(
|
||||
|
||||
@@ -577,8 +577,8 @@ class NixlConnectorWorker:
|
||||
use_mla=self.use_mla)
|
||||
self.backend_name = backend.get_name()
|
||||
attn_backend = backend_name_to_enum(self.backend_name)
|
||||
self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1
|
||||
self._use_pallas_v1 = attn_backend == _Backend.PALLAS_VLLM_V1
|
||||
self._use_flashinfer = attn_backend == _Backend.FLASHINFER
|
||||
self._use_pallas = attn_backend == _Backend.PALLAS
|
||||
self.kv_cache_layout = get_kv_cache_layout()
|
||||
logger.debug("Detected attention backend %s", self.backend_name)
|
||||
logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
|
||||
@@ -749,7 +749,7 @@ class NixlConnectorWorker:
|
||||
# (roughly 8KB vs 5KB).
|
||||
# Conversely for FlashInfer, K and V are registered in the same region
|
||||
# to better exploit the memory layout (ie num_blocks is the first dim).
|
||||
split_k_and_v = not (self.use_mla or self._use_pallas_v1
|
||||
split_k_and_v = not (self.use_mla or self._use_pallas
|
||||
or self._use_flashinfer)
|
||||
tensor_size_bytes = None
|
||||
for layer_name, cache_or_caches in xfer_buffers.items():
|
||||
@@ -938,7 +938,7 @@ class NixlConnectorWorker:
|
||||
tp_ratio = divide(self._tp_size[self.engine_id],
|
||||
self._tp_size[engine_id])
|
||||
assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
|
||||
assert not self._use_pallas_v1 or tp_ratio == 1, \
|
||||
assert not self._use_pallas or tp_ratio == 1, \
|
||||
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
|
||||
|
||||
# Handle tp_size>num_kv_heads: replicate KV cache.
|
||||
|
||||
@@ -1479,25 +1479,21 @@ class EngineArgs:
|
||||
"such as ngram, medusa, eagle, or deepseek_mtp.")
|
||||
|
||||
V1_BACKENDS = [
|
||||
"FLASH_ATTN_VLLM_V1",
|
||||
"FLASH_ATTN",
|
||||
"PALLAS",
|
||||
"PALLAS_VLLM_V1",
|
||||
"TRITON_ATTN_VLLM_V1",
|
||||
"TRITON_ATTN",
|
||||
"TRITON_MLA",
|
||||
"CUTLASS_MLA",
|
||||
"FLASHMLA",
|
||||
"FLASHMLA_VLLM_V1",
|
||||
"FLASH_ATTN_MLA",
|
||||
"FLASHINFER",
|
||||
"FLASHINFER_VLLM_V1",
|
||||
"FLASHINFER_MLA",
|
||||
"ROCM_AITER_MLA",
|
||||
"TORCH_SDPA_VLLM_V1",
|
||||
"TORCH_SDPA",
|
||||
"FLEX_ATTENTION",
|
||||
"TREE_ATTN",
|
||||
"XFORMERS_VLLM_V1",
|
||||
"ROCM_ATTN_VLLM_V1",
|
||||
"XFORMERS",
|
||||
"ROCM_ATTN",
|
||||
]
|
||||
if (envs.is_set("VLLM_ATTENTION_BACKEND")
|
||||
and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
|
||||
|
||||
@@ -42,7 +42,7 @@ def kernel_warmup(worker: "Worker"):
|
||||
# and is not a pooling model
|
||||
def _is_flashinfer_backend(backend):
|
||||
try:
|
||||
return backend.get_name() == "FLASHINFER_VLLM_V1"
|
||||
return backend.get_name() == "FLASHINFER"
|
||||
except NotImplementedError:
|
||||
return False
|
||||
|
||||
|
||||
@@ -241,9 +241,8 @@ class CudaPlatformBase(Platform):
|
||||
use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or (
|
||||
selected_backend is None and cls.is_device_capability(100)
|
||||
and block_size in [32, 64])
|
||||
use_flashmla = selected_backend in [
|
||||
_Backend.FLASHMLA, _Backend.FLASHMLA_VLLM_V1
|
||||
] or (selected_backend is None and is_flashmla_supported()[0])
|
||||
use_flashmla = selected_backend == _Backend.FLASHMLA or (
|
||||
selected_backend is None and is_flashmla_supported()[0])
|
||||
use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or (
|
||||
selected_backend is None and flash_attn_supports_mla())
|
||||
use_triton = selected_backend == _Backend.TRITON_MLA or (
|
||||
@@ -282,7 +281,7 @@ class CudaPlatformBase(Platform):
|
||||
if use_v1:
|
||||
FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501
|
||||
FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
|
||||
TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
|
||||
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
|
||||
FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
|
||||
TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501
|
||||
XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501
|
||||
@@ -300,16 +299,16 @@ class CudaPlatformBase(Platform):
|
||||
elif selected_backend == _Backend.FLEX_ATTENTION:
|
||||
logger.info_once("Using FlexAttention backend on V1 engine.")
|
||||
return FLEX_ATTENTION_V1
|
||||
elif selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
|
||||
elif selected_backend == _Backend.TRITON_ATTN:
|
||||
logger.info_once("Using Triton backend on V1 engine.")
|
||||
return TRITON_ATTN_VLLM_V1
|
||||
return TRITON_ATTN
|
||||
elif selected_backend == _Backend.FLASH_ATTN:
|
||||
logger.info_once("Using Flash Attention backend on V1 engine.")
|
||||
return FLASH_ATTN_V1
|
||||
elif selected_backend == _Backend.TREE_ATTN:
|
||||
logger.info_once("Using Tree Attention backend on V1 engine.")
|
||||
return TREE_ATTN_V1
|
||||
elif selected_backend == _Backend.XFORMERS_VLLM_V1:
|
||||
elif selected_backend == _Backend.XFORMERS:
|
||||
logger.info_once("Using XFormers backend on V1 engine.")
|
||||
return XFORMERS_V1
|
||||
|
||||
@@ -341,7 +340,7 @@ class CudaPlatformBase(Platform):
|
||||
if (has_sink or
|
||||
use_fp8_kv_cache) and not cls.is_device_capability(90):
|
||||
logger.info_once("Using Triton backend on V1 engine.")
|
||||
return TRITON_ATTN_VLLM_V1
|
||||
return TRITON_ATTN
|
||||
elif is_default_backend_supported := is_attn_backend_supported(
|
||||
FLASH_ATTN_V1, head_size, dtype,
|
||||
allow_import_error=False):
|
||||
@@ -457,12 +456,12 @@ class CudaPlatformBase(Platform):
|
||||
else:
|
||||
# Default to FlashAttention
|
||||
if attention_backend is None:
|
||||
attention_backend = "FLASH_ATTN_VLLM_V1"
|
||||
attention_backend = "FLASH_ATTN"
|
||||
|
||||
# All Blackwell backends support fp8
|
||||
if cls.is_device_capability(100):
|
||||
supported = True
|
||||
elif attention_backend == "FLASH_ATTN_VLLM_V1":
|
||||
elif attention_backend == "FLASH_ATTN":
|
||||
if fp8_attention:
|
||||
from vllm.attention.utils.fa_utils import (
|
||||
flash_attn_supports_fp8)
|
||||
@@ -471,7 +470,7 @@ class CudaPlatformBase(Platform):
|
||||
supported = True
|
||||
elif attention_backend == "FLASHINFER":
|
||||
supported = True
|
||||
elif attention_backend == "TRITON_ATTN_VLLM_V1":
|
||||
elif attention_backend == "TRITON_ATTN":
|
||||
supported = cls.supports_fp8()
|
||||
return supported
|
||||
|
||||
|
||||
@@ -40,34 +40,26 @@ def in_wsl() -> bool:
|
||||
|
||||
class _Backend(enum.Enum):
|
||||
FLASH_ATTN = enum.auto()
|
||||
FLASH_ATTN_VLLM_V1 = enum.auto()
|
||||
TRITON_ATTN_VLLM_V1 = enum.auto()
|
||||
TRITON_ATTN = enum.auto()
|
||||
XFORMERS = enum.auto()
|
||||
ROCM_FLASH = enum.auto()
|
||||
ROCM_AITER_MLA = enum.auto() # Supported by V1
|
||||
ROCM_AITER_MLA_VLLM_V1 = enum.auto()
|
||||
ROCM_AITER_FA = enum.auto() # used for ViT attn backend
|
||||
TORCH_SDPA = enum.auto()
|
||||
TORCH_SDPA_VLLM_V1 = enum.auto()
|
||||
FLASHINFER = enum.auto()
|
||||
FLASHINFER_VLLM_V1 = enum.auto()
|
||||
FLASHINFER_MLA = enum.auto()
|
||||
TRITON_MLA = enum.auto() # Supported by V1
|
||||
TRITON_MLA_VLLM_V1 = enum.auto()
|
||||
CUTLASS_MLA = enum.auto()
|
||||
FLASHMLA = enum.auto() # Supported by V1
|
||||
FLASHMLA_VLLM_V1 = enum.auto()
|
||||
FLASH_ATTN_MLA = enum.auto() # Supported by V1
|
||||
PALLAS = enum.auto()
|
||||
PALLAS_VLLM_V1 = enum.auto()
|
||||
IPEX = enum.auto()
|
||||
DUAL_CHUNK_FLASH_ATTN = enum.auto()
|
||||
DIFFERENTIAL_FLASH_ATTN = enum.auto()
|
||||
NO_ATTENTION = enum.auto()
|
||||
FLEX_ATTENTION = enum.auto()
|
||||
TREE_ATTN = enum.auto()
|
||||
XFORMERS_VLLM_V1 = enum.auto()
|
||||
ROCM_ATTN_VLLM_V1 = enum.auto()
|
||||
ROCM_ATTN = enum.auto()
|
||||
|
||||
|
||||
class PlatformEnum(enum.Enum):
|
||||
|
||||
@@ -218,8 +218,7 @@ class RocmPlatform(Platform):
|
||||
raise ValueError(
|
||||
f" The selected backend, {selected_backend.name},"
|
||||
f"does not support block size {block_size}.")
|
||||
if selected_backend in (_Backend.ROCM_AITER_MLA,
|
||||
_Backend.ROCM_AITER_MLA_VLLM_V1):
|
||||
if selected_backend == _Backend.ROCM_AITER_MLA:
|
||||
if block_size == 1:
|
||||
logger.info("Using AITER MLA backend on V1 engine.")
|
||||
return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
|
||||
@@ -240,7 +239,7 @@ class RocmPlatform(Platform):
|
||||
elif (envs.VLLM_ROCM_USE_AITER and
|
||||
envs.VLLM_USE_AITER_UNIFIED_ATTENTION) or \
|
||||
envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION or \
|
||||
selected_backend == _Backend.ROCM_ATTN_VLLM_V1:
|
||||
selected_backend == _Backend.ROCM_ATTN:
|
||||
# rocm specific backend, with aiter and/or
|
||||
# triton prefix-prefill
|
||||
logger.info("Using Rocm/Aiter Attention backend on V1 engine.")
|
||||
|
||||
@@ -50,8 +50,7 @@ class TpuPlatform(Platform):
|
||||
dtype: torch.dtype, kv_cache_dtype: Optional[str],
|
||||
block_size: int, use_v1: bool, use_mla: bool,
|
||||
has_sink) -> str:
|
||||
if (selected_backend != _Backend.PALLAS
|
||||
and selected_backend != _Backend.PALLAS_VLLM_V1):
|
||||
if selected_backend != _Backend.PALLAS:
|
||||
logger.info("Cannot use %s backend on TPU.", selected_backend)
|
||||
|
||||
if not use_v1:
|
||||
|
||||
@@ -40,14 +40,14 @@ class XPUPlatform(Platform):
|
||||
use_v1 = envs.VLLM_USE_V1
|
||||
if not use_v1:
|
||||
raise ValueError("XPU backend only supports V1.")
|
||||
TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
|
||||
FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
|
||||
if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
|
||||
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
|
||||
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
|
||||
if selected_backend == _Backend.TRITON_ATTN:
|
||||
logger.info_once("Using Triton backend on V1 engine.")
|
||||
return TRITON_ATTN_VLLM_V1
|
||||
return TRITON_ATTN
|
||||
elif selected_backend == _Backend.FLASH_ATTN:
|
||||
logger.info_once("Using Flash Attention backend on V1 engine.")
|
||||
return FLASH_ATTN_V1
|
||||
return FLASH_ATTN
|
||||
elif selected_backend:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend for {cls.device_name}, "
|
||||
@@ -64,7 +64,7 @@ class XPUPlatform(Platform):
|
||||
XPU only support fp8 kv cache with triton backend.
|
||||
"""
|
||||
if envs.is_set("VLLM_ATTENTION_BACKEND") and \
|
||||
envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN_VLLM_V1":
|
||||
envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN":
|
||||
return kv_cache_dtype in ["fp8_e4m3", "fp8_e5m2", "fp8"]
|
||||
|
||||
return False
|
||||
|
||||
@@ -54,7 +54,7 @@ class TorchSDPABackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TORCH_SDPA_VLLM_V1"
|
||||
return "TORCH_SDPA"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["TorchSDPABackendImpl"]:
|
||||
|
||||
@@ -60,7 +60,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASH_ATTN_VLLM_V1"
|
||||
return "FLASH_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashAttentionImpl"]:
|
||||
|
||||
@@ -167,7 +167,7 @@ class FlashInferBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHINFER_VLLM_V1"
|
||||
return "FLASHINFER"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type[FlashInferImpl]:
|
||||
|
||||
@@ -270,7 +270,7 @@ class MLACommonBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TRITON_MLA_VLLM_V1"
|
||||
return "TRITON_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["AttentionMetadata"]:
|
||||
|
||||
@@ -27,7 +27,7 @@ class FlashMLABackend(MLACommonBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHMLA_VLLM_V1"
|
||||
return "FLASHMLA"
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["FlashMLAMetadata"]:
|
||||
|
||||
@@ -33,7 +33,7 @@ class AiterMLABackend(MLACommonBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ROCM_AITER_MLA_VLLM_V1"
|
||||
return "ROCM_AITER_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["AiterMLAImpl"]:
|
||||
|
||||
@@ -24,7 +24,7 @@ class TritonMLABackend(MLACommonBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TRITON_MLA_VLLM_V1"
|
||||
return "TRITON_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["TritonMLAImpl"]:
|
||||
|
||||
@@ -86,7 +86,7 @@ class PallasAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "PALLAS_VLLM_V1"
|
||||
return "PALLAS"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
|
||||
|
||||
@@ -340,7 +340,7 @@ class AiterFlashAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASH_ATTN_VLLM_V1"
|
||||
return "FLASH_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["AiterFlashAttentionImpl"]:
|
||||
|
||||
@@ -159,7 +159,7 @@ class RocmAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ROCM_ATTN_VLLM_V1"
|
||||
return "ROCM_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["RocmAttentionImpl"]:
|
||||
|
||||
@@ -52,7 +52,7 @@ class TreeAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TREE_ATTN_VLLM_V1"
|
||||
return "TREE_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["TreeAttentionImpl"]:
|
||||
|
||||
@@ -155,7 +155,7 @@ class TritonAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TRITON_ATTN_VLLM_V1"
|
||||
return "TRITON_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["TritonAttentionImpl"]:
|
||||
|
||||
@@ -90,7 +90,7 @@ class XFormersAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "XFORMERS_VLLM_V1"
|
||||
return "XFORMERS"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["XFormersAttentionImpl"]:
|
||||
|
||||
Reference in New Issue
Block a user