[Model Runner V2] Use NamedTuple for execute_model_state (#35930)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
Woosuk Kwon
2026-03-09 11:17:34 -07:00
committed by GitHub
parent fe0c085c28
commit 10a5f4d53d

View File

@@ -21,6 +21,7 @@ import functools
import gc import gc
import time import time
from copy import deepcopy from copy import deepcopy
from typing import Any, NamedTuple
import numpy as np import numpy as np
import torch import torch
@@ -44,7 +45,7 @@ from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
from vllm.v1.worker.cp_utils import check_attention_cp_compatibility from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
from vllm.v1.worker.gpu.async_utils import AsyncOutput, AsyncPoolingOutput from vllm.v1.worker.gpu.async_utils import AsyncOutput, AsyncPoolingOutput
from vllm.v1.worker.gpu.attn_utils import ( from vllm.v1.worker.gpu.attn_utils import (
@@ -213,7 +214,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.pooling_runner: PoolingRunner | None = None self.pooling_runner: PoolingRunner | None = None
# For transferring state from execute_model to subsequent sample_tokens call. # For transferring state from execute_model to subsequent sample_tokens call.
self.execute_model_state: tuple | None = None self.execute_model_state: ExecuteModelState | None = None
def update_max_model_len(self, max_model_len: int) -> None: def update_max_model_len(self, max_model_len: int) -> None:
self.max_model_len = max_model_len self.max_model_len = max_model_len
@@ -375,16 +376,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return None, None return None, None
assert self.execute_model_state is not None assert self.execute_model_state is not None
( input_batch = self.execute_model_state.input_batch
input_batch, attn_metadata = self.execute_model_state.attn_metadata
model_inputs, slot_mappings_by_layer = self.execute_model_state.slot_mappings_by_layer
attn_metadata, hidden_states = self.execute_model_state.hidden_states
slot_mappings_by_layer, aux_hidden_states = self.execute_model_state.aux_hidden_states
hidden_states, num_tokens_across_dp = self.execute_model_state.num_tokens_across_dp
aux_hidden_states,
kv_connector_output,
num_tokens_across_dp,
) = self.execute_model_state
self.execute_model_state = None self.execute_model_state = None
# dummy run the eagle speculator's propose to ensure DP/EP sync. # dummy run the eagle speculator's propose to ensure DP/EP sync.
@@ -989,15 +986,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states = None aux_hidden_states = None
kv_connector_output = self.kv_connector.post_forward(scheduler_output) kv_connector_output = self.kv_connector.post_forward(scheduler_output)
self.execute_model_state = ( self.execute_model_state = ExecuteModelState(
input_batch, input_batch=input_batch,
model_inputs, attn_metadata=attn_metadata,
attn_metadata, slot_mappings_by_layer=slot_mappings_by_layer,
slot_mappings_by_layer, hidden_states=hidden_states,
hidden_states, aux_hidden_states=aux_hidden_states,
aux_hidden_states, kv_connector_output=kv_connector_output,
kv_connector_output, num_tokens_across_dp=num_tokens_across_dp,
num_tokens_across_dp,
) )
if not self.is_last_pp_rank: if not self.is_last_pp_rank:
@@ -1016,16 +1012,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.execute_model_state is None: if self.execute_model_state is None:
# The prior execute_model call must have failed. # The prior execute_model call must have failed.
return None return None
(
input_batch, input_batch = self.execute_model_state.input_batch
model_inputs, attn_metadata = self.execute_model_state.attn_metadata
attn_metadata, slot_mappings_by_layer = self.execute_model_state.slot_mappings_by_layer
slot_mappings_by_layer, hidden_states = self.execute_model_state.hidden_states
hidden_states, aux_hidden_states = self.execute_model_state.aux_hidden_states
aux_hidden_states, kv_connector_output = self.execute_model_state.kv_connector_output
kv_connector_output, num_tokens_across_dp = self.execute_model_state.num_tokens_across_dp
num_tokens_across_dp,
) = self.execute_model_state
self.execute_model_state = None self.execute_model_state = None
if not self.is_last_pp_rank: if not self.is_last_pp_rank:
@@ -1116,9 +1110,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# The prior execute_model call must have failed. # The prior execute_model call must have failed.
return None return None
input_batch, _, _, _, hidden_states, _, kv_connector_output, _ = ( input_batch = self.execute_model_state.input_batch
self.execute_model_state hidden_states = self.execute_model_state.hidden_states
) kv_connector_output = self.execute_model_state.kv_connector_output
self.execute_model_state = None self.execute_model_state = None
if not self.is_last_pp_rank: if not self.is_last_pp_rank:
@@ -1164,3 +1158,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
np.minimum( np.minimum(
computed_prefill, self.req_states.prefill_len.np, out=computed_prefill computed_prefill, self.req_states.prefill_len.np, out=computed_prefill
) )
class ExecuteModelState(NamedTuple):
    """Immutable snapshot of per-step state produced by ``execute_model``.

    Stored on ``GPUModelRunner.execute_model_state`` for transferring state
    from an ``execute_model`` call to the subsequent ``sample_tokens`` call,
    which reads the fields it needs and then resets the attribute to ``None``.
    """

    # The batch of requests this step was executed for.
    input_batch: InputBatch
    # Per-layer attention metadata for the step; None when not applicable.
    attn_metadata: dict[str, Any] | None
    # KV-cache slot mapping tensors keyed by layer name; None when not applicable.
    slot_mappings_by_layer: dict[str, torch.Tensor] | None
    # Model forward output. IntermediateTensors presumably on non-last PP
    # ranks (callers branch on ``is_last_pp_rank``) — TODO confirm.
    hidden_states: torch.Tensor | IntermediateTensors
    # Auxiliary hidden states (e.g. for speculative decoding); may be None.
    aux_hidden_states: list[torch.Tensor] | None
    # Result of ``kv_connector.post_forward(...)``; None if no connector output.
    kv_connector_output: KVConnectorOutput | None
    # NOTE(review): presumably per-rank token counts across the DP group
    # used for DP/EP sync — verify against callers.
    num_tokens_across_dp: torch.Tensor | None