[Model Runner V2] Refactor prefill token preparation (#29712)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2025-11-28 19:49:17 -08:00
committed by GitHub
parent 762a4a6ca9
commit ca1b1e7296
5 changed files with 83 additions and 78 deletions

View File

@@ -78,7 +78,7 @@ class CudaGraphManager:
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
) -> None: ) -> None:
num_reqs = min(num_tokens, self.max_num_reqs) num_reqs = min(num_tokens, self.max_num_reqs)
input_ids = input_buffers.input_ids.gpu[:num_tokens] input_ids = input_buffers.input_ids[:num_tokens]
positions = input_buffers.positions[:num_tokens] positions = input_buffers.positions[:num_tokens]
attn_metadata = prepare_inputs_to_capture( attn_metadata = prepare_inputs_to_capture(
num_reqs, num_reqs,

View File

@@ -3,7 +3,6 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
import numba
import numpy as np import numpy as np
import torch import torch
@@ -30,15 +29,12 @@ class InputBuffers:
self.pin_memory = pin_memory self.pin_memory = pin_memory
self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32) self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32) self.input_ids = torch.zeros(max_num_tokens, dtype=torch.int32, device=device)
self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device) self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32) self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device) self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
self.cu_num_logits = self._make_buffer(max_num_reqs + 1, dtype=torch.int32) self.cu_num_logits = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
# Spec decoding.
self.next_prefill_tokens = self._make_buffer(max_num_reqs, dtype=torch.int32)
# Structured outputs. # Structured outputs.
self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32) self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32)
self.grammar_bitmask = self._make_buffer( self.grammar_bitmask = self._make_buffer(
@@ -120,7 +116,7 @@ class InputBatch:
input_buffers.seq_lens[num_reqs:] = 0 input_buffers.seq_lens[num_reqs:] = 0
seq_lens = input_buffers.seq_lens[:num_reqs] seq_lens = input_buffers.seq_lens[:num_reqs]
input_ids = input_buffers.input_ids.copy_to_gpu(num_tokens) input_ids = input_buffers.input_ids[:num_tokens]
positions = input_buffers.positions[:num_tokens] positions = input_buffers.positions[:num_tokens]
# attn_metadata = defaultdict(lambda: None) # attn_metadata = defaultdict(lambda: None)
logits_indices = query_start_loc[1:] - 1 logits_indices = query_start_loc[1:] - 1
@@ -146,41 +142,63 @@ class InputBatch:
) )
@numba.njit(cache=True) @triton.jit
def _prepare_prefill_inputs( def _prepare_prefill_inputs_kernel(
idx_mapping: np.ndarray, # [B] input_ids_ptr,
query_lens: np.ndarray, # [B] next_prefill_tokens_ptr,
query_start_loc: np.ndarray, # [B + 1] idx_mapping_ptr,
prefill_token_ids: np.ndarray, # [N, max_model_len] query_start_loc_ptr,
num_computed_prefill_tokens: np.ndarray, # [N] prefill_token_ids_ptr,
input_ids: np.ndarray, # [num_input_tokens] prefill_token_ids_stride,
) -> None: prefill_lens_ptr,
num_reqs = idx_mapping.shape[0] num_computed_tokens_ptr,
query_starts = query_start_loc[:num_reqs] BLOCK_SIZE: tl.constexpr,
query_ends = query_start_loc[1 : num_reqs + 1] ):
starts = num_computed_prefill_tokens[idx_mapping] batch_idx = tl.program_id(0)
ends = starts + query_lens req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
for i in range(num_reqs): prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
input_ids[query_starts[i] : query_ends[i]] = prefill_token_ids[ num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
idx_mapping[i], starts[i] : ends[i] if num_computed >= prefill_len:
] # Not prefill.
return
query_start = tl.load(query_start_loc_ptr + batch_idx)
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
query_len = query_end - query_start
prefill_ptr = prefill_token_ids_ptr + req_state_idx * prefill_token_ids_stride
for i in range(0, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len
tokens = tl.load(prefill_ptr + num_computed + block, mask=mask)
tl.store(input_ids_ptr + query_start + block, tokens, mask=mask)
next_pos = num_computed + query_len
if next_pos < prefill_len:
next_token = tl.load(prefill_ptr + next_pos)
tl.store(next_prefill_tokens_ptr + req_state_idx, next_token)
def prepare_prefill_inputs( def prepare_prefill_inputs(
idx_mapping: np.ndarray, input_ids: torch.Tensor,
num_scheduled_tokens: np.ndarray, next_prefill_tokens: torch.Tensor,
query_start_loc: np.ndarray, idx_mapping: torch.Tensor,
prefill_token_ids: np.ndarray, query_start_loc: torch.Tensor,
num_computed_prefill_tokens: np.ndarray, prefill_token_ids: torch.Tensor,
input_ids: np.ndarray, prefill_len: torch.Tensor,
num_computed_tokens: torch.Tensor,
) -> None: ) -> None:
_prepare_prefill_inputs( num_reqs = idx_mapping.shape[0]
_prepare_prefill_inputs_kernel[(num_reqs,)](
input_ids,
next_prefill_tokens,
idx_mapping, idx_mapping,
num_scheduled_tokens,
query_start_loc, query_start_loc,
prefill_token_ids, prefill_token_ids,
num_computed_prefill_tokens, prefill_token_ids.stride(0),
input_ids, prefill_len,
num_computed_tokens,
BLOCK_SIZE=1024,
) )

View File

@@ -104,11 +104,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if self.use_async_scheduling: if self.use_async_scheduling:
self.input_prep_event = torch.cuda.Event() self.input_prep_event = torch.cuda.Event()
self.structured_outputs_event = torch.cuda.Event() self.structured_outputs_event = torch.cuda.Event()
self.spec_decode_event = torch.cuda.Event()
else: else:
self.input_prep_event = None self.input_prep_event = None
self.structured_outputs_event = None self.structured_outputs_event = None
self.spec_decode_event = None
if self.speculative_config is not None: if self.speculative_config is not None:
self.do_spec_decode = True self.do_spec_decode = True
@@ -412,9 +410,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cu_num_new_blocks[i].append(x + len(block_ids)) cu_num_new_blocks[i].append(x + len(block_ids))
new_block_ids[i].extend(block_ids) new_block_ids[i].extend(block_ids)
overwrite.append(True) overwrite.append(True)
# Update the GPU tensors for request states.
if scheduler_output.scheduled_new_reqs:
self.req_states.prefill_len.copy_to_gpu()
# Add new blocks for the existing requests. # Add new blocks for the existing requests.
cached_reqs = scheduler_output.scheduled_cached_reqs cached_reqs = scheduler_output.scheduled_cached_reqs
@@ -507,16 +502,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
query_start_loc_cpu = self.input_buffers.query_start_loc.cpu[: num_reqs + 1] query_start_loc_cpu = self.input_buffers.query_start_loc.cpu[: num_reqs + 1]
query_start_loc_np = self.input_buffers.query_start_loc.np[: num_reqs + 1] query_start_loc_np = self.input_buffers.query_start_loc.np[: num_reqs + 1]
# Copy prefill tokens from CPU to GPU. # Get prefill tokens.
prepare_prefill_inputs( prepare_prefill_inputs(
idx_mapping_np, self.input_buffers.input_ids,
num_scheduled_tokens, self.req_states.next_prefill_tokens,
query_start_loc_np, idx_mapping,
self.req_states.prefill_token_ids.np, query_start_loc_gpu,
self.req_states.num_computed_prefill_tokens, self.req_states.prefill_token_ids.gpu,
self.input_buffers.input_ids.np, self.req_states.prefill_len.gpu,
self.req_states.num_computed_tokens,
) )
self.input_buffers.input_ids.copy_to_gpu(num_tokens)
# Prepare positions and seq_lens. # Prepare positions and seq_lens.
prepare_pos_seq_lens( prepare_pos_seq_lens(
@@ -531,7 +526,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Some input token ids are directly read from the last sampled tokens # Some input token ids are directly read from the last sampled tokens
# and draft tokens. Also, get the logits indices to sample tokens from. # and draft tokens. Also, get the logits indices to sample tokens from.
logits_indices = combine_sampled_and_draft_tokens( logits_indices = combine_sampled_and_draft_tokens(
self.input_buffers.input_ids.gpu, self.input_buffers.input_ids,
idx_mapping, idx_mapping,
self.req_states.last_sampled_tokens, self.req_states.last_sampled_tokens,
query_start_loc_gpu, query_start_loc_gpu,
@@ -572,7 +567,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_cache_config=self.kv_cache_config, kv_cache_config=self.kv_cache_config,
) )
input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding] input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
positions = self.input_buffers.positions[:num_tokens_after_padding] positions = self.input_buffers.positions[:num_tokens_after_padding]
return InputBatch( return InputBatch(
req_ids=req_ids, req_ids=req_ids,
@@ -782,20 +777,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_sampled: torch.Tensor, num_sampled: torch.Tensor,
num_rejected: torch.Tensor, num_rejected: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
num_reqs = input_batch.num_reqs
idx_mapping_np = input_batch.idx_mapping_np
with async_barrier(self.spec_decode_event):
self.input_buffers.next_prefill_tokens.np[:num_reqs] = (
self.req_states.prefill_token_ids.np[
idx_mapping_np,
self.req_states.num_computed_prefill_tokens[idx_mapping_np],
]
)
next_prefill_tokens = self.input_buffers.next_prefill_tokens.copy_to_gpu(
num_reqs
)
assert self.speculator is not None assert self.speculator is not None
last_sampled_tokens = self.req_states.last_sampled_tokens[
input_batch.idx_mapping
]
next_prefill_tokens = self.req_states.next_prefill_tokens[
input_batch.idx_mapping
]
draft_tokens = self.speculator.propose( draft_tokens = self.speculator.propose(
input_batch, input_batch,
sampling_metadata, sampling_metadata,
@@ -803,7 +791,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
aux_hidden_states, aux_hidden_states,
num_sampled, num_sampled,
num_rejected, num_rejected,
self.req_states.last_sampled_tokens, last_sampled_tokens,
next_prefill_tokens, next_prefill_tokens,
) )
return draft_tokens return draft_tokens

View File

@@ -121,7 +121,7 @@ class EagleSpeculator:
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
): ):
ret_hidden_states = self.model( ret_hidden_states = self.model(
input_ids=self.input_buffers.input_ids.gpu[:num_tokens], input_ids=self.input_buffers.input_ids[:num_tokens],
positions=self.input_buffers.positions[:num_tokens], positions=self.input_buffers.positions[:num_tokens],
hidden_states=self.hidden_states[:num_tokens], hidden_states=self.hidden_states[:num_tokens],
) )
@@ -194,7 +194,7 @@ class EagleSpeculator:
num_sampled: torch.Tensor, num_sampled: torch.Tensor,
# [num_reqs] # [num_reqs]
num_rejected: torch.Tensor, num_rejected: torch.Tensor,
# [max_num_reqs, 1] # [num_reqs]
last_sampled: torch.Tensor, last_sampled: torch.Tensor,
# [num_reqs] # [num_reqs]
next_prefill_tokens: torch.Tensor, next_prefill_tokens: torch.Tensor,
@@ -316,7 +316,6 @@ def _prepare_eagle_inputs_kernel(
eagle_positions_ptr, eagle_positions_ptr,
target_input_ids_ptr, target_input_ids_ptr,
target_positions_ptr, target_positions_ptr,
idx_mapping_ptr,
last_sampled_ptr, last_sampled_ptr,
next_prefill_tokens_ptr, next_prefill_tokens_ptr,
num_sampled_ptr, num_sampled_ptr,
@@ -335,8 +334,7 @@ def _prepare_eagle_inputs_kernel(
num_sampled = tl.load(num_sampled_ptr + batch_idx) num_sampled = tl.load(num_sampled_ptr + batch_idx)
if num_sampled > 0: if num_sampled > 0:
req_state_idx = tl.load(idx_mapping_ptr + batch_idx) next_token = tl.load(last_sampled_ptr + batch_idx).to(tl.int32)
next_token = tl.load(last_sampled_ptr + req_state_idx).to(tl.int32)
else: else:
# Chunked prefilling. # Chunked prefilling.
# Get the next prefill token. # Get the next prefill token.
@@ -368,9 +366,9 @@ def prepare_eagle_inputs(
num_sampled: torch.Tensor, num_sampled: torch.Tensor,
# [num_reqs] # [num_reqs]
num_rejected: torch.Tensor, num_rejected: torch.Tensor,
# [max_num_reqs, 1] # [num_reqs]
last_sampled: torch.Tensor, last_sampled: torch.Tensor,
# [max_num_reqs] # [num_reqs]
next_prefill_tokens: torch.Tensor, next_prefill_tokens: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
num_reqs = input_batch.num_reqs num_reqs = input_batch.num_reqs
@@ -381,11 +379,10 @@ def prepare_eagle_inputs(
) )
_prepare_eagle_inputs_kernel[(num_reqs,)]( _prepare_eagle_inputs_kernel[(num_reqs,)](
last_token_indices, last_token_indices,
input_buffers.input_ids.gpu, input_buffers.input_ids,
input_buffers.positions, input_buffers.positions,
input_batch.input_ids, input_batch.input_ids,
input_batch.positions, input_batch.positions,
input_batch.idx_mapping,
last_sampled, last_sampled,
next_prefill_tokens, next_prefill_tokens,
num_sampled, num_sampled,
@@ -485,7 +482,7 @@ def prepare_eagle_decode(
last_token_indices, last_token_indices,
target_seq_lens, target_seq_lens,
num_rejected, num_rejected,
input_buffers.input_ids.gpu, input_buffers.input_ids,
input_buffers.positions, input_buffers.positions,
input_hidden_states, input_hidden_states,
input_hidden_states.stride(0), input_hidden_states.stride(0),
@@ -553,7 +550,7 @@ def update_eagle_inputs(
): ):
num_reqs, hidden_size = output_hidden_states.shape num_reqs, hidden_size = output_hidden_states.shape
_update_eagle_inputs_kernel[(num_reqs,)]( _update_eagle_inputs_kernel[(num_reqs,)](
input_buffers.input_ids.gpu, input_buffers.input_ids,
input_buffers.positions, input_buffers.positions,
hidden_states, hidden_states,
hidden_states.stride(0), hidden_states.stride(0),

View File

@@ -117,8 +117,7 @@ class RequestState:
self.prefill_token_ids = UvaBuffer( self.prefill_token_ids = UvaBuffer(
self.max_num_reqs, self.max_model_len, dtype=torch.int32 self.max_num_reqs, self.max_model_len, dtype=torch.int32
) )
self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32) self.prefill_len = UvaBuffer(self.max_num_reqs, dtype=torch.int32)
# Number of computed tokens. # Number of computed tokens.
self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32) self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
self.num_computed_tokens = torch.zeros( self.num_computed_tokens = torch.zeros(
@@ -140,6 +139,9 @@ class RequestState:
dtype=torch.int64, dtype=torch.int64,
device=device, device=device,
) )
self.next_prefill_tokens = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=device
)
# LoRA. # LoRA.
self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32) self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
@@ -380,13 +382,13 @@ def _expand_sampling_metadata_kernel(
expanded_top_p_ptr, expanded_top_p_ptr,
top_k_ptr, top_k_ptr,
expanded_top_k_ptr, expanded_top_k_ptr,
seeds_ptr,
rep_penalty_ptr, rep_penalty_ptr,
expanded_rep_penalty_ptr, expanded_rep_penalty_ptr,
freq_penalty_ptr, freq_penalty_ptr,
expanded_freq_penalty_ptr, expanded_freq_penalty_ptr,
pres_penalty_ptr, pres_penalty_ptr,
expanded_pres_penalty_ptr, expanded_pres_penalty_ptr,
seeds_ptr,
expanded_seeds_ptr, expanded_seeds_ptr,
cu_num_logits_ptr, cu_num_logits_ptr,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,