[Model Runner V2] Refactor update_states (#32562)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
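Splits the monolithic update_states into four focused steps: finish_requests removes preempted (and finished) requests, free_states releases freed encoder-cache entries, add_requests registers newly scheduled requests and flushes their staged writes only when new requests exist, and update_requests appends new blocks for already-running requests. execute_model now invokes these steps directly, before computing the CUDA-graph and DP padding, so a zero-token step can still update request state and return early.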
@@ -425,7 +425,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self._dummy_run(self.max_num_tokens, skip_attn=False)
         torch.cuda.synchronize()
 
-    def update_states(self, scheduler_output: SchedulerOutput) -> None:
+    def finish_requests(self, scheduler_output: SchedulerOutput) -> None:
         if scheduler_output.preempted_req_ids is not None:
             for req_id in scheduler_output.preempted_req_ids:
                 self.req_states.remove_request(req_id)
@@ -436,11 +436,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             if self.supports_mm_inputs:
                 self.encoder_runner.remove_request(req_id)
 
+    def free_states(self, scheduler_output: SchedulerOutput) -> None:
         if self.supports_mm_inputs:
             for mm_hash in scheduler_output.free_encoder_mm_hashes:
                 self.encoder_runner.free_encoder_cache(mm_hash)
 
-        # Add new requests.
+    def add_requests(self, scheduler_output: SchedulerOutput) -> None:
         for new_req_data in scheduler_output.scheduled_new_reqs:
             assert new_req_data.prompt_token_ids is not None
             assert new_req_data.prefill_token_ids is not None
@@ -476,6 +477,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 req_index, prompt_len, new_req_data.sampling_params
             )
 
+        if scheduler_output.scheduled_new_reqs:
+            self.req_states.apply_staged_writes()
+            self.sampler.apply_staged_writes(
+                self.req_states.prefill_token_ids.gpu,
+                self.req_states.prefill_len.np,
+                self.req_states.prompt_len,
+            )
+            if self.uses_mrope:
+                self.mrope_states.apply_staged_writes()
+
+    def update_requests(self, scheduler_output: SchedulerOutput) -> None:
         # Add new blocks for the existing requests.
         cached_reqs = scheduler_output.scheduled_cached_reqs
         for i, req_id in enumerate(cached_reqs.req_ids):
@@ -486,16 +498,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 req_index, req_new_block_ids, overwrite=False
             )
 
-        self.req_states.apply_staged_writes()
-        self.block_tables.apply_staged_writes()
-        self.sampler.apply_staged_writes(
-            self.req_states.prefill_token_ids.gpu,
-            self.req_states.prefill_len.np,
-            self.req_states.prompt_len,
-        )
-        if self.uses_mrope:
-            self.mrope_states.apply_staged_writes()
-
     def prepare_inputs(
         self,
         scheduler_output: SchedulerOutput,
@@ -951,15 +953,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         dummy_run: bool = False,
     ) -> ModelRunnerOutput | None:
         assert intermediate_tensors is None
-        if scheduler_output.total_num_scheduled_tokens == 0 and not dummy_run:
-            # No need to run the model.
-            self.update_states(scheduler_output)
-            return EMPTY_MODEL_RUNNER_OUTPUT
+        if not dummy_run:
+            # Update the request states.
+            self.finish_requests(scheduler_output)
+            self.free_states(scheduler_output)
+            self.add_requests(scheduler_output)
+            self.update_requests(scheduler_output)
+            self.block_tables.apply_staged_writes()
+            if scheduler_output.total_num_scheduled_tokens == 0:
+                # No need to run the model.
+                return EMPTY_MODEL_RUNNER_OUTPUT
 
         cudagraph_mode, num_tokens_after_padding, num_tokens_across_dp = (
             self.get_cudagraph_and_dp_padding(scheduler_output)
         )
-        self.update_states(scheduler_output)
         if num_tokens_after_padding == 0:
             # All DP ranks have zero tokens to run.
             return EMPTY_MODEL_RUNNER_OUTPUT
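For orientation, a minimal sketch of how the split-out steps recompose the old update_states flow. This is not part of the commit: the helper name _update_request_states is hypothetical, and only the call order is taken from the execute_model hunk above.

def _update_request_states(runner, scheduler_output) -> None:
    # Hypothetical recomposition of the old update_states, following the
    # call order shown in the execute_model hunk above.
    runner.finish_requests(scheduler_output)   # drop preempted/finished requests
    runner.free_states(scheduler_output)       # release freed encoder-cache entries
    runner.add_requests(scheduler_output)      # register newly scheduled requests
    runner.update_requests(scheduler_output)   # add blocks for running requests
    runner.block_tables.apply_staged_writes()  # flush staged block-table writes once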