[Model Runner V2] Refactor update_states (#32562)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon
Date: 2026-01-18 17:32:42 -08:00
Committed by: GitHub
Parent: bb1848cd62
Commit: 9a1f16da1e

@@ -425,7 +425,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self._dummy_run(self.max_num_tokens, skip_attn=False)
         torch.cuda.synchronize()

-    def update_states(self, scheduler_output: SchedulerOutput) -> None:
+    def finish_requests(self, scheduler_output: SchedulerOutput) -> None:
         if scheduler_output.preempted_req_ids is not None:
             for req_id in scheduler_output.preempted_req_ids:
                 self.req_states.remove_request(req_id)
@@ -436,11 +436,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             if self.supports_mm_inputs:
                 self.encoder_runner.remove_request(req_id)

+    def free_states(self, scheduler_output: SchedulerOutput) -> None:
         if self.supports_mm_inputs:
             for mm_hash in scheduler_output.free_encoder_mm_hashes:
                 self.encoder_runner.free_encoder_cache(mm_hash)

-        # Add new requests.
+    def add_requests(self, scheduler_output: SchedulerOutput) -> None:
         for new_req_data in scheduler_output.scheduled_new_reqs:
             assert new_req_data.prompt_token_ids is not None
             assert new_req_data.prefill_token_ids is not None
@@ -476,6 +477,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 req_index, prompt_len, new_req_data.sampling_params
             )

+        if scheduler_output.scheduled_new_reqs:
+            self.req_states.apply_staged_writes()
+            self.sampler.apply_staged_writes(
+                self.req_states.prefill_token_ids.gpu,
+                self.req_states.prefill_len.np,
+                self.req_states.prompt_len,
+            )
+            if self.uses_mrope:
+                self.mrope_states.apply_staged_writes()
+
+    def update_requests(self, scheduler_output: SchedulerOutput) -> None:
         # Add new blocks for the existing requests.
         cached_reqs = scheduler_output.scheduled_cached_reqs
         for i, req_id in enumerate(cached_reqs.req_ids):
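
The apply_staged_writes calls above follow a staging pattern: per-request writes are buffered cheaply as they are recorded, then flushed to the backing storage in a single batch once every newly scheduled request has been added. A minimal sketch of that idea, assuming a hypothetical StagedColumn buffer (the class name, stage method, and backing array are illustrative only, not vLLM APIs):

    import numpy as np

    class StagedColumn:
        """Toy staged-write buffer: record per-request updates cheaply,
        then apply them to the backing array in one batch."""

        def __init__(self, max_num_reqs: int) -> None:
            self.np = np.zeros(max_num_reqs, dtype=np.int64)  # backing CPU array
            self._pending: list[tuple[int, int]] = []         # (req_index, value)

        def stage(self, req_index: int, value: int) -> None:
            # Record the write without touching the backing array yet.
            self._pending.append((req_index, value))

        def apply_staged_writes(self) -> None:
            # Flush every pending write at once, then reset the buffer.
            for req_index, value in self._pending:
                self.np[req_index] = value
            self._pending.clear()

    # Usage: stage writes while adding requests, flush once at the end.
    prompt_len = StagedColumn(max_num_reqs=4)
    prompt_len.stage(0, 128)
    prompt_len.stage(1, 96)
    prompt_len.apply_staged_writes()
    assert prompt_len.np[1] == 96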
@@ -486,16 +498,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 req_index, req_new_block_ids, overwrite=False
             )

-        self.req_states.apply_staged_writes()
-        self.block_tables.apply_staged_writes()
-        self.sampler.apply_staged_writes(
-            self.req_states.prefill_token_ids.gpu,
-            self.req_states.prefill_len.np,
-            self.req_states.prompt_len,
-        )
-        if self.uses_mrope:
-            self.mrope_states.apply_staged_writes()
-
     def prepare_inputs(
         self,
         scheduler_output: SchedulerOutput,
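
Taken together, the hunks above split the old monolithic update_states into four lifecycle methods; the per-request flushes move into add_requests, while the block-table flush moves up to the caller (next hunk). A skeleton of the resulting structure, reconstructed from the diff with bodies abbreviated to one-line summaries:

    class GPUModelRunner:  # skeleton only; reconstructed from the hunks above
        def finish_requests(self, scheduler_output) -> None:
            ...  # drop state for preempted and finished requests

        def free_states(self, scheduler_output) -> None:
            ...  # free encoder cache entries for released multimodal inputs

        def add_requests(self, scheduler_output) -> None:
            ...  # register newly scheduled requests, then flush their staged writes

        def update_requests(self, scheduler_output) -> None:
            ...  # add new blocks for already-running requests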
@@ -951,15 +953,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         dummy_run: bool = False,
     ) -> ModelRunnerOutput | None:
         assert intermediate_tensors is None
-        if scheduler_output.total_num_scheduled_tokens == 0 and not dummy_run:
-            # No need to run the model.
-            self.update_states(scheduler_output)
-            return EMPTY_MODEL_RUNNER_OUTPUT
+        if not dummy_run:
+            # Update the request states.
+            self.finish_requests(scheduler_output)
+            self.free_states(scheduler_output)
+            self.add_requests(scheduler_output)
+            self.update_requests(scheduler_output)
+            self.block_tables.apply_staged_writes()
+            if scheduler_output.total_num_scheduled_tokens == 0:
+                # No need to run the model.
+                return EMPTY_MODEL_RUNNER_OUTPUT

         cudagraph_mode, num_tokens_after_padding, num_tokens_across_dp = (
             self.get_cudagraph_and_dp_padding(scheduler_output)
         )
-        self.update_states(scheduler_output)
         if num_tokens_after_padding == 0:
             # All DP ranks have zero tokens to run.
             return EMPTY_MODEL_RUNNER_OUTPUT
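
With the state updates split out, the non-dummy path of execute_model now runs the four methods unconditionally and in a fixed order, flushes the block tables once, and only then decides whether the forward pass can be skipped. A condensed restatement of that control flow, as a sketch rather than the actual implementation (EMPTY_MODEL_RUNNER_OUTPUT is stubbed here):

    EMPTY_MODEL_RUNNER_OUTPUT = None  # stand-in for vLLM's sentinel output

    def execute_model_flow(runner, scheduler_output, dummy_run: bool = False):
        """Condensed restatement of the refactored ordering; not the real code."""
        if not dummy_run:
            # Request-state updates happen on every step, in this order.
            runner.finish_requests(scheduler_output)   # preempted/finished reqs out
            runner.free_states(scheduler_output)       # encoder cache entries freed
            runner.add_requests(scheduler_output)      # new reqs in, staged writes flushed
            runner.update_requests(scheduler_output)   # block tables of cached reqs grown
            runner.block_tables.apply_staged_writes()  # single block-table flush
            if scheduler_output.total_num_scheduled_tokens == 0:
                # State is already up to date, so the forward pass can be skipped.
                return EMPTY_MODEL_RUNNER_OUTPUT
        # ... CUDA-graph / DP padding checks and the forward pass follow here.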