diff --git a/vllm/v1/worker/gpu/block_table.py b/vllm/v1/worker/gpu/block_table.py index 5a1edc076..3a2c0562a 100644 --- a/vllm/v1/worker/gpu/block_table.py +++ b/vllm/v1/worker/gpu/block_table.py @@ -138,10 +138,8 @@ class BlockTables: num_tokens_padded: int, ) -> torch.Tensor: num_reqs = idx_mapping.shape[0] - num_tokens = positions.shape[0] num_groups = self.num_kv_cache_groups _compute_slot_mappings_kernel[(num_groups, num_reqs + 1)]( - num_tokens, self.max_num_batched_tokens, idx_mapping, query_start_loc, @@ -213,7 +211,6 @@ def _gather_block_tables_kernel( @triton.jit def _compute_slot_mappings_kernel( - num_tokens, max_num_tokens, idx_mapping, # [num_reqs] query_start_loc, # [num_reqs + 1] @@ -236,7 +233,11 @@ def _compute_slot_mappings_kernel( if batch_idx == tl.num_programs(1) - 1: # Pad remaining slots to -1. This is needed for CUDA graphs. - for i in range(num_tokens, max_num_tokens, TRITON_BLOCK_SIZE): + # Start from actual token count (not padded) to cover the gap + # between actual tokens and padded tokens that can contain stale + # valid slot IDs from previous chunks during chunked prefill. + actual_num_tokens = tl.load(query_start_loc + batch_idx) + for i in range(actual_num_tokens, max_num_tokens, TRITON_BLOCK_SIZE): offset = i + tl.arange(0, TRITON_BLOCK_SIZE) tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens) return