Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -24,30 +24,36 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (destroy_distributed_environment,
|
||||
destroy_model_parallel)
|
||||
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
|
||||
MessageQueue)
|
||||
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
|
||||
get_pp_group, get_tp_group)
|
||||
from vllm.distributed import destroy_distributed_environment, destroy_model_parallel
|
||||
from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_dp_group,
|
||||
get_ep_group,
|
||||
get_pp_group,
|
||||
get_tp_group,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.cache import worker_receiver_cache_from_config
|
||||
from vllm.utils import (_maybe_force_spawn, decorate_logs,
|
||||
get_distributed_init_method, get_loopback_ip,
|
||||
get_mp_context, get_open_port, set_process_title)
|
||||
from vllm.utils import (
|
||||
_maybe_force_spawn,
|
||||
decorate_logs,
|
||||
get_distributed_init_method,
|
||||
get_loopback_ip,
|
||||
get_mp_context,
|
||||
get_open_port,
|
||||
set_process_title,
|
||||
)
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.executor.abstract import Executor, FailureCallback
|
||||
from vllm.v1.executor.utils import get_and_update_mm_cache
|
||||
from vllm.v1.outputs import (AsyncModelRunnerOutput, DraftTokenIds,
|
||||
ModelRunnerOutput)
|
||||
from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
|
||||
from vllm.v1.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MultiprocExecutor(Executor):
|
||||
|
||||
supports_pp: bool = True
|
||||
|
||||
def _init_executor(self) -> None:
|
||||
@@ -65,7 +71,8 @@ class MultiprocExecutor(Executor):
|
||||
assert self.world_size == tensor_parallel_size * pp_parallel_size, (
|
||||
f"world_size ({self.world_size}) must be equal to the "
|
||||
f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
|
||||
f"_parallel_size ({pp_parallel_size}). ")
|
||||
f"_parallel_size ({pp_parallel_size}). "
|
||||
)
|
||||
|
||||
# Set multiprocessing envs
|
||||
set_multiprocessing_worker_envs()
|
||||
@@ -74,14 +81,15 @@ class MultiprocExecutor(Executor):
|
||||
# Since it only works for single node, we can use the loopback address
|
||||
# get_loopback_ip() for communication.
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_loopback_ip(), get_open_port())
|
||||
get_loopback_ip(), get_open_port()
|
||||
)
|
||||
|
||||
# Initialize worker and set up message queues for SchedulerOutputs
|
||||
# and ModelRunnerOutputs
|
||||
max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
|
||||
self.rpc_broadcast_mq = MessageQueue(self.world_size,
|
||||
self.world_size,
|
||||
max_chunk_bytes=max_chunk_bytes)
|
||||
self.rpc_broadcast_mq = MessageQueue(
|
||||
self.world_size, self.world_size, max_chunk_bytes=max_chunk_bytes
|
||||
)
|
||||
scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
|
||||
|
||||
# Create workers
|
||||
@@ -99,7 +107,8 @@ class MultiprocExecutor(Executor):
|
||||
distributed_init_method=distributed_init_method,
|
||||
input_shm_handle=scheduler_output_handle,
|
||||
shared_worker_lock=shared_worker_lock,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
# Workers must be created before wait_for_ready to avoid
|
||||
# deadlock, since worker.init_device() does a device sync.
|
||||
@@ -120,8 +129,7 @@ class MultiprocExecutor(Executor):
|
||||
for uw in unready_workers:
|
||||
if uw.death_writer is not None:
|
||||
uw.death_writer.close()
|
||||
self._ensure_worker_termination(
|
||||
[uw.proc for uw in unready_workers])
|
||||
self._ensure_worker_termination([uw.proc for uw in unready_workers])
|
||||
|
||||
# For pipeline parallel, we use a thread pool for asynchronous
|
||||
# execute_model.
|
||||
@@ -130,7 +138,8 @@ class MultiprocExecutor(Executor):
|
||||
# from the response queue
|
||||
# _async_aggregate_workers_output also assumes a single IO thread
|
||||
self.io_thread_pool = ThreadPoolExecutor(
|
||||
max_workers=1, thread_name_prefix="mp_exec_io")
|
||||
max_workers=1, thread_name_prefix="mp_exec_io"
|
||||
)
|
||||
|
||||
self.output_rank = self._get_output_rank()
|
||||
self.has_connector = self.vllm_config.kv_transfer_config is not None
|
||||
@@ -146,23 +155,22 @@ class MultiprocExecutor(Executor):
|
||||
sentinels = [h.proc.sentinel for h in workers]
|
||||
died = multiprocessing.connection.wait(sentinels)
|
||||
_self = self_ref()
|
||||
if not _self or getattr(_self, 'shutting_down', False):
|
||||
if not _self or getattr(_self, "shutting_down", False):
|
||||
return
|
||||
_self.is_failed = True
|
||||
proc_name = next(h.proc.name for h in workers
|
||||
if h.proc.sentinel == died[0])
|
||||
proc_name = next(h.proc.name for h in workers if h.proc.sentinel == died[0])
|
||||
logger.error(
|
||||
"Worker proc %s died unexpectedly, "
|
||||
"shutting down executor.", proc_name)
|
||||
"Worker proc %s died unexpectedly, shutting down executor.", proc_name
|
||||
)
|
||||
_self.shutdown()
|
||||
callback = _self.failure_callback
|
||||
if callback is not None:
|
||||
_self.failure_callback = None
|
||||
callback()
|
||||
|
||||
Thread(target=monitor_workers,
|
||||
daemon=True,
|
||||
name="MultiprocWorkerMonitor").start()
|
||||
Thread(
|
||||
target=monitor_workers, daemon=True, name="MultiprocWorkerMonitor"
|
||||
).start()
|
||||
|
||||
def register_failure_callback(self, callback: FailureCallback):
|
||||
if self.is_failed:
|
||||
@@ -175,47 +183,49 @@ class MultiprocExecutor(Executor):
|
||||
scheduler_output: SchedulerOutput,
|
||||
non_block: bool = False,
|
||||
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
|
||||
|
||||
if not self.has_connector:
|
||||
# get output only from a single worker (output_rank)
|
||||
(output, ) = self.collective_rpc(
|
||||
(output,) = self.collective_rpc(
|
||||
"execute_model",
|
||||
args=(scheduler_output, ),
|
||||
args=(scheduler_output,),
|
||||
unique_reply_rank=self.output_rank,
|
||||
non_block=non_block,
|
||||
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
|
||||
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
|
||||
)
|
||||
return output
|
||||
|
||||
# get output from all workers
|
||||
outputs = self.collective_rpc(
|
||||
"execute_model",
|
||||
args=(scheduler_output, ),
|
||||
args=(scheduler_output,),
|
||||
non_block=non_block,
|
||||
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
|
||||
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
# aggregate all workers output to a single output
|
||||
if non_block:
|
||||
return self.kv_output_aggregator.async_aggregate(
|
||||
outputs, self.output_rank)
|
||||
return self.kv_output_aggregator.async_aggregate(outputs, self.output_rank)
|
||||
return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
|
||||
|
||||
def execute_dummy_batch(self) -> None:
|
||||
self.collective_rpc("execute_dummy_batch",
|
||||
unique_reply_rank=self.output_rank)
|
||||
self.collective_rpc("execute_dummy_batch", unique_reply_rank=self.output_rank)
|
||||
|
||||
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
|
||||
# OPTIMIZATION: Get output only from a single worker (output_rank)
|
||||
outputs = self.collective_rpc("take_draft_token_ids",
|
||||
unique_reply_rank=self.output_rank)
|
||||
outputs = self.collective_rpc(
|
||||
"take_draft_token_ids", unique_reply_rank=self.output_rank
|
||||
)
|
||||
return outputs[0]
|
||||
|
||||
def collective_rpc(self,
|
||||
method: Union[str, Callable],
|
||||
timeout: Optional[float] = None,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict] = None,
|
||||
non_block: bool = False,
|
||||
unique_reply_rank: Optional[int] = None) -> list[Any]:
|
||||
def collective_rpc(
|
||||
self,
|
||||
method: Union[str, Callable],
|
||||
timeout: Optional[float] = None,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict] = None,
|
||||
non_block: bool = False,
|
||||
unique_reply_rank: Optional[int] = None,
|
||||
) -> list[Any]:
|
||||
if self.is_failed:
|
||||
raise RuntimeError("Executor failed.")
|
||||
|
||||
@@ -230,42 +240,53 @@ class MultiprocExecutor(Executor):
|
||||
send_method = method
|
||||
else:
|
||||
send_method = cloudpickle.dumps(
|
||||
method, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
method, protocol=pickle.HIGHEST_PROTOCOL
|
||||
)
|
||||
self.rpc_broadcast_mq.enqueue(
|
||||
(send_method, args, kwargs, unique_reply_rank))
|
||||
(send_method, args, kwargs, unique_reply_rank)
|
||||
)
|
||||
|
||||
workers = (self.workers[unique_reply_rank],
|
||||
) if unique_reply_rank is not None else self.workers
|
||||
workers = (
|
||||
(self.workers[unique_reply_rank],)
|
||||
if unique_reply_rank is not None
|
||||
else self.workers
|
||||
)
|
||||
responses = []
|
||||
|
||||
def get_response(w: WorkerProcHandle,
|
||||
dequeue_timeout: Optional[float] = None,
|
||||
cancel_event: Optional[threading.Event] = None):
|
||||
def get_response(
|
||||
w: WorkerProcHandle,
|
||||
dequeue_timeout: Optional[float] = None,
|
||||
cancel_event: Optional[threading.Event] = None,
|
||||
):
|
||||
status, result = w.worker_response_mq.dequeue(
|
||||
timeout=dequeue_timeout, cancel=cancel_event)
|
||||
timeout=dequeue_timeout, cancel=cancel_event
|
||||
)
|
||||
|
||||
if status != WorkerProc.ResponseStatus.SUCCESS:
|
||||
raise RuntimeError(
|
||||
f"Worker failed with error '{result}', please check the"
|
||||
" stack trace above for the root cause")
|
||||
" stack trace above for the root cause"
|
||||
)
|
||||
return result
|
||||
|
||||
for w in workers:
|
||||
dequeue_timeout = None if deadline is None else (
|
||||
deadline - time.monotonic())
|
||||
dequeue_timeout = (
|
||||
None if deadline is None else (deadline - time.monotonic())
|
||||
)
|
||||
|
||||
if self.io_thread_pool is not None:
|
||||
# We must consume worker_response_mq from a single thread.
|
||||
result = self.io_thread_pool.submit( # type: ignore
|
||||
get_response, w, dequeue_timeout, self.shutdown_event)
|
||||
get_response, w, dequeue_timeout, self.shutdown_event
|
||||
)
|
||||
if not non_block:
|
||||
result = result.result()
|
||||
elif not non_block:
|
||||
result = get_response(w, dequeue_timeout,
|
||||
self.shutdown_event)
|
||||
result = get_response(w, dequeue_timeout, self.shutdown_event)
|
||||
else:
|
||||
raise RuntimeError("non_block can only be used when"
|
||||
" max_concurrent_batches > 1")
|
||||
raise RuntimeError(
|
||||
"non_block can only be used when max_concurrent_batches > 1"
|
||||
)
|
||||
responses.append(result)
|
||||
|
||||
return responses
|
||||
@@ -302,11 +323,11 @@ class MultiprocExecutor(Executor):
|
||||
|
||||
def shutdown(self):
|
||||
"""Properly shut down the executor and its workers"""
|
||||
if not getattr(self, 'shutting_down', False):
|
||||
if not getattr(self, "shutting_down", False):
|
||||
self.shutting_down = True
|
||||
|
||||
# Make sure all the worker processes are terminated first.
|
||||
if workers := getattr(self, 'workers', None):
|
||||
if workers := getattr(self, "workers", None):
|
||||
for w in workers:
|
||||
# Close death_writer to signal child processes to exit
|
||||
if w.death_writer is not None:
|
||||
@@ -348,6 +369,7 @@ class MultiprocExecutor(Executor):
|
||||
@dataclass
|
||||
class UnreadyWorkerProcHandle:
|
||||
"""WorkerProcess handle before READY."""
|
||||
|
||||
proc: BaseProcess
|
||||
rank: int
|
||||
ready_pipe: Connection
|
||||
@@ -363,8 +385,8 @@ class WorkerProcHandle:
|
||||
|
||||
@classmethod
|
||||
def from_unready_handle(
|
||||
cls, unready_handle: UnreadyWorkerProcHandle,
|
||||
worker_response_mq: MessageQueue) -> "WorkerProcHandle":
|
||||
cls, unready_handle: UnreadyWorkerProcHandle, worker_response_mq: MessageQueue
|
||||
) -> "WorkerProcHandle":
|
||||
return cls(
|
||||
proc=unready_handle.proc,
|
||||
rank=unready_handle.rank,
|
||||
@@ -393,8 +415,7 @@ class WorkerProc:
|
||||
all_kwargs: list[dict] = [
|
||||
{} for _ in range(vllm_config.parallel_config.world_size)
|
||||
]
|
||||
is_driver_worker = (
|
||||
rank % vllm_config.parallel_config.tensor_parallel_size == 0)
|
||||
is_driver_worker = rank % vllm_config.parallel_config.tensor_parallel_size == 0
|
||||
all_kwargs[rank] = {
|
||||
"vllm_config": vllm_config,
|
||||
"local_rank": local_rank,
|
||||
@@ -407,7 +428,8 @@ class WorkerProc:
|
||||
|
||||
# Initialize MessageQueue for receiving SchedulerOutput
|
||||
self.rpc_broadcast_mq = MessageQueue.create_from_handle(
|
||||
input_shm_handle, self.worker.rank)
|
||||
input_shm_handle, self.worker.rank
|
||||
)
|
||||
|
||||
# Initializes a message queue for sending the model output
|
||||
self.worker_response_mq = MessageQueue(1, 1)
|
||||
@@ -419,19 +441,22 @@ class WorkerProc:
|
||||
self.async_output_copy_thread = Thread(
|
||||
target=self.async_output_busy_loop,
|
||||
daemon=True,
|
||||
name="WorkerAsyncOutputCopy")
|
||||
name="WorkerAsyncOutputCopy",
|
||||
)
|
||||
self.async_output_copy_thread.start()
|
||||
|
||||
# Initialize multimodal receiver cache if needed
|
||||
self.mm_receiver_cache = worker_receiver_cache_from_config(
|
||||
vllm_config, MULTIMODAL_REGISTRY, shared_worker_lock)
|
||||
vllm_config, MULTIMODAL_REGISTRY, shared_worker_lock
|
||||
)
|
||||
|
||||
# Initialize device
|
||||
self.worker.init_device()
|
||||
|
||||
# Set process title and log prefix
|
||||
self.setup_proc_title_and_log_prefix(
|
||||
enable_ep=vllm_config.parallel_config.enable_expert_parallel)
|
||||
enable_ep=vllm_config.parallel_config.enable_expert_parallel
|
||||
)
|
||||
|
||||
# Load model
|
||||
self.worker.load_model()
|
||||
@@ -463,10 +488,12 @@ class WorkerProc:
|
||||
"shared_worker_lock": shared_worker_lock,
|
||||
}
|
||||
# Run EngineCore busy loop in background process.
|
||||
proc = context.Process(target=WorkerProc.worker_main,
|
||||
kwargs=process_kwargs,
|
||||
name=f"VllmWorker-{rank}",
|
||||
daemon=True)
|
||||
proc = context.Process(
|
||||
target=WorkerProc.worker_main,
|
||||
kwargs=process_kwargs,
|
||||
name=f"VllmWorker-{rank}",
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
proc.start()
|
||||
writer.close()
|
||||
@@ -476,16 +503,18 @@ class WorkerProc:
|
||||
|
||||
@staticmethod
|
||||
def wait_for_ready(
|
||||
unready_proc_handles: list[UnreadyWorkerProcHandle]
|
||||
unready_proc_handles: list[UnreadyWorkerProcHandle],
|
||||
) -> list[WorkerProcHandle]:
|
||||
|
||||
e = Exception("WorkerProc initialization failed due to "
|
||||
"an exception in a background process. "
|
||||
"See stack trace for root cause.")
|
||||
e = Exception(
|
||||
"WorkerProc initialization failed due to "
|
||||
"an exception in a background process. "
|
||||
"See stack trace for root cause."
|
||||
)
|
||||
|
||||
pipes = {handle.ready_pipe: handle for handle in unready_proc_handles}
|
||||
ready_proc_handles: list[Optional[WorkerProcHandle]] = (
|
||||
[None] * len(unready_proc_handles))
|
||||
ready_proc_handles: list[Optional[WorkerProcHandle]] = [None] * len(
|
||||
unready_proc_handles
|
||||
)
|
||||
while pipes:
|
||||
ready = multiprocessing.connection.wait(pipes.keys())
|
||||
for pipe in ready:
|
||||
@@ -499,10 +528,13 @@ class WorkerProc:
|
||||
|
||||
# Extract the message queue handle.
|
||||
worker_response_mq = MessageQueue.create_from_handle(
|
||||
response["handle"], 0)
|
||||
response["handle"], 0
|
||||
)
|
||||
ready_proc_handles[unready_proc_handle.rank] = (
|
||||
WorkerProcHandle.from_unready_handle(
|
||||
unready_proc_handle, worker_response_mq))
|
||||
unready_proc_handle, worker_response_mq
|
||||
)
|
||||
)
|
||||
|
||||
except EOFError:
|
||||
e.__suppress_context__ = True
|
||||
@@ -523,8 +555,8 @@ class WorkerProc:
|
||||
|
||||
@staticmethod
|
||||
def worker_main(*args, **kwargs):
|
||||
""" Worker initialization and execution loops.
|
||||
This runs a background process """
|
||||
"""Worker initialization and execution loops.
|
||||
This runs a background process"""
|
||||
|
||||
# Signal handler used for graceful termination.
|
||||
# SystemExit exception is only raised once to allow this and worker
|
||||
@@ -561,9 +593,9 @@ class WorkerProc:
|
||||
except Exception as e:
|
||||
logger.warning("Death monitoring error: %s", e)
|
||||
|
||||
death_monitor = Thread(target=monitor_parent_death,
|
||||
daemon=True,
|
||||
name="WorkerDeathMonitor")
|
||||
death_monitor = Thread(
|
||||
target=monitor_parent_death, daemon=True, name="WorkerDeathMonitor"
|
||||
)
|
||||
death_monitor.start()
|
||||
|
||||
try:
|
||||
@@ -571,12 +603,12 @@ class WorkerProc:
|
||||
worker = WorkerProc(*args, **kwargs)
|
||||
|
||||
# Send READY once we know everything is loaded
|
||||
ready_writer.send({
|
||||
"status":
|
||||
WorkerProc.READY_STR,
|
||||
"handle":
|
||||
worker.worker_response_mq.export_handle(),
|
||||
})
|
||||
ready_writer.send(
|
||||
{
|
||||
"status": WorkerProc.READY_STR,
|
||||
"handle": worker.worker_response_mq.export_handle(),
|
||||
}
|
||||
)
|
||||
|
||||
# Ensure message queues are ready. Will deadlock if re-ordered.
|
||||
# Must be kept consistent with the Executor
|
||||
@@ -653,15 +685,18 @@ class WorkerProc:
|
||||
"""Main busy loop for Multiprocessing Workers"""
|
||||
while True:
|
||||
method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue(
|
||||
cancel=cancel, indefinite=True)
|
||||
cancel=cancel, indefinite=True
|
||||
)
|
||||
try:
|
||||
if isinstance(method, str):
|
||||
func = getattr(self.worker, method)
|
||||
elif isinstance(method, bytes):
|
||||
func = partial(cloudpickle.loads(method), self.worker)
|
||||
# retrieve from shm cache if available
|
||||
if self.mm_receiver_cache is not None \
|
||||
and func.__name__ == "execute_model":
|
||||
if (
|
||||
self.mm_receiver_cache is not None
|
||||
and func.__name__ == "execute_model"
|
||||
):
|
||||
get_and_update_mm_cache(self.mm_receiver_cache, args)
|
||||
output = func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
@@ -701,7 +736,7 @@ class WorkerProc:
|
||||
|
||||
|
||||
def set_multiprocessing_worker_envs():
|
||||
""" Set up environment variables that should be used when there are workers
|
||||
"""Set up environment variables that should be used when there are workers
|
||||
in a multiprocessing environment. This should be called by the parent
|
||||
process before worker processes are created"""
|
||||
|
||||
@@ -714,13 +749,16 @@ def set_multiprocessing_worker_envs():
|
||||
# impact on performance. The contention is amplified when running in a
|
||||
# container where CPU limits can cause throttling.
|
||||
default_omp_num_threads = 1
|
||||
if "OMP_NUM_THREADS" not in os.environ and (
|
||||
current_parallelism :=
|
||||
torch.get_num_threads()) > default_omp_num_threads:
|
||||
if (
|
||||
"OMP_NUM_THREADS" not in os.environ
|
||||
and (current_parallelism := torch.get_num_threads()) > default_omp_num_threads
|
||||
):
|
||||
logger.warning(
|
||||
"Reducing Torch parallelism from %d threads to %d to avoid "
|
||||
"unnecessary CPU contention. Set OMP_NUM_THREADS in the "
|
||||
"external environment to tune this value as needed.",
|
||||
current_parallelism, default_omp_num_threads)
|
||||
current_parallelism,
|
||||
default_omp_num_threads,
|
||||
)
|
||||
os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads)
|
||||
torch.set_num_threads(default_omp_num_threads)
|
||||
|
||||
Reference in New Issue
Block a user