[TPU][Core] Enable Pipeline Parallelism on TPU backend (#28506)
Signed-off-by: Chenyaaang <chenyangli@google.com>
This commit is contained in:
@@ -104,16 +104,7 @@ class MultiprocExecutor(Executor):
|
||||
self.shutdown_event = threading.Event()
|
||||
self.failure_callback: FailureCallback | None = None
|
||||
|
||||
self.world_size = self.parallel_config.world_size
|
||||
assert self.world_size % self.parallel_config.nnodes_within_dp == 0, (
|
||||
f"global world_size ({self.parallel_config.world_size}) must be "
|
||||
f"divisible by nnodes_within_dp "
|
||||
f"({self.parallel_config.nnodes_within_dp}). "
|
||||
)
|
||||
self.local_world_size = self.parallel_config.local_world_size
|
||||
tp_size = self.parallel_config.tensor_parallel_size
|
||||
pp_size = self.parallel_config.pipeline_parallel_size
|
||||
pcp_size = self.parallel_config.prefill_context_parallel_size
|
||||
tp_size, pp_size, pcp_size = self._get_parallel_sizes()
|
||||
assert self.world_size == tp_size * pp_size * pcp_size, (
|
||||
f"world_size ({self.world_size}) must be equal to the "
|
||||
f"tensor_parallel_size ({tp_size}) x pipeline"
|
||||
@@ -154,6 +145,7 @@ class MultiprocExecutor(Executor):
|
||||
)
|
||||
for local_rank in range(self.local_world_size):
|
||||
global_rank = global_start_rank + local_rank
|
||||
is_driver_worker = self._is_driver_worker(global_rank)
|
||||
unready_workers.append(
|
||||
WorkerProc.make_worker_process(
|
||||
vllm_config=self.vllm_config,
|
||||
@@ -162,6 +154,7 @@ class MultiprocExecutor(Executor):
|
||||
distributed_init_method=distributed_init_method,
|
||||
input_shm_handle=scheduler_output_handle,
|
||||
shared_worker_lock=shared_worker_lock,
|
||||
is_driver_worker=is_driver_worker,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -199,6 +192,11 @@ class MultiprocExecutor(Executor):
|
||||
# Wait for all remote response mqs to be ready.
|
||||
for response_mq in self.response_mqs:
|
||||
response_mq.wait_until_ready()
|
||||
|
||||
self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
|
||||
|
||||
self._post_init_executor()
|
||||
|
||||
success = True
|
||||
finally:
|
||||
if not success:
|
||||
@@ -209,10 +207,27 @@ class MultiprocExecutor(Executor):
|
||||
uw.death_writer.close()
|
||||
self._ensure_worker_termination([uw.proc for uw in unready_workers])
|
||||
|
||||
self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
|
||||
|
||||
self.output_rank = self._get_output_rank()
|
||||
|
||||
def _get_parallel_sizes(self) -> tuple[int, int, int]:
|
||||
self.world_size = self.parallel_config.world_size
|
||||
assert self.world_size % self.parallel_config.nnodes_within_dp == 0, (
|
||||
f"global world_size ({self.parallel_config.world_size}) must be "
|
||||
f"divisible by nnodes_within_dp "
|
||||
f"({self.parallel_config.nnodes_within_dp}). "
|
||||
)
|
||||
self.local_world_size = self.parallel_config.local_world_size
|
||||
tp_size = self.parallel_config.tensor_parallel_size
|
||||
pp_size = self.parallel_config.pipeline_parallel_size
|
||||
pcp_size = self.parallel_config.prefill_context_parallel_size
|
||||
return tp_size, pp_size, pcp_size
|
||||
|
||||
def _post_init_executor(self) -> None:
|
||||
pass
|
||||
|
||||
def _is_driver_worker(self, rank: int) -> bool:
|
||||
return rank % self.parallel_config.tensor_parallel_size == 0
|
||||
|
||||
def start_worker_monitor(self, inline=False) -> None:
|
||||
workers = self.workers
|
||||
self_ref = weakref.ref(self)
|
||||
@@ -517,6 +532,7 @@ class WorkerProc:
|
||||
distributed_init_method: str,
|
||||
input_shm_handle: Handle,
|
||||
shared_worker_lock: LockType,
|
||||
is_driver_worker: bool,
|
||||
):
|
||||
self.rank = rank
|
||||
wrapper = WorkerWrapperBase(rpc_rank=local_rank, global_rank=rank)
|
||||
@@ -524,7 +540,6 @@ class WorkerProc:
|
||||
all_kwargs: list[dict] = [
|
||||
{} for _ in range(vllm_config.parallel_config.world_size)
|
||||
]
|
||||
is_driver_worker = rank % vllm_config.parallel_config.tensor_parallel_size == 0
|
||||
all_kwargs[local_rank] = {
|
||||
"vllm_config": vllm_config,
|
||||
"local_rank": local_rank,
|
||||
@@ -571,6 +586,7 @@ class WorkerProc:
|
||||
distributed_init_method: str,
|
||||
input_shm_handle, # Receive SchedulerOutput
|
||||
shared_worker_lock: LockType,
|
||||
is_driver_worker: bool,
|
||||
) -> UnreadyWorkerProcHandle:
|
||||
context = get_mp_context()
|
||||
# (reader, writer)
|
||||
@@ -588,6 +604,7 @@ class WorkerProc:
|
||||
"ready_pipe": (reader, writer),
|
||||
"death_pipe": death_reader,
|
||||
"shared_worker_lock": shared_worker_lock,
|
||||
"is_driver_worker": is_driver_worker,
|
||||
}
|
||||
# Run EngineCore busy loop in background process.
|
||||
proc = context.Process(
|
||||
|
||||
@@ -103,7 +103,7 @@ try:
|
||||
output = self.worker.model_runner.execute_model(
|
||||
scheduler_output, intermediate_tensors
|
||||
)
|
||||
if isinstance(output, IntermediateTensors):
|
||||
if self._is_intermediate_tensors(output):
|
||||
return scheduler_output, grammar_output, output
|
||||
|
||||
if isinstance(output, AsyncModelRunnerOutput):
|
||||
@@ -125,6 +125,9 @@ try:
|
||||
def override_env_vars(self, vars: dict[str, str]):
|
||||
os.environ.update(vars)
|
||||
|
||||
def _is_intermediate_tensors(self, output) -> bool:
|
||||
return isinstance(output, IntermediateTensors)
|
||||
|
||||
ray_import_err = None
|
||||
|
||||
except ImportError as e:
|
||||
|
||||
Reference in New Issue
Block a user