[TPU][Core] Enable Pipeline Parallelism on TPU backend (#28506)
Signed-off-by: Chenyaaang <chenyangli@google.com>
This commit is contained in:
@@ -104,16 +104,7 @@ class MultiprocExecutor(Executor):
|
|||||||
self.shutdown_event = threading.Event()
|
self.shutdown_event = threading.Event()
|
||||||
self.failure_callback: FailureCallback | None = None
|
self.failure_callback: FailureCallback | None = None
|
||||||
|
|
||||||
self.world_size = self.parallel_config.world_size
|
tp_size, pp_size, pcp_size = self._get_parallel_sizes()
|
||||||
assert self.world_size % self.parallel_config.nnodes_within_dp == 0, (
|
|
||||||
f"global world_size ({self.parallel_config.world_size}) must be "
|
|
||||||
f"divisible by nnodes_within_dp "
|
|
||||||
f"({self.parallel_config.nnodes_within_dp}). "
|
|
||||||
)
|
|
||||||
self.local_world_size = self.parallel_config.local_world_size
|
|
||||||
tp_size = self.parallel_config.tensor_parallel_size
|
|
||||||
pp_size = self.parallel_config.pipeline_parallel_size
|
|
||||||
pcp_size = self.parallel_config.prefill_context_parallel_size
|
|
||||||
assert self.world_size == tp_size * pp_size * pcp_size, (
|
assert self.world_size == tp_size * pp_size * pcp_size, (
|
||||||
f"world_size ({self.world_size}) must be equal to the "
|
f"world_size ({self.world_size}) must be equal to the "
|
||||||
f"tensor_parallel_size ({tp_size}) x pipeline"
|
f"tensor_parallel_size ({tp_size}) x pipeline"
|
||||||
@@ -154,6 +145,7 @@ class MultiprocExecutor(Executor):
|
|||||||
)
|
)
|
||||||
for local_rank in range(self.local_world_size):
|
for local_rank in range(self.local_world_size):
|
||||||
global_rank = global_start_rank + local_rank
|
global_rank = global_start_rank + local_rank
|
||||||
|
is_driver_worker = self._is_driver_worker(global_rank)
|
||||||
unready_workers.append(
|
unready_workers.append(
|
||||||
WorkerProc.make_worker_process(
|
WorkerProc.make_worker_process(
|
||||||
vllm_config=self.vllm_config,
|
vllm_config=self.vllm_config,
|
||||||
@@ -162,6 +154,7 @@ class MultiprocExecutor(Executor):
|
|||||||
distributed_init_method=distributed_init_method,
|
distributed_init_method=distributed_init_method,
|
||||||
input_shm_handle=scheduler_output_handle,
|
input_shm_handle=scheduler_output_handle,
|
||||||
shared_worker_lock=shared_worker_lock,
|
shared_worker_lock=shared_worker_lock,
|
||||||
|
is_driver_worker=is_driver_worker,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -199,6 +192,11 @@ class MultiprocExecutor(Executor):
|
|||||||
# Wait for all remote response mqs to be ready.
|
# Wait for all remote response mqs to be ready.
|
||||||
for response_mq in self.response_mqs:
|
for response_mq in self.response_mqs:
|
||||||
response_mq.wait_until_ready()
|
response_mq.wait_until_ready()
|
||||||
|
|
||||||
|
self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
|
||||||
|
|
||||||
|
self._post_init_executor()
|
||||||
|
|
||||||
success = True
|
success = True
|
||||||
finally:
|
finally:
|
||||||
if not success:
|
if not success:
|
||||||
@@ -209,10 +207,27 @@ class MultiprocExecutor(Executor):
|
|||||||
uw.death_writer.close()
|
uw.death_writer.close()
|
||||||
self._ensure_worker_termination([uw.proc for uw in unready_workers])
|
self._ensure_worker_termination([uw.proc for uw in unready_workers])
|
||||||
|
|
||||||
self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
|
|
||||||
|
|
||||||
self.output_rank = self._get_output_rank()
|
self.output_rank = self._get_output_rank()
|
||||||
|
|
||||||
|
def _get_parallel_sizes(self) -> tuple[int, int, int]:
|
||||||
|
self.world_size = self.parallel_config.world_size
|
||||||
|
assert self.world_size % self.parallel_config.nnodes_within_dp == 0, (
|
||||||
|
f"global world_size ({self.parallel_config.world_size}) must be "
|
||||||
|
f"divisible by nnodes_within_dp "
|
||||||
|
f"({self.parallel_config.nnodes_within_dp}). "
|
||||||
|
)
|
||||||
|
self.local_world_size = self.parallel_config.local_world_size
|
||||||
|
tp_size = self.parallel_config.tensor_parallel_size
|
||||||
|
pp_size = self.parallel_config.pipeline_parallel_size
|
||||||
|
pcp_size = self.parallel_config.prefill_context_parallel_size
|
||||||
|
return tp_size, pp_size, pcp_size
|
||||||
|
|
||||||
|
def _post_init_executor(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _is_driver_worker(self, rank: int) -> bool:
|
||||||
|
return rank % self.parallel_config.tensor_parallel_size == 0
|
||||||
|
|
||||||
def start_worker_monitor(self, inline=False) -> None:
|
def start_worker_monitor(self, inline=False) -> None:
|
||||||
workers = self.workers
|
workers = self.workers
|
||||||
self_ref = weakref.ref(self)
|
self_ref = weakref.ref(self)
|
||||||
@@ -517,6 +532,7 @@ class WorkerProc:
|
|||||||
distributed_init_method: str,
|
distributed_init_method: str,
|
||||||
input_shm_handle: Handle,
|
input_shm_handle: Handle,
|
||||||
shared_worker_lock: LockType,
|
shared_worker_lock: LockType,
|
||||||
|
is_driver_worker: bool,
|
||||||
):
|
):
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
wrapper = WorkerWrapperBase(rpc_rank=local_rank, global_rank=rank)
|
wrapper = WorkerWrapperBase(rpc_rank=local_rank, global_rank=rank)
|
||||||
@@ -524,7 +540,6 @@ class WorkerProc:
|
|||||||
all_kwargs: list[dict] = [
|
all_kwargs: list[dict] = [
|
||||||
{} for _ in range(vllm_config.parallel_config.world_size)
|
{} for _ in range(vllm_config.parallel_config.world_size)
|
||||||
]
|
]
|
||||||
is_driver_worker = rank % vllm_config.parallel_config.tensor_parallel_size == 0
|
|
||||||
all_kwargs[local_rank] = {
|
all_kwargs[local_rank] = {
|
||||||
"vllm_config": vllm_config,
|
"vllm_config": vllm_config,
|
||||||
"local_rank": local_rank,
|
"local_rank": local_rank,
|
||||||
@@ -571,6 +586,7 @@ class WorkerProc:
|
|||||||
distributed_init_method: str,
|
distributed_init_method: str,
|
||||||
input_shm_handle, # Receive SchedulerOutput
|
input_shm_handle, # Receive SchedulerOutput
|
||||||
shared_worker_lock: LockType,
|
shared_worker_lock: LockType,
|
||||||
|
is_driver_worker: bool,
|
||||||
) -> UnreadyWorkerProcHandle:
|
) -> UnreadyWorkerProcHandle:
|
||||||
context = get_mp_context()
|
context = get_mp_context()
|
||||||
# (reader, writer)
|
# (reader, writer)
|
||||||
@@ -588,6 +604,7 @@ class WorkerProc:
|
|||||||
"ready_pipe": (reader, writer),
|
"ready_pipe": (reader, writer),
|
||||||
"death_pipe": death_reader,
|
"death_pipe": death_reader,
|
||||||
"shared_worker_lock": shared_worker_lock,
|
"shared_worker_lock": shared_worker_lock,
|
||||||
|
"is_driver_worker": is_driver_worker,
|
||||||
}
|
}
|
||||||
# Run EngineCore busy loop in background process.
|
# Run EngineCore busy loop in background process.
|
||||||
proc = context.Process(
|
proc = context.Process(
|
||||||
|
|||||||
@@ -103,7 +103,7 @@ try:
|
|||||||
output = self.worker.model_runner.execute_model(
|
output = self.worker.model_runner.execute_model(
|
||||||
scheduler_output, intermediate_tensors
|
scheduler_output, intermediate_tensors
|
||||||
)
|
)
|
||||||
if isinstance(output, IntermediateTensors):
|
if self._is_intermediate_tensors(output):
|
||||||
return scheduler_output, grammar_output, output
|
return scheduler_output, grammar_output, output
|
||||||
|
|
||||||
if isinstance(output, AsyncModelRunnerOutput):
|
if isinstance(output, AsyncModelRunnerOutput):
|
||||||
@@ -125,6 +125,9 @@ try:
|
|||||||
def override_env_vars(self, vars: dict[str, str]):
|
def override_env_vars(self, vars: dict[str, str]):
|
||||||
os.environ.update(vars)
|
os.environ.update(vars)
|
||||||
|
|
||||||
|
def _is_intermediate_tensors(self, output) -> bool:
|
||||||
|
return isinstance(output, IntermediateTensors)
|
||||||
|
|
||||||
ray_import_err = None
|
ray_import_err = None
|
||||||
|
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
|
|||||||
Reference in New Issue
Block a user