[Model Runner V2] Refactor prefill token preparation (#29712)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon
Date: 2025-11-28 19:49:17 -08:00
Committed by: GitHub
Parent: 762a4a6ca9
Commit: ca1b1e7296
5 changed files with 83 additions and 78 deletions


@@ -104,11 +104,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         if self.use_async_scheduling:
             self.input_prep_event = torch.cuda.Event()
             self.structured_outputs_event = torch.cuda.Event()
-            self.spec_decode_event = torch.cuda.Event()
         else:
             self.input_prep_event = None
             self.structured_outputs_event = None
-            self.spec_decode_event = None
         if self.speculative_config is not None:
             self.do_spec_decode = True
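
The removed spec_decode_event was one of the CUDA events used with async_barrier when async scheduling is enabled; it becomes unnecessary once the draft-token inputs are prepared entirely on the GPU (see the last hunk below). A minimal sketch of such an event-guarded section, assuming async_barrier roughly means "wait for the previous step's GPU work that reads a host buffer, then re-record the event" -- the repo's actual helper may behave differently:

# Hedged sketch (assumption, not this repo's implementation) of an
# event-guarded host section used with async scheduling.
from contextlib import contextmanager
import torch

@contextmanager
def async_barrier(event: torch.cuda.Event | None):
    # No-op when async scheduling is disabled (event is None).
    if event is not None:
        # Wait until the GPU has finished reading the host buffers that
        # the guarded section is about to overwrite.
        event.synchronize()
    yield
    if event is not None:
        # Mark the point that the next iteration must wait for.
        event.record()
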
@@ -412,9 +410,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 cu_num_new_blocks[i].append(x + len(block_ids))
                 new_block_ids[i].extend(block_ids)
                 overwrite.append(True)
-        # Update the GPU tensors for request states.
-        if scheduler_output.scheduled_new_reqs:
-            self.req_states.prefill_len.copy_to_gpu()
         # Add new blocks for the existing requests.
         cached_reqs = scheduler_output.scheduled_cached_reqs
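
The removed lines eagerly synced prefill_len to the GPU whenever new requests arrived. The .np / .gpu / .copy_to_gpu() attributes used throughout these hunks suggest a paired pinned-host/device buffer; a rough sketch of that shape, with an illustrative class name and signature that may not match the repo's actual helper:

# Hedged sketch (assumption): a pinned CPU array mirrored by a device tensor,
# matching the .np / .gpu / .copy_to_gpu() usage seen in the diff.
import torch

class CpuGpuBuffer:
    def __init__(self, size: int, dtype: torch.dtype, device: torch.device):
        # Pinned host tensor plus a numpy view for cheap CPU-side writes
        # (dtype must be numpy-compatible, e.g. int32/int64/float32).
        self.cpu = torch.zeros(size, dtype=dtype, pin_memory=True)
        self.np = self.cpu.numpy()
        self.gpu = torch.zeros(size, dtype=dtype, device=device)

    def copy_to_gpu(self, n: int | None = None) -> torch.Tensor:
        # Async host-to-device copy of the first n elements (all when n is None).
        n = self.cpu.shape[0] if n is None else n
        self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
        return self.gpu[:n]
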
@@ -507,16 +502,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         query_start_loc_cpu = self.input_buffers.query_start_loc.cpu[: num_reqs + 1]
         query_start_loc_np = self.input_buffers.query_start_loc.np[: num_reqs + 1]
-        # Copy prefill tokens from CPU to GPU.
+        # Get prefill tokens.
         prepare_prefill_inputs(
-            idx_mapping_np,
-            num_scheduled_tokens,
-            query_start_loc_np,
-            self.req_states.prefill_token_ids.np,
-            self.req_states.num_computed_prefill_tokens,
-            self.input_buffers.input_ids.np,
+            self.input_buffers.input_ids,
+            self.req_states.next_prefill_tokens,
+            idx_mapping,
+            query_start_loc_gpu,
+            self.req_states.prefill_token_ids.gpu,
+            self.req_states.prefill_len.gpu,
+            self.req_states.num_computed_tokens,
         )
-        self.input_buffers.input_ids.copy_to_gpu(num_tokens)
         # Prepare positions and seq_lens.
         prepare_pos_seq_lens(
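
This hunk moves prefill token preparation from numpy arrays plus an explicit copy_to_gpu onto GPU-resident tensors. A pure-torch sketch of what a GPU-side prepare_prefill_inputs could do, under the assumption that it writes this step's prompt tokens into the flat input_ids buffer and also records each request's next uncomputed prompt token; the committed kernel (and its argument order) may differ:

# Hedged sketch (assumption): device-side prefill token preparation with no
# host round-trip. A real implementation would be a parallel kernel.
import torch

def prepare_prefill_inputs_sketch(
    input_ids: torch.Tensor,            # [max_num_tokens], written in place
    next_prefill_tokens: torch.Tensor,  # [max_num_reqs], written in place
    idx_mapping: torch.Tensor,          # [num_reqs] batch slot -> request state row
    query_start_loc: torch.Tensor,      # [num_reqs + 1] cumulative token offsets
    prefill_token_ids: torch.Tensor,    # [max_num_reqs, max_prompt_len]
    prefill_len: torch.Tensor,          # [max_num_reqs]
    num_computed_tokens: torch.Tensor,  # [max_num_reqs]
) -> None:
    num_reqs = idx_mapping.shape[0]
    for i in range(num_reqs):  # a kernel would parallelize over requests
        r = idx_mapping[i]
        start = num_computed_tokens[r]
        scheduled = query_start_loc[i + 1] - query_start_loc[i]
        # Only the still-unprocessed prompt part counts as prefill;
        # decode-only requests contribute nothing here.
        end = torch.minimum(start + scheduled, prefill_len[r])
        n = torch.clamp(end - start, min=0)
        out = query_start_loc[i]
        input_ids[out : out + n] = prefill_token_ids[r, start : start + n]
        # Remember the next prompt token for the speculator (clamped to the
        # last prompt position once the prefill is finished).
        nxt = torch.minimum(start + n, prefill_len[r] - 1)
        next_prefill_tokens[r] = prefill_token_ids[r, nxt]
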
@@ -531,7 +526,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Some input token ids are directly read from the last sampled tokens
         # and draft tokens. Also, get the logits indices to sample tokens from.
         logits_indices = combine_sampled_and_draft_tokens(
-            self.input_buffers.input_ids.gpu,
+            self.input_buffers.input_ids,
             idx_mapping,
             self.req_states.last_sampled_tokens,
             query_start_loc_gpu,
@@ -572,7 +567,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             kv_cache_config=self.kv_cache_config,
         )
-        input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding]
+        input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
         positions = self.input_buffers.positions[:num_tokens_after_padding]
         return InputBatch(
             req_ids=req_ids,
@@ -782,20 +777,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         num_sampled: torch.Tensor,
         num_rejected: torch.Tensor,
     ) -> torch.Tensor:
-        num_reqs = input_batch.num_reqs
-        idx_mapping_np = input_batch.idx_mapping_np
-        with async_barrier(self.spec_decode_event):
-            self.input_buffers.next_prefill_tokens.np[:num_reqs] = (
-                self.req_states.prefill_token_ids.np[
-                    idx_mapping_np,
-                    self.req_states.num_computed_prefill_tokens[idx_mapping_np],
-                ]
-            )
-            next_prefill_tokens = self.input_buffers.next_prefill_tokens.copy_to_gpu(
-                num_reqs
-            )
         assert self.speculator is not None
+        last_sampled_tokens = self.req_states.last_sampled_tokens[
+            input_batch.idx_mapping
+        ]
+        next_prefill_tokens = self.req_states.next_prefill_tokens[
+            input_batch.idx_mapping
+        ]
         draft_tokens = self.speculator.propose(
             input_batch,
             sampling_metadata,
@@ -803,7 +791,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             aux_hidden_states,
             num_sampled,
             num_rejected,
-            self.req_states.last_sampled_tokens,
+            last_sampled_tokens,
             next_prefill_tokens,
         )
         return draft_tokens
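
The old path built next_prefill_tokens on the CPU with numpy fancy indexing under an event barrier and copied it to the GPU every step; the new path keeps next_prefill_tokens resident on the GPU in req_states and only gathers the per-batch rows with idx_mapping. A short sketch of that device-side gather, assuming next_prefill_tokens is kept up to date elsewhere (e.g. during prefill input preparation):

# Hedged sketch (assumption): the new draft-token input prep is a single
# device-side gather, so no CUDA event is needed to fence host writes.
import torch

def gather_next_prefill_tokens(
    next_prefill_tokens_gpu: torch.Tensor,  # [max_num_reqs], maintained on device
    idx_mapping: torch.Tensor,              # [num_reqs] batch slot -> state row
) -> torch.Tensor:
    # Equivalent to next_prefill_tokens_gpu[idx_mapping], entirely on the GPU.
    return torch.index_select(next_prefill_tokens_gpu, 0, idx_mapping)
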