Triton Attention: Support cross-layers blocks (#30687)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
Or Ozeri
2026-01-05 21:29:16 +02:00
committed by GitHub
parent 21156ff199
commit d8e38d4939
2 changed files with 14 additions and 3 deletions

View File

@@ -290,6 +290,19 @@ class TritonAttentionBackend(AttentionBackend):
raise ValueError("Block size must be a multiple of 16.")
return (num_blocks, 2, block_size, num_kv_heads, head_size)
@staticmethod
def get_kv_cache_stride_order(
    include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
    """Return the dimension permutation for the KV cache memory layout.

    The returned tuple is the permutation that takes the dimensions
    reported by `get_kv_cache_shape` to the memory layout we actually want.

    Args:
        include_num_layers_dimension: When true, the permutation covers
            six dimensions (a `num_layers` dimension is included);
            otherwise it covers the five per-layer dimensions.

    Returns:
        The stride order as a tuple of dimension indices.
    """
    if not include_num_layers_dimension:
        # Per-layer cache: the identity permutation, i.e. the layout stays
        # (num_blocks, 2, block_size, num_kv_heads, head_size).
        return (0, 1, 2, 3, 4)

    # Cross-layer cache: swap the two leading dimensions so the layout is
    # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size).
    return (1, 0, 2, 3, 4, 5)
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
    """Report whether cascade attention should be used.

    All positional and keyword arguments are accepted and ignored; the
    answer is unconditionally ``False``.
    NOTE(review): presumably this is how the backend opts out of cascade
    attention — confirm against the caller.
    """
    return False