[Model Runner V2] Use NamedTuple for execute_model_state (#35930)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
Woosuk Kwon
2026-03-09 11:17:34 -07:00
committed by GitHub
parent fe0c085c28
commit 10a5f4d53d

View File

@@ -21,6 +21,7 @@ import functools
import gc
import time
from copy import deepcopy
from typing import Any, NamedTuple
import numpy as np
import torch
@@ -44,7 +45,7 @@ from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
from vllm.v1.worker.gpu.async_utils import AsyncOutput, AsyncPoolingOutput
from vllm.v1.worker.gpu.attn_utils import (
@@ -213,7 +214,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.pooling_runner: PoolingRunner | None = None
# For transferring state from execute_model to subsequent sample_tokens call.
self.execute_model_state: tuple | None = None
self.execute_model_state: ExecuteModelState | None = None
def update_max_model_len(self, max_model_len: int) -> None:
self.max_model_len = max_model_len
@@ -375,16 +376,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return None, None
assert self.execute_model_state is not None
(
input_batch,
model_inputs,
attn_metadata,
slot_mappings_by_layer,
hidden_states,
aux_hidden_states,
kv_connector_output,
num_tokens_across_dp,
) = self.execute_model_state
input_batch = self.execute_model_state.input_batch
attn_metadata = self.execute_model_state.attn_metadata
slot_mappings_by_layer = self.execute_model_state.slot_mappings_by_layer
hidden_states = self.execute_model_state.hidden_states
aux_hidden_states = self.execute_model_state.aux_hidden_states
num_tokens_across_dp = self.execute_model_state.num_tokens_across_dp
self.execute_model_state = None
# dummy run the eagle speculator's propose to ensure DP/EP sync.
@@ -989,15 +986,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states = None
kv_connector_output = self.kv_connector.post_forward(scheduler_output)
self.execute_model_state = (
input_batch,
model_inputs,
attn_metadata,
slot_mappings_by_layer,
hidden_states,
aux_hidden_states,
kv_connector_output,
num_tokens_across_dp,
self.execute_model_state = ExecuteModelState(
input_batch=input_batch,
attn_metadata=attn_metadata,
slot_mappings_by_layer=slot_mappings_by_layer,
hidden_states=hidden_states,
aux_hidden_states=aux_hidden_states,
kv_connector_output=kv_connector_output,
num_tokens_across_dp=num_tokens_across_dp,
)
if not self.is_last_pp_rank:
@@ -1016,16 +1012,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.execute_model_state is None:
# The prior execute_model call must have failed.
return None
(
input_batch,
model_inputs,
attn_metadata,
slot_mappings_by_layer,
hidden_states,
aux_hidden_states,
kv_connector_output,
num_tokens_across_dp,
) = self.execute_model_state
input_batch = self.execute_model_state.input_batch
attn_metadata = self.execute_model_state.attn_metadata
slot_mappings_by_layer = self.execute_model_state.slot_mappings_by_layer
hidden_states = self.execute_model_state.hidden_states
aux_hidden_states = self.execute_model_state.aux_hidden_states
kv_connector_output = self.execute_model_state.kv_connector_output
num_tokens_across_dp = self.execute_model_state.num_tokens_across_dp
self.execute_model_state = None
if not self.is_last_pp_rank:
@@ -1116,9 +1110,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# The prior execute_model call must have failed.
return None
input_batch, _, _, _, hidden_states, _, kv_connector_output, _ = (
self.execute_model_state
)
input_batch = self.execute_model_state.input_batch
hidden_states = self.execute_model_state.hidden_states
kv_connector_output = self.execute_model_state.kv_connector_output
self.execute_model_state = None
if not self.is_last_pp_rank:
@@ -1164,3 +1158,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
np.minimum(
computed_prefill, self.req_states.prefill_len.np, out=computed_prefill
)
class ExecuteModelState(NamedTuple):
    """Forward-pass state handed from ``execute_model`` to the subsequent
    ``sample_tokens`` call on the same runner.

    Stored on ``GPUModelRunner.execute_model_state`` and reset to ``None``
    once consumed, so a ``None`` value signals that the prior
    ``execute_model`` call either failed or has already been drained.
    """

    # Batch of requests that was run through the model's forward pass.
    input_batch: InputBatch
    # Attention metadata produced for this step; presumably keyed by layer
    # name — TODO confirm against the attention backend.
    attn_metadata: dict[str, Any] | None
    # KV-cache slot-mapping tensors, keyed by layer name.
    slot_mappings_by_layer: dict[str, torch.Tensor] | None
    # Model output: final hidden states on the last PP rank, otherwise
    # IntermediateTensors to forward to the next pipeline stage.
    hidden_states: torch.Tensor | IntermediateTensors
    # Auxiliary hidden states (e.g. for EAGLE-style speculation); None when
    # the model does not emit them.
    aux_hidden_states: list[torch.Tensor] | None
    # Output from the KV connector's post_forward hook, if a connector is
    # configured.
    kv_connector_output: KVConnectorOutput | None
    # Per-rank token counts for data-parallel synchronization; None when DP
    # padding/sync is not needed.
    num_tokens_across_dp: torch.Tensor | None