[Core] Pipeline Parallel Support (#4412)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
This commit is contained in:
committed by
GitHub
parent
15aba081f3
commit
c5832d2ae9
@@ -69,7 +69,7 @@ class DistributedGPUExecutor(GPUExecutor):
|
||||
if self.parallel_worker_tasks is None:
|
||||
self.parallel_worker_tasks = self._run_workers(
|
||||
"start_worker_execution_loop",
|
||||
async_run_remote_workers_only=True,
|
||||
async_run_tensor_parallel_workers_only=True,
|
||||
**self.extra_execute_model_run_workers_kwargs)
|
||||
|
||||
# Only the driver worker returns the sampling results.
|
||||
@@ -138,17 +138,17 @@ class DistributedGPUExecutor(GPUExecutor):
|
||||
self,
|
||||
method: str,
|
||||
*args,
|
||||
async_run_remote_workers_only: bool = False,
|
||||
async_run_tensor_parallel_workers_only: bool = False,
|
||||
max_concurrent_workers: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Runs the given method on all workers.
|
||||
|
||||
Args:
|
||||
async_run_remote_workers_only: If True the method will be run only
|
||||
in the remote workers, not the driver worker. It will also be
|
||||
run asynchronously and return a list of futures rather than
|
||||
blocking on the results.
|
||||
async_run_tensor_parallel_workers_only: If True the method will be
|
||||
run only in the remote TP workers, not the driver worker.
|
||||
It will also be run asynchronously and return a list of futures
|
||||
rather than blocking on the results.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
@@ -110,6 +111,30 @@ class ExecutorBase(ABC):
|
||||
|
||||
class ExecutorAsyncBase(ExecutorBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
cache_config: CacheConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
device_config: DeviceConfig,
|
||||
load_config: LoadConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
speculative_config: Optional[SpeculativeConfig],
|
||||
) -> None:
|
||||
# This locks each pipeline parallel stage so multiple virtual engines
|
||||
# can't execute on the same stage at the same time
|
||||
self.pp_locks = [
|
||||
asyncio.Lock()
|
||||
for _ in range(parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
|
||||
super().__init__(model_config, cache_config, parallel_config,
|
||||
scheduler_config, device_config, load_config,
|
||||
lora_config, vision_language_config,
|
||||
speculative_config)
|
||||
|
||||
@abstractmethod
|
||||
async def execute_model_async(
|
||||
self,
|
||||
|
||||
@@ -45,7 +45,8 @@ class GPUExecutor(ExecutorBase):
|
||||
lora_config=self.lora_config,
|
||||
vision_language_config=self.vision_language_config,
|
||||
speculative_config=self.speculative_config,
|
||||
is_driver_worker=rank == 0,
|
||||
is_driver_worker=(not self.parallel_config)
|
||||
or (rank % self.parallel_config.tensor_parallel_size == 0),
|
||||
)
|
||||
|
||||
def _create_worker(self,
|
||||
|
||||
@@ -91,17 +91,17 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
|
||||
self,
|
||||
method: str,
|
||||
*args,
|
||||
async_run_remote_workers_only: bool = False,
|
||||
async_run_tensor_parallel_workers_only: bool = False,
|
||||
max_concurrent_workers: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Runs the given method on all workers.
|
||||
|
||||
Args:
|
||||
async_run_remote_workers_only: If True the method will be run only
|
||||
in the remote workers, not the driver worker. It will also be
|
||||
run asynchronously and return a list of futures rather than
|
||||
blocking on the results.
|
||||
async_run_tensor_parallel_workers_only: If True the method will be
|
||||
run only in the remote TP workers, not the driver worker.
|
||||
It will also be run asynchronously and return a list of futures
|
||||
rather than blocking on the results.
|
||||
"""
|
||||
|
||||
if max_concurrent_workers:
|
||||
@@ -114,7 +114,7 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
|
||||
for worker in self.workers
|
||||
]
|
||||
|
||||
if async_run_remote_workers_only:
|
||||
if async_run_tensor_parallel_workers_only:
|
||||
# Just return futures
|
||||
return worker_outputs
|
||||
|
||||
|
||||
@@ -62,7 +62,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
|
||||
def _init_workers_ray(self, placement_group: "PlacementGroup",
|
||||
**ray_remote_kwargs):
|
||||
if self.parallel_config.tensor_parallel_size == 1:
|
||||
if (self.parallel_config.tensor_parallel_size == 1
|
||||
and self.parallel_config.pipeline_parallel_size == 1):
|
||||
# For single GPU case, we use a ray worker with constrained memory.
|
||||
num_gpus = self.cache_config.gpu_memory_utilization
|
||||
else:
|
||||
@@ -189,6 +190,26 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
max_concurrent_workers=self.parallel_config.
|
||||
max_parallel_loading_workers)
|
||||
|
||||
# This is the list of workers that are rank 0 of each TP group EXCEPT
|
||||
# global rank 0. These are the workers that will broadcast to the
|
||||
# rest of the workers.
|
||||
self.tp_driver_workers: List[RayWorkerWrapper] = []
|
||||
# This is the list of workers that are not drivers and not the first
|
||||
# worker in a TP group. These are the workers that will be
|
||||
# broadcasted to.
|
||||
self.non_driver_workers: List[RayWorkerWrapper] = []
|
||||
|
||||
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
|
||||
for tp_rank in range(self.parallel_config.tensor_parallel_size):
|
||||
rank = (pp_rank *
|
||||
self.parallel_config.tensor_parallel_size) + tp_rank
|
||||
if rank == 0:
|
||||
pass
|
||||
elif rank % self.parallel_config.tensor_parallel_size == 0:
|
||||
self.tp_driver_workers.append(self.workers[rank - 1])
|
||||
else:
|
||||
self.non_driver_workers.append(self.workers[rank - 1])
|
||||
|
||||
def _driver_execute_model(
|
||||
self, execute_model_req: Optional[ExecuteModelRequest]
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
@@ -204,7 +225,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
self,
|
||||
method: str,
|
||||
*args,
|
||||
async_run_remote_workers_only: bool = False,
|
||||
async_run_tensor_parallel_workers_only: bool = False,
|
||||
all_args: Optional[List[Tuple[Any, ...]]] = None,
|
||||
all_kwargs: Optional[List[Dict[str, Any]]] = None,
|
||||
use_dummy_driver: bool = False,
|
||||
@@ -215,10 +236,11 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
"""Runs the given method on all workers. Can be used in the following
|
||||
ways:
|
||||
|
||||
- async_run_remote_workers_only: If True the method will be run only
|
||||
in the remote workers, not the driver worker. It will also be
|
||||
run asynchronously and return a list of futures rather than blocking
|
||||
on the results.
|
||||
Args:
|
||||
- async_run_tensor_parallel_workers_only: If True the method will be
|
||||
run only in the remote TP workers, not the driver worker.
|
||||
It will also be run asynchronously and return a list of futures
|
||||
rather than blocking on the results.
|
||||
- args/kwargs: All workers share the same args/kwargs
|
||||
- all_args/all_kwargs: args/kwargs for each worker are specified
|
||||
individually
|
||||
@@ -228,7 +250,9 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
raise NotImplementedError(
|
||||
"max_concurrent_workers is not supported yet.")
|
||||
|
||||
count = len(self.workers)
|
||||
count = len(self.workers) if not \
|
||||
async_run_tensor_parallel_workers_only \
|
||||
else len(self.non_driver_workers)
|
||||
all_worker_args = repeat(args, count) if all_args is None \
|
||||
else islice(all_args, 1, None)
|
||||
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
|
||||
@@ -242,14 +266,17 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
ray_worker_outputs = []
|
||||
else:
|
||||
# Start the ray workers first.
|
||||
ray_workers = self.workers
|
||||
if async_run_tensor_parallel_workers_only:
|
||||
ray_workers = self.non_driver_workers
|
||||
ray_worker_outputs = [
|
||||
worker.execute_method.remote(method, *worker_args,
|
||||
**worker_kwargs)
|
||||
for (worker, worker_args, worker_kwargs
|
||||
) in zip(self.workers, all_worker_args, all_worker_kwargs)
|
||||
) in zip(ray_workers, all_worker_args, all_worker_kwargs)
|
||||
]
|
||||
|
||||
if async_run_remote_workers_only:
|
||||
if async_run_tensor_parallel_workers_only:
|
||||
# Just return futures
|
||||
return ray_worker_outputs
|
||||
|
||||
@@ -319,12 +346,32 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||
) -> List[SamplerOutput]:
|
||||
return await self.driver_exec_method("execute_model",
|
||||
execute_model_req)
|
||||
|
||||
async def _run_task_with_lock(task, lock, *args, **kwargs):
|
||||
async with lock:
|
||||
return await task(*args, **kwargs)
|
||||
|
||||
tasks = []
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
_run_task_with_lock(self.driver_exec_method, self.pp_locks[0],
|
||||
"execute_model", execute_model_req)))
|
||||
for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
|
||||
start=1):
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
_run_task_with_lock(driver_worker.execute_method.remote,
|
||||
self.pp_locks[pp_rank],
|
||||
"execute_model", execute_model_req)))
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# Only the last PP stage has the final results.
|
||||
return results[-1]
|
||||
|
||||
async def _start_worker_execution_loop(self):
|
||||
coros = [
|
||||
worker.execute_method.remote("start_worker_execution_loop")
|
||||
for worker in self.workers
|
||||
for worker in self.non_driver_workers
|
||||
]
|
||||
return await asyncio.gather(*coros)
|
||||
|
||||
Reference in New Issue
Block a user