From b8bf5c45bbbb1249bd1656557603da71e9d4c3b9 Mon Sep 17 00:00:00 2001 From: jvlunteren <161835099+jvlunteren@users.noreply.github.com> Date: Sat, 10 Jan 2026 19:13:44 +0100 Subject: [PATCH] [Kernel] Optimize Sliding Window Attention in 3D Triton Kernel (#31984) Signed-off-by: Jan van Lunteren --- .../attention/ops/triton_unified_attention.py | 29 +++++++++++++++++-- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/vllm/v1/attention/ops/triton_unified_attention.py b/vllm/v1/attention/ops/triton_unified_attention.py index c946dbd8a..345889969 100644 --- a/vllm/v1/attention/ops/triton_unified_attention.py +++ b/vllm/v1/attention/ops/triton_unified_attention.py @@ -545,10 +545,33 @@ def kernel_unified_attention_3d( # this prefix can be skipped) num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE) - # iterate through tiles within current segment + # ---- Sliding-window tile pruning -------------------- + # Default: keep previous global behavior + tile_start = 0 + tile_end = num_tiles + # TODO(Isotr0py): sliding window pruning with image bidirectional mask + if SLIDING_WINDOW > 0 and not USE_MM_PREFIX: + # Query rows covered by this Q-block + qpos_lo = q_block_local_idx * BLOCK_Q + qpos_hi = tl.minimum( + qpos_lo + (BLOCK_M - 1) // num_queries_per_kv, + cur_batch_query_len - 1, + ) + # For sliding window, each query position q can only attend to + # keys in the range [q_abs - SLIDING_WINDOW + 1, q_abs] + # where q_abs = context_len + q + # The union of allowed key positions for this Q-block is: + # [context_len + qpos_lo - SLIDING_WINDOW + 1, context_len + qpos_hi] + first_allowed_key = context_len + qpos_lo - SLIDING_WINDOW + 1 + last_allowed_key = context_len + qpos_hi + # Convert to tile indices and clamp + tile_start = tl.maximum(0, first_allowed_key // TILE_SIZE) + tile_end = tl.minimum((last_allowed_key // TILE_SIZE) + 1, num_tiles) + + # iterate through tiles (now limited to the sliding window range) for j in range( - segm_idx * tiles_per_segment, - min((segm_idx + 1) * tiles_per_segment, num_tiles), + max(segm_idx * tiles_per_segment, tile_start), + min((segm_idx + 1) * tiles_per_segment, tile_end), ): seq_offset = j * TILE_SIZE + offs_t tile_mask = seq_offset < max_seq_prefix_len