[Model Runner V2] Use NamedTuple for execute_model_state (#35930)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
@@ -21,6 +21,7 @@ import functools
 import gc
 import time
 from copy import deepcopy
+from typing import Any, NamedTuple

 import numpy as np
 import torch
@@ -44,7 +45,7 @@ from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
-from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
+from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
 from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
 from vllm.v1.worker.gpu.async_utils import AsyncOutput, AsyncPoolingOutput
 from vllm.v1.worker.gpu.attn_utils import (
@@ -213,7 +214,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.pooling_runner: PoolingRunner | None = None

         # For transferring state from execute_model to subsequent sample_tokens call.
-        self.execute_model_state: tuple | None = None
+        self.execute_model_state: ExecuteModelState | None = None

     def update_max_model_len(self, max_model_len: int) -> None:
         self.max_model_len = max_model_len
@@ -375,16 +376,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             return None, None

         assert self.execute_model_state is not None
-        (
-            input_batch,
-            model_inputs,
-            attn_metadata,
-            slot_mappings_by_layer,
-            hidden_states,
-            aux_hidden_states,
-            kv_connector_output,
-            num_tokens_across_dp,
-        ) = self.execute_model_state
+        input_batch = self.execute_model_state.input_batch
+        attn_metadata = self.execute_model_state.attn_metadata
+        slot_mappings_by_layer = self.execute_model_state.slot_mappings_by_layer
+        hidden_states = self.execute_model_state.hidden_states
+        aux_hidden_states = self.execute_model_state.aux_hidden_states
+        num_tokens_across_dp = self.execute_model_state.num_tokens_across_dp
         self.execute_model_state = None

         # dummy run the eagle speculator's propose to ensure DP/EP sync.
@@ -989,15 +986,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             aux_hidden_states = None

         kv_connector_output = self.kv_connector.post_forward(scheduler_output)
-        self.execute_model_state = (
-            input_batch,
-            model_inputs,
-            attn_metadata,
-            slot_mappings_by_layer,
-            hidden_states,
-            aux_hidden_states,
-            kv_connector_output,
-            num_tokens_across_dp,
+        self.execute_model_state = ExecuteModelState(
+            input_batch=input_batch,
+            attn_metadata=attn_metadata,
+            slot_mappings_by_layer=slot_mappings_by_layer,
+            hidden_states=hidden_states,
+            aux_hidden_states=aux_hidden_states,
+            kv_connector_output=kv_connector_output,
+            num_tokens_across_dp=num_tokens_across_dp,
         )

         if not self.is_last_pp_rank:
@@ -1016,16 +1012,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if self.execute_model_state is None:
             # The prior execute_model call must have failed.
             return None
-        (
-            input_batch,
-            model_inputs,
-            attn_metadata,
-            slot_mappings_by_layer,
-            hidden_states,
-            aux_hidden_states,
-            kv_connector_output,
-            num_tokens_across_dp,
-        ) = self.execute_model_state
+
+        input_batch = self.execute_model_state.input_batch
+        attn_metadata = self.execute_model_state.attn_metadata
+        slot_mappings_by_layer = self.execute_model_state.slot_mappings_by_layer
+        hidden_states = self.execute_model_state.hidden_states
+        aux_hidden_states = self.execute_model_state.aux_hidden_states
+        kv_connector_output = self.execute_model_state.kv_connector_output
+        num_tokens_across_dp = self.execute_model_state.num_tokens_across_dp
         self.execute_model_state = None

         if not self.is_last_pp_rank:
@@ -1116,9 +1110,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # The prior execute_model call must have failed.
             return None

-        input_batch, _, _, _, hidden_states, _, kv_connector_output, _ = (
-            self.execute_model_state
-        )
+        input_batch = self.execute_model_state.input_batch
+        hidden_states = self.execute_model_state.hidden_states
+        kv_connector_output = self.execute_model_state.kv_connector_output
         self.execute_model_state = None

         if not self.is_last_pp_rank:
@@ -1164,3 +1158,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         np.minimum(
             computed_prefill, self.req_states.prefill_len.np, out=computed_prefill
         )
+
+
+class ExecuteModelState(NamedTuple):
+    input_batch: InputBatch
+    attn_metadata: dict[str, Any] | None
+    slot_mappings_by_layer: dict[str, torch.Tensor] | None
+    hidden_states: torch.Tensor | IntermediateTensors
+    aux_hidden_states: list[torch.Tensor] | None
+    kv_connector_output: KVConnectorOutput | None
+    num_tokens_across_dp: torch.Tensor | None