diff --git a/vllm/v1/worker/gpu/block_table.py b/vllm/v1/worker/gpu/block_table.py
index 9dfdf834d..b06a35805 100644
--- a/vllm/v1/worker/gpu/block_table.py
+++ b/vllm/v1/worker/gpu/block_table.py
@@ -119,6 +119,10 @@ class BlockTables:
         return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
 
     def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
+        # NOTE(woosuk): The output may be used for CUDA graph capture.
+        # Therefore, this method must return the persistent tensor
+        # with the same memory address as that used during the model's forward pass,
+        # rather than allocating a new tensor.
         return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
 
     def compute_slot_mappings(
@@ -150,7 +154,14 @@ class BlockTables:
         return self.slot_mappings[:, :num_tokens]
 
     def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
+        # Fill the entire slot_mappings tensor, not just the first `num_tokens` entries.
+        # This is because the padding logic is complex and kernels may access beyond
+        # the requested range.
         self.slot_mappings.fill_(PAD_SLOT_ID)
+        # NOTE(woosuk): The output may be used for CUDA graph capture.
+        # Therefore, this method must return the persistent tensor
+        # with the same memory address as that used during the model's forward pass,
+        # rather than allocating a new tensor.
         return self.slot_mappings[:, :num_tokens]
 
 
diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py
index 783715cfe..c9ae28abf 100644
--- a/vllm/v1/worker/gpu/cudagraph_utils.py
+++ b/vllm/v1/worker/gpu/cudagraph_utils.py
@@ -420,8 +420,8 @@ def prepare_inputs_to_capture(
     input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens
     input_buffers.dcp_local_seq_lens[num_reqs:] = 0
 
-    input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
-    slot_mappings = block_tables.slot_mappings[:, :num_tokens]
+    input_block_tables = block_tables.get_dummy_block_tables(num_reqs)
+    slot_mappings = block_tables.get_dummy_slot_mappings(num_tokens)
     slot_mappings_by_layer = build_slot_mappings_by_layer(
         slot_mappings, kv_cache_config
     )