[Attention] Refactor FA block_size limitations to hybrid models only (#29084)
Signed-off-by: NickLucche <nlucches@redhat.com>
@@ -154,7 +154,6 @@ class TritonAttentionBackend(AttentionBackend):
         torch.bfloat16,
         torch.float32,
     ]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
         "fp8",
@@ -162,6 +161,10 @@ class TritonAttentionBackend(AttentionBackend):
         "fp8_e5m2",
     ]
 
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [MultipleOf(16)]
+
     @staticmethod
     def get_name() -> str:
         return "TRITON_ATTN"
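For context, the pattern in this diff moves the backend's supported kernel block sizes from a class attribute to a static method that callers can query. Below is a minimal, self-contained sketch of that pattern, not vLLM's actual implementation: `MultipleOf` is a simplified stand-in for vLLM's helper, and `is_block_size_supported` is a hypothetical caller added purely for illustration.

    # Sketch only: simplified stand-ins for the types named in the diff.
    from dataclasses import dataclass


    @dataclass(frozen=True)
    class MultipleOf:
        """Marks every positive multiple of `base` as a supported block size."""
        base: int


    class TritonAttentionBackend:
        @staticmethod
        def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
            # As in the diff: the Triton kernel accepts any multiple of 16.
            return [MultipleOf(16)]

        @staticmethod
        def get_name() -> str:
            return "TRITON_ATTN"


    def is_block_size_supported(backend, block_size: int) -> bool:
        # Hypothetical helper: check a concrete block size against the
        # backend's declared constraints (exact ints or MultipleOf specs).
        for spec in backend.get_supported_kernel_block_sizes():
            if isinstance(spec, MultipleOf):
                if block_size % spec.base == 0:
                    return True
            elif block_size == spec:
                return True
        return False


    if __name__ == "__main__":
        print(is_block_size_supported(TritonAttentionBackend, 32))  # True
        print(is_block_size_supported(TritonAttentionBackend, 24))  # False

One design note: exposing the constraint through a method rather than a ClassVar lets a backend compute supported block sizes dynamically (e.g., per model or hardware configuration), which fits this commit's goal of scoping the block-size limitation to hybrid models only.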