[Core] Move ray-specific WorkerWrapperBase methods to RayWorkerWrapper (#35328)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -16,6 +16,7 @@ from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.network_utils import get_ip
|
||||
from vllm.v1.outputs import AsyncModelRunnerOutput
|
||||
from vllm.v1.serial_utils import run_method
|
||||
from vllm.v1.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -50,6 +51,29 @@ try:
|
||||
# that thread.
|
||||
self.compiled_dag_cuda_device_set = False
|
||||
|
||||
def adjust_rank(self, rank_mapping: dict[int, int]) -> None:
|
||||
"""
|
||||
Adjust the rpc_rank based on the given mapping.
|
||||
It is only used during the initialization of the executor,
|
||||
to adjust the rpc_rank of workers after we create all workers.
|
||||
"""
|
||||
if self.rpc_rank in rank_mapping:
|
||||
self.rpc_rank = rank_mapping[self.rpc_rank]
|
||||
|
||||
def execute_method(self, method: str | bytes, *args, **kwargs):
|
||||
try:
|
||||
return run_method(self, method, args, kwargs)
|
||||
except Exception as e:
|
||||
# if the driver worker also execute methods,
|
||||
# exceptions in the rest worker may cause deadlock in rpc
|
||||
# see https://github.com/vllm-project/vllm/issues/3455
|
||||
msg = (
|
||||
f"Error executing method {method!r}. "
|
||||
"This might cause deadlock in distributed execution."
|
||||
)
|
||||
logger.exception(msg)
|
||||
raise e
|
||||
|
||||
def get_node_ip(self) -> str:
|
||||
return get_ip()
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ from vllm.tracing import instrument
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
from vllm.v1.kv_cache_interface import KVCacheSpec
|
||||
from vllm.v1.serial_utils import run_method
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
@@ -211,15 +210,6 @@ class WorkerWrapperBase:
|
||||
if self.worker is not None:
|
||||
self.worker.shutdown()
|
||||
|
||||
def adjust_rank(self, rank_mapping: dict[int, int]) -> None:
|
||||
"""
|
||||
Adjust the rpc_rank based on the given mapping.
|
||||
It is only used during the initialization of the executor,
|
||||
to adjust the rpc_rank of workers after we create all workers.
|
||||
"""
|
||||
if self.rpc_rank in rank_mapping:
|
||||
self.rpc_rank = rank_mapping[self.rpc_rank]
|
||||
|
||||
def update_environment_variables(
|
||||
self,
|
||||
envs_list: list[dict[str, str]],
|
||||
@@ -325,25 +315,6 @@ class WorkerWrapperBase:
|
||||
# To make vLLM config available during device initialization
|
||||
self.worker.init_device() # type: ignore
|
||||
|
||||
def execute_method(self, method: str | bytes, *args, **kwargs):
|
||||
try:
|
||||
# method resolution order:
|
||||
# if a method is defined in this class, it will be called directly.
|
||||
# otherwise, since we define `__getattr__` and redirect attribute
|
||||
# query to `self.worker`, the method will be called on the worker.
|
||||
return run_method(self, method, args, kwargs)
|
||||
except Exception as e:
|
||||
# if the driver worker also execute methods,
|
||||
# exceptions in the rest worker may cause deadlock in rpc like ray
|
||||
# see https://github.com/vllm-project/vllm/issues/3455
|
||||
# print the error and inform the user to solve the error
|
||||
msg = (
|
||||
f"Error executing method {method!r}. "
|
||||
"This might cause deadlock in distributed execution."
|
||||
)
|
||||
logger.exception(msg)
|
||||
raise e
|
||||
|
||||
def __getattr__(self, attr: str):
|
||||
return getattr(self.worker, attr)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user