[Model Runner V2] Add ModelStateInterface [4/N] (#35621)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
@@ -22,7 +22,7 @@ from vllm.v1.worker.gpu.attn_utils import (
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
|
||||
from vllm.v1.worker.gpu.input_batch import InputBuffers
|
||||
from vllm.v1.worker.gpu.model_states import ModelState
|
||||
from vllm.v1.worker.gpu.model_states.interface import ModelState
|
||||
from vllm.v1.worker.utils import AttentionGroup
|
||||
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ from vllm.v1.worker.gpu.kv_connector import (
|
||||
)
|
||||
from vllm.v1.worker.gpu.lora_utils import LoraState
|
||||
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
|
||||
from vllm.v1.worker.gpu.model_states import ModelState
|
||||
from vllm.v1.worker.gpu.model_states import init_model_state
|
||||
from vllm.v1.worker.gpu.pool.pooling_runner import PoolingRunner
|
||||
from vllm.v1.worker.gpu.pp_utils import pp_broadcast, pp_receive
|
||||
from vllm.v1.worker.gpu.sample.output import SamplerOutput
|
||||
@@ -267,7 +267,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
prepare_communication_buffer_for_model(self.speculator)
|
||||
|
||||
# Initialize the components that require the model.
|
||||
self.model_state = ModelState(
|
||||
self.model_state = init_model_state(
|
||||
self.vllm_config, self.model, self.encoder_cache, self.device
|
||||
)
|
||||
if self.is_pooling_model:
|
||||
|
||||
18
vllm/v1/worker/gpu/model_states/__init__.py
Normal file
18
vllm/v1/worker/gpu/model_states/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
|
||||
|
||||
|
||||
def init_model_state(
|
||||
vllm_config: VllmConfig,
|
||||
model: nn.Module,
|
||||
encoder_cache: EncoderCache | None,
|
||||
device: torch.device,
|
||||
):
|
||||
from vllm.v1.worker.gpu.model_states.default import DefaultModelState
|
||||
|
||||
return DefaultModelState(vllm_config, model, encoder_cache, device)
|
||||
@@ -13,11 +13,12 @@ from vllm.v1.worker.gpu.input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
|
||||
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
|
||||
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
|
||||
from vllm.v1.worker.gpu.model_states.interface import ModelState
|
||||
from vllm.v1.worker.gpu.states import RequestState
|
||||
from vllm.v1.worker.utils import AttentionGroup
|
||||
|
||||
|
||||
class ModelState:
|
||||
class DefaultModelState(ModelState):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
67
vllm/v1/worker/gpu/model_states/interface.py
Normal file
67
vllm/v1/worker/gpu/model_states/interface.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.core.sched.output import NewRequestData
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
|
||||
from vllm.v1.worker.gpu.states import RequestState
|
||||
from vllm.v1.worker.utils import AttentionGroup
|
||||
|
||||
|
||||
class ModelState(ABC):
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
model: nn.Module,
|
||||
encoder_cache: EncoderCache | None,
|
||||
device: torch.device,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply_staged_writes(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_mm_embeddings(
|
||||
self,
|
||||
scheduled_encoder_inputs: dict[str, list[int]],
|
||||
input_batch: InputBatch,
|
||||
req_states: RequestState,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def prepare_inputs(
|
||||
self, input_batch: InputBatch, req_states: RequestState
|
||||
) -> dict[str, torch.Tensor | None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def prepare_dummy_inputs(
|
||||
self, num_reqs: int, num_tokens: int
|
||||
) -> dict[str, torch.Tensor | None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def prepare_attn(
|
||||
self,
|
||||
input_batch: InputBatch,
|
||||
block_tables: tuple[torch.Tensor, ...],
|
||||
slot_mappings: torch.Tensor,
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
Reference in New Issue
Block a user