[Misc] Better RayExecutor and multiprocessing compatibility (#14705)
Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
@@ -2147,20 +2147,48 @@ def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
|
||||
ctx.destroy(linger=0)
|
||||
|
||||
|
||||
def _check_multiproc_method():
|
||||
if (cuda_is_initialized()
|
||||
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
|
||||
logger.warning("CUDA was previously initialized. We must use "
|
||||
"the `spawn` multiprocessing start method. Setting "
|
||||
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
|
||||
"See https://docs.vllm.ai/en/latest/getting_started/"
|
||||
"troubleshooting.html#python-multiprocessing "
|
||||
"for more information.")
|
||||
def is_in_ray_actor():
|
||||
"""Check if we are in a Ray actor."""
|
||||
|
||||
try:
|
||||
import ray
|
||||
return (ray.is_initialized()
|
||||
and ray.get_runtime_context().get_actor_id() is not None)
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def _maybe_force_spawn():
|
||||
"""Check if we need to force the use of the `spawn` multiprocessing start
|
||||
method.
|
||||
"""
|
||||
if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn":
|
||||
return
|
||||
|
||||
reason = None
|
||||
if cuda_is_initialized():
|
||||
reason = "CUDA is initialized"
|
||||
elif is_in_ray_actor():
|
||||
reason = "In a Ray actor and can only be spawned"
|
||||
|
||||
if reason is not None:
|
||||
logger.warning(
|
||||
"We must use the `spawn` multiprocessing start method. "
|
||||
"Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
|
||||
"See https://docs.vllm.ai/en/latest/getting_started/"
|
||||
"troubleshooting.html#python-multiprocessing "
|
||||
"for more information. Reason: %s", reason)
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
|
||||
def get_mp_context():
|
||||
_check_multiproc_method()
|
||||
"""Get a multiprocessing context with a particular method (spawn or fork).
|
||||
By default we follow the value of the VLLM_WORKER_MULTIPROC_METHOD to
|
||||
determine the multiprocessing method (default is fork). However, under
|
||||
certain conditions, we may enforce spawn and override the value of
|
||||
VLLM_WORKER_MULTIPROC_METHOD.
|
||||
"""
|
||||
_maybe_force_spawn()
|
||||
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
|
||||
return multiprocessing.get_context(mp_method)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user