[Model Runner V2] Fix _compute_slot_mappings_kernel for chunked prefill (#36580)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -138,10 +138,8 @@ class BlockTables:
|
||||
num_tokens_padded: int,
|
||||
) -> torch.Tensor:
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
num_tokens = positions.shape[0]
|
||||
num_groups = self.num_kv_cache_groups
|
||||
_compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
|
||||
num_tokens,
|
||||
self.max_num_batched_tokens,
|
||||
idx_mapping,
|
||||
query_start_loc,
|
||||
@@ -213,7 +211,6 @@ def _gather_block_tables_kernel(
|
||||
|
||||
@triton.jit
|
||||
def _compute_slot_mappings_kernel(
|
||||
num_tokens,
|
||||
max_num_tokens,
|
||||
idx_mapping, # [num_reqs]
|
||||
query_start_loc, # [num_reqs + 1]
|
||||
@@ -236,7 +233,11 @@ def _compute_slot_mappings_kernel(
|
||||
|
||||
if batch_idx == tl.num_programs(1) - 1:
|
||||
# Pad remaining slots to -1. This is needed for CUDA graphs.
|
||||
for i in range(num_tokens, max_num_tokens, TRITON_BLOCK_SIZE):
|
||||
# Start from actual token count (not padded) to cover the gap
|
||||
# between actual tokens and padded tokens that can contain stale
|
||||
# valid slot IDs from previous chunks during chunked prefill.
|
||||
actual_num_tokens = tl.load(query_start_loc + batch_idx)
|
||||
for i in range(actual_num_tokens, max_num_tokens, TRITON_BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
|
||||
tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens)
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user