[Bugfix][Kernel] Give unique name to BlockSparseFlashAttention (#12040)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -89,8 +89,7 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
# For attention layer compatibility
|
||||
return "FLASH_ATTN"
|
||||
return "BLOCK_SPARSE_FLASH_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
|
||||
|
||||
@@ -33,6 +33,7 @@ class _Backend(enum.Enum):
|
||||
HPU_ATTN = enum.auto()
|
||||
PALLAS = enum.auto()
|
||||
IPEX = enum.auto()
|
||||
BLOCK_SPARSE_FLASH_ATTN = enum.auto()
|
||||
NO_ATTENTION = enum.auto()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user