[Model Runner V2] Minor refactor for compute_slot_mappings (#32794)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -116,24 +116,26 @@ class BlockTables:
|
||||
|
||||
def compute_slot_mappings(
|
||||
self,
|
||||
idx_mapping: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
num_reqs = query_start_loc.shape[0] - 1
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
num_tokens = positions.shape[0]
|
||||
num_groups = self.num_kv_cache_groups
|
||||
_compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
|
||||
num_tokens,
|
||||
self.max_num_batched_tokens,
|
||||
idx_mapping,
|
||||
query_start_loc,
|
||||
positions,
|
||||
self.input_block_table_ptrs,
|
||||
self.block_table_ptrs,
|
||||
self.block_table_strides,
|
||||
self.block_sizes_tensor,
|
||||
self.slot_mappings,
|
||||
self.slot_mappings.stride(0),
|
||||
PAD_ID=PAD_SLOT_ID,
|
||||
BLOCK_SIZE=1024, # type: ignore
|
||||
TRITON_BLOCK_SIZE=1024, # type: ignore
|
||||
)
|
||||
return self.slot_mappings[:, :num_tokens]
|
||||
|
||||
@@ -176,42 +178,44 @@ def _gather_block_tables_kernel(
|
||||
def _compute_slot_mappings_kernel(
|
||||
num_tokens,
|
||||
max_num_tokens,
|
||||
cu_num_tokens, # [num_reqs + 1]
|
||||
idx_mapping, # [num_reqs]
|
||||
query_start_loc, # [num_reqs + 1]
|
||||
pos, # [num_tokens]
|
||||
block_table_ptrs, # [num_kv_cache_groups]
|
||||
block_table_strides, # [num_kv_cache_groups]
|
||||
page_sizes, # [num_kv_cache_groups]
|
||||
block_sizes, # [num_kv_cache_groups]
|
||||
slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens]
|
||||
slot_mappings_stride,
|
||||
PAD_ID: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
TRITON_BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
# kv cache group id
|
||||
group_id = tl.program_id(0)
|
||||
req_idx = tl.program_id(1)
|
||||
batch_idx = tl.program_id(1)
|
||||
slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride
|
||||
|
||||
if req_idx == tl.num_programs(1) - 1:
|
||||
if batch_idx == tl.num_programs(1) - 1:
|
||||
# Pad remaining slots to -1. This is needed for CUDA graphs.
|
||||
for i in range(num_tokens, max_num_tokens, BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, BLOCK_SIZE)
|
||||
for i in range(num_tokens, max_num_tokens, TRITON_BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
|
||||
tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens)
|
||||
return
|
||||
|
||||
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
|
||||
block_table_stride = tl.load(block_table_strides + group_id)
|
||||
page_size = tl.load(page_sizes + group_id)
|
||||
block_size = tl.load(block_sizes + group_id)
|
||||
|
||||
start_idx = tl.load(cu_num_tokens + req_idx)
|
||||
end_idx = tl.load(cu_num_tokens + req_idx + 1)
|
||||
for i in range(start_idx, end_idx, BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, BLOCK_SIZE)
|
||||
req_state_idx = tl.load(idx_mapping + batch_idx)
|
||||
start_idx = tl.load(query_start_loc + batch_idx)
|
||||
end_idx = tl.load(query_start_loc + batch_idx + 1)
|
||||
for i in range(start_idx, end_idx, TRITON_BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
|
||||
positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
|
||||
block_indices = positions // page_size
|
||||
block_indices = positions // block_size
|
||||
block_numbers = tl.load(
|
||||
block_table_ptr + req_idx * block_table_stride + block_indices
|
||||
block_table_ptr + req_state_idx * block_table_stride + block_indices
|
||||
)
|
||||
slot_ids = block_numbers * page_size + positions % page_size
|
||||
slot_ids = block_numbers * block_size + positions % block_size
|
||||
tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
|
||||
|
||||
|
||||
|
||||
@@ -607,7 +607,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
query_start_loc, self.input_buffers.positions[:num_tokens]
|
||||
idx_mapping,
|
||||
query_start_loc,
|
||||
self.input_buffers.positions[:num_tokens],
|
||||
)
|
||||
|
||||
# Layer name -> attention metadata.
|
||||
|
||||
@@ -138,6 +138,7 @@ class EagleSpeculator:
|
||||
) -> None:
|
||||
pos = self.input_buffers.positions[:num_reqs]
|
||||
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
|
||||
idx_mapping = self.idx_mapping[:num_reqs]
|
||||
for step in range(1, self.num_speculative_steps):
|
||||
# Run the eagle model.
|
||||
last_hidden_states, hidden_states = self.run_model(
|
||||
@@ -149,7 +150,7 @@ class EagleSpeculator:
|
||||
# used for draft and target sampling.
|
||||
draft_tokens = gumbel_sample(
|
||||
logits,
|
||||
self.idx_mapping[:num_reqs],
|
||||
idx_mapping,
|
||||
self.temperature,
|
||||
self.seeds,
|
||||
pos + 1,
|
||||
@@ -166,7 +167,9 @@ class EagleSpeculator:
|
||||
self.hidden_states,
|
||||
self.max_model_len,
|
||||
)
|
||||
self.block_tables.compute_slot_mappings(query_start_loc, pos)
|
||||
self.block_tables.compute_slot_mappings(
|
||||
idx_mapping, query_start_loc, pos
|
||||
)
|
||||
|
||||
def capture_model(self) -> None:
|
||||
if self.num_speculative_steps == 1:
|
||||
@@ -279,7 +282,9 @@ class EagleSpeculator:
|
||||
self.max_num_reqs,
|
||||
)
|
||||
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(query_start_loc, pos)
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
idx_mapping, query_start_loc, pos
|
||||
)
|
||||
|
||||
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
|
||||
if cudagraph_size is not None:
|
||||
|
||||
Reference in New Issue
Block a user