[V1][Core] Add worker_base for v1 worker (#12816)

Signed-off-by: Aoyu <aoyuzhan@amazon.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Aoyu <aoyuzhan@amazon.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
Aoyu
2025-02-13 20:35:18 +08:00
committed by GitHub
parent c9d3ecf016
commit 2092a6fa7d
4 changed files with 154 additions and 53 deletions

View File

@@ -21,6 +21,7 @@ from vllm.utils import GiB_bytes
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase
logger = init_logger(__name__)
@@ -28,7 +29,7 @@ if TYPE_CHECKING:
from vllm.v1.core.scheduler_output import SchedulerOutput
class Worker:
class Worker(WorkerBase):
def __init__(
self,
@@ -39,23 +40,11 @@ class Worker:
is_driver_worker: bool = False,
):
# TODO: use WorkerBase.__init__(self, vllm_config=vllm_config)
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
self.parallel_config.rank = rank
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
super().__init__(vllm_config=vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker)
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
@@ -126,7 +115,8 @@ class Worker:
set_random_seed(self.model_config.seed)
# Construct the model runner
self.model_runner = GPUModelRunner(self.vllm_config, self.device)
self.model_runner: GPUModelRunner = GPUModelRunner(
self.vllm_config, self.device)
def load_model(self) -> None:
if self.vllm_config.model_config.enable_sleep_mode:

View File

@@ -0,0 +1,63 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.worker.worker_base import WorkerBase as WorkerBaseV0
logger = init_logger(__name__)
class WorkerBase(WorkerBaseV0):
"""
Abstract class for v1 worker, mainly define some methods for v1.
For methods shared by v0 and v1, define them in v0 WorkerBase
"""
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
):
"""
Initialize common worker components.
Args:
vllm_config: Complete vLLM configuration
local_rank: Local device index
rank: Global rank in distributed setup
distributed_init_method: Distributed initialization method
is_driver_worker: Whether this worker handles driver
responsibilities
"""
# Configuration storage
super().__init__(vllm_config=vllm_config)
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker
# Device and model state
self.device: Optional[torch.device] = None
self.model_runner: Optional[nn.Module] = None
def get_kv_cache_spec(self) -> KVCacheSpec:
"""Get specifications for KV cache implementation."""
raise NotImplementedError
def compile_or_warm_up_model(self) -> None:
"""Prepare model for execution through compilation/warmup."""
raise NotImplementedError
def check_health(self) -> None:
"""Basic health check (override for device-specific checks)."""
return