[Kernel] Optimize Sliding Window Attention in 3D Triton Kernel (#31984)
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
This commit is contained in:
@@ -545,10 +545,33 @@ def kernel_unified_attention_3d(
|
||||
# this prefix can be skipped)
|
||||
num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)
|
||||
|
||||
# iterate through tiles within current segment
|
||||
# ---- Sliding-window tile pruning --------------------
|
||||
# Default: keep previous global behavior
|
||||
tile_start = 0
|
||||
tile_end = num_tiles
|
||||
# TODO(Isotr0py): sliding window pruning with image bidirectional mask
|
||||
if SLIDING_WINDOW > 0 and not USE_MM_PREFIX:
|
||||
# Query rows covered by this Q-block
|
||||
qpos_lo = q_block_local_idx * BLOCK_Q
|
||||
qpos_hi = tl.minimum(
|
||||
qpos_lo + (BLOCK_M - 1) // num_queries_per_kv,
|
||||
cur_batch_query_len - 1,
|
||||
)
|
||||
# For sliding window, each query position q can only attend to
|
||||
# keys in the range [q_abs - SLIDING_WINDOW + 1, q_abs]
|
||||
# where q_abs = context_len + q
|
||||
# The union of allowed key positions for this Q-block is:
|
||||
# [context_len + qpos_lo - SLIDING_WINDOW + 1, context_len + qpos_hi]
|
||||
first_allowed_key = context_len + qpos_lo - SLIDING_WINDOW + 1
|
||||
last_allowed_key = context_len + qpos_hi
|
||||
# Convert to tile indices and clamp
|
||||
tile_start = tl.maximum(0, first_allowed_key // TILE_SIZE)
|
||||
tile_end = tl.minimum((last_allowed_key // TILE_SIZE) + 1, num_tiles)
|
||||
|
||||
# iterate through tiles (now limited to the sliding window range)
|
||||
for j in range(
|
||||
segm_idx * tiles_per_segment,
|
||||
min((segm_idx + 1) * tiles_per_segment, num_tiles),
|
||||
max(segm_idx * tiles_per_segment, tile_start),
|
||||
min((segm_idx + 1) * tiles_per_segment, tile_end),
|
||||
):
|
||||
seq_offset = j * TILE_SIZE + offs_t
|
||||
tile_mask = seq_offset < max_seq_prefix_len
|
||||
|
||||
Reference in New Issue
Block a user