[Model Runner V2] Fix _compute_slot_mappings_kernel for chunked prefill (#36580)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Nick Hill
2026-03-10 00:23:42 -07:00
committed by GitHub
parent 156e33553c
commit 9efc3bdcd6

View File

@@ -138,10 +138,8 @@ class BlockTables:
num_tokens_padded: int,
) -> torch.Tensor:
num_reqs = idx_mapping.shape[0]
num_tokens = positions.shape[0]
num_groups = self.num_kv_cache_groups
_compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
num_tokens,
self.max_num_batched_tokens,
idx_mapping,
query_start_loc,
@@ -213,7 +211,6 @@ def _gather_block_tables_kernel(
@triton.jit
def _compute_slot_mappings_kernel(
num_tokens,
max_num_tokens,
idx_mapping, # [num_reqs]
query_start_loc, # [num_reqs + 1]
@@ -236,7 +233,11 @@ def _compute_slot_mappings_kernel(
if batch_idx == tl.num_programs(1) - 1:
# Pad remaining slots to -1. This is needed for CUDA graphs.
for i in range(num_tokens, max_num_tokens, TRITON_BLOCK_SIZE):
# Start from actual token count (not padded) to cover the gap
# between actual tokens and padded tokens that can contain stale
# valid slot IDs from previous chunks during chunked prefill.
actual_num_tokens = tl.load(query_start_loc + batch_idx)
for i in range(actual_num_tokens, max_num_tokens, TRITON_BLOCK_SIZE):
offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens)
return