[TPU][Core] Enable Pipeline Parallelism on TPU backend (#28506)

Signed-off-by: Chenyaaang <chenyangli@google.com>
This commit is contained in:
Chenyaaang
2026-01-16 15:29:20 -08:00
committed by GitHub
parent ca21288080
commit 484e22bc18
2 changed files with 34 additions and 14 deletions

View File

@@ -104,16 +104,7 @@ class MultiprocExecutor(Executor):
self.shutdown_event = threading.Event()
self.failure_callback: FailureCallback | None = None
self.world_size = self.parallel_config.world_size
assert self.world_size % self.parallel_config.nnodes_within_dp == 0, (
f"global world_size ({self.parallel_config.world_size}) must be "
f"divisible by nnodes_within_dp "
f"({self.parallel_config.nnodes_within_dp}). "
)
self.local_world_size = self.parallel_config.local_world_size
tp_size = self.parallel_config.tensor_parallel_size
pp_size = self.parallel_config.pipeline_parallel_size
pcp_size = self.parallel_config.prefill_context_parallel_size
tp_size, pp_size, pcp_size = self._get_parallel_sizes()
assert self.world_size == tp_size * pp_size * pcp_size, (
f"world_size ({self.world_size}) must be equal to the "
f"tensor_parallel_size ({tp_size}) x pipeline"
@@ -154,6 +145,7 @@ class MultiprocExecutor(Executor):
)
for local_rank in range(self.local_world_size):
global_rank = global_start_rank + local_rank
is_driver_worker = self._is_driver_worker(global_rank)
unready_workers.append(
WorkerProc.make_worker_process(
vllm_config=self.vllm_config,
@@ -162,6 +154,7 @@ class MultiprocExecutor(Executor):
distributed_init_method=distributed_init_method,
input_shm_handle=scheduler_output_handle,
shared_worker_lock=shared_worker_lock,
is_driver_worker=is_driver_worker,
)
)
@@ -199,6 +192,11 @@ class MultiprocExecutor(Executor):
# Wait for all remote response mqs to be ready.
for response_mq in self.response_mqs:
response_mq.wait_until_ready()
self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
self._post_init_executor()
success = True
finally:
if not success:
@@ -209,10 +207,27 @@ class MultiprocExecutor(Executor):
uw.death_writer.close()
self._ensure_worker_termination([uw.proc for uw in unready_workers])
self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
self.output_rank = self._get_output_rank()
def _get_parallel_sizes(self) -> tuple[int, int, int]:
self.world_size = self.parallel_config.world_size
assert self.world_size % self.parallel_config.nnodes_within_dp == 0, (
f"global world_size ({self.parallel_config.world_size}) must be "
f"divisible by nnodes_within_dp "
f"({self.parallel_config.nnodes_within_dp}). "
)
self.local_world_size = self.parallel_config.local_world_size
tp_size = self.parallel_config.tensor_parallel_size
pp_size = self.parallel_config.pipeline_parallel_size
pcp_size = self.parallel_config.prefill_context_parallel_size
return tp_size, pp_size, pcp_size
def _post_init_executor(self) -> None:
pass
def _is_driver_worker(self, rank: int) -> bool:
return rank % self.parallel_config.tensor_parallel_size == 0
def start_worker_monitor(self, inline=False) -> None:
workers = self.workers
self_ref = weakref.ref(self)
@@ -517,6 +532,7 @@ class WorkerProc:
distributed_init_method: str,
input_shm_handle: Handle,
shared_worker_lock: LockType,
is_driver_worker: bool,
):
self.rank = rank
wrapper = WorkerWrapperBase(rpc_rank=local_rank, global_rank=rank)
@@ -524,7 +540,6 @@ class WorkerProc:
all_kwargs: list[dict] = [
{} for _ in range(vllm_config.parallel_config.world_size)
]
is_driver_worker = rank % vllm_config.parallel_config.tensor_parallel_size == 0
all_kwargs[local_rank] = {
"vllm_config": vllm_config,
"local_rank": local_rank,
@@ -571,6 +586,7 @@ class WorkerProc:
distributed_init_method: str,
input_shm_handle, # Receive SchedulerOutput
shared_worker_lock: LockType,
is_driver_worker: bool,
) -> UnreadyWorkerProcHandle:
context = get_mp_context()
# (reader, writer)
@@ -588,6 +604,7 @@ class WorkerProc:
"ready_pipe": (reader, writer),
"death_pipe": death_reader,
"shared_worker_lock": shared_worker_lock,
"is_driver_worker": is_driver_worker,
}
# Run EngineCore busy loop in background process.
proc = context.Process(

View File

@@ -103,7 +103,7 @@ try:
output = self.worker.model_runner.execute_model(
scheduler_output, intermediate_tensors
)
if isinstance(output, IntermediateTensors):
if self._is_intermediate_tensors(output):
return scheduler_output, grammar_output, output
if isinstance(output, AsyncModelRunnerOutput):
@@ -125,6 +125,9 @@ try:
def override_env_vars(self, vars: dict[str, str]):
os.environ.update(vars)
def _is_intermediate_tensors(self, output) -> bool:
    """Return True when *output* is an ``IntermediateTensors`` instance.

    When this is True, the caller short-circuits and returns the output
    alongside the scheduler/grammar outputs instead of treating it as a
    finished model result.

    NOTE(review): presumably extracted into a method (rather than an
    inline isinstance check) so backend subclasses can override it —
    TODO confirm against the TPU worker.
    """
    return isinstance(output, IntermediateTensors)
ray_import_err = None
except ImportError as e: