[Model Runner V2] Refactor prefill token preparation (#29712)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -104,11 +104,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
if self.use_async_scheduling:
|
||||
self.input_prep_event = torch.cuda.Event()
|
||||
self.structured_outputs_event = torch.cuda.Event()
|
||||
self.spec_decode_event = torch.cuda.Event()
|
||||
else:
|
||||
self.input_prep_event = None
|
||||
self.structured_outputs_event = None
|
||||
self.spec_decode_event = None
|
||||
|
||||
if self.speculative_config is not None:
|
||||
self.do_spec_decode = True
|
||||
@@ -412,9 +410,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
cu_num_new_blocks[i].append(x + len(block_ids))
|
||||
new_block_ids[i].extend(block_ids)
|
||||
overwrite.append(True)
|
||||
# Update the GPU tensors for request states.
|
||||
if scheduler_output.scheduled_new_reqs:
|
||||
self.req_states.prefill_len.copy_to_gpu()
|
||||
|
||||
# Add new blocks for the existing requests.
|
||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||
@@ -507,16 +502,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
query_start_loc_cpu = self.input_buffers.query_start_loc.cpu[: num_reqs + 1]
|
||||
query_start_loc_np = self.input_buffers.query_start_loc.np[: num_reqs + 1]
|
||||
|
||||
# Copy prefill tokens from CPU to GPU.
|
||||
# Get prefill tokens.
|
||||
prepare_prefill_inputs(
|
||||
idx_mapping_np,
|
||||
num_scheduled_tokens,
|
||||
query_start_loc_np,
|
||||
self.req_states.prefill_token_ids.np,
|
||||
self.req_states.num_computed_prefill_tokens,
|
||||
self.input_buffers.input_ids.np,
|
||||
self.input_buffers.input_ids,
|
||||
self.req_states.next_prefill_tokens,
|
||||
idx_mapping,
|
||||
query_start_loc_gpu,
|
||||
self.req_states.prefill_token_ids.gpu,
|
||||
self.req_states.prefill_len.gpu,
|
||||
self.req_states.num_computed_tokens,
|
||||
)
|
||||
self.input_buffers.input_ids.copy_to_gpu(num_tokens)
|
||||
|
||||
# Prepare positions and seq_lens.
|
||||
prepare_pos_seq_lens(
|
||||
@@ -531,7 +526,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# Some input token ids are directly read from the last sampled tokens
|
||||
# and draft tokens. Also, get the logits indices to sample tokens from.
|
||||
logits_indices = combine_sampled_and_draft_tokens(
|
||||
self.input_buffers.input_ids.gpu,
|
||||
self.input_buffers.input_ids,
|
||||
idx_mapping,
|
||||
self.req_states.last_sampled_tokens,
|
||||
query_start_loc_gpu,
|
||||
@@ -572,7 +567,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
)
|
||||
|
||||
input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding]
|
||||
input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
|
||||
positions = self.input_buffers.positions[:num_tokens_after_padding]
|
||||
return InputBatch(
|
||||
req_ids=req_ids,
|
||||
@@ -782,20 +777,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_sampled: torch.Tensor,
|
||||
num_rejected: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
num_reqs = input_batch.num_reqs
|
||||
idx_mapping_np = input_batch.idx_mapping_np
|
||||
with async_barrier(self.spec_decode_event):
|
||||
self.input_buffers.next_prefill_tokens.np[:num_reqs] = (
|
||||
self.req_states.prefill_token_ids.np[
|
||||
idx_mapping_np,
|
||||
self.req_states.num_computed_prefill_tokens[idx_mapping_np],
|
||||
]
|
||||
)
|
||||
next_prefill_tokens = self.input_buffers.next_prefill_tokens.copy_to_gpu(
|
||||
num_reqs
|
||||
)
|
||||
|
||||
assert self.speculator is not None
|
||||
last_sampled_tokens = self.req_states.last_sampled_tokens[
|
||||
input_batch.idx_mapping
|
||||
]
|
||||
next_prefill_tokens = self.req_states.next_prefill_tokens[
|
||||
input_batch.idx_mapping
|
||||
]
|
||||
draft_tokens = self.speculator.propose(
|
||||
input_batch,
|
||||
sampling_metadata,
|
||||
@@ -803,7 +791,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
aux_hidden_states,
|
||||
num_sampled,
|
||||
num_rejected,
|
||||
self.req_states.last_sampled_tokens,
|
||||
last_sampled_tokens,
|
||||
next_prefill_tokens,
|
||||
)
|
||||
return draft_tokens
|
||||
|
||||
Reference in New Issue
Block a user