[V1][Core] Add worker_base for v1 worker (#12816)
Signed-off-by: Aoyu <aoyuzhan@amazon.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Aoyu <aoyuzhan@amazon.com> Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -21,6 +21,7 @@ from vllm.utils import GiB_bytes
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -28,7 +29,7 @@ if TYPE_CHECKING:
|
||||
from vllm.v1.core.scheduler_output import SchedulerOutput
|
||||
|
||||
|
||||
class Worker:
|
||||
class Worker(WorkerBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -39,23 +40,11 @@ class Worker:
|
||||
is_driver_worker: bool = False,
|
||||
):
|
||||
|
||||
# TODO: use WorkerBase.__init__(self, vllm_config=vllm_config)
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.load_config = vllm_config.load_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.device_config = vllm_config.device_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.prompt_adapter_config = vllm_config.prompt_adapter_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
|
||||
self.parallel_config.rank = rank
|
||||
self.local_rank = local_rank
|
||||
self.rank = rank
|
||||
self.distributed_init_method = distributed_init_method
|
||||
super().__init__(vllm_config=vllm_config,
|
||||
local_rank=local_rank,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
is_driver_worker=is_driver_worker)
|
||||
|
||||
if self.model_config.trust_remote_code:
|
||||
# note: lazy import to avoid importing torch before initializing
|
||||
@@ -126,7 +115,8 @@ class Worker:
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
# Construct the model runner
|
||||
self.model_runner = GPUModelRunner(self.vllm_config, self.device)
|
||||
self.model_runner: GPUModelRunner = GPUModelRunner(
|
||||
self.vllm_config, self.device)
|
||||
|
||||
def load_model(self) -> None:
|
||||
if self.vllm_config.model_config.enable_sleep_mode:
|
||||
|
||||
63
vllm/v1/worker/worker_base.py
Normal file
63
vllm/v1/worker/worker_base.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.kv_cache_interface import KVCacheSpec
|
||||
from vllm.worker.worker_base import WorkerBase as WorkerBaseV0
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class WorkerBase(WorkerBaseV0):
|
||||
"""
|
||||
Abstract class for v1 worker, mainly define some methods for v1.
|
||||
For methods shared by v0 and v1, define them in v0 WorkerBase
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize common worker components.
|
||||
|
||||
Args:
|
||||
vllm_config: Complete vLLM configuration
|
||||
local_rank: Local device index
|
||||
rank: Global rank in distributed setup
|
||||
distributed_init_method: Distributed initialization method
|
||||
is_driver_worker: Whether this worker handles driver
|
||||
responsibilities
|
||||
"""
|
||||
# Configuration storage
|
||||
super().__init__(vllm_config=vllm_config)
|
||||
|
||||
self.local_rank = local_rank
|
||||
self.rank = rank
|
||||
self.distributed_init_method = distributed_init_method
|
||||
self.is_driver_worker = is_driver_worker
|
||||
|
||||
# Device and model state
|
||||
self.device: Optional[torch.device] = None
|
||||
self.model_runner: Optional[nn.Module] = None
|
||||
|
||||
def get_kv_cache_spec(self) -> KVCacheSpec:
|
||||
"""Get specifications for KV cache implementation."""
|
||||
raise NotImplementedError
|
||||
|
||||
def compile_or_warm_up_model(self) -> None:
|
||||
"""Prepare model for execution through compilation/warmup."""
|
||||
raise NotImplementedError
|
||||
|
||||
def check_health(self) -> None:
|
||||
"""Basic health check (override for device-specific checks)."""
|
||||
return
|
||||
Reference in New Issue
Block a user