[Model Runner V2] Minor refactor for compute_slot_mappings (#32794)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2026-01-21 10:24:35 -08:00
committed by GitHub
parent 9b693d023c
commit e1da249c93
3 changed files with 33 additions and 22 deletions

View File

@@ -116,24 +116,26 @@ class BlockTables:
def compute_slot_mappings(
self,
idx_mapping: torch.Tensor,
query_start_loc: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
num_reqs = query_start_loc.shape[0] - 1
num_reqs = idx_mapping.shape[0]
num_tokens = positions.shape[0]
num_groups = self.num_kv_cache_groups
_compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
num_tokens,
self.max_num_batched_tokens,
idx_mapping,
query_start_loc,
positions,
self.input_block_table_ptrs,
self.block_table_ptrs,
self.block_table_strides,
self.block_sizes_tensor,
self.slot_mappings,
self.slot_mappings.stride(0),
PAD_ID=PAD_SLOT_ID,
BLOCK_SIZE=1024, # type: ignore
TRITON_BLOCK_SIZE=1024, # type: ignore
)
return self.slot_mappings[:, :num_tokens]
@@ -176,42 +178,44 @@ def _gather_block_tables_kernel(
def _compute_slot_mappings_kernel(
num_tokens,
max_num_tokens,
cu_num_tokens, # [num_reqs + 1]
idx_mapping, # [num_reqs]
query_start_loc, # [num_reqs + 1]
pos, # [num_tokens]
block_table_ptrs, # [num_kv_cache_groups]
block_table_strides, # [num_kv_cache_groups]
page_sizes, # [num_kv_cache_groups]
block_sizes, # [num_kv_cache_groups]
slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens]
slot_mappings_stride,
PAD_ID: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
TRITON_BLOCK_SIZE: tl.constexpr,
):
# kv cache group id
group_id = tl.program_id(0)
req_idx = tl.program_id(1)
batch_idx = tl.program_id(1)
slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride
if req_idx == tl.num_programs(1) - 1:
if batch_idx == tl.num_programs(1) - 1:
# Pad remaining slots to -1. This is needed for CUDA graphs.
for i in range(num_tokens, max_num_tokens, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
for i in range(num_tokens, max_num_tokens, TRITON_BLOCK_SIZE):
offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens)
return
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
block_table_stride = tl.load(block_table_strides + group_id)
page_size = tl.load(page_sizes + group_id)
block_size = tl.load(block_sizes + group_id)
start_idx = tl.load(cu_num_tokens + req_idx)
end_idx = tl.load(cu_num_tokens + req_idx + 1)
for i in range(start_idx, end_idx, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
req_state_idx = tl.load(idx_mapping + batch_idx)
start_idx = tl.load(query_start_loc + batch_idx)
end_idx = tl.load(query_start_loc + batch_idx + 1)
for i in range(start_idx, end_idx, TRITON_BLOCK_SIZE):
offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
block_indices = positions // page_size
block_indices = positions // block_size
block_numbers = tl.load(
block_table_ptr + req_idx * block_table_stride + block_indices
block_table_ptr + req_state_idx * block_table_stride + block_indices
)
slot_ids = block_numbers * page_size + positions % page_size
slot_ids = block_numbers * block_size + positions % block_size
tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)

View File

@@ -607,7 +607,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings = self.block_tables.compute_slot_mappings(
query_start_loc, self.input_buffers.positions[:num_tokens]
idx_mapping,
query_start_loc,
self.input_buffers.positions[:num_tokens],
)
# Layer name -> attention metadata.

View File

@@ -138,6 +138,7 @@ class EagleSpeculator:
) -> None:
pos = self.input_buffers.positions[:num_reqs]
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
idx_mapping = self.idx_mapping[:num_reqs]
for step in range(1, self.num_speculative_steps):
# Run the eagle model.
last_hidden_states, hidden_states = self.run_model(
@@ -149,7 +150,7 @@ class EagleSpeculator:
# used for draft and target sampling.
draft_tokens = gumbel_sample(
logits,
self.idx_mapping[:num_reqs],
idx_mapping,
self.temperature,
self.seeds,
pos + 1,
@@ -166,7 +167,9 @@ class EagleSpeculator:
self.hidden_states,
self.max_model_len,
)
self.block_tables.compute_slot_mappings(query_start_loc, pos)
self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)
def capture_model(self) -> None:
if self.num_speculative_steps == 1:
@@ -279,7 +282,9 @@ class EagleSpeculator:
self.max_num_reqs,
)
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings(query_start_loc, pos)
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
if cudagraph_size is not None: