diff --git a/vllm/v1/worker/gpu/block_table.py b/vllm/v1/worker/gpu/block_table.py index d45917d4b..ca7c68120 100644 --- a/vllm/v1/worker/gpu/block_table.py +++ b/vllm/v1/worker/gpu/block_table.py @@ -116,24 +116,26 @@ class BlockTables: def compute_slot_mappings( self, + idx_mapping: torch.Tensor, query_start_loc: torch.Tensor, positions: torch.Tensor, ) -> torch.Tensor: - num_reqs = query_start_loc.shape[0] - 1 + num_reqs = idx_mapping.shape[0] num_tokens = positions.shape[0] num_groups = self.num_kv_cache_groups _compute_slot_mappings_kernel[(num_groups, num_reqs + 1)]( num_tokens, self.max_num_batched_tokens, + idx_mapping, query_start_loc, positions, - self.input_block_table_ptrs, + self.block_table_ptrs, self.block_table_strides, self.block_sizes_tensor, self.slot_mappings, self.slot_mappings.stride(0), PAD_ID=PAD_SLOT_ID, - BLOCK_SIZE=1024, # type: ignore + TRITON_BLOCK_SIZE=1024, # type: ignore ) return self.slot_mappings[:, :num_tokens] @@ -176,42 +178,44 @@ def _gather_block_tables_kernel( def _compute_slot_mappings_kernel( num_tokens, max_num_tokens, - cu_num_tokens, # [num_reqs + 1] + idx_mapping, # [num_reqs] + query_start_loc, # [num_reqs + 1] pos, # [num_tokens] block_table_ptrs, # [num_kv_cache_groups] block_table_strides, # [num_kv_cache_groups] - page_sizes, # [num_kv_cache_groups] + block_sizes, # [num_kv_cache_groups] slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens] slot_mappings_stride, PAD_ID: tl.constexpr, - BLOCK_SIZE: tl.constexpr, + TRITON_BLOCK_SIZE: tl.constexpr, ): # kv cache group id group_id = tl.program_id(0) - req_idx = tl.program_id(1) + batch_idx = tl.program_id(1) slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride - if req_idx == tl.num_programs(1) - 1: + if batch_idx == tl.num_programs(1) - 1: # Pad remaining slots to -1. This is needed for CUDA graphs. - for i in range(num_tokens, max_num_tokens, BLOCK_SIZE): - offset = i + tl.arange(0, BLOCK_SIZE) + for i in range(num_tokens, max_num_tokens, TRITON_BLOCK_SIZE): + offset = i + tl.arange(0, TRITON_BLOCK_SIZE) tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens) return block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32) block_table_stride = tl.load(block_table_strides + group_id) - page_size = tl.load(page_sizes + group_id) + block_size = tl.load(block_sizes + group_id) - start_idx = tl.load(cu_num_tokens + req_idx) - end_idx = tl.load(cu_num_tokens + req_idx + 1) - for i in range(start_idx, end_idx, BLOCK_SIZE): - offset = i + tl.arange(0, BLOCK_SIZE) + req_state_idx = tl.load(idx_mapping + batch_idx) + start_idx = tl.load(query_start_loc + batch_idx) + end_idx = tl.load(query_start_loc + batch_idx + 1) + for i in range(start_idx, end_idx, TRITON_BLOCK_SIZE): + offset = i + tl.arange(0, TRITON_BLOCK_SIZE) positions = tl.load(pos + offset, mask=offset < end_idx, other=0) - block_indices = positions // page_size + block_indices = positions // block_size block_numbers = tl.load( - block_table_ptr + req_idx * block_table_stride + block_indices + block_table_ptr + req_state_idx * block_table_stride + block_indices ) - slot_ids = block_numbers * page_size + positions % page_size + slot_ids = block_numbers * block_size + positions % block_size tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx) diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 898d64879..4aad46385 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -607,7 +607,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Compute slot mappings: [num_kv_cache_groups, num_tokens] slot_mappings = self.block_tables.compute_slot_mappings( - query_start_loc, self.input_buffers.positions[:num_tokens] + idx_mapping, + query_start_loc, + self.input_buffers.positions[:num_tokens], ) # Layer name -> attention metadata. diff --git a/vllm/v1/worker/gpu/spec_decode/eagle.py b/vllm/v1/worker/gpu/spec_decode/eagle.py index e8eeac7ec..f86b53793 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle.py @@ -138,6 +138,7 @@ class EagleSpeculator: ) -> None: pos = self.input_buffers.positions[:num_reqs] query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] + idx_mapping = self.idx_mapping[:num_reqs] for step in range(1, self.num_speculative_steps): # Run the eagle model. last_hidden_states, hidden_states = self.run_model( @@ -149,7 +150,7 @@ class EagleSpeculator: # used for draft and target sampling. draft_tokens = gumbel_sample( logits, - self.idx_mapping[:num_reqs], + idx_mapping, self.temperature, self.seeds, pos + 1, @@ -166,7 +167,9 @@ class EagleSpeculator: self.hidden_states, self.max_model_len, ) - self.block_tables.compute_slot_mappings(query_start_loc, pos) + self.block_tables.compute_slot_mappings( + idx_mapping, query_start_loc, pos + ) def capture_model(self) -> None: if self.num_speculative_steps == 1: @@ -279,7 +282,9 @@ class EagleSpeculator: self.max_num_reqs, ) query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] - slot_mappings = self.block_tables.compute_slot_mappings(query_start_loc, pos) + slot_mappings = self.block_tables.compute_slot_mappings( + idx_mapping, query_start_loc, pos + ) cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs) if cudagraph_size is not None: