[TPU][Core] Enable Pipeline Parallelism on TPU backend (#28506)

Signed-off-by: Chenyaaang <chenyangli@google.com>
This commit is contained in:
Chenyaaang
2026-01-16 15:29:20 -08:00
committed by GitHub
parent ca21288080
commit 484e22bc18
2 changed files with 34 additions and 14 deletions

View File

@@ -104,16 +104,7 @@ class MultiprocExecutor(Executor):
self.shutdown_event = threading.Event() self.shutdown_event = threading.Event()
self.failure_callback: FailureCallback | None = None self.failure_callback: FailureCallback | None = None
self.world_size = self.parallel_config.world_size tp_size, pp_size, pcp_size = self._get_parallel_sizes()
assert self.world_size % self.parallel_config.nnodes_within_dp == 0, (
f"global world_size ({self.parallel_config.world_size}) must be "
f"divisible by nnodes_within_dp "
f"({self.parallel_config.nnodes_within_dp}). "
)
self.local_world_size = self.parallel_config.local_world_size
tp_size = self.parallel_config.tensor_parallel_size
pp_size = self.parallel_config.pipeline_parallel_size
pcp_size = self.parallel_config.prefill_context_parallel_size
assert self.world_size == tp_size * pp_size * pcp_size, ( assert self.world_size == tp_size * pp_size * pcp_size, (
f"world_size ({self.world_size}) must be equal to the " f"world_size ({self.world_size}) must be equal to the "
f"tensor_parallel_size ({tp_size}) x pipeline" f"tensor_parallel_size ({tp_size}) x pipeline"
@@ -154,6 +145,7 @@ class MultiprocExecutor(Executor):
) )
for local_rank in range(self.local_world_size): for local_rank in range(self.local_world_size):
global_rank = global_start_rank + local_rank global_rank = global_start_rank + local_rank
is_driver_worker = self._is_driver_worker(global_rank)
unready_workers.append( unready_workers.append(
WorkerProc.make_worker_process( WorkerProc.make_worker_process(
vllm_config=self.vllm_config, vllm_config=self.vllm_config,
@@ -162,6 +154,7 @@ class MultiprocExecutor(Executor):
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
input_shm_handle=scheduler_output_handle, input_shm_handle=scheduler_output_handle,
shared_worker_lock=shared_worker_lock, shared_worker_lock=shared_worker_lock,
is_driver_worker=is_driver_worker,
) )
) )
@@ -199,6 +192,11 @@ class MultiprocExecutor(Executor):
# Wait for all remote response mqs to be ready. # Wait for all remote response mqs to be ready.
for response_mq in self.response_mqs: for response_mq in self.response_mqs:
response_mq.wait_until_ready() response_mq.wait_until_ready()
self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
self._post_init_executor()
success = True success = True
finally: finally:
if not success: if not success:
@@ -209,10 +207,27 @@ class MultiprocExecutor(Executor):
uw.death_writer.close() uw.death_writer.close()
self._ensure_worker_termination([uw.proc for uw in unready_workers]) self._ensure_worker_termination([uw.proc for uw in unready_workers])
self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
self.output_rank = self._get_output_rank() self.output_rank = self._get_output_rank()
def _get_parallel_sizes(self) -> tuple[int, int, int]:
    """Cache world-size attributes and return the parallel degrees.

    Despite the "get" name, this also has side effects: it sets
    ``self.world_size`` and ``self.local_world_size`` from the parallel
    config (callers rely on these being populated).

    Returns:
        A ``(tp_size, pp_size, pcp_size)`` tuple: the tensor-, pipeline-,
        and prefill-context-parallel sizes from the parallel config.
    """
    config = self.parallel_config
    self.world_size = config.world_size
    # Ranks are split evenly across nodes within a DP group, so the
    # global world size must divide cleanly.
    assert self.world_size % self.parallel_config.nnodes_within_dp == 0, (
        f"global world_size ({self.parallel_config.world_size}) must be "
        f"divisible by nnodes_within_dp "
        f"({self.parallel_config.nnodes_within_dp}). "
    )
    self.local_world_size = config.local_world_size
    return (
        config.tensor_parallel_size,
        config.pipeline_parallel_size,
        config.prefill_context_parallel_size,
    )
def _post_init_executor(self) -> None:
    """Subclass hook invoked once all workers have reported ready.

    The base implementation intentionally does nothing; platform-specific
    executors (e.g. the TPU backend) may override this to run extra
    post-initialization setup.
    """
def _is_driver_worker(self, rank: int) -> bool:
    """Return True if *rank* leads its tensor-parallel group.

    The driver worker of each TP group is the rank whose global index is
    a multiple of the tensor-parallel size (i.e. TP-local rank 0).
    """
    tp_size = self.parallel_config.tensor_parallel_size
    return (rank % tp_size) == 0
def start_worker_monitor(self, inline=False) -> None: def start_worker_monitor(self, inline=False) -> None:
workers = self.workers workers = self.workers
self_ref = weakref.ref(self) self_ref = weakref.ref(self)
@@ -517,6 +532,7 @@ class WorkerProc:
distributed_init_method: str, distributed_init_method: str,
input_shm_handle: Handle, input_shm_handle: Handle,
shared_worker_lock: LockType, shared_worker_lock: LockType,
is_driver_worker: bool,
): ):
self.rank = rank self.rank = rank
wrapper = WorkerWrapperBase(rpc_rank=local_rank, global_rank=rank) wrapper = WorkerWrapperBase(rpc_rank=local_rank, global_rank=rank)
@@ -524,7 +540,6 @@ class WorkerProc:
all_kwargs: list[dict] = [ all_kwargs: list[dict] = [
{} for _ in range(vllm_config.parallel_config.world_size) {} for _ in range(vllm_config.parallel_config.world_size)
] ]
is_driver_worker = rank % vllm_config.parallel_config.tensor_parallel_size == 0
all_kwargs[local_rank] = { all_kwargs[local_rank] = {
"vllm_config": vllm_config, "vllm_config": vllm_config,
"local_rank": local_rank, "local_rank": local_rank,
@@ -571,6 +586,7 @@ class WorkerProc:
distributed_init_method: str, distributed_init_method: str,
input_shm_handle, # Receive SchedulerOutput input_shm_handle, # Receive SchedulerOutput
shared_worker_lock: LockType, shared_worker_lock: LockType,
is_driver_worker: bool,
) -> UnreadyWorkerProcHandle: ) -> UnreadyWorkerProcHandle:
context = get_mp_context() context = get_mp_context()
# (reader, writer) # (reader, writer)
@@ -588,6 +604,7 @@ class WorkerProc:
"ready_pipe": (reader, writer), "ready_pipe": (reader, writer),
"death_pipe": death_reader, "death_pipe": death_reader,
"shared_worker_lock": shared_worker_lock, "shared_worker_lock": shared_worker_lock,
"is_driver_worker": is_driver_worker,
} }
# Run EngineCore busy loop in background process. # Run EngineCore busy loop in background process.
proc = context.Process( proc = context.Process(

View File

@@ -103,7 +103,7 @@ try:
output = self.worker.model_runner.execute_model( output = self.worker.model_runner.execute_model(
scheduler_output, intermediate_tensors scheduler_output, intermediate_tensors
) )
if isinstance(output, IntermediateTensors): if self._is_intermediate_tensors(output):
return scheduler_output, grammar_output, output return scheduler_output, grammar_output, output
if isinstance(output, AsyncModelRunnerOutput): if isinstance(output, AsyncModelRunnerOutput):
@@ -125,6 +125,9 @@ try:
def override_env_vars(self, vars: dict[str, str]): def override_env_vars(self, vars: dict[str, str]):
os.environ.update(vars) os.environ.update(vars)
def _is_intermediate_tensors(self, output) -> bool:
    """Return True when *output* is an ``IntermediateTensors`` instance.

    Factored into a method so subclasses can override how pipeline-stage
    intermediate outputs are detected; the base check is a plain
    isinstance test.
    """
    if isinstance(output, IntermediateTensors):
        return True
    return False
ray_import_err = None ray_import_err = None
except ImportError as e: except ImportError as e: