[Kernel] Optimize Sliding Window Attention in 3D Triton Kernel (#31984)

Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
This commit is contained in:
jvlunteren
2026-01-10 19:13:44 +01:00
committed by GitHub
parent e6c6f2c79d
commit b8bf5c45bb

View File

@@ -545,10 +545,33 @@ def kernel_unified_attention_3d(
# this prefix can be skipped)
num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)
# iterate through tiles within current segment
# ---- Sliding-window tile pruning --------------------
# Default: keep previous global behavior
tile_start = 0
tile_end = num_tiles
# TODO(Isotr0py): sliding window pruning with image bidirectional mask
if SLIDING_WINDOW > 0 and not USE_MM_PREFIX:
# Query rows covered by this Q-block
qpos_lo = q_block_local_idx * BLOCK_Q
qpos_hi = tl.minimum(
qpos_lo + (BLOCK_M - 1) // num_queries_per_kv,
cur_batch_query_len - 1,
)
# For sliding window, each query position q can only attend to
# keys in the range [q_abs - SLIDING_WINDOW + 1, q_abs]
# where q_abs = context_len + q
# The union of allowed key positions for this Q-block is:
# [context_len + qpos_lo - SLIDING_WINDOW + 1, context_len + qpos_hi]
first_allowed_key = context_len + qpos_lo - SLIDING_WINDOW + 1
last_allowed_key = context_len + qpos_hi
# Convert to tile indices and clamp
tile_start = tl.maximum(0, first_allowed_key // TILE_SIZE)
tile_end = tl.minimum((last_allowed_key // TILE_SIZE) + 1, num_tiles)
# iterate through tiles (now limited to the sliding window range)
for j in range(
segm_idx * tiles_per_segment,
min((segm_idx + 1) * tiles_per_segment, num_tiles),
max(segm_idx * tiles_per_segment, tile_start),
min((segm_idx + 1) * tiles_per_segment, tile_end),
):
seq_offset = j * TILE_SIZE + offs_t
tile_mask = seq_offset < max_seq_prefix_len