[Core] RayWorkerVllm --> WorkerWrapper to reduce duplication (#4024)
[Core] Replace the narrow-usage RayWorkerVllm with the general WorkerWrapper to reduce code duplication (#4024)
This commit is contained in:
@@ -1,8 +1,14 @@
|
||||
import importlib
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||
from vllm.utils import update_environment_variables
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class WorkerBase(ABC):
|
||||
@@ -82,3 +88,53 @@ class LoraNotSupportedWorkerBase(WorkerBase):
|
||||
|
||||
def list_loras(self) -> List[int]:
    """Always raise ``ValueError``: this worker type has no LoRA support.

    Raises:
        ValueError: unconditionally; the message names the concrete type
            of ``self``.
    """
    worker_type = type(self)
    raise ValueError(f"{worker_type} does not support LoRA")
|
||||
|
||||
|
||||
class WorkerWrapperBase:
    """Lazily construct a worker and forward method calls to it.

    The wrapper is instantiated first and only remembers the worker's
    module name and class name. Environment variables can then be adjusted
    through `update_environment_variables`, and the real worker object is
    created later, in `init_worker`.
    """

    def __init__(self,
                 worker_module_name=None,
                 worker_class_name=None) -> None:
        """Remember where the worker class lives without importing it.

        Args:
            worker_module_name: Importable module path that defines the
                worker class.
            worker_class_name: Name of the worker class inside that module.
        """
        self.worker_module_name = worker_module_name
        self.worker_class_name = worker_class_name
        # Set by `init_worker`; stays None until then.
        self.worker = None

    def update_environment_variables(self, envs: Dict[str, str]) -> None:
        """Apply ``envs`` to the process environment.

        Args:
            envs: Mapping of environment variable names to values.
        """
        key = 'CUDA_VISIBLE_DEVICES'
        if key in envs and key in os.environ:
            # overwriting CUDA_VISIBLE_DEVICES is desired behavior
            # suppress the warning in `update_environment_variables`
            del os.environ[key]
        # This resolves to the module-level helper, not to this method.
        update_environment_variables(envs)

    def init_worker(self, *args, **kwargs):
        """Import the worker class and instantiate it.

        All positional and keyword arguments are passed unchanged to the
        worker class constructor. The instance is stored in `self.worker`.

        Raises:
            ValueError: if the module name or class name was never set.
        """
        if self.worker_module_name is None or self.worker_class_name is None:
            # Fail with a clear message instead of an opaque error raised
            # from inside importlib / getattr.
            raise ValueError(
                "worker_module_name and worker_class_name must be set "
                "before calling init_worker")
        mod = importlib.import_module(self.worker_module_name)
        worker_class = getattr(mod, self.worker_class_name)
        self.worker = worker_class(*args, **kwargs)

    def execute_method(self, method, *args, **kwargs):
        """Call ``method`` by name on the wrapper or on the wrapped worker.

        An attribute found on the wrapper itself takes precedence;
        otherwise the name is looked up on `self.worker`.

        Args:
            method: Name of the method to invoke.
            *args: Positional arguments forwarded to the method.
            **kwargs: Keyword arguments forwarded to the method.

        Returns:
            Whatever the invoked method returns.

        Raises:
            Exception: any error from the lookup or the call is logged and
                then re-raised unchanged.
        """
        try:
            if hasattr(self, method):
                executor = getattr(self, method)
            else:
                executor = getattr(self.worker, method)
            return executor(*args, **kwargs)
        except Exception:
            # if the driver worker also execute methods,
            # exceptions in the rest worker may cause deadlock in rpc like ray
            # see https://github.com/vllm-project/vllm/issues/3455
            # print the error and inform the user to solve the error
            msg = (f"Error executing method {method}. "
                   "This might cause deadlock in distributed execution.")
            logger.exception(msg)
            # Bare raise re-raises the active exception as-is.
            raise
|
||||
|
||||
Reference in New Issue
Block a user