[Core] Move ray-specific WorkerWrapperBase methods to RayWorkerWrapper (#35328)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Nick Hill
2026-03-04 20:11:59 -08:00
committed by GitHub
parent 3b23d57c96
commit 16c472abe7
2 changed files with 24 additions and 29 deletions

View File

@@ -16,6 +16,7 @@ from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.network_utils import get_ip
from vllm.v1.outputs import AsyncModelRunnerOutput
from vllm.v1.serial_utils import run_method
from vllm.v1.worker.worker_base import WorkerWrapperBase
if TYPE_CHECKING:
@@ -50,6 +51,29 @@ try:
# that thread.
self.compiled_dag_cuda_device_set = False
def adjust_rank(self, rank_mapping: dict[int, int]) -> None:
"""
Adjust the rpc_rank based on the given mapping.
It is only used during the initialization of the executor,
to adjust the rpc_rank of workers after we create all workers.
"""
if self.rpc_rank in rank_mapping:
self.rpc_rank = rank_mapping[self.rpc_rank]
def execute_method(self, method: str | bytes, *args, **kwargs):
try:
return run_method(self, method, args, kwargs)
except Exception as e:
# if the driver worker also execute methods,
# exceptions in the rest worker may cause deadlock in rpc
# see https://github.com/vllm-project/vllm/issues/3455
msg = (
f"Error executing method {method!r}. "
"This might cause deadlock in distributed execution."
)
logger.exception(msg)
raise e
def get_node_ip(self) -> str:
return get_ip()

View File

@@ -15,7 +15,6 @@ from vllm.tracing import instrument
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.system_utils import update_environment_variables
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.serial_utils import run_method
if TYPE_CHECKING:
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
@@ -211,15 +210,6 @@ class WorkerWrapperBase:
if self.worker is not None:
self.worker.shutdown()
def adjust_rank(self, rank_mapping: dict[int, int]) -> None:
"""
Adjust the rpc_rank based on the given mapping.
It is only used during the initialization of the executor,
to adjust the rpc_rank of workers after we create all workers.
"""
if self.rpc_rank in rank_mapping:
self.rpc_rank = rank_mapping[self.rpc_rank]
def update_environment_variables(
self,
envs_list: list[dict[str, str]],
@@ -325,25 +315,6 @@ class WorkerWrapperBase:
# To make vLLM config available during device initialization
self.worker.init_device() # type: ignore
def execute_method(self, method: str | bytes, *args, **kwargs):
try:
# method resolution order:
# if a method is defined in this class, it will be called directly.
# otherwise, since we define `__getattr__` and redirect attribute
# query to `self.worker`, the method will be called on the worker.
return run_method(self, method, args, kwargs)
except Exception as e:
# if the driver worker also execute methods,
# exceptions in the rest worker may cause deadlock in rpc like ray
# see https://github.com/vllm-project/vllm/issues/3455
# print the error and inform the user to solve the error
msg = (
f"Error executing method {method!r}. "
"This might cause deadlock in distributed execution."
)
logger.exception(msg)
raise e
def __getattr__(self, attr: str):
return getattr(self.worker, attr)