diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
index 4735035d7..6e1e41ed6 100644
--- a/vllm/v1/executor/multiproc_executor.py
+++ b/vllm/v1/executor/multiproc_executor.py
@@ -104,16 +104,7 @@ class MultiprocExecutor(Executor):
         self.shutdown_event = threading.Event()
         self.failure_callback: FailureCallback | None = None
 
-        self.world_size = self.parallel_config.world_size
-        assert self.world_size % self.parallel_config.nnodes_within_dp == 0, (
-            f"global world_size ({self.parallel_config.world_size}) must be "
-            f"divisible by nnodes_within_dp "
-            f"({self.parallel_config.nnodes_within_dp}). "
-        )
-        self.local_world_size = self.parallel_config.local_world_size
-        tp_size = self.parallel_config.tensor_parallel_size
-        pp_size = self.parallel_config.pipeline_parallel_size
-        pcp_size = self.parallel_config.prefill_context_parallel_size
+        tp_size, pp_size, pcp_size = self._get_parallel_sizes()
         assert self.world_size == tp_size * pp_size * pcp_size, (
             f"world_size ({self.world_size}) must be equal to the "
             f"tensor_parallel_size ({tp_size}) x pipeline"
@@ -154,6 +145,7 @@ class MultiprocExecutor(Executor):
             )
             for local_rank in range(self.local_world_size):
                 global_rank = global_start_rank + local_rank
+                is_driver_worker = self._is_driver_worker(global_rank)
                 unready_workers.append(
                     WorkerProc.make_worker_process(
                         vllm_config=self.vllm_config,
@@ -162,6 +154,7 @@ class MultiprocExecutor(Executor):
                         distributed_init_method=distributed_init_method,
                         input_shm_handle=scheduler_output_handle,
                         shared_worker_lock=shared_worker_lock,
+                        is_driver_worker=is_driver_worker,
                     )
                 )
 
@@ -199,6 +192,11 @@ class MultiprocExecutor(Executor):
             # Wait for all remote response mqs to be ready.
             for response_mq in self.response_mqs:
                 response_mq.wait_until_ready()
+
+            self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
+
+            self._post_init_executor()
+
             success = True
         finally:
             if not success:
@@ -209,10 +207,27 @@ class MultiprocExecutor(Executor):
                     uw.death_writer.close()
                 self._ensure_worker_termination([uw.proc for uw in unready_workers])
 
-        self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
-
         self.output_rank = self._get_output_rank()
 
+    def _get_parallel_sizes(self) -> tuple[int, int, int]:
+        self.world_size = self.parallel_config.world_size
+        assert self.world_size % self.parallel_config.nnodes_within_dp == 0, (
+            f"global world_size ({self.parallel_config.world_size}) must be "
+            f"divisible by nnodes_within_dp "
+            f"({self.parallel_config.nnodes_within_dp}). "
+        )
+        self.local_world_size = self.parallel_config.local_world_size
+        tp_size = self.parallel_config.tensor_parallel_size
+        pp_size = self.parallel_config.pipeline_parallel_size
+        pcp_size = self.parallel_config.prefill_context_parallel_size
+        return tp_size, pp_size, pcp_size
+
+    def _post_init_executor(self) -> None:
+        pass
+
+    def _is_driver_worker(self, rank: int) -> bool:
+        return rank % self.parallel_config.tensor_parallel_size == 0
+
     def start_worker_monitor(self, inline=False) -> None:
         workers = self.workers
         self_ref = weakref.ref(self)
@@ -517,6 +532,7 @@ class WorkerProc:
         distributed_init_method: str,
         input_shm_handle: Handle,
         shared_worker_lock: LockType,
+        is_driver_worker: bool,
     ):
         self.rank = rank
         wrapper = WorkerWrapperBase(rpc_rank=local_rank, global_rank=rank)
@@ -524,7 +540,6 @@ class WorkerProc:
         all_kwargs: list[dict] = [
             {} for _ in range(vllm_config.parallel_config.world_size)
         ]
-        is_driver_worker = rank % vllm_config.parallel_config.tensor_parallel_size == 0
         all_kwargs[local_rank] = {
             "vllm_config": vllm_config,
             "local_rank": local_rank,
@@ -571,6 +586,7 @@ class WorkerProc:
         distributed_init_method: str,
         input_shm_handle,  # Receive SchedulerOutput
         shared_worker_lock: LockType,
+        is_driver_worker: bool,
     ) -> UnreadyWorkerProcHandle:
         context = get_mp_context()
         # (reader, writer)
@@ -588,6 +604,7 @@ class WorkerProc:
             "ready_pipe": (reader, writer),
             "death_pipe": death_reader,
             "shared_worker_lock": shared_worker_lock,
+            "is_driver_worker": is_driver_worker,
         }
         # Run EngineCore busy loop in background process.
         proc = context.Process(
diff --git a/vllm/v1/executor/ray_utils.py b/vllm/v1/executor/ray_utils.py
index dadf55006..21403e1c0 100644
--- a/vllm/v1/executor/ray_utils.py
+++ b/vllm/v1/executor/ray_utils.py
@@ -103,7 +103,7 @@ try:
                 output = self.worker.model_runner.execute_model(
                     scheduler_output, intermediate_tensors
                 )
-                if isinstance(output, IntermediateTensors):
+                if self._is_intermediate_tensors(output):
                     return scheduler_output, grammar_output, output
 
                 if isinstance(output, AsyncModelRunnerOutput):
@@ -125,6 +125,9 @@ try:
         def override_env_vars(self, vars: dict[str, str]):
             os.environ.update(vars)
 
+        def _is_intermediate_tensors(self, output) -> bool:
+            return isinstance(output, IntermediateTensors)
+
     ray_import_err = None
 
 except ImportError as e: