[Core] RayWorkerVllm --> WorkerWrapper to reduce duplication (#4024)

[Core] Replace the narrow-usage RayWorkerVllm with a general WorkerWrapper to reduce code duplication (#4024)
This commit is contained in:
youkaichao
2024-04-17 01:34:33 -07:00
committed by GitHub
parent 11d652bd4f
commit 8438e0569e
8 changed files with 174 additions and 114 deletions

View File

@@ -1,8 +1,14 @@
import importlib
import os
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import update_environment_variables
logger = init_logger(__name__)
class WorkerBase(ABC):
@@ -82,3 +88,53 @@ class LoraNotSupportedWorkerBase(WorkerBase):
def list_loras(self) -> List[int]:
    """Unconditionally raise ValueError: this worker type has no LoRA support."""
    worker_type = type(self)
    raise ValueError(f"{worker_type} does not support LoRA")
class WorkerWrapperBase:
    """Lazily construct a worker and forward method calls to it.

    The wrapper is instantiated first and only remembers the module name and
    class name of the worker. Environment variables can then be adjusted with
    `update_environment_variables`; the worker object itself is created
    later, when `init_worker` is called.
    """

    def __init__(self,
                 worker_module_name=None,
                 worker_class_name=None) -> None:
        # Passed to `importlib.import_module` / `getattr` in `init_worker`.
        self.worker_module_name = worker_module_name
        self.worker_class_name = worker_class_name
        # Stays None until `init_worker` has run.
        self.worker = None

    def update_environment_variables(self, envs: Dict[str, str]) -> None:
        """Forward `envs` to the module-level `update_environment_variables`.

        If `CUDA_VISIBLE_DEVICES` is present both in `envs` and in
        `os.environ`, the existing entry is removed first.
        """
        key = 'CUDA_VISIBLE_DEVICES'
        if key in envs and key in os.environ:
            # overwriting CUDA_VISIBLE_DEVICES is desired behavior
            # suppress the warning in `update_environment_variables`
            del os.environ[key]
        # Inside a method body this name resolves to the module-level helper,
        # not to this method.
        update_environment_variables(envs)

    def init_worker(self, *args, **kwargs):
        """Import the worker class and instantiate it.

        All positional and keyword arguments are passed unchanged to the
        worker class constructor; the instance is stored in `self.worker`.
        """
        mod = importlib.import_module(self.worker_module_name)
        worker_class = getattr(mod, self.worker_class_name)
        self.worker = worker_class(*args, **kwargs)

    def execute_method(self, method, *args, **kwargs):
        """Call the method named `method` and return its result.

        The name is looked up on the wrapper itself first and on the wrapped
        worker otherwise. Any `Exception` raised during lookup or during the
        call is logged and then re-raised.
        """
        try:
            if hasattr(self, method):
                executor = getattr(self, method)
            else:
                executor = getattr(self.worker, method)
            return executor(*args, **kwargs)
        except Exception:
            # If the driver worker also executes methods, an exception in one
            # of the other workers may cause a deadlock in RPC frameworks
            # such as Ray.
            # See https://github.com/vllm-project/vllm/issues/3455
            # Log the error so the user is told what needs fixing.
            msg = (f"Error executing method {method}. "
                   "This might cause deadlock in distributed execution.")
            logger.exception(msg)
            # Bare `raise` re-raises the active exception unchanged.
            raise