From e3eb146f7ad4bc920e11e98cf88cee3839cf5f89 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 28 Feb 2026 13:19:45 -0800 Subject: [PATCH] [Model Runner V2] Add ModelStateInterface [4/N] (#35621) Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu/cudagraph_utils.py | 2 +- vllm/v1/worker/gpu/model_runner.py | 4 +- vllm/v1/worker/gpu/model_states/__init__.py | 18 +++++ .../default.py} | 3 +- vllm/v1/worker/gpu/model_states/interface.py | 67 +++++++++++++++++++ 5 files changed, 90 insertions(+), 4 deletions(-) create mode 100644 vllm/v1/worker/gpu/model_states/__init__.py rename vllm/v1/worker/gpu/{model_states.py => model_states/default.py} (98%) create mode 100644 vllm/v1/worker/gpu/model_states/interface.py diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py index 6e43043bc..783715cfe 100644 --- a/vllm/v1/worker/gpu/cudagraph_utils.py +++ b/vllm/v1/worker/gpu/cudagraph_utils.py @@ -22,7 +22,7 @@ from vllm.v1.worker.gpu.attn_utils import ( from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp from vllm.v1.worker.gpu.input_batch import InputBuffers -from vllm.v1.worker.gpu.model_states import ModelState +from vllm.v1.worker.gpu.model_states.interface import ModelState from vllm.v1.worker.utils import AttentionGroup diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 188a2694e..ca44ad164 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -78,7 +78,7 @@ from vllm.v1.worker.gpu.kv_connector import ( ) from vllm.v1.worker.gpu.lora_utils import LoraState from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache -from vllm.v1.worker.gpu.model_states import ModelState +from vllm.v1.worker.gpu.model_states import init_model_state from vllm.v1.worker.gpu.pool.pooling_runner import PoolingRunner from vllm.v1.worker.gpu.pp_utils import pp_broadcast, pp_receive from vllm.v1.worker.gpu.sample.output import SamplerOutput @@ -267,7 +267,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): prepare_communication_buffer_for_model(self.speculator) # Initialize the components that require the model. - self.model_state = ModelState( + self.model_state = init_model_state( self.vllm_config, self.model, self.encoder_cache, self.device ) if self.is_pooling_model: diff --git a/vllm/v1/worker/gpu/model_states/__init__.py b/vllm/v1/worker/gpu/model_states/__init__.py new file mode 100644 index 000000000..3ddce0fdc --- /dev/null +++ b/vllm/v1/worker/gpu/model_states/__init__.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +import torch.nn as nn + +from vllm.config import VllmConfig +from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache + + +def init_model_state( + vllm_config: VllmConfig, + model: nn.Module, + encoder_cache: EncoderCache | None, + device: torch.device, +): + from vllm.v1.worker.gpu.model_states.default import DefaultModelState + + return DefaultModelState(vllm_config, model, encoder_cache, device) diff --git a/vllm/v1/worker/gpu/model_states.py b/vllm/v1/worker/gpu/model_states/default.py similarity index 98% rename from vllm/v1/worker/gpu/model_states.py rename to vllm/v1/worker/gpu/model_states/default.py index ca4d63e6b..d52f7d0ec 100644 --- a/vllm/v1/worker/gpu/model_states.py +++ b/vllm/v1/worker/gpu/model_states/default.py @@ -13,11 +13,12 @@ from vllm.v1.worker.gpu.input_batch import InputBatch from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState +from vllm.v1.worker.gpu.model_states.interface import ModelState from vllm.v1.worker.gpu.states import RequestState from vllm.v1.worker.utils import AttentionGroup -class ModelState: +class DefaultModelState(ModelState): def __init__( self, vllm_config: VllmConfig, diff --git a/vllm/v1/worker/gpu/model_states/interface.py b/vllm/v1/worker/gpu/model_states/interface.py new file mode 100644 index 000000000..d5a25710c --- /dev/null +++ b/vllm/v1/worker/gpu/model_states/interface.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod +from typing import Any + +import torch +import torch.nn as nn + +from vllm.config import VllmConfig +from vllm.v1.core.sched.output import NewRequestData +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.worker.gpu.input_batch import InputBatch +from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache +from vllm.v1.worker.gpu.states import RequestState +from vllm.v1.worker.utils import AttentionGroup + + +class ModelState(ABC): + @abstractmethod + def __init__( + self, + vllm_config: VllmConfig, + model: nn.Module, + encoder_cache: EncoderCache | None, + device: torch.device, + ) -> None: + raise NotImplementedError + + @abstractmethod + def add_request(self, req_index: int, new_req_data: NewRequestData) -> None: + raise NotImplementedError + + @abstractmethod + def apply_staged_writes(self) -> None: + raise NotImplementedError + + @abstractmethod + def get_mm_embeddings( + self, + scheduled_encoder_inputs: dict[str, list[int]], + input_batch: InputBatch, + req_states: RequestState, + ) -> torch.Tensor: + raise NotImplementedError + + @abstractmethod + def prepare_inputs( + self, input_batch: InputBatch, req_states: RequestState + ) -> dict[str, torch.Tensor | None]: + raise NotImplementedError + + @abstractmethod + def prepare_dummy_inputs( + self, num_reqs: int, num_tokens: int + ) -> dict[str, torch.Tensor | None]: + raise NotImplementedError + + @abstractmethod + def prepare_attn( + self, + input_batch: InputBatch, + block_tables: tuple[torch.Tensor, ...], + slot_mappings: torch.Tensor, + attn_groups: list[list[AttentionGroup]], + kv_cache_config: KVCacheConfig, + ) -> dict[str, Any]: + raise NotImplementedError