Triton Attention: Support cross-layers blocks (#30687)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
@@ -15,12 +15,10 @@ from vllm.distributed.kv_events import BlockStored, KVEventBatch
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# CPU KV-cache block sizes exercised by these tests.
CPU_BLOCK_SIZES = [48]

# Attention backends under test. TRITON_ATTN is exercised on every platform;
# the stale single-entry assignment (["FLASH_ATTN"]) from the pre-commit
# version is removed — only one ATTN_BACKENDS definition may exist.
ATTN_BACKENDS = ["FLASH_ATTN", "TRITON_ATTN"]

if current_platform.is_cuda():
    # FLASHINFER is only added on CUDA devices.
    ATTN_BACKENDS.append("FLASHINFER")
elif current_platform.is_rocm():
    # On ROCm, restrict the sweep to the Triton backend.
    ATTN_BACKENDS = ["TRITON_ATTN"]
|
||||
|
||||
|
||||
class MockSubscriber:
|
||||
|
||||
@@ -290,6 +290,19 @@ class TritonAttentionBackend(AttentionBackend):
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
return (num_blocks, 2, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_stride_order(
|
||||
include_num_layers_dimension: bool = False,
|
||||
) -> tuple[int, ...]:
|
||||
# `stride_order` indicates the permutation that gets
|
||||
# us from `get_kv_cache_shape` to the actual memory layout we want.
|
||||
if include_num_layers_dimension:
|
||||
# (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
|
||||
return (1, 0, 2, 3, 4, 5)
|
||||
|
||||
# (num_blocks, 2, block_size, num_kv_heads, head_size)
|
||||
return (0, 1, 2, 3, 4)
|
||||
|
||||
@staticmethod
|
||||
def use_cascade_attention(*args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user