From 9a1f16da1e423ede2c2f52a9850cbfbb39cefe96 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Sun, 18 Jan 2026 17:32:42 -0800
Subject: [PATCH] [Model Runner V2] Refactor `update_states` (#32562)

Signed-off-by: Woosuk Kwon
---
 vllm/v1/worker/gpu/model_runner.py | 41 +++++++++++++++++-------------
 1 file changed, 24 insertions(+), 17 deletions(-)

diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 1dc844bb3..a55519f0f 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -425,7 +425,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self._dummy_run(self.max_num_tokens, skip_attn=False)
         torch.cuda.synchronize()
 
-    def update_states(self, scheduler_output: SchedulerOutput) -> None:
+    def finish_requests(self, scheduler_output: SchedulerOutput) -> None:
         if scheduler_output.preempted_req_ids is not None:
             for req_id in scheduler_output.preempted_req_ids:
                 self.req_states.remove_request(req_id)
@@ -436,11 +436,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             if self.supports_mm_inputs:
                 self.encoder_runner.remove_request(req_id)
 
+    def free_states(self, scheduler_output: SchedulerOutput) -> None:
         if self.supports_mm_inputs:
             for mm_hash in scheduler_output.free_encoder_mm_hashes:
                 self.encoder_runner.free_encoder_cache(mm_hash)
 
-        # Add new requests.
+    def add_requests(self, scheduler_output: SchedulerOutput) -> None:
         for new_req_data in scheduler_output.scheduled_new_reqs:
             assert new_req_data.prompt_token_ids is not None
             assert new_req_data.prefill_token_ids is not None
@@ -476,6 +477,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 req_index, prompt_len, new_req_data.sampling_params
             )
 
+        if scheduler_output.scheduled_new_reqs:
+            self.req_states.apply_staged_writes()
+            self.sampler.apply_staged_writes(
+                self.req_states.prefill_token_ids.gpu,
+                self.req_states.prefill_len.np,
+                self.req_states.prompt_len,
+            )
+            if self.uses_mrope:
+                self.mrope_states.apply_staged_writes()
+
+    def update_requests(self, scheduler_output: SchedulerOutput) -> None:
         # Add new blocks for the existing requests.
         cached_reqs = scheduler_output.scheduled_cached_reqs
         for i, req_id in enumerate(cached_reqs.req_ids):
@@ -486,16 +498,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 req_index, req_new_block_ids, overwrite=False
             )
 
-        self.req_states.apply_staged_writes()
-        self.block_tables.apply_staged_writes()
-        self.sampler.apply_staged_writes(
-            self.req_states.prefill_token_ids.gpu,
-            self.req_states.prefill_len.np,
-            self.req_states.prompt_len,
-        )
-        if self.uses_mrope:
-            self.mrope_states.apply_staged_writes()
-
     def prepare_inputs(
         self,
         scheduler_output: SchedulerOutput,
@@ -951,15 +953,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         dummy_run: bool = False,
     ) -> ModelRunnerOutput | None:
         assert intermediate_tensors is None
-        if scheduler_output.total_num_scheduled_tokens == 0 and not dummy_run:
-            # No need to run the model.
-            self.update_states(scheduler_output)
-            return EMPTY_MODEL_RUNNER_OUTPUT
+        if not dummy_run:
+            # Update the request states.
+            self.finish_requests(scheduler_output)
+            self.free_states(scheduler_output)
+            self.add_requests(scheduler_output)
+            self.update_requests(scheduler_output)
+            self.block_tables.apply_staged_writes()
+            if scheduler_output.total_num_scheduled_tokens == 0:
+                # No need to run the model.
+                return EMPTY_MODEL_RUNNER_OUTPUT
 
         cudagraph_mode, num_tokens_after_padding, num_tokens_across_dp = (
             self.get_cudagraph_and_dp_padding(scheduler_output)
         )
-        self.update_states(scheduler_output)
         if num_tokens_after_padding == 0:
             # All DP ranks have zero tokens to run.
             return EMPTY_MODEL_RUNNER_OUTPUT
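
Below is a minimal, self-contained sketch of the call-order contract this patch establishes, assuming hypothetical stand-ins: StagedTable, Runner, and the plain-dict scheduler output are illustrative only, while the four method names, the self-flush inside add_requests, and the final block_tables.apply_staged_writes() before the zero-token early return mirror the diff above.

    # sketch.py -- hypothetical stand-ins; not vLLM's real classes.
    from dataclasses import dataclass, field


    @dataclass
    class StagedTable:
        """Stages writes, then applies them in one batch (illustrative)."""

        rows: dict = field(default_factory=dict)
        _staged: list = field(default_factory=list)

        def stage(self, key, value) -> None:
            # Record the write; readers do not see it yet.
            self._staged.append((key, value))

        def apply_staged_writes(self) -> None:
            # Flush all staged writes at once (in vLLM this is where
            # batched host-to-device copies would happen).
            for key, value in self._staged:
                self.rows[key] = value
            self._staged.clear()


    class Runner:
        def __init__(self) -> None:
            self.req_states = StagedTable()
            self.block_tables = StagedTable()

        # The patch splits the old update_states() into these four steps.
        def finish_requests(self, out) -> None:
            for req_id in out.get("finished", []):
                self.req_states.rows.pop(req_id, None)

        def free_states(self, out) -> None:
            pass  # e.g. free encoder caches for multimodal inputs

        def add_requests(self, out) -> None:
            for req_id in out.get("new", []):
                self.req_states.stage(req_id, "running")
            if out.get("new"):
                # As in the patch: new requests flush their own staged
                # state immediately.
                self.req_states.apply_staged_writes()

        def update_requests(self, out) -> None:
            for req_id, blocks in out.get("new_blocks", {}).items():
                self.block_tables.stage(req_id, blocks)

        def execute_model(self, out):
            # Same order as the patched execute_model: update all request
            # state first, flush the block tables, then decide whether a
            # forward pass is needed at all.
            self.finish_requests(out)
            self.free_states(out)
            self.add_requests(out)
            self.update_requests(out)
            self.block_tables.apply_staged_writes()
            if out.get("num_tokens", 0) == 0:
                return None  # nothing scheduled; skip the forward pass
            return "ran model"


    if __name__ == "__main__":
        runner = Runner()
        print(runner.execute_model({"new": ["req-0"], "num_tokens": 8}))

One consequence of this ordering, visible in both the sketch and the diff: request state is now updated on every non-dummy call, even when zero tokens are scheduled, so the separate update_states call before the early return is no longer needed.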