[V1] Remove scheduling constraint on partial requests (#12674)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@@ -46,6 +46,8 @@ class BlockTable:
         start: int,
         block_ids: List[int],
     ) -> None:
+        if not block_ids:
+            return
         num_blocks = len(block_ids)
         self.block_table_np[row_idx, start:start + num_blocks] = block_ids
         self.num_blocks_per_row[row_idx] = start + num_blocks
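With the scheduling constraint gone, a cached request can be rescheduled in a step that allocates no new KV-cache blocks, so `append_row` may now be called with an empty `block_ids` list; the new early return simply makes that a no-op. A minimal sketch of the same guard on a standalone NumPy-backed table (the `SimpleBlockTable` class and its sizes are illustrative, not the actual vLLM implementation):

```python
from typing import List

import numpy as np


class SimpleBlockTable:
    """Toy fixed-size block table mirroring the append_row guard."""

    def __init__(self, max_num_reqs: int = 4, max_blocks_per_req: int = 8):
        self.block_table_np = np.zeros((max_num_reqs, max_blocks_per_req),
                                       dtype=np.int32)
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

    def append_row(self, row_idx: int, start: int,
                   block_ids: List[int]) -> None:
        if not block_ids:
            # Nothing was allocated for this request in this step: leave the
            # row and its recorded length untouched.
            return
        num_blocks = len(block_ids)
        self.block_table_np[row_idx, start:start + num_blocks] = block_ids
        self.num_blocks_per_row[row_idx] = start + num_blocks


table = SimpleBlockTable()
table.append_row(0, 0, [10, 11])   # two blocks allocated
table.append_row(0, 2, [])         # no new blocks: a no-op, not an error
assert table.num_blocks_per_row[0] == 2
```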
@@ -205,12 +205,32 @@ class GPUModelRunner:
             pin_memory=self.pin_memory)
         self.seq_lens_np = self.seq_lens_cpu.numpy()
 
-    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
-        # Remove stopped requests from the cached states.
-        # Keep the states of the preempted requests.
+    def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
+        """Update the cached states and the persistent batch with the scheduler
+        output.
+
+        The updated states are used by the `_prepare_inputs` function to create
+        the input GPU tensors for the model.
+
+        Returns:
+            True if there is a new/resumed/paused/finished request in the batch.
+            If False, we can skip copying SamplingMetadata to the GPU.
+        """
+        # Remove finished requests from the cached states.
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
             self.encoder_cache.pop(req_id, None)
+        # Remove the finished requests from the persistent batch.
+        # NOTE(woosuk): There could be an edge case where finished_req_ids and
+        # scheduled_req_ids overlap. This happens when a request is aborted and
+        # then resubmitted with the same ID. In this case, we treat them as two
+        # distinct requests - clearing the cached states for the first request
+        # and handling the second as a new request.
+        removed_req_indices: List[int] = []
+        for req_id in scheduler_output.finished_req_ids:
+            req_index = self.input_batch.remove_request(req_id)
+            if req_index is not None:
+                removed_req_indices.append(req_index)
 
         # Free the cached encoder outputs.
         for req_id, input_id in scheduler_output.free_encoder_input_ids:
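The docstring above is the heart of the change to `_update_states`: the method now reports whether batch membership changed so the caller can decide whether sampling metadata must be re-copied to the GPU. A rough, self-contained sketch of the finished-request bookkeeping in this hunk, using plain dicts in place of the real caches (`ToyBatch`, `drop_finished`, and the request ids are made up for illustration):

```python
from typing import Dict, List, Optional, Set


class ToyBatch:
    """Stand-in for the persistent batch: maps request id -> row index."""

    def __init__(self) -> None:
        self.req_id_to_index: Dict[str, int] = {}

    def remove_request(self, req_id: str) -> Optional[int]:
        # Returns the freed row index, or None if the request was not batched.
        return self.req_id_to_index.pop(req_id, None)


def drop_finished(requests: Dict[str, object],
                  encoder_cache: Dict[str, object],
                  batch: ToyBatch,
                  finished_req_ids: Set[str]) -> List[int]:
    """Clear cached state for finished requests and record freed batch rows."""
    removed_req_indices: List[int] = []
    for req_id in finished_req_ids:
        # pop(..., None) tolerates the aborted-and-resubmitted edge case,
        # where a finished id may no longer have cached state.
        requests.pop(req_id, None)
        encoder_cache.pop(req_id, None)
        req_index = batch.remove_request(req_id)
        if req_index is not None:
            removed_req_indices.append(req_index)
    return removed_req_indices


batch = ToyBatch()
batch.req_id_to_index = {"req-0": 0, "req-1": 1}
freed = drop_finished({"req-0": object(), "req-1": object()}, {}, batch,
                      {"req-1"})
assert freed == [1]
```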
@@ -220,36 +240,22 @@ class GPUModelRunner:
                 if not encoder_outputs:
                     self.encoder_cache.pop(req_id, None)
 
-        # Remove the requests from the persistent batch.
-        stopped_req_ids = set().union(
-            scheduler_output.preempted_req_ids,
-            scheduler_output.finished_req_ids,
-        )
-        removed_req_indices: List[int] = []
-        for req_id in stopped_req_ids:
+        # Remove the unscheduled requests from the persistent batch.
+        # NOTE(woosuk): The unscheduled requests are either preempted requests
+        # or running requests that are not scheduled in this step. We remove
+        # them from the persistent batch but keep their cached states since
+        # they will be scheduled again sometime in the future.
+        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
+        cached_req_ids = self.input_batch.req_id_to_index.keys()
+        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
+        # NOTE(woosuk): The persistent batch optimization assumes that
+        # consecutive batches contain mostly the same requests. If batches
+        # have low request overlap (e.g., alternating between two distinct
+        # sets of requests), this optimization becomes very inefficient.
+        for req_id in unscheduled_req_ids:
             req_index = self.input_batch.remove_request(req_id)
-            if req_index is not None:
-                removed_req_indices.append(req_index)
-
-        # Update the states of the running requests.
-        for req_data in scheduler_output.scheduled_running_reqs:
-            req_id = req_data.req_id
-            req_state = self.requests[req_id]
-            req_index = self.input_batch.req_id_to_index[req_id]
-
-            # Update the num_computed_tokens.
-            req_state.num_computed_tokens = req_data.num_computed_tokens
-            self.input_batch.num_computed_tokens_cpu[req_index] = (
-                req_data.num_computed_tokens)
-
-            # Update the block table.
-            num_new_blocks = len(req_data.new_block_ids)
-            if num_new_blocks == 0:
-                continue
-            start_index = len(req_state.block_ids)
-            req_state.block_ids.extend(req_data.new_block_ids)
-            self.input_batch.block_table.append_row(req_index, start_index,
-                                                    req_data.new_block_ids)
+            assert req_index is not None
+            removed_req_indices.append(req_index)
 
         req_ids_to_add: List[str] = []
         # Add new requests to the cached states.
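After this hunk, the runner no longer needs the scheduler to enumerate preempted requests: anything cached in the persistent batch but missing from this step's `num_scheduled_tokens` is treated as unscheduled. A small demonstration of that set difference over dict key views (the ids and token counts are invented):

```python
# Requests currently held in the persistent batch (id -> row index).
req_id_to_index = {"a": 0, "b": 1, "c": 2}

# Requests the scheduler actually scheduled this step (id -> token count).
num_scheduled_tokens = {"a": 16, "c": 1}

# Dict key views support set operations, so no intermediate copies are needed.
unscheduled_req_ids = req_id_to_index.keys() - num_scheduled_tokens.keys()
assert unscheduled_req_ids == {"b"}

# "b" would be removed from the persistent batch but its cached state kept,
# since it is expected to be scheduled again in a later step.
```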
@@ -305,14 +311,36 @@ class GPUModelRunner:
 
             req_ids_to_add.append(req_id)
 
-        # Update the cached states of the resumed requests.
-        for res_req_data in scheduler_output.scheduled_resumed_reqs:
-            req_id = res_req_data.req_id
+        # Update the states of the running/resumed requests.
+        for req_data in scheduler_output.scheduled_cached_reqs:
+            req_id = req_data.req_id
             req_state = self.requests[req_id]
 
-            req_state.block_ids = res_req_data.block_ids
-            req_state.num_computed_tokens = res_req_data.num_computed_tokens
-            req_ids_to_add.append(req_id)
+            # Update the cached states.
+            req_state.num_computed_tokens = req_data.num_computed_tokens
+            if not req_data.resumed_from_preemption:
+                # Append the new blocks to the existing block IDs.
+                req_state.block_ids.extend(req_data.new_block_ids)
+            else:
+                # The request is resumed from preemption.
+                # Replace the existing block IDs with the new ones.
+                req_state.block_ids = req_data.new_block_ids
+
+            req_index = self.input_batch.req_id_to_index.get(req_id)
+            if req_index is None:
+                # The request is not in the persistent batch.
+                # The request was either preempted and resumed later, or was not
+                # scheduled in the previous step and needs to be added again.
+                req_ids_to_add.append(req_id)
+                continue
+
+            # Update the persistent batch.
+            self.input_batch.num_computed_tokens_cpu[req_index] = (
+                req_data.num_computed_tokens)
+            start_index = len(req_state.block_ids) - len(
+                req_data.new_block_ids)
+            self.input_batch.block_table.append_row(req_index, start_index,
+                                                    req_data.new_block_ids)
 
         # Add the new or resumed requests to the persistent batch.
         # The smaller empty indices are filled first.
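For cached requests, the block-table row is appended starting where the new blocks begin, and `len(block_ids) - len(new_block_ids)` yields that offset whether the new ids were appended in place or replaced the whole list after preemption. A short worked example under those assumptions (the helper and its inputs are hypothetical):

```python
def updated_block_ids(block_ids, new_block_ids, resumed_from_preemption):
    """Mirror the cached-request update: extend in place, or replace after preemption."""
    if not resumed_from_preemption:
        block_ids = block_ids + new_block_ids
    else:
        # The scheduler resends the full block list after preemption.
        block_ids = list(new_block_ids)
    # Offset of the freshly allocated blocks within the updated list.
    start_index = len(block_ids) - len(new_block_ids)
    return block_ids, start_index


# Running request that received one extra block this step.
ids, start = updated_block_ids([3, 7], [9], resumed_from_preemption=False)
assert (ids, start) == ([3, 7, 9], 2)

# Request resumed from preemption: new_block_ids is the complete list.
ids, start = updated_block_ids([3, 7], [4, 5, 6], resumed_from_preemption=True)
assert (ids, start) == ([4, 5, 6], 0)
```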
@@ -330,6 +358,7 @@
         # Condense the batched states if there are empty indices.
         if removed_req_indices:
             self.input_batch.condense(removed_req_indices)
+        return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
 
     def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -536,10 +565,10 @@ class GPUModelRunner:
             prefix_kv_lens=prefix_kv_lens,
             suffix_kv_lens=suffix_kv_lens,
         )
-        # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
-        # request in the batch. While we should not sample any token from this
-        # partial request, we do so for simplicity. We will ignore the sampled
-        # token from the partial request.
+        # NOTE(woosuk): Due to chunked prefills, the batch may contain partial
+        # requests. While we should not sample any token from these partial
+        # requests, we do so for simplicity. We will ignore the sampled
+        # tokens from the partial requests.
         # TODO: Support prompt logprobs.
         logits_indices = query_start_loc[1:] - 1
         return attn_metadata, logits_indices
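`query_start_loc` holds cumulative scheduled-token counts per request, so `query_start_loc[1:] - 1` selects the flattened position of each request's last scheduled token; for a partial (chunked-prefill) request that position falls mid-prompt, and the token sampled there is discarded downstream. A small NumPy illustration with made-up per-request token counts:

```python
import numpy as np

# Tokens scheduled this step for three requests; the last one is a partial
# prefill (its prompt is longer than the 4 tokens scheduled here).
num_scheduled_tokens = np.array([5, 1, 4])

# Cumulative offsets into the flattened token batch: [0, 5, 6, 10].
query_start_loc = np.concatenate(([0], np.cumsum(num_scheduled_tokens)))

# Index of each request's last scheduled token in the flat batch.
logits_indices = query_start_loc[1:] - 1
assert logits_indices.tolist() == [4, 5, 9]

# Position 9 belongs to the partial request; a token is still sampled there
# for simplicity, but it is ignored downstream.
```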
@@ -601,22 +630,15 @@ class GPUModelRunner:
 
     def _prepare_sampling(
         self,
-        scheduler_output: "SchedulerOutput",
+        batch_changed: bool,
     ) -> SamplingMetadata:
-        skip_copy = True
-        if (scheduler_output.finished_req_ids
-                or scheduler_output.preempted_req_ids):
-            skip_copy = False
-        if (scheduler_output.scheduled_new_reqs
-                or scheduler_output.scheduled_resumed_reqs):
-            skip_copy = False
         # Create the sampling metadata.
         req_id_output_token_ids: Dict[str, List[int]] = \
             {req_id: req.output_token_ids \
                 for req_id, req in self.requests.items()}
 
         sampling_metadata = self.input_batch.make_sampling_metadata(
-            req_id_output_token_ids, skip_copy)
+            req_id_output_token_ids, skip_copy=not batch_changed)
         return sampling_metadata
 
     def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
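`_prepare_sampling` now takes the single `batch_changed` flag instead of re-deriving the condition from the scheduler output: when `_update_states` reports no membership change, the device-side sampling tensors are still valid and the CPU-to-GPU copy can be skipped. A hedged sketch of what such a `skip_copy` fast path could look like (`ToySamplingState` is illustrative, not vLLM's `InputBatch`):

```python
import torch


class ToySamplingState:
    """Keeps a CPU staging tensor and a separate 'device' copy of temperatures."""

    def __init__(self, temperatures):
        self.temperature_cpu = torch.tensor(temperatures, dtype=torch.float32)
        # In the real runner this would live on the GPU; a clone keeps the
        # sketch runnable without CUDA.
        self.temperature_device = self.temperature_cpu.clone()

    def make_sampling_metadata(self, skip_copy: bool = False):
        if not skip_copy:
            # Batch membership changed: refresh the device tensor from CPU.
            self.temperature_device.copy_(self.temperature_cpu)
        return {"temperature": self.temperature_device}


state = ToySamplingState([1.0, 0.7])
meta = state.make_sampling_metadata(skip_copy=False)   # new batch: copy happens
meta = state.make_sampling_metadata(skip_copy=True)    # same batch: copy skipped
assert torch.equal(meta["temperature"], state.temperature_cpu)
```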
@@ -715,7 +737,7 @@ class GPUModelRunner:
         self,
         scheduler_output: "SchedulerOutput",
     ) -> ModelRunnerOutput:
-        self._update_states(scheduler_output)
+        batch_changed = self._update_states(scheduler_output)
 
         if self.is_multimodal_model:
             # Run the multimodal encoder if any.
@@ -778,7 +800,7 @@ class GPUModelRunner:
         logits = self.model.compute_logits(hidden_states, None)
 
         # Sample the next token and get logprobs if needed.
-        sampling_metadata = self._prepare_sampling(scheduler_output)
+        sampling_metadata = self._prepare_sampling(batch_changed)
         sampler_output = self.model.sample(
             logits=logits,
             sampling_metadata=sampling_metadata,