[Model Runner V2] Use NamedTuple for execute_model_state (#35930)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
@@ -21,6 +21,7 @@ import functools
|
||||
import gc
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -44,7 +45,7 @@ from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
||||
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
|
||||
from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
|
||||
from vllm.v1.worker.gpu.async_utils import AsyncOutput, AsyncPoolingOutput
|
||||
from vllm.v1.worker.gpu.attn_utils import (
|
||||
@@ -213,7 +214,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.pooling_runner: PoolingRunner | None = None
|
||||
|
||||
# For transferring state from execute_model to subsequent sample_tokens call.
|
||||
self.execute_model_state: tuple | None = None
|
||||
self.execute_model_state: ExecuteModelState | None = None
|
||||
|
||||
def update_max_model_len(self, max_model_len: int) -> None:
|
||||
self.max_model_len = max_model_len
|
||||
@@ -375,16 +376,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
return None, None
|
||||
|
||||
assert self.execute_model_state is not None
|
||||
(
|
||||
input_batch,
|
||||
model_inputs,
|
||||
attn_metadata,
|
||||
slot_mappings_by_layer,
|
||||
hidden_states,
|
||||
aux_hidden_states,
|
||||
kv_connector_output,
|
||||
num_tokens_across_dp,
|
||||
) = self.execute_model_state
|
||||
input_batch = self.execute_model_state.input_batch
|
||||
attn_metadata = self.execute_model_state.attn_metadata
|
||||
slot_mappings_by_layer = self.execute_model_state.slot_mappings_by_layer
|
||||
hidden_states = self.execute_model_state.hidden_states
|
||||
aux_hidden_states = self.execute_model_state.aux_hidden_states
|
||||
num_tokens_across_dp = self.execute_model_state.num_tokens_across_dp
|
||||
self.execute_model_state = None
|
||||
|
||||
# dummy run the eagle speculator's propose to ensure DP/EP sync.
|
||||
@@ -989,15 +986,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
aux_hidden_states = None
|
||||
|
||||
kv_connector_output = self.kv_connector.post_forward(scheduler_output)
|
||||
self.execute_model_state = (
|
||||
input_batch,
|
||||
model_inputs,
|
||||
attn_metadata,
|
||||
slot_mappings_by_layer,
|
||||
hidden_states,
|
||||
aux_hidden_states,
|
||||
kv_connector_output,
|
||||
num_tokens_across_dp,
|
||||
self.execute_model_state = ExecuteModelState(
|
||||
input_batch=input_batch,
|
||||
attn_metadata=attn_metadata,
|
||||
slot_mappings_by_layer=slot_mappings_by_layer,
|
||||
hidden_states=hidden_states,
|
||||
aux_hidden_states=aux_hidden_states,
|
||||
kv_connector_output=kv_connector_output,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
)
|
||||
|
||||
if not self.is_last_pp_rank:
|
||||
@@ -1016,16 +1012,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
if self.execute_model_state is None:
|
||||
# The prior execute_model call must have failed.
|
||||
return None
|
||||
(
|
||||
input_batch,
|
||||
model_inputs,
|
||||
attn_metadata,
|
||||
slot_mappings_by_layer,
|
||||
hidden_states,
|
||||
aux_hidden_states,
|
||||
kv_connector_output,
|
||||
num_tokens_across_dp,
|
||||
) = self.execute_model_state
|
||||
|
||||
input_batch = self.execute_model_state.input_batch
|
||||
attn_metadata = self.execute_model_state.attn_metadata
|
||||
slot_mappings_by_layer = self.execute_model_state.slot_mappings_by_layer
|
||||
hidden_states = self.execute_model_state.hidden_states
|
||||
aux_hidden_states = self.execute_model_state.aux_hidden_states
|
||||
kv_connector_output = self.execute_model_state.kv_connector_output
|
||||
num_tokens_across_dp = self.execute_model_state.num_tokens_across_dp
|
||||
self.execute_model_state = None
|
||||
|
||||
if not self.is_last_pp_rank:
|
||||
@@ -1116,9 +1110,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# The prior execute_model call must have failed.
|
||||
return None
|
||||
|
||||
input_batch, _, _, _, hidden_states, _, kv_connector_output, _ = (
|
||||
self.execute_model_state
|
||||
)
|
||||
input_batch = self.execute_model_state.input_batch
|
||||
hidden_states = self.execute_model_state.hidden_states
|
||||
kv_connector_output = self.execute_model_state.kv_connector_output
|
||||
self.execute_model_state = None
|
||||
|
||||
if not self.is_last_pp_rank:
|
||||
@@ -1164,3 +1158,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
np.minimum(
|
||||
computed_prefill, self.req_states.prefill_len.np, out=computed_prefill
|
||||
)
|
||||
|
||||
|
||||
class ExecuteModelState(NamedTuple):
|
||||
input_batch: InputBatch
|
||||
attn_metadata: dict[str, Any] | None
|
||||
slot_mappings_by_layer: dict[str, torch.Tensor] | None
|
||||
hidden_states: torch.Tensor | IntermediateTensors
|
||||
aux_hidden_states: list[torch.Tensor] | None
|
||||
kv_connector_output: KVConnectorOutput | None
|
||||
num_tokens_across_dp: torch.Tensor | None
|
||||
|
||||
Reference in New Issue
Block a user